inital integration with ATOM

Signed-off-by: Arvindh <arvindh91@gmail.com>
This commit is contained in:
Arvindh
2026-05-14 19:10:18 +05:30
parent d4f0d8fdef
commit 32319881a9
531 changed files with 4954 additions and 145562 deletions
+44 -33
View File
@@ -10,7 +10,9 @@ on:
paths:
- ".github/workflows/api-tests.yaml"
- "api/**"
- "auth/api/http/**"
- "core/**"
- "cmd/core/**"
- "internal/atom/**"
- "channels/api/http/**"
- "clients/api/http/**"
- "domains/api/http/**"
@@ -20,9 +22,9 @@ on:
- "bootstrap/api/**"
- "certs/api/http/**"
- "readers/api/http/**"
- "re/api/**"
- "alarms/api/**"
- "reports/api/**"
- "re/**"
- "alarms/**"
- "reports/**"
- "apidocs/openapi/**"
pull_request:
branches:
@@ -30,7 +32,9 @@ on:
paths:
- ".github/workflows/api-tests.yaml"
- "api/**"
- "auth/api/http/**"
- "core/**"
- "cmd/core/**"
- "internal/atom/**"
- "channels/api/http/**"
- "clients/api/http/**"
- "domains/api/http/**"
@@ -40,9 +44,9 @@ on:
- "bootstrap/api/**"
- "certs/api/http/**"
- "readers/api/http/**"
- "re/api/**"
- "alarms/api/**"
- "reports/api/**"
- "re/**"
- "alarms/**"
- "reports/**"
- "apidocs/openapi/**"
concurrency:
@@ -50,17 +54,16 @@ concurrency:
cancel-in-progress: true
env:
TOKENS_URL: http://localhost:9002/users/tokens/issue
CREATE_DOMAINS_URL: http://localhost:9003/domains
TOKENS_URL: http://localhost:9000/users/tokens/issue
CREATE_DOMAINS_URL: http://localhost:9000/domains
USER_IDENTITY: admin@example.com
USER_SECRET: 12345678
DOMAIN_NAME: demo-test
USERS_URL: http://localhost:9002
DOMAIN_URL: http://localhost:9003
CLIENTS_URL: http://localhost:9006
CHANNELS_URL: http://localhost:9005
GROUPS_URL: http://localhost:9004
AUTH_URL: http://localhost:9001
USERS_URL: http://localhost:9000
DOMAIN_URL: http://localhost:9000
CLIENTS_URL: http://localhost:9000
CHANNELS_URL: http://localhost:9000
GROUPS_URL: http://localhost:9000
JOURNAL_URL: http://localhost:9021
BOOTSTRAP_URL: http://localhost:9013
CERTS_URL: http://localhost:9019
@@ -96,28 +99,39 @@ jobs:
- "apidocs/openapi/journal.yaml"
- "journal/api/**"
auth:
- "apidocs/openapi/auth.yaml"
- "auth/api/http/**"
domains:
- "apidocs/openapi/domains.yaml"
- "core/**"
- "cmd/core/**"
- "internal/atom/**"
- "domains/api/http/**"
clients:
- "apidocs/openapi/clients.yaml"
- "core/**"
- "cmd/core/**"
- "internal/atom/**"
- "clients/api/http/**"
channels:
- "apidocs/openapi/channels.yaml"
- "core/**"
- "cmd/core/**"
- "internal/atom/**"
- "channels/api/http/**"
groups:
- "apidocs/openapi/groups.yaml"
- "core/**"
- "cmd/core/**"
- "internal/atom/**"
- "groups/api/http/**"
users:
- "apidocs/openapi/users.yaml"
- "core/**"
- "cmd/core/**"
- "internal/atom/**"
- "users/api/**"
bootstrap:
@@ -134,15 +148,21 @@ jobs:
re:
- "apidocs/openapi/rules.yaml"
- "re/api/**"
- "re/**"
- "cmd/re/**"
- "internal/atom/**"
alarms:
- "apidocs/openapi/alarms.yaml"
- "alarms/api/**"
- "alarms/**"
- "cmd/alarms/**"
- "internal/atom/**"
reports:
- "apidocs/openapi/reports.yaml"
- "reports/api/**"
- "reports/**"
- "cmd/reports/**"
- "internal/atom/**"
- name: Build images
run: make all -j $(nproc) && make dockers_dev -j $(nproc)
@@ -157,7 +177,7 @@ jobs:
# Check if services are responding
for i in {1..30}; do
if curl -f -s http://localhost:9002/health > /dev/null 2>&1; then
if curl -f -s http://localhost:9000/health > /dev/null 2>&1; then
echo "Services are ready!"
break
fi
@@ -209,15 +229,6 @@ jobs:
checks: all
args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples'
- name: Run Auth API tests
if: steps.changes.outputs.auth == 'true' || steps.changes.outputs.workflow == 'true'
uses: schemathesis/action@v3.0.0
with:
schema: apidocs/openapi/auth.yaml
base-url: ${{ env.AUTH_URL }}
checks: all
args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples'
- name: Run Domains API tests
if: steps.changes.outputs.domains == 'true' || steps.changes.outputs.workflow == 'true'
uses: schemathesis/action@v3.0.0
+16 -44
View File
@@ -66,28 +66,10 @@ jobs:
workflow:
- ".github/workflows/tests.yaml"
auth:
- "auth/**"
- "cmd/auth/**"
- "auth.proto"
- "auth.pb.go"
- "auth_grpc.pb.go"
- "pkg/ulid/**"
- "pkg/uuid/**"
bootstrap:
- "bootstrap/**"
- "cmd/bootstrap/**"
- "pkg/bootstrap/**"
- "provision/**"
- "pkg/sdk/**"
channels:
- "channels/**"
- "cmd/channels/**"
- "auth.pb.go"
- "auth_grpc.pb.go"
- "auth/**"
- "core/**"
- "cmd/core/**"
- "pkg/sdk/**"
- "clients/api/grpc/**"
- "groups/api/grpc/**"
@@ -101,10 +83,8 @@ jobs:
clients:
- "clients/**"
- "cmd/clients/**"
- "auth.pb.go"
- "auth_grpc.pb.go"
- "auth/**"
- "core/**"
- "cmd/core/**"
- "pkg/ulid/**"
- "pkg/uuid/**"
- "pkg/events/**"
@@ -115,18 +95,14 @@ jobs:
domains:
- "domains/**"
- "cmd/domains/**"
- "auth.pb.go"
- "auth_grpc.pb.go"
- "auth/**"
- "core/**"
- "cmd/core/**"
- "internal/grpc/**"
groups:
- "groups/**"
- "cmd/groups/**"
- "auth.pb.go"
- "auth_grpc.pb.go"
- "auth/**"
- "core/**"
- "cmd/core/**"
- "pkg/ulid/**"
- "pkg/uuid/**"
- "clients/api/grpc/**"
@@ -140,9 +116,6 @@ jobs:
journal:
- "journal/**"
- "cmd/journal/**"
- "auth.pb.go"
- "auth_grpc.pb.go"
- "auth/**"
- "pkg/events/**"
logger:
@@ -165,7 +138,6 @@ jobs:
- "pkg/sdk/**"
- "pkg/errors/**"
- "pkg/groups/**"
- "auth/**"
- "internal/*"
- "clients/**"
- "users/**"
@@ -189,10 +161,8 @@ jobs:
users:
- "users/**"
- "cmd/users/**"
- "auth.pb.go"
- "auth_grpc.pb.go"
- "auth/**"
- "core/**"
- "cmd/core/**"
- "pkg/ulid/**"
- "pkg/uuid/**"
- "pkg/events/**"
@@ -200,8 +170,6 @@ jobs:
notifications:
- "notifications/**"
- "cmd/notifications/**"
- "auth.pb.go"
- "auth_grpc.pb.go"
- "consumers/notifier.go"
- "pkg/events/**"
@@ -232,6 +200,10 @@ jobs:
reports:
- "reports/**"
- "cmd/reports/**"
core:
- "core/**"
- "cmd/core/**"
- "internal/atom/**"
- name: Set matrix for changed modules
id: set-matrix
@@ -240,11 +212,11 @@ jobs:
if [[ "${{ steps.changes.outputs.workflow }}" == "true" || "${{ steps.changes.outputs.pkg-errors }}" == "true" ]]; then
# If workflow or pkg/errors changed, test everything
modules=("auth" "bootstrap" "channels" "cli" "clients" "domains" "groups" "internal" "journal" "logger" "pkg-errors" "pkg-events" "pkg-grpcclient" "pkg-messaging" "pkg-sdk" "pkg-transformers" "pkg-ulid" "pkg-uuid" "users" "notifications" "api" "consumers" "readers" "re" "alarms" "reports")
modules=("auth" "core" "channels" "cli" "clients" "domains" "groups" "internal" "journal" "logger" "pkg-errors" "pkg-events" "pkg-grpcclient" "pkg-messaging" "pkg-sdk" "pkg-transformers" "pkg-ulid" "pkg-uuid" "users" "notifications" "api" "consumers" "readers" "re" "alarms" "reports")
else
# Add only changed modules
[[ "${{ steps.changes.outputs.auth }}" == "true" ]] && modules+=("auth")
[[ "${{ steps.changes.outputs.bootstrap }}" == "true" ]] && modules+=("bootstrap")
[[ "${{ steps.changes.outputs.core }}" == "true" ]] && modules+=("core")
[[ "${{ steps.changes.outputs.channels }}" == "true" ]] && modules+=("channels")
[[ "${{ steps.changes.outputs.cli }}" == "true" ]] && modules+=("cli")
[[ "${{ steps.changes.outputs.clients }}" == "true" ]] && modules+=("clients")
+3
View File
@@ -21,3 +21,6 @@ docker/addons/certs/openbao/
# Ignore SeaweedFS data directory as it contains runtime-generated data
docker/data/*
demo-ui
node_modules
+10 -11
View File
@@ -4,8 +4,8 @@
override MG_DOCKER_IMAGE_NAME_PREFIX := ghcr.io/absmach/magistrala
MG_DOCKER_VOLUME_NAME_PREFIX ?= magistrala
BUILD_DIR ?= build
SERVICES = auth users clients groups channels domains notifications certs re postgres-writer postgres-reader timescale-writer timescale-reader cli alarms reports bootstrap provision journal fluxmq
TEST_API_SERVICES = journal auth certs clients users channels groups domains
SERVICES = notifications certs re postgres-writer postgres-reader timescale-writer timescale-reader alarms reports journal fluxmq
TEST_API_SERVICES = journal certs clients users channels groups domains
TEST_API = $(addprefix test_api_,$(TEST_API_SERVICES))
DOCKERS = $(addprefix docker_,$(SERVICES))
DOCKERS_DEV = $(addprefix docker_dev_,$(SERVICES))
@@ -199,12 +199,11 @@ define test_api_service
--phases=examples,stateful
endef
test_api_users: TEST_API_URL := http://localhost:9002
test_api_clients: TEST_API_URL := http://localhost:9006
test_api_domains: TEST_API_URL := http://localhost:9003
test_api_channels: TEST_API_URL := http://localhost:9005
test_api_groups: TEST_API_URL := http://localhost:9004
test_api_auth: TEST_API_URL := http://localhost:9001
test_api_users: TEST_API_URL := http://localhost:9000
test_api_clients: TEST_API_URL := http://localhost:9000
test_api_domains: TEST_API_URL := http://localhost:9000
test_api_channels: TEST_API_URL := http://localhost:9000
test_api_groups: TEST_API_URL := http://localhost:9000
test_api_certs: TEST_API_URL := http://localhost:9019
test_api_journal: TEST_API_URL := http://localhost:9021
@@ -262,7 +261,7 @@ rundev:
cd scripts && ./run.sh
grpc_mtls_certs:
$(MAKE) -C docker/ssl auth_grpc_certs clients_grpc_certs
$(MAKE) -C docker/ssl clients_grpc_certs
check_tls:
ifeq ($(GRPC_TLS),true)
@@ -284,7 +283,7 @@ check_certs: check_mtls check_tls
ifeq ($(GRPC_MTLS_CERT_FILES_EXISTS),0)
ifeq ($(filter true,$(GRPC_MTLS) $(GRPC_TLS)),true)
ifeq ($(filter $(DEFAULT_DOCKER_COMPOSE_COMMAND),$(DOCKER_COMPOSE_COMMAND)),$(DEFAULT_DOCKER_COMPOSE_COMMAND))
$(MAKE) -C docker/ssl auth_grpc_certs clients_grpc_certs
$(MAKE) -C docker/ssl clients_grpc_certs
endif
endif
endif
@@ -312,7 +311,7 @@ run_stable: check_certs
run_addons: check_certs
$(foreach SVC,$(RUN_ADDON_ARGS),$(if $(filter $(SVC),$(ADDON_SERVICES) $(EXTERNAL_SERVICES)),,$(error Invalid Service $(SVC))))
@$(DOCKER_PLATFORM) docker compose -f docker/docker-compose.yaml --env-file ./docker/.env -p $(DOCKER_PROJECT) up -d auth domains jaeger
@$(DOCKER_PLATFORM) docker compose -f docker/docker-compose.yaml --env-file ./docker/.env -p $(DOCKER_PROJECT) up -d atom jaeger
@for SVC in $(RUN_ADDON_ARGS); do \
MG_ADDONS_CERTS_PATH_PREFIX="../" $(DOCKER_PLATFORM) docker compose -f docker/addons/$$SVC/docker-compose.yaml -p $(DOCKER_PROJECT) --env-file ./docker/.env $(DOCKER_COMPOSE_COMMAND) $(args) & \
done
+6 -11
View File
@@ -26,16 +26,11 @@ The service is configured using the following environment variables (values show
| `MG_MESSAGE_BROKER_URL` | Message broker URL for alarm ingestion | `nats://nats:4222` |
| `MG_JAEGER_URL` | Jaeger collector endpoint | `http://jaeger:4318/v1/traces` |
| `MG_JAEGER_TRACE_RATIO` | Trace sampling ratio | `1.0` |
| `MG_AUTH_GRPC_URL` | Auth gRPC endpoint | `auth:7001` |
| `MG_AUTH_GRPC_TIMEOUT` | Auth gRPC timeout | `300s` |
| `MG_AUTH_GRPC_CLIENT_CERT` | Auth gRPC client cert path | `${GRPC_MTLS:+./ssl/certs/auth-grpc-client.crt}` |
| `MG_AUTH_GRPC_CLIENT_KEY` | Auth gRPC client key path | `${GRPC_MTLS:+./ssl/certs/auth-grpc-client.key}` |
| `MG_AUTH_GRPC_SERVER_CA_CERTS` | Auth gRPC server CA path | `${GRPC_MTLS:+./ssl/certs/ca.crt}` |
| `MG_DOMAINS_GRPC_URL` | Domains gRPC endpoint | `domains:7003` |
| `MG_DOMAINS_GRPC_TIMEOUT` | Domains gRPC timeout | `300s` |
| `MG_DOMAINS_GRPC_CLIENT_CERT` | Domains gRPC client cert path | `${GRPC_MTLS:+./ssl/certs/domains-grpc-client.crt}` |
| `MG_DOMAINS_GRPC_CLIENT_KEY` | Domains gRPC client key path | `${GRPC_MTLS:+./ssl/certs/domains-grpc-client.key}` |
| `MG_DOMAINS_GRPC_SERVER_CA_CERTS` | Domains gRPC server CA path | `${GRPC_MTLS:+./ssl/certs/ca.crt}` |
| `ATOM_URL` | Atom HTTP endpoint | `http://atom:8080` |
| `ATOM_JWKS_URL` | Atom JWKS endpoint for JWT verification | `http://atom:8080/.well-known/jwks.json` |
| `ATOM_ADMIN_USERNAME` | Atom admin login for service projections | `atom-admin` |
| `ATOM_ADMIN_SECRET` | Atom admin secret for service projections | `change-me` |
| `ATOM_TIMEOUT` | Atom request timeout | `5s` |
| `MG_ALLOW_UNVERIFIED_USER` | Allow unverified users to access | `true` |
## Features
@@ -44,7 +39,7 @@ The service is configured using the following environment variables (values show
- **Stateful updates**: Updates assignee, acknowledgment, resolution, and metadata fields.
- **Filtering and paging**: Lists alarms by domain, rule, channel, client, subtopic, status, severity, and time range.
- **Observability**: `/metrics` Prometheus endpoint and Jaeger tracing support.
- **Auth and authorization**: Authn/authz enforced via gRPC auth and domains services.
- **Auth and authorization**: Authn/authz enforced through Atom JWT verification and PDP checks.
## Architecture
+1 -2
View File
@@ -106,7 +106,7 @@ func (a Alarm) Validate() error {
// Service specifies an API that must be fulfilled by the domain service.
type Service interface {
CreateAlarm(ctx context.Context, alarm Alarm) error
CreateAlarm(ctx context.Context, alarm Alarm) (Alarm, error)
UpdateAlarm(ctx context.Context, session authn.Session, alarm Alarm) (Alarm, error)
ViewAlarm(ctx context.Context, session authn.Session, id string) (Alarm, error)
ListAlarms(ctx context.Context, session authn.Session, pm PageMetadata) (AlarmsPage, error)
@@ -118,6 +118,5 @@ type Repository interface {
UpdateAlarm(ctx context.Context, alarm Alarm) (Alarm, error)
ViewAlarm(ctx context.Context, alarmID, domainID string) (Alarm, error)
ListAllAlarms(ctx context.Context, pm PageMetadata) (AlarmsPage, error)
ListUserAlarms(ctx context.Context, userID string, pm PageMetadata) (AlarmsPage, error)
DeleteAlarm(ctx context.Context, id string) error
}
+78
View File
@@ -0,0 +1,78 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package alarms
import (
"context"
"github.com/absmach/magistrala/internal/atom"
"github.com/absmach/magistrala/pkg/authn"
)
type atomService struct {
Service
projector atom.Projector
}
func WithAtom(svc Service, projector atom.Projector) Service {
if projector == nil {
return svc
}
return atomService{Service: svc, projector: projector}
}
func (svc atomService) CreateAlarm(ctx context.Context, alarm Alarm) (Alarm, error) {
created, err := svc.Service.CreateAlarm(ctx, alarm)
if err != nil {
return created, err
}
if created.ID == "" {
return created, nil
}
if err := svc.projector.UpsertResource(ctx, alarmProjection(created)); err != nil {
return created, nil
}
return created, nil
}
func (svc atomService) UpdateAlarm(ctx context.Context, session authn.Session, alarm Alarm) (Alarm, error) {
updated, err := svc.Service.UpdateAlarm(ctx, session, alarm)
if err != nil {
return updated, err
}
if err := svc.projector.UpsertResource(ctx, alarmProjection(updated)); err != nil {
return updated, nil
}
return updated, nil
}
func (svc atomService) DeleteAlarm(ctx context.Context, session authn.Session, id string) error {
if err := svc.Service.DeleteAlarm(ctx, session, id); err != nil {
return err
}
_ = svc.projector.DeleteResource(ctx, id)
return nil
}
func alarmProjection(a Alarm) atom.Resource {
res := atom.ResourceFromFields(atom.ObjectFields{
ID: a.ID,
Kind: atom.KindAlarm,
Name: a.Cause,
TenantID: a.DomainID,
OwnerID: a.AssigneeID,
Status: a.Status.String(),
Metadata: map[string]any(a.Metadata),
UpdatedBy: a.UpdatedBy,
CreatedAt: a.CreatedAt,
UpdatedAt: a.UpdatedAt,
})
res.Attributes["rule_id"] = a.RuleID
res.Attributes["channel_id"] = a.ChannelID
res.Attributes["client_id"] = a.ClientID
res.Attributes["severity"] = a.Severity
res.Attributes["measurement"] = a.Measurement
res.Attributes["assignee_id"] = a.AssigneeID
return res
}
+77
View File
@@ -0,0 +1,77 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package alarms
import (
"context"
"testing"
"github.com/absmach/magistrala/internal/atom"
"github.com/absmach/magistrala/pkg/authn"
)
func TestAtomServiceCreateAlarmProjectsCreatedAlarm(t *testing.T) {
projector := &alarmProjector{}
svc := WithAtom(alarmService{
create: Alarm{
ID: "alarm-1",
RuleID: "rule-1",
DomainID: "domain-1",
ChannelID: "channel-1",
ClientID: "client-1",
Cause: "high temperature",
Measurement: "temperature",
Severity: 90,
Status: ActiveStatus,
},
}, projector)
created, err := svc.CreateAlarm(context.Background(), Alarm{RuleID: "rule-1"})
if err != nil {
t.Fatalf("create alarm: %v", err)
}
if created.ID != "alarm-1" {
t.Fatalf("unexpected created alarm: %#v", created)
}
if projector.resource.ID != "alarm-1" || projector.resource.Kind != atom.KindAlarm {
t.Fatalf("unexpected projection: %#v", projector.resource)
}
if projector.resource.Attributes["rule_id"] != "rule-1" {
t.Fatalf("missing rule projection: %#v", projector.resource.Attributes)
}
}
type alarmService struct {
create Alarm
}
func (svc alarmService) CreateAlarm(context.Context, Alarm) (Alarm, error) {
return svc.create, nil
}
func (svc alarmService) UpdateAlarm(context.Context, authn.Session, Alarm) (Alarm, error) {
return Alarm{}, nil
}
func (svc alarmService) ViewAlarm(context.Context, authn.Session, string) (Alarm, error) {
return Alarm{}, nil
}
func (svc alarmService) ListAlarms(context.Context, authn.Session, PageMetadata) (AlarmsPage, error) {
return AlarmsPage{}, nil
}
func (svc alarmService) DeleteAlarm(context.Context, authn.Session, string) error {
return nil
}
type alarmProjector struct {
atom.Projector
resource atom.Resource
}
func (p *alarmProjector) UpsertResource(_ context.Context, resource atom.Resource) error {
p.resource = resource
return nil
}
+2 -1
View File
@@ -48,7 +48,8 @@ func (h handler) Handle(msg *messaging.Message) (err error) {
return err
}
return h.svc.CreateAlarm(context.Background(), alarm)
_, err = h.svc.CreateAlarm(context.Background(), alarm)
return err
}
func (h handler) Cancel() error {
+34 -12
View File
@@ -9,6 +9,7 @@ import (
"github.com/absmach/magistrala/alarms"
"github.com/absmach/magistrala/alarms/operations"
"github.com/absmach/magistrala/auth"
"github.com/absmach/magistrala/internal/atom"
"github.com/absmach/magistrala/pkg/authn"
smqauthz "github.com/absmach/magistrala/pkg/authz"
"github.com/absmach/magistrala/pkg/errors"
@@ -26,6 +27,7 @@ var (
type authorizationMiddleware struct {
svc alarms.Service
authz smqauthz.Authorization
atomAuthz atom.Authorizer
entitiesOps permissions.EntitiesOperations[permissions.Operation]
}
@@ -43,7 +45,19 @@ func NewAuthorizationMiddleware(svc alarms.Service, authz smqauthz.Authorization
}, nil
}
func (am *authorizationMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error {
func NewAtomAuthorizationMiddleware(svc alarms.Service, authz atom.Authorizer, entitiesOps permissions.EntitiesOperations[permissions.Operation]) (alarms.Service, error) {
if err := entitiesOps.Validate(); err != nil {
return nil, err
}
return &authorizationMiddleware{
svc: svc,
atomAuthz: authz,
entitiesOps: entitiesOps,
}, nil
}
func (am *authorizationMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) {
return am.svc.CreateAlarm(ctx, alarm)
}
@@ -58,17 +72,19 @@ func (am *authorizationMiddleware) UpdateAlarm(ctx context.Context, session auth
if err := am.authorize(ctx, operations.OpAssignAlarm, session, policies.DomainType, session.DomainID); err != nil {
return alarms.Alarm{}, errors.Wrap(errDomainUpdateAlarms, err)
}
domainUserID := auth.EncodeDomainUserID(session.DomainID, alarm.AssigneeID)
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: domainUserID,
Permission: policies.MembershipPermission,
ObjectType: policies.DomainType,
Object: session.DomainID,
}, nil); err != nil {
return alarms.Alarm{}, err
if am.atomAuthz == nil {
domainUserID := auth.EncodeDomainUserID(session.DomainID, alarm.AssigneeID)
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: domainUserID,
Permission: policies.MembershipPermission,
ObjectType: policies.DomainType,
Object: session.DomainID,
}, nil); err != nil {
return alarms.Alarm{}, err
}
}
}
@@ -124,6 +140,9 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions
if err != nil {
return err
}
if am.atomAuthz != nil {
return atom.Authorize(ctx, am.atomAuthz, session, perm.String(), objType, obj, atom.KindAlarm)
}
pr := smqauthz.PolicyReq{
Domain: session.DomainID,
@@ -159,6 +178,9 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session
if session.Role != authn.SuperAdminRole {
return svcerr.ErrSuperAdminAction
}
if am.atomAuthz != nil {
return atom.Authorize(ctx, am.atomAuthz, session, policies.AdminPermission, policies.PlatformType, policies.MagistralaObject, policies.PlatformType)
}
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
SubjectType: policies.UserType,
Subject: session.UserID,
+2 -2
View File
@@ -27,7 +27,7 @@ func NewLoggingMiddleware(logger *slog.Logger, service alarms.Service) alarms.Se
}
}
func (lm *loggingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (err error) {
func (lm *loggingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (created alarms.Alarm, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
@@ -52,7 +52,7 @@ func (lm *loggingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm
lm.logger.Warn("Create alarm failed", args...)
return
}
if alarm.ID != "" {
if created.ID != "" {
lm.logger.Info("Create alarm completed successfully", args...)
}
}(time.Now())
+1 -1
View File
@@ -28,7 +28,7 @@ func NewMetricsMiddleware(counter metrics.Counter, latency metrics.Histogram, se
}
}
func (mm *metricsMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error {
func (mm *metricsMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) {
defer func(begin time.Time) {
mm.counter.With("method", "create_alarm").Add(1)
mm.latency.With("method", "create_alarm").Observe(time.Since(begin).Seconds())
+1 -1
View File
@@ -27,7 +27,7 @@ func NewTracingMiddleware(tracer trace.Tracer, svc alarms.Service) alarms.Servic
}
}
func (tm *tracingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error {
func (tm *tracingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) {
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "create_alarm", trace.WithAttributes(
attribute.String("rule_id", alarm.RuleID),
attribute.String("measurement", alarm.Measurement),
-72
View File
@@ -231,78 +231,6 @@ func (_c *Repository_ListAllAlarms_Call) RunAndReturn(run func(ctx context.Conte
return _c
}
// ListUserAlarms provides a mock function for the type Repository
func (_mock *Repository) ListUserAlarms(ctx context.Context, userID string, pm alarms.PageMetadata) (alarms.AlarmsPage, error) {
ret := _mock.Called(ctx, userID, pm)
if len(ret) == 0 {
panic("no return value specified for ListUserAlarms")
}
var r0 alarms.AlarmsPage
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, alarms.PageMetadata) (alarms.AlarmsPage, error)); ok {
return returnFunc(ctx, userID, pm)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string, alarms.PageMetadata) alarms.AlarmsPage); ok {
r0 = returnFunc(ctx, userID, pm)
} else {
r0 = ret.Get(0).(alarms.AlarmsPage)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string, alarms.PageMetadata) error); ok {
r1 = returnFunc(ctx, userID, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Repository_ListUserAlarms_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserAlarms'
type Repository_ListUserAlarms_Call struct {
*mock.Call
}
// ListUserAlarms is a helper method to define mock.On call
// - ctx context.Context
// - userID string
// - pm alarms.PageMetadata
func (_e *Repository_Expecter) ListUserAlarms(ctx interface{}, userID interface{}, pm interface{}) *Repository_ListUserAlarms_Call {
return &Repository_ListUserAlarms_Call{Call: _e.mock.On("ListUserAlarms", ctx, userID, pm)}
}
func (_c *Repository_ListUserAlarms_Call) Run(run func(ctx context.Context, userID string, pm alarms.PageMetadata)) *Repository_ListUserAlarms_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 alarms.PageMetadata
if args[2] != nil {
arg2 = args[2].(alarms.PageMetadata)
}
run(
arg0,
arg1,
arg2,
)
})
return _c
}
func (_c *Repository_ListUserAlarms_Call) Return(alarmsPage alarms.AlarmsPage, err error) *Repository_ListUserAlarms_Call {
_c.Call.Return(alarmsPage, err)
return _c
}
func (_c *Repository_ListUserAlarms_Call) RunAndReturn(run func(ctx context.Context, userID string, pm alarms.PageMetadata) (alarms.AlarmsPage, error)) *Repository_ListUserAlarms_Call {
_c.Call.Return(run)
return _c
}
// UpdateAlarm provides a mock function for the type Repository
func (_mock *Repository) UpdateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) {
ret := _mock.Called(ctx, alarm)
+17 -8
View File
@@ -44,20 +44,29 @@ func (_m *Service) EXPECT() *Service_Expecter {
}
// CreateAlarm provides a mock function for the type Service
func (_mock *Service) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error {
func (_mock *Service) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) {
ret := _mock.Called(ctx, alarm)
if len(ret) == 0 {
panic("no return value specified for CreateAlarm")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) error); ok {
var r0 alarms.Alarm
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) (alarms.Alarm, error)); ok {
return returnFunc(ctx, alarm)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) alarms.Alarm); ok {
r0 = returnFunc(ctx, alarm)
} else {
r0 = ret.Error(0)
r0 = ret.Get(0).(alarms.Alarm)
}
return r0
if returnFunc, ok := ret.Get(1).(func(context.Context, alarms.Alarm) error); ok {
r1 = returnFunc(ctx, alarm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Service_CreateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateAlarm'
@@ -90,12 +99,12 @@ func (_c *Service_CreateAlarm_Call) Run(run func(ctx context.Context, alarm alar
return _c
}
func (_c *Service_CreateAlarm_Call) Return(err error) *Service_CreateAlarm_Call {
_c.Call.Return(err)
func (_c *Service_CreateAlarm_Call) Return(alarm1 alarms.Alarm, err error) *Service_CreateAlarm_Call {
_c.Call.Return(alarm1, err)
return _c
}
func (_c *Service_CreateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) error) *Service_CreateAlarm_Call {
func (_c *Service_CreateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error)) *Service_CreateAlarm_Call {
_c.Call.Return(run)
return _c
}
-29
View File
@@ -198,35 +198,6 @@ func (r *repository) ListAllAlarms(ctx context.Context, pm alarms.PageMetadata)
return r.alarmsPage(ctx, comQuery, pm)
}
func (r *repository) ListUserAlarms(ctx context.Context, userID string, pm alarms.PageMetadata) (alarms.AlarmsPage, error) {
clauses := []string{
`(
EXISTS (
SELECT 1
FROM rules_roles rr
JOIN rules_role_members rrm ON rrm.role_id = rr.id
WHERE rr.entity_id = alarms.rule_id AND rrm.member_id = :user_id
)
OR EXISTS (
SELECT 1
FROM domains_roles dr
JOIN domains_role_members drm ON drm.role_id = dr.id
JOIN domains_role_actions dra ON dra.role_id = dr.id
WHERE dr.entity_id = alarms.domain_id
AND drm.member_id = :user_id
AND dra.action LIKE 'alarm%'
)
)`,
}
clauses = append(clauses, pageQueryConditions(pm)...)
query := fmt.Sprintf("WHERE %s", strings.Join(clauses, " AND "))
pm.UserID = userID
comQuery := fmt.Sprintf(`SELECT DISTINCT %s FROM alarms %s`, alarmColumns, query)
return r.alarmsPage(ctx, comQuery, pm)
}
func (r *repository) alarmsPage(ctx context.Context, comQuery string, pm alarms.PageMetadata) (alarms.AlarmsPage, error) {
dir := api.DescDir
if pm.Dir == api.AscDir {
-209
View File
@@ -415,215 +415,6 @@ func TestListAlarms(t *testing.T) {
}
}
func TestListUserAlarms(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM domains_role_actions")
require.Nil(t, err, fmt.Sprintf("clean domains_role_actions unexpected error: %s", err))
_, err = db.Exec("DELETE FROM domains_role_members")
require.Nil(t, err, fmt.Sprintf("clean domains_role_members unexpected error: %s", err))
_, err = db.Exec("DELETE FROM domains_roles")
require.Nil(t, err, fmt.Sprintf("clean domains_roles unexpected error: %s", err))
_, err = db.Exec("DELETE FROM domains")
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
_, err = db.Exec("DELETE FROM alarms")
require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err))
_, err = db.Exec("DELETE FROM rules")
require.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
})
repo := postgres.NewAlarmsRepo(db)
domainID := generateUUID(t)
domainRoute := generateUUID(t)
userID := generateUUID(t)
otherUserID := generateUUID(t)
adminUserID := generateUUID(t)
domainUserID := generateUUID(t)
_, err := db.Exec(`INSERT INTO domains (id, name, route, status) VALUES ($1, $2, $3, $4)`, domainID, namegen.Generate(), domainRoute, 0)
require.Nil(t, err, fmt.Sprintf("insert domains unexpected error: %s", err))
// Create 10 rules and 10 alarms referencing them.
// Assign userID to the first 6 rules via role membership.
var ruleIDs []string
var createdAlarms []alarms.Alarm
for i := range 10 {
ruleID := generateUUID(t)
_, err := db.Exec(`INSERT INTO rules (id, name, domain_id, status, logic_type, logic_value) VALUES ($1, $2, $3, 0, 0, '')`,
ruleID, fmt.Sprintf("rule-%d", i), domainID)
require.Nil(t, err, fmt.Sprintf("insert rule unexpected error: %s", err))
ruleIDs = append(ruleIDs, ruleID)
alarm := alarms.Alarm{
ID: generateUUID(t),
RuleID: ruleID,
DomainID: domainID,
ChannelID: generateUUID(t),
ClientID: generateUUID(t),
Measurement: namegen.Generate(),
Value: namegen.Generate(),
Unit: namegen.Generate(),
Threshold: namegen.Generate(),
Cause: namegen.Generate(),
Status: 0,
AssigneeID: generateUUID(t),
CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute),
}
alarm, err = repo.CreateAlarm(context.Background(), alarm)
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
createdAlarms = append(createdAlarms, alarm)
}
// Assign userID to the first 6 rules via rules_roles + rules_role_members.
userRoleIDs := make([]string, 6)
for i := range 6 {
roleID := generateUUID(t)
userRoleIDs[i] = roleID
_, err := db.Exec(`INSERT INTO rules_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", ruleIDs[i])
require.Nil(t, err, fmt.Sprintf("insert rules_roles unexpected error: %s", err))
_, err = db.Exec(`INSERT INTO rules_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, userID, ruleIDs[i])
require.Nil(t, err, fmt.Sprintf("insert rules_role_members unexpected error: %s", err))
}
for i := range 10 {
var roleID string
if i < 6 {
roleID = userRoleIDs[i]
} else {
roleID = generateUUID(t)
_, err := db.Exec(`INSERT INTO rules_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", ruleIDs[i])
require.Nil(t, err, fmt.Sprintf("insert rules_roles unexpected error: %s", err))
}
_, err := db.Exec(`INSERT INTO rules_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, adminUserID, ruleIDs[i])
require.Nil(t, err, fmt.Sprintf("insert rules_role_members unexpected error: %s", err))
}
domainRoleID := generateUUID(t)
_, err = db.Exec(`INSERT INTO domains_roles (id, name, entity_id) VALUES ($1, $2, $3)`, domainRoleID, "admin", domainID)
require.Nil(t, err, fmt.Sprintf("insert domains_roles unexpected error: %s", err))
_, err = db.Exec(`INSERT INTO domains_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, domainRoleID, domainUserID, domainID)
require.Nil(t, err, fmt.Sprintf("insert domains_role_members unexpected error: %s", err))
_, err = db.Exec(`INSERT INTO domains_role_actions (role_id, action) VALUES ($1, $2)`, domainRoleID, "alarm_read")
require.Nil(t, err, fmt.Sprintf("insert domains_role_actions unexpected error: %s", err))
_ = createdAlarms
cases := []struct {
desc string
userID string
pm alarms.PageMetadata
count int
err error
}{
{
desc: "list user alarms returns only accessible alarms",
userID: userID,
pm: alarms.PageMetadata{
Offset: 0,
Limit: 100,
},
count: 6,
err: nil,
},
{
desc: "list user alarms with limit",
userID: userID,
pm: alarms.PageMetadata{
Offset: 0,
Limit: 3,
},
count: 3,
err: nil,
},
{
desc: "list user alarms with offset",
userID: userID,
pm: alarms.PageMetadata{
Offset: 4,
Limit: 100,
},
count: 2,
err: nil,
},
{
desc: "list user alarms with domain filter",
userID: userID,
pm: alarms.PageMetadata{
DomainID: domainID,
Offset: 0,
Limit: 100,
},
count: 6,
err: nil,
},
{
desc: "list user alarms with non-existing domain returns 0",
userID: userID,
pm: alarms.PageMetadata{
DomainID: generateUUID(t),
Offset: 0,
Limit: 100,
},
count: 0,
err: nil,
},
{
desc: "list alarms for user with no role assignments returns 0",
userID: otherUserID,
pm: alarms.PageMetadata{
Offset: 0,
Limit: 100,
},
count: 0,
err: nil,
},
{
desc: "list alarms for admin user with role on all rules returns all alarms",
userID: adminUserID,
pm: alarms.PageMetadata{
Offset: 0,
Limit: 100,
},
count: 10,
err: nil,
},
{
desc: "list alarms for user with domain-level rule access returns all alarms",
userID: domainUserID,
pm: alarms.PageMetadata{
Offset: 0,
Limit: 100,
},
count: 10,
err: nil,
},
{
desc: "list user alarms ordered by created_at ascending",
userID: userID,
pm: alarms.PageMetadata{
Offset: 0,
Limit: 100,
Order: "created_at",
Dir: "asc",
},
count: 6,
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
page, err := repo.ListUserAlarms(context.Background(), tc.userID, tc.pm)
if tc.err != nil {
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
return
}
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
assert.Equal(t, tc.count, len(page.Alarms), fmt.Sprintf("%s: expected %d alarms, got %d", tc.desc, tc.count, len(page.Alarms)))
})
}
}
func TestDeleteAlarm(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM alarms")
-10
View File
@@ -4,9 +4,6 @@
package postgres
import (
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
rpostgres "github.com/absmach/magistrala/re/postgres"
_ "github.com/jackc/pgx/v5/stdlib" // required for SQL access
migrate "github.com/rubenv/sql-migrate"
)
@@ -54,12 +51,5 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
},
}
rulesMigration, err := rpostgres.Migration()
if err != nil {
return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err)
}
alarmsMigration.Migrations = append(alarmsMigration.Migrations, rulesMigration.Migrations...)
return alarmsMigration, nil
}
+12 -10
View File
@@ -26,10 +26,10 @@ func NewService(idp magistrala.IDProvider, repo Repository) Service {
}
}
func (s *service) CreateAlarm(ctx context.Context, alarm Alarm) error {
func (s *service) CreateAlarm(ctx context.Context, alarm Alarm) (Alarm, error) {
id, err := s.idp.ID()
if err != nil {
return err
return Alarm{}, err
}
alarm.ID = id
if alarm.CreatedAt.IsZero() {
@@ -37,14 +37,18 @@ func (s *service) CreateAlarm(ctx context.Context, alarm Alarm) error {
}
if err := alarm.Validate(); err != nil {
return err
return Alarm{}, err
}
if _, err = s.repo.CreateAlarm(ctx, alarm); err != nil && err != repoerr.ErrNotFound {
return err
created, err := s.repo.CreateAlarm(ctx, alarm)
if err != nil && err != repoerr.ErrNotFound {
return Alarm{}, err
}
if err == repoerr.ErrNotFound {
return Alarm{}, nil
}
return nil
return created, nil
}
func (s *service) ViewAlarm(ctx context.Context, session authn.Session, alarmID string) (Alarm, error) {
@@ -52,10 +56,8 @@ func (s *service) ViewAlarm(ctx context.Context, session authn.Session, alarmID
}
func (s *service) ListAlarms(ctx context.Context, session authn.Session, pm PageMetadata) (AlarmsPage, error) {
if session.SuperAdmin {
return s.repo.ListAllAlarms(ctx, pm)
}
return s.repo.ListUserAlarms(ctx, session.UserID, pm)
pm.DomainID = session.DomainID
return s.repo.ListAllAlarms(ctx, pm)
}
func (s *service) DeleteAlarm(ctx context.Context, session authn.Session, alarmID string) error {
+2 -2
View File
@@ -72,7 +72,7 @@ func TestCreateAlarm(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := repo.On("CreateAlarm", context.Background(), mock.Anything).Return(tc.alarm, tc.err)
err := svc.CreateAlarm(context.Background(), tc.alarm)
_, err := svc.CreateAlarm(context.Background(), tc.alarm)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
})
@@ -205,7 +205,7 @@ func TestListAlarms(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
s := authn.Session{DomainID: tc.pm.DomainID}
repoCall := repo.On("ListUserAlarms", context.Background(), s.UserID, tc.pm).Return(tc.page, tc.err)
repoCall := repo.On("ListAllAlarms", context.Background(), tc.pm).Return(tc.page, tc.err)
_, err := svc.ListAlarms(context.Background(), s, tc.pm)
if tc.err != nil {
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
+3 -6
View File
@@ -13,10 +13,7 @@ import (
"github.com/absmach/magistrala"
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/clients"
"github.com/absmach/magistrala/groups"
"github.com/absmach/magistrala/pkg/errors"
"github.com/absmach/magistrala/users"
"github.com/gofrs/uuid/v5"
)
@@ -80,9 +77,9 @@ const (
DefStartLevel = 1
DefEndLevel = 0
DefStatus = "enabled"
DefClientStatus = clients.Enabled
DefUserStatus = users.Enabled
DefGroupStatus = groups.Enabled
DefClientStatus = "enabled"
DefUserStatus = "enabled"
DefGroupStatus = "enabled"
// ContentType represents JSON content type.
ContentType = "application/json"
-401
View File
@@ -1,401 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package auth_test
import (
"context"
"fmt"
"net"
"testing"
"time"
grpcAuthV1 "github.com/absmach/magistrala/api/grpc/auth/v1"
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/auth"
grpcapi "github.com/absmach/magistrala/auth/api/grpc/auth"
"github.com/absmach/magistrala/internal/testsutil"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/policies"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
const (
port = 8081
id = "testID"
usersType = "users"
adminPermission = "admin"
authoritiesObj = "authorities"
memberRelation = "member"
validToken = "valid"
inValidToken = "invalid"
validPATToken = "valid"
)
var (
domainID = testsutil.GenerateUUID(&testing.T{})
authAddr = fmt.Sprintf("localhost:%d", port)
clientID = testsutil.GenerateUUID(&testing.T{})
)
func startGRPCServer(svc auth.Service, port int) *grpc.Server {
listener, _ := net.Listen("tcp", fmt.Sprintf(":%d", port))
server := grpc.NewServer()
grpcAuthV1.RegisterAuthServiceServer(server, grpcapi.NewAuthServer(svc))
go func() {
err := server.Serve(listener)
assert.Nil(&testing.T{}, err, fmt.Sprintf(`"Unexpected error creating auth server %s"`, err))
}()
return server
}
func TestIdentify(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
defer conn.Close()
grpcClient := grpcapi.NewAuthClient(conn, time.Second)
cases := []struct {
desc string
token string
key auth.Key
idt *grpcAuthV1.AuthNRes
svcErr error
err error
}{
{
desc: "authenticate user with valid user token",
token: validToken,
key: auth.Key{ID: "", Subject: id, Role: auth.UserRole},
idt: &grpcAuthV1.AuthNRes{UserId: id, UserRole: uint32(auth.UserRole)},
err: nil,
},
{
desc: "authenticate user with invalid user token",
token: "invalid",
key: auth.Key{},
idt: &grpcAuthV1.AuthNRes{},
svcErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "authenticate user with empty token",
token: "",
idt: &grpcAuthV1.AuthNRes{},
err: apiutil.ErrBearerToken,
},
{
desc: "authenticate user with valid PAT token",
token: "pat_" + validPATToken,
key: auth.Key{ID: id, Type: auth.PersonalAccessToken, Subject: clientID, Role: auth.UserRole},
idt: &grpcAuthV1.AuthNRes{Id: id, UserId: clientID, UserRole: uint32(auth.UserRole), TokenType: uint32(auth.PersonalAccessToken)},
err: nil,
},
{
desc: "authenticate user with invalid PAT token",
token: "pat_invalid",
key: auth.Key{},
idt: &grpcAuthV1.AuthNRes{},
svcErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("Identify", mock.Anything, tc.token).Return(tc.key, tc.svcErr)
idt, err := grpcClient.Authenticate(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token})
if idt != nil {
assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}
func TestAuthorize(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
defer conn.Close()
grpcClient := grpcapi.NewAuthClient(conn, time.Second)
cases := []struct {
desc string
token string
authRequest *grpcAuthV1.AuthZReq
authResponse *grpcAuthV1.AuthZRes
expectedReq *policies.Policy
expectedPAT *auth.PATAuthz
err error
}{
{
desc: "authorize user with authorized token",
token: validToken,
authRequest: &grpcAuthV1.AuthZReq{
PolicyReq: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: true},
err: nil,
},
{
desc: "authorize user with unauthorized token",
token: inValidToken,
authRequest: &grpcAuthV1.AuthZReq{
PolicyReq: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: svcerr.ErrAuthorization,
},
{
desc: "authorize user with empty subject",
token: validToken,
authRequest: &grpcAuthV1.AuthZReq{
PolicyReq: &grpcAuthV1.PolicyReq{
Subject: "",
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingPolicySub,
},
{
desc: "authorize user with empty subject type",
token: validToken,
authRequest: &grpcAuthV1.AuthZReq{
PolicyReq: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: "",
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingPolicySub,
},
{
desc: "authorize user with empty object",
token: validToken,
authRequest: &grpcAuthV1.AuthZReq{
PolicyReq: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: "",
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingPolicyObj,
},
{
desc: "authorize user with empty object type",
token: validToken,
authRequest: &grpcAuthV1.AuthZReq{
PolicyReq: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: "",
Relation: memberRelation,
Permission: adminPermission,
},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingPolicyObj,
},
{
desc: "authorize user with empty permission",
token: validToken,
authRequest: &grpcAuthV1.AuthZReq{
PolicyReq: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: "",
},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMalformedPolicyPer,
},
{
desc: "authorize user with valid PAT token",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZReq{
PolicyReq: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.ViewPermission,
ObjectType: policies.ClientType,
Domain: domainID,
Object: clientID,
},
PatReq: &grpcAuthV1.PATReq{
PatId: id,
Domain: domainID,
Operation: "view",
UserId: id,
EntityId: clientID,
EntityType: auth.ClientsScopeStr,
},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: true},
err: nil,
},
{
desc: "authorize bootstrap PAT keeps PAT domain when policy domain is empty",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZReq{
PolicyReq: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.MembershipPermission,
ObjectType: policies.DomainType,
Object: domainID,
},
PatReq: &grpcAuthV1.PATReq{
PatId: id,
Domain: domainID,
Operation: "create",
UserId: id,
EntityId: auth.AnyIDs,
EntityType: auth.BootstrapStr,
},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: true},
expectedReq: &policies.Policy{
Domain: domainID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: id,
Permission: policies.MembershipPermission,
ObjectType: policies.DomainType,
Object: domainID,
},
expectedPAT: &auth.PATAuthz{
PatID: id,
UserID: id,
EntityType: auth.BootstrapType,
EntityID: auth.AnyIDs,
Operation: "create",
Domain: domainID,
},
err: nil,
},
{
desc: "authorize user with unauthorized PAT token",
token: inValidToken,
authRequest: &grpcAuthV1.AuthZReq{
PolicyReq: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.ViewPermission,
ObjectType: policies.ClientType,
Domain: domainID,
Object: clientID,
},
PatReq: &grpcAuthV1.PATReq{
PatId: id,
Domain: domainID,
Operation: "view",
UserId: id,
EntityId: clientID,
EntityType: auth.ClientsScopeStr,
},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: svcerr.ErrAuthorization,
},
{
desc: "authorize PAT with missing user id",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZReq{
PolicyReq: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.ViewPermission,
ObjectType: policies.ClientType,
Domain: domainID,
Object: clientID,
},
PatReq: &grpcAuthV1.PATReq{
PatId: id,
Domain: domainID,
Operation: "view",
EntityId: clientID,
EntityType: auth.ClientsScopeStr,
},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingUserID,
},
{
desc: "authorize PAT with missing entity id",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZReq{
PolicyReq: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.ViewPermission,
ObjectType: policies.ClientType,
Domain: domainID,
Object: clientID,
},
PatReq: &grpcAuthV1.PATReq{
PatId: id,
Domain: domainID,
Operation: "view",
UserId: id,
EntityType: auth.ClientsScopeStr,
},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("Authorize", mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
ar, err := grpcClient.Authorize(context.Background(), tc.authRequest)
if ar != nil {
assert.Equal(t, tc.authResponse, ar, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.authResponse, ar))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}
-24
View File
@@ -1,24 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package auth_test
import (
"os"
"testing"
"github.com/absmach/magistrala/auth/mocks"
)
var svc *mocks.Service
func TestMain(m *testing.M) {
svc = new(mocks.Service)
server := startGRPCServer(svc, port)
code := m.Run()
server.GracefulStop()
os.Exit(code)
}
-245
View File
@@ -1,245 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package token_test
import (
"context"
"fmt"
"net"
"testing"
"time"
grpcTokenV1 "github.com/absmach/magistrala/api/grpc/token/v1"
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/auth"
grpcapi "github.com/absmach/magistrala/auth/api/grpc/token"
"github.com/absmach/magistrala/internal/testsutil"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
const (
port = 8082
validToken = "valid"
inValidToken = "invalid"
invalidID = "invalid"
)
var (
validID = testsutil.GenerateUUID(&testing.T{})
authAddr = fmt.Sprintf("localhost:%d", port)
)
func startGRPCServer(svc auth.Service, port int) *grpc.Server {
listener, _ := net.Listen("tcp", fmt.Sprintf(":%d", port))
server := grpc.NewServer()
grpcTokenV1.RegisterTokenServiceServer(server, grpcapi.NewTokenServer(svc))
go func() {
err := server.Serve(listener)
assert.Nil(&testing.T{}, err, fmt.Sprintf(`"Unexpected error creating auth server %s"`, err))
}()
return server
}
func TestIssue(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
defer conn.Close()
cases := []struct {
desc string
userId string
kind auth.KeyType
issueResponse auth.Token
err error
}{
{
desc: "issue for user with valid token",
userId: validID,
kind: auth.AccessKey,
issueResponse: auth.Token{
AccessToken: validToken,
RefreshToken: validToken,
},
err: nil,
},
{
desc: "issue recovery key",
userId: validID,
kind: auth.RecoveryKey,
issueResponse: auth.Token{
AccessToken: validToken,
RefreshToken: validToken,
},
err: nil,
},
{
desc: "issue API key unauthenticated",
userId: validID,
kind: auth.APIKey,
issueResponse: auth.Token{},
err: svcerr.ErrAuthentication,
},
{
desc: "issue for invalid key type",
userId: validID,
kind: 32,
issueResponse: auth.Token{},
err: errors.ErrMalformedEntity,
},
{
desc: "issue for user that does notexist",
userId: "",
kind: auth.APIKey,
issueResponse: auth.Token{},
err: svcerr.ErrAuthentication,
},
}
for _, tc := range cases {
svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err)
_, err := grpcClient.Issue(context.Background(), &grpcTokenV1.IssueReq{UserId: tc.userId, Type: uint32(tc.kind)})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
}
}
func TestRefresh(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
defer conn.Close()
cases := []struct {
desc string
token string
issueResponse auth.Token
err error
}{
{
desc: "refresh token with valid token",
token: validToken,
issueResponse: auth.Token{
AccessToken: validToken,
RefreshToken: validToken,
},
err: nil,
},
{
desc: "refresh token with invalid token",
token: inValidToken,
issueResponse: auth.Token{},
err: svcerr.ErrAuthentication,
},
{
desc: "refresh token with empty token",
token: "",
issueResponse: auth.Token{},
err: apiutil.ErrMissingSecret,
},
}
for _, tc := range cases {
svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err)
_, err := grpcClient.Refresh(context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: tc.token})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
}
}
func TestRevoke(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
defer conn.Close()
cases := []struct {
desc string
id string
err error
}{
{
desc: "revoke token with valid id",
id: validID,
err: nil,
},
{
desc: "revoke token with invalid id",
id: invalidID,
err: svcerr.ErrAuthentication,
},
{
desc: "revoke token with empty id",
id: "",
err: apiutil.ErrMissingID,
},
{
desc: "revoke already revoked token",
id: validID,
err: svcerr.ErrConflict,
},
}
for _, tc := range cases {
svcCall := svc.On("RevokeToken", mock.Anything, mock.Anything, tc.id).Return(tc.err)
_, err := grpcClient.Revoke(context.Background(), &grpcTokenV1.RevokeReq{TokenId: tc.id})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
}
}
func TestListUserRefreshTokens(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
defer conn.Close()
cases := []struct {
desc string
userID string
listResponse []auth.TokenInfo
err error
}{
{
desc: "list tokens for user with valid id",
userID: validID,
listResponse: []auth.TokenInfo{
{ID: testsutil.GenerateUUID(&testing.T{}), Description: "Token 1"},
{ID: testsutil.GenerateUUID(&testing.T{}), Description: "Token 2"},
},
err: nil,
},
{
desc: "list tokens for user with empty list",
userID: validID,
listResponse: []auth.TokenInfo{},
err: nil,
},
{
desc: "list tokens with invalid user id",
userID: invalidID,
listResponse: nil,
err: svcerr.ErrAuthentication,
},
{
desc: "list tokens with empty user id",
userID: "",
listResponse: nil,
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
svcCall := svc.On("ListUserRefreshTokens", mock.Anything, tc.userID).Return(tc.listResponse, tc.err)
_, err := grpcClient.ListUserRefreshTokens(context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.userID})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
}
}
-24
View File
@@ -1,24 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package token_test
import (
"os"
"testing"
"github.com/absmach/magistrala/auth/mocks"
)
var svc *mocks.Service
func TestMain(m *testing.M) {
svc = new(mocks.Service)
server := startGRPCServer(svc, port)
code := m.Run()
server.GracefulStop()
os.Exit(code)
}
+3
View File
@@ -1,3 +1,6 @@
//go:build oldservices
// +build oldservices
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
-122
View File
@@ -1,122 +0,0 @@
# BOOTSTRAP SERVICE
New devices need to be configured properly and connected to the Magistrala. Bootstrap service is used in order to accomplish that. This service provides the following features:
1. Creating new Magistrala Clients
2. Providing basic configuration for the newly created Clients
3. Enabling/disabling bootstrap enrollments
Pre-provisioning a new Client is as simple as sending Configuration data to the Bootstrap service. Once the Client is online, it sends a request for initial config to Bootstrap service. Bootstrap service provides an API for enabling and disabling bootstrap enrollments. Bootstrapping does not implicitly enable an enrollment; it has to be done manually.
In order to bootstrap successfully, the Client needs to send bootstrapping request to the specific URL, as well as a secret key. This key and URL are pre-provisioned during the manufacturing process. If the Client is provisioned on the Bootstrap service side, the corresponding configuration will be sent as a response. Otherwise, the Client will be saved so that it can be provisioned later.
## Client Configuration Entity
Client Configuration consists of two logical parts: the custom configuration that can be interpreted by the Client itself and Magistrala-related configuration. Magistrala config contains:
1. corresponding Magistrala Client ID
2. corresponding Magistrala Client key
3. list of the Magistrala channels the Client is connected to
> Note: list of channels contains IDs of the Magistrala channels. These channels are _pre-provisioned_ on the Magistrala side and, unlike corresponding Magistrala Client, Bootstrap service is not able to create Magistrala Channels.
Enabling and disabling a bootstrap enrollment is an enrollment toggle. Configuration keeps a _status_:
| Status | What it means |
| -------- | ----------------------------------------------------------- |
| disabled | Enrollment exists, but bootstrap is not allowed |
| enabled | Enrollment can be used to fetch bootstrap configuration |
Switching between statuses `enabled` and `disabled` enables and disables the enrollment, respectively.
Client configuration also contains the so-called `external ID` and `external key`. An external ID is a unique identifier of corresponding Client. For example, a device MAC address is a good choice for external ID. External key is a secret key that is used for authentication during the bootstrapping procedure.
## Configuration
The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values.
| Variable | Description | Default |
| ------------------------------ | -------------------------------------------------------------------------------- | --------------------------------- |
| MG_BOOTSTRAP_LOG_LEVEL | Log level for Bootstrap (debug, info, warn, error) | info |
| MG_BOOTSTRAP_DB_HOST | Database host address | localhost |
| MG_BOOTSTRAP_DB_PORT | Database host port | 5432 |
| MG_BOOTSTRAP_DB_USER | Database user | magistrala |
| MG_BOOTSTRAP_DB_PASS | Database password | magistrala |
| MG_BOOTSTRAP_DB_NAME | Name of the database used by the service | bootstrap |
| MG_BOOTSTRAP_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable |
| MG_BOOTSTRAP_DB_SSL_CERT | Path to the PEM encoded certificate file | "" |
| MG_BOOTSTRAP_DB_SSL_KEY | Path to the PEM encoded key file | "" |
| MG_BOOTSTRAP_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" |
| MG_BOOTSTRAP_ENCRYPT_KEY | Secret key for secure bootstrapping encryption | 12345678910111213141516171819202 |
| MG_BOOTSTRAP_HTTP_HOST | Bootstrap service HTTP host | "" |
| MG_BOOTSTRAP_HTTP_PORT | Bootstrap service HTTP port | 9013 |
| MG_BOOTSTRAP_HTTP_SERVER_CERT | Path to server certificate in pem format | "" |
| MG_BOOTSTRAP_HTTP_SERVER_KEY | Path to server key in pem format | "" |
| MG_BOOTSTRAP_EVENT_CONSUMER | Bootstrap service event source consumer name | bootstrap |
| MG_ES_URL | Event store URL | <nats://localhost:4222> |
| MG_AUTH_GRPC_URL | Auth service Auth gRPC URL | <localhost:8181> |
| MG_AUTH_GRPC_TIMEOUT | Auth service Auth gRPC request timeout in seconds | 1s |
| MG_AUTH_GRPC_CLIENT_CERT | Path to the PEM encoded auth service Auth gRPC client certificate file | "" |
| MG_AUTH_GRPC_CLIENT_KEY | Path to the PEM encoded auth service Auth gRPC client key file | "" |
| MG_AUTH_GRPC_SERVER_CERTS | Path to the PEM encoded auth server Auth gRPC server trusted CA certificate file | "" |
| MG_CLIENTS_URL | Base URL for Magistrala Clients | <http://localhost:9000> |
| MG_JAEGER_URL | Jaeger server URL | <http://localhost:4318/v1/traces> |
| MG_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 |
| MG_SEND_TELEMETRY | Send telemetry to magistrala call home server | true |
| MG_BOOTSTRAP_INSTANCE_ID | Bootstrap service instance ID | "" |
## Deployment
The service itself is distributed as Docker container. Check the [`bootstrap`](https://github.com/absmach/magistrala/blob/main/docker/addons/bootstrap/docker-compose.yaml) service section in docker-compose file to see how service is deployed.
To start the service outside of the container, execute the following shell script:
```bash
# download the latest version of the service
git clone https://github.com/absmach/magistrala
cd magistrala
# compile the servic e
make bootstrap
# copy binary to bin
make install
# set the environment variables and run the service
MG_BOOTSTRAP_LOG_LEVEL=info \
MG_BOOTSTRAP_DB_HOST=localhost \
MG_BOOTSTRAP_DB_PORT=5432 \
MG_BOOTSTRAP_DB_USER=magistrala \
MG_BOOTSTRAP_DB_PASS=magistrala \
MG_BOOTSTRAP_DB_NAME=bootstrap \
MG_BOOTSTRAP_DB_SSL_MODE=disable \
MG_BOOTSTRAP_DB_SSL_CERT="" \
MG_BOOTSTRAP_DB_SSL_KEY="" \
MG_BOOTSTRAP_DB_SSL_ROOT_CERT="" \
MG_BOOTSTRAP_HTTP_HOST=localhost \
MG_BOOTSTRAP_HTTP_PORT=9013 \
MG_BOOTSTRAP_HTTP_SERVER_CERT="" \
MG_BOOTSTRAP_HTTP_SERVER_KEY="" \
MG_BOOTSTRAP_EVENT_CONSUMER=bootstrap \
MG_ES_URL=nats://localhost:4222 \
MG_AUTH_GRPC_URL=localhost:8181 \
MG_AUTH_GRPC_TIMEOUT=1s \
MG_AUTH_GRPC_CLIENT_CERT="" \
MG_AUTH_GRPC_CLIENT_KEY="" \
MG_AUTH_GRPC_SERVER_CERTS="" \
MG_CLIENTS_URL=http://localhost:9000 \
MG_JAEGER_URL=http://localhost:14268/api/traces \
MG_JAEGER_TRACE_RATIO=1.0 \
MG_SEND_TELEMETRY=true \
MG_BOOTSTRAP_INSTANCE_ID="" \
$GOBIN/magistrala-bootstrap
```
Setting `MG_BOOTSTRAP_HTTP_SERVER_CERT` and `MG_BOOTSTRAP_HTTP_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key.
Setting `MG_AUTH_GRPC_CLIENT_CERT` and `MG_AUTH_GRPC_CLIENT_KEY` will enable TLS against the auth service. The service expects a file in PEM format for both the certificate and the key. Setting `MG_AUTH_GRPC_SERVER_CERTS` will enable TLS against the auth service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs.
## Usage
For more information about service capabilities and its usage, please check out the [API documentation](https://docs.api.magistrala.absmach.eu/?urls.primaryName=bootstrap.yaml).
-5
View File
@@ -1,5 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package api contains implementation of bootstrap service HTTP API.
package api
-506
View File
@@ -1,506 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
import (
"context"
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/bootstrap"
"github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/go-kit/kit/endpoint"
)
func addEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(addReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
config := bootstrap.Config{
ExternalID: req.ExternalID,
ExternalKey: req.ExternalKey,
Name: req.Name,
ClientCert: req.ClientCert,
ClientKey: req.ClientKey,
CACert: req.CACert,
Content: req.Content,
ProfileID: req.ProfileID,
RenderContext: req.RenderContext,
}
saved, err := svc.Add(ctx, session, req.token, config)
if err != nil {
return nil, err
}
res := configRes{
ID: saved.ID,
ExternalID: saved.ExternalID,
Name: saved.Name,
Content: saved.Content,
Status: saved.Status,
ProfileID: saved.ProfileID,
RenderContext: saved.RenderContext,
ClientCert: saved.ClientCert,
CACert: saved.CACert,
ClientKey: saved.ClientKey,
created: true,
}
return res, nil
}
}
func updateCertEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(updateCertReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
cfg, err := svc.UpdateCert(ctx, session, req.configID, req.ClientCert, req.ClientKey, req.CACert)
if err != nil {
return nil, err
}
res := updateConfigRes{
ID: cfg.ID,
ClientCert: cfg.ClientCert,
CACert: cfg.CACert,
ClientKey: cfg.ClientKey,
}
return res, nil
}
}
func viewEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(entityReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
config, err := svc.View(ctx, session, req.id)
if err != nil {
return nil, err
}
res := viewRes{
ID: config.ID,
ExternalID: config.ExternalID,
Name: config.Name,
Content: config.Content,
Status: config.Status,
ProfileID: config.ProfileID,
RenderContext: config.RenderContext,
}
return res, nil
}
}
func updateEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(updateReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
config := bootstrap.Config{
ID: req.id,
Name: req.Name,
Content: req.Content,
RenderContext: req.RenderContext,
}
if err := svc.Update(ctx, session, config); err != nil {
return nil, err
}
return updateRes{}, nil
}
}
func listEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(listReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
page, err := svc.List(ctx, session, req.filter, req.offset, req.limit)
if err != nil {
return nil, err
}
res := listRes{
Total: page.Total,
Offset: page.Offset,
Limit: page.Limit,
Configs: []viewRes{},
}
for _, cfg := range page.Configs {
view := viewRes{
ID: cfg.ID,
ExternalID: cfg.ExternalID,
Name: cfg.Name,
Content: cfg.Content,
Status: cfg.Status,
ProfileID: cfg.ProfileID,
RenderContext: cfg.RenderContext,
}
res.Configs = append(res.Configs, view)
}
return res, nil
}
}
func removeEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(entityReq)
if err := req.validate(); err != nil {
return removeRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
if err := svc.Remove(ctx, session, req.id); err != nil {
return nil, err
}
return removeRes{}, nil
}
}
func bootstrapEndpoint(svc bootstrap.Service, reader bootstrap.ConfigReader, secure bool) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(bootstrapReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
cfg, err := svc.Bootstrap(ctx, req.key, req.id, secure)
if err != nil {
return nil, err
}
return reader.ReadConfig(cfg, secure)
}
}
func enableConfigEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(changeConfigStatusReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
cfg, err := svc.EnableConfig(ctx, session, req.id)
if err != nil {
return nil, err
}
return changeConfigStatusRes{Config: cfg}, nil
}
}
func disableConfigEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(changeConfigStatusReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
cfg, err := svc.DisableConfig(ctx, session, req.id)
if err != nil {
return nil, err
}
return changeConfigStatusRes{Config: cfg}, nil
}
}
func createProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(createProfileReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
saved, err := svc.CreateProfile(ctx, session, req.Profile)
if err != nil {
return nil, err
}
return profileRes{Profile: saved, created: true}, nil
}
}
func uploadProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(uploadProfileReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
saved, err := svc.CreateProfile(ctx, session, req.Profile)
if err != nil {
return nil, err
}
return profileRes{Profile: saved, created: true}, nil
}
}
func viewProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(viewProfileReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
p, err := svc.ViewProfile(ctx, session, req.profileID)
if err != nil {
return nil, err
}
return profileRes{Profile: p}, nil
}
}
func profileSlotsEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(viewProfileReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
p, err := svc.ViewProfile(ctx, session, req.profileID)
if err != nil {
return nil, err
}
return profileSlotsRes{BindingSlots: p.BindingSlots}, nil
}
}
func renderPreviewEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(renderPreviewReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
p, err := svc.ViewProfile(ctx, session, req.profileID)
if err != nil {
return nil, err
}
cfg := req.Config
bindings := req.Bindings
if req.ConfigID != "" {
stored, err := svc.View(ctx, session, req.ConfigID)
if err != nil {
return nil, err
}
cfg = stored
bindings, err = svc.ListBindings(ctx, session, req.ConfigID)
if err != nil {
return nil, err
}
}
cfg.DomainID = session.DomainID
cfg.ProfileID = p.ID
if cfg.RenderContext == nil {
cfg.RenderContext = req.RenderContext
}
rendered, err := bootstrap.NewRenderer().Render(p, cfg, bindings)
if err != nil {
return nil, err
}
return renderPreviewRes{Content: string(rendered)}, nil
}
}
func updateProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(updateProfileReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
req.Profile.ID = req.profileID
updated, err := svc.UpdateProfile(ctx, session, req.Profile)
if err != nil {
return nil, err
}
return profileRes{Profile: updated}, nil
}
}
func deleteProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(deleteProfileReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
if err := svc.DeleteProfile(ctx, session, req.profileID); err != nil {
return nil, err
}
return removeRes{}, nil
}
}
func listProfilesEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(listProfilesReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
page, err := svc.ListProfiles(ctx, session, req.offset, req.limit, req.name)
if err != nil {
return nil, err
}
return profilesPageRes{ProfilesPage: page}, nil
}
}
func assignProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(assignProfileReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
if err := svc.AssignProfile(ctx, session, req.configID, req.ProfileID); err != nil {
return nil, err
}
return removeRes{}, nil
}
}
func bindResourcesEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(bindResourcesReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
if err := svc.BindResources(ctx, session, req.token, req.configID, req.Bindings); err != nil {
return nil, err
}
return removeRes{}, nil
}
}
func listBindingsEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(listBindingsReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
snapshots, err := svc.ListBindings(ctx, session, req.configID)
if err != nil {
return nil, err
}
return bindingsRes{Bindings: snapshots}, nil
}
}
func refreshBindingsEndpoint(svc bootstrap.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(refreshBindingsReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
if err := svc.RefreshBindings(ctx, session, req.token, req.configID); err != nil {
return nil, err
}
return removeRes{}, nil
}
}
File diff suppressed because it is too large Load Diff
-280
View File
@@ -1,280 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
import (
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/bootstrap"
)
const maxLimitSize = 100
type addReq struct {
token string
ExternalID string `json:"external_id"`
ExternalKey string `json:"external_key"`
Name string `json:"name"`
Content string `json:"content"`
ClientCert string `json:"client_cert"`
ClientKey string `json:"client_key"`
CACert string `json:"ca_cert"`
ProfileID string `json:"profile_id"`
RenderContext map[string]any `json:"render_context"`
}
func (req addReq) validate() error {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.ExternalID == "" {
return apiutil.ErrMissingID
}
if req.ExternalKey == "" {
return apiutil.ErrBearerKey
}
return nil
}
type entityReq struct {
id string
}
func (req entityReq) validate() error {
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type updateReq struct {
id string
Name string `json:"name"`
Content string `json:"content"`
RenderContext map[string]any `json:"render_context"`
}
func (req updateReq) validate() error {
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type updateCertReq struct {
configID string
ClientCert string `json:"client_cert"`
ClientKey string `json:"client_key"`
CACert string `json:"ca_cert"`
}
func (req updateCertReq) validate() error {
if req.configID == "" {
return apiutil.ErrMissingID
}
return nil
}
type listReq struct {
filter bootstrap.Filter
offset uint64
limit uint64
}
func (req listReq) validate() error {
if req.limit > maxLimitSize {
return apiutil.ErrLimitSize
}
return nil
}
type bootstrapReq struct {
key string
id string
}
func (req bootstrapReq) validate() error {
if req.key == "" {
return apiutil.ErrBearerKey
}
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type changeConfigStatusReq struct {
token string
id string
}
func (req changeConfigStatusReq) validate() error {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
// --- Profile requests ---
type createProfileReq struct {
bootstrap.Profile
}
func (req createProfileReq) validate() error {
if req.Name == "" {
return apiutil.ErrMissingName
}
return nil
}
type uploadProfileReq struct {
bootstrap.Profile
}
func (req uploadProfileReq) validate() error {
if req.Name == "" {
return apiutil.ErrMissingName
}
return nil
}
type viewProfileReq struct {
profileID string
}
func (req viewProfileReq) validate() error {
if req.profileID == "" {
return apiutil.ErrMissingID
}
return nil
}
type updateProfileReq struct {
profileID string
bootstrap.Profile
}
func (req updateProfileReq) validate() error {
if req.profileID == "" {
return apiutil.ErrMissingID
}
return nil
}
type renderPreviewReq struct {
profileID string
ConfigID string `json:"config_id,omitempty"`
Config bootstrap.Config `json:"config"`
RenderContext map[string]any `json:"render_context,omitempty"`
Bindings []bootstrap.BindingSnapshot `json:"bindings,omitempty"`
}
func (req renderPreviewReq) validate() error {
if req.profileID == "" {
return apiutil.ErrMissingID
}
return nil
}
type deleteProfileReq struct {
profileID string
}
func (req deleteProfileReq) validate() error {
if req.profileID == "" {
return apiutil.ErrMissingID
}
return nil
}
type listProfilesReq struct {
offset uint64
limit uint64
name string
}
func (req listProfilesReq) validate() error {
if req.limit == 0 || req.limit > maxLimitSize {
return apiutil.ErrLimitSize
}
return nil
}
// --- Enrollment binding requests ---
type assignProfileReq struct {
configID string
ProfileID string `json:"profile_id"`
}
func (req assignProfileReq) validate() error {
if req.configID == "" || req.ProfileID == "" {
return apiutil.ErrMissingID
}
return nil
}
type bindResourcesReq struct {
token string
configID string
Bindings []bootstrap.BindingRequest `json:"bindings"`
}
func (req bindResourcesReq) validate() error {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.configID == "" {
return apiutil.ErrMissingID
}
if len(req.Bindings) == 0 {
return apiutil.ErrEmptyList
}
for _, b := range req.Bindings {
if b.Slot == "" || b.Type == "" || b.ResourceID == "" {
return apiutil.ErrMissingID
}
}
return nil
}
type listBindingsReq struct {
configID string
}
func (req listBindingsReq) validate() error {
if req.configID == "" {
return apiutil.ErrMissingID
}
return nil
}
type refreshBindingsReq struct {
token string
configID string
}
func (req refreshBindingsReq) validate() error {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.configID == "" {
return apiutil.ErrMissingID
}
return nil
}
-245
View File
@@ -1,245 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
import (
"fmt"
"testing"
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/stretchr/testify/assert"
)
func TestAddReqValidation(t *testing.T) {
cases := []struct {
desc string
token string
externalID string
externalKey string
err error
}{
{
desc: "valid request",
token: "token",
externalID: "external-id",
externalKey: "external-key",
err: nil,
},
{
desc: "empty token",
token: "",
externalID: "external-id",
externalKey: "external-key",
err: apiutil.ErrBearerToken,
},
{
desc: "empty external ID",
token: "token",
externalID: "",
externalKey: "external-key",
err: apiutil.ErrMissingID,
},
{
desc: "empty external key",
token: "token",
externalID: "external-id",
externalKey: "",
err: apiutil.ErrBearerKey,
},
{
desc: "empty external key and external ID",
token: "token",
externalID: "",
externalKey: "",
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
req := addReq{
token: tc.token,
ExternalID: tc.externalID,
ExternalKey: tc.externalKey,
}
err := req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestEntityReqValidation(t *testing.T) {
cases := []struct {
desc string
id string
err error
}{
{
desc: "empty id",
id: "",
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
req := entityReq{
id: tc.id,
}
err := req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestUpdateReqValidation(t *testing.T) {
cases := []struct {
desc string
id string
err error
}{
{
desc: "valid request",
id: "id",
err: nil,
},
{
desc: "empty id",
id: "",
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
req := updateReq{
id: tc.id,
}
err := req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestUpdateCertReqValidation(t *testing.T) {
cases := []struct {
desc string
configID string
err error
}{
{
desc: "empty config id",
configID: "",
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
req := updateCertReq{
configID: tc.configID,
}
err := req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestListReqValidation(t *testing.T) {
cases := []struct {
desc string
offset uint64
limit uint64
err error
}{
{
desc: "too large limit",
offset: 0,
limit: maxLimitSize + 1,
err: apiutil.ErrLimitSize,
},
{
desc: "default limit",
offset: 0,
limit: defLimit,
err: nil,
},
}
for _, tc := range cases {
req := listReq{
offset: tc.offset,
limit: tc.limit,
}
err := req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestBootstrapReqValidation(t *testing.T) {
cases := []struct {
desc string
externKey string
externID string
err error
}{
{
desc: "empty external key",
externKey: "",
externID: "id",
err: apiutil.ErrBearerKey,
},
{
desc: "empty external id",
externKey: "key",
externID: "",
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
req := bootstrapReq{
id: tc.externID,
key: tc.externKey,
}
err := req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestChangeConfigStatusReqValidation(t *testing.T) {
cases := []struct {
desc string
token string
id string
err error
}{
{
desc: "empty token",
token: "",
id: "id",
err: apiutil.ErrBearerToken,
},
{
desc: "empty id",
token: "token",
id: "",
err: apiutil.ErrMissingID,
},
{
desc: "valid request",
token: "token",
id: "id",
err: nil,
},
}
for _, tc := range cases {
req := changeConfigStatusReq{
token: tc.token,
id: tc.id,
}
err := req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
-223
View File
@@ -1,223 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
import (
"fmt"
"net/http"
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/bootstrap"
)
var (
_ magistrala.Response = (*removeRes)(nil)
_ magistrala.Response = (*configRes)(nil)
_ magistrala.Response = (*changeConfigStatusRes)(nil)
_ magistrala.Response = (*viewRes)(nil)
_ magistrala.Response = (*listRes)(nil)
)
type removeRes struct{}
func (res removeRes) Code() int {
return http.StatusNoContent
}
func (res removeRes) Headers() map[string]string {
return map[string]string{}
}
func (res removeRes) Empty() bool {
return true
}
type updateRes struct{}
func (res updateRes) Code() int {
return http.StatusOK
}
func (res updateRes) Headers() map[string]string {
return map[string]string{}
}
func (res updateRes) Empty() bool {
return true
}
type configRes struct {
ID string `json:"id"`
ExternalID string `json:"external_id"`
Name string `json:"name,omitempty"`
Content string `json:"content,omitempty"`
Status bootstrap.Status `json:"status"`
ProfileID string `json:"profile_id,omitempty"`
RenderContext map[string]any `json:"render_context,omitempty"`
ClientCert string `json:"client_cert,omitempty"`
CACert string `json:"ca_cert,omitempty"`
ClientKey string `json:"client_key,omitempty"`
created bool
}
func (res configRes) Code() int {
if res.created {
return http.StatusCreated
}
return http.StatusOK
}
func (res configRes) Headers() map[string]string {
if res.created {
return map[string]string{
"Location": fmt.Sprintf("/clients/configs/%s", res.ID),
}
}
return map[string]string{}
}
func (res configRes) Empty() bool {
return false
}
type viewRes struct {
ID string `json:"id,omitempty"`
ExternalID string `json:"external_id"`
Content string `json:"content,omitempty"`
Name string `json:"name,omitempty"`
Status bootstrap.Status `json:"status"`
ProfileID string `json:"profile_id,omitempty"`
RenderContext map[string]any `json:"render_context,omitempty"`
ClientCert string `json:"client_cert,omitempty"`
CACert string `json:"ca_cert,omitempty"`
ClientKey string `json:"client_key,omitempty"`
}
func (res viewRes) Code() int {
return http.StatusOK
}
func (res viewRes) Headers() map[string]string {
return map[string]string{}
}
func (res viewRes) Empty() bool {
return false
}
type listRes struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Configs []viewRes `json:"configs"`
}
func (res listRes) Code() int {
return http.StatusOK
}
func (res listRes) Headers() map[string]string {
return map[string]string{}
}
func (res listRes) Empty() bool {
return false
}
type changeConfigStatusRes struct {
bootstrap.Config
}
func (res changeConfigStatusRes) Code() int {
return http.StatusOK
}
func (res changeConfigStatusRes) Headers() map[string]string {
return map[string]string{}
}
func (res changeConfigStatusRes) Empty() bool {
return false
}
type updateConfigRes struct {
ID string `json:"id,omitempty"`
CACert string `json:"ca_cert,omitempty"`
ClientCert string `json:"client_cert,omitempty"`
ClientKey string `json:"client_key,omitempty"`
}
func (res updateConfigRes) Code() int {
return http.StatusOK
}
func (res updateConfigRes) Headers() map[string]string {
return map[string]string{}
}
func (res updateConfigRes) Empty() bool {
return false
}
// profileRes is returned on create (201) or update (200).
type profileRes struct {
bootstrap.Profile
created bool
}
func (res profileRes) Code() int {
if res.created {
return http.StatusCreated
}
return http.StatusOK
}
func (res profileRes) Headers() map[string]string {
if res.created {
return map[string]string{
"Location": fmt.Sprintf("/bootstrap/profiles/%s", res.ID),
}
}
return map[string]string{}
}
func (res profileRes) Empty() bool { return false }
// profilesPageRes is returned by ListProfiles.
type profilesPageRes struct {
bootstrap.ProfilesPage
}
func (res profilesPageRes) Code() int { return http.StatusOK }
func (res profilesPageRes) Headers() map[string]string { return map[string]string{} }
func (res profilesPageRes) Empty() bool { return false }
// profileSlotsRes is returned by profile slots endpoint.
type profileSlotsRes struct {
BindingSlots []bootstrap.BindingSlot `json:"binding_slots"`
}
func (res profileSlotsRes) Code() int { return http.StatusOK }
func (res profileSlotsRes) Headers() map[string]string { return map[string]string{} }
func (res profileSlotsRes) Empty() bool { return false }
// renderPreviewRes is returned by profile render-preview endpoint.
type renderPreviewRes struct {
Content string `json:"content"`
}
func (res renderPreviewRes) Code() int { return http.StatusOK }
func (res renderPreviewRes) Headers() map[string]string { return map[string]string{} }
func (res renderPreviewRes) Empty() bool { return false }
// bindingsRes is returned by ListBindings.
type bindingsRes struct {
Bindings []bootstrap.BindingSnapshot `json:"bindings"`
}
func (res bindingsRes) Code() int { return http.StatusOK }
func (res bindingsRes) Headers() map[string]string { return map[string]string{} }
func (res bindingsRes) Empty() bool { return false }
-512
View File
@@ -1,512 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
import (
"context"
"encoding/json"
"io"
"log/slog"
"net/http"
"net/url"
"strings"
"github.com/absmach/magistrala"
api "github.com/absmach/magistrala/api/http"
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/bootstrap"
smqauthn "github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/errors"
"github.com/go-chi/chi/v5"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/pelletier/go-toml/v2"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"gopkg.in/yaml.v3"
)
const (
contentType = "application/json"
yamlContentType = "yaml"
tomlContentType = "toml"
byteContentType = "application/octet-stream"
offsetKey = "offset"
limitKey = "limit"
defOffset = 0
defLimit = 10
)
var (
fullMatch = []string{"status", "external_id", "id"}
partialMatch = []string{"name"}
// ErrBootstrap indicates error in getting bootstrap configuration.
ErrBootstrap = errors.New("failed to read bootstrap configuration")
)
// MakeHandler returns a HTTP handler for API endpoints.
func MakeHandler(svc bootstrap.Service, authn smqauthn.AuthNMiddleware, reader bootstrap.ConfigReader, logger *slog.Logger, instanceID string) http.Handler {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
r := chi.NewRouter()
r.Route("/{domainID}/clients", func(r chi.Router) {
r.Group(func(r chi.Router) {
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware())
r.Route("/configs", func(r chi.Router) {
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
addEndpoint(svc),
decodeAddRequest,
api.EncodeResponse,
opts...), "add").ServeHTTP)
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
listEndpoint(svc),
decodeListRequest,
api.EncodeResponse,
opts...), "list").ServeHTTP)
r.Get("/{configID}", otelhttp.NewHandler(kithttp.NewServer(
viewEndpoint(svc),
decodeEntityRequest,
api.EncodeResponse,
opts...), "view").ServeHTTP)
r.Patch("/{configID}", otelhttp.NewHandler(kithttp.NewServer(
updateEndpoint(svc),
decodeUpdateRequest,
api.EncodeResponse,
opts...), "update").ServeHTTP)
r.Delete("/{configID}", otelhttp.NewHandler(kithttp.NewServer(
removeEndpoint(svc),
decodeEntityRequest,
api.EncodeResponse,
opts...), "remove").ServeHTTP)
r.Patch("/certs/{configID}", otelhttp.NewHandler(kithttp.NewServer(
updateCertEndpoint(svc),
decodeUpdateCertRequest,
api.EncodeResponse,
opts...), "update_cert").ServeHTTP)
r.Post("/{configID}/enable", otelhttp.NewHandler(kithttp.NewServer(
enableConfigEndpoint(svc),
decodeChangeConfigStatusRequest,
api.EncodeResponse,
opts...), "enable_config").ServeHTTP)
r.Post("/{configID}/disable", otelhttp.NewHandler(kithttp.NewServer(
disableConfigEndpoint(svc),
decodeChangeConfigStatusRequest,
api.EncodeResponse,
opts...), "disable_config").ServeHTTP)
})
})
// Profile and enrollment binding endpoints.
r.Route("/bootstrap", func(r chi.Router) {
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware())
r.Route("/profiles", func(r chi.Router) {
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
createProfileEndpoint(svc),
decodeCreateProfileRequest,
api.EncodeResponse,
opts...), "create_profile").ServeHTTP)
r.Post("/upload", otelhttp.NewHandler(kithttp.NewServer(
uploadProfileEndpoint(svc),
decodeUploadProfileRequest,
api.EncodeResponse,
opts...), "upload_profile").ServeHTTP)
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
listProfilesEndpoint(svc),
decodeListProfilesRequest,
api.EncodeResponse,
opts...), "list_profiles").ServeHTTP)
r.Get("/{profileID}", otelhttp.NewHandler(kithttp.NewServer(
viewProfileEndpoint(svc),
decodeProfileEntityRequest,
api.EncodeResponse,
opts...), "view_profile").ServeHTTP)
r.Get("/{profileID}/slots", otelhttp.NewHandler(kithttp.NewServer(
profileSlotsEndpoint(svc),
decodeProfileEntityRequest,
api.EncodeResponse,
opts...), "profile_slots").ServeHTTP)
r.Post("/{profileID}/render-preview", otelhttp.NewHandler(kithttp.NewServer(
renderPreviewEndpoint(svc),
decodeRenderPreviewRequest,
api.EncodeResponse,
opts...), "render_preview").ServeHTTP)
r.Patch("/{profileID}", otelhttp.NewHandler(kithttp.NewServer(
updateProfileEndpoint(svc),
decodeUpdateProfileRequest,
api.EncodeResponse,
opts...), "update_profile").ServeHTTP)
r.Delete("/{profileID}", otelhttp.NewHandler(kithttp.NewServer(
deleteProfileEndpoint(svc),
decodeDeleteProfileRequest,
api.EncodeResponse,
opts...), "delete_profile").ServeHTTP)
})
r.Route("/enrollments", func(r chi.Router) {
r.Patch("/{configID}/profile", otelhttp.NewHandler(kithttp.NewServer(
assignProfileEndpoint(svc),
decodeAssignProfileRequest,
api.EncodeResponse,
opts...), "assign_profile").ServeHTTP)
r.Put("/{configID}/bindings", otelhttp.NewHandler(kithttp.NewServer(
bindResourcesEndpoint(svc),
decodeBindResourcesRequest,
api.EncodeResponse,
opts...), "bind_resources").ServeHTTP)
r.Get("/{configID}/bindings", otelhttp.NewHandler(kithttp.NewServer(
listBindingsEndpoint(svc),
decodeEnrollmentEntityRequest,
api.EncodeResponse,
opts...), "list_bindings").ServeHTTP)
r.Post("/{configID}/bindings/refresh", otelhttp.NewHandler(kithttp.NewServer(
refreshBindingsEndpoint(svc),
decodeRefreshBindingsRequest,
api.EncodeResponse,
opts...), "refresh_bindings").ServeHTTP)
})
})
})
r.Route("/clients/bootstrap", func(r chi.Router) {
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
bootstrapEndpoint(svc, reader, false),
decodeBootstrapRequest,
api.EncodeResponse,
opts...), "bootstrap").ServeHTTP)
r.Get("/{externalID}", otelhttp.NewHandler(kithttp.NewServer(
bootstrapEndpoint(svc, reader, false),
decodeBootstrapRequest,
api.EncodeResponse,
opts...), "bootstrap").ServeHTTP)
r.Get("/secure/{externalID}", otelhttp.NewHandler(kithttp.NewServer(
bootstrapEndpoint(svc, reader, true),
decodeBootstrapRequest,
encodeSecureRes,
opts...), "bootstrap_secure").ServeHTTP)
})
r.Get("/health", magistrala.Health("bootstrap", instanceID))
r.Handle("/metrics", promhttp.Handler())
return r
}
func decodeAddRequest(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
req := addReq{
token: apiutil.ExtractBearerToken(r),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeUpdateRequest(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
req := updateReq{
id: chi.URLParam(r, "configID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeUpdateCertRequest(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
req := updateCertReq{
configID: chi.URLParam(r, "configID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeListRequest(_ context.Context, r *http.Request) (any, error) {
o, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
l, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
q, err := url.ParseQuery(r.URL.RawQuery)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidQueryParams)
}
req := listReq{
filter: parseFilter(q),
offset: o,
limit: l,
}
rawStatus := q.Get("status")
parsed, err := bootstrap.ToStatus(rawStatus)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidQueryParams)
}
if parsed == bootstrap.AllStatus {
delete(req.filter.FullMatch, "status")
} else {
req.filter.FullMatch["status"] = parsed.String()
}
return req, nil
}
func decodeBootstrapRequest(_ context.Context, r *http.Request) (any, error) {
req := bootstrapReq{
id: chi.URLParam(r, "externalID"),
key: apiutil.ExtractClientSecret(r),
}
return req, nil
}
func decodeChangeConfigStatusRequest(_ context.Context, r *http.Request) (any, error) {
return changeConfigStatusReq{
token: apiutil.ExtractBearerToken(r),
id: chi.URLParam(r, "configID"),
}, nil
}
func decodeEntityRequest(_ context.Context, r *http.Request) (any, error) {
req := entityReq{
id: chi.URLParam(r, "configID"),
}
return req, nil
}
func encodeSecureRes(_ context.Context, w http.ResponseWriter, response any) error {
w.Header().Set("Content-Type", byteContentType)
w.WriteHeader(http.StatusOK)
if b, ok := response.([]byte); ok {
if _, err := w.Write(b); err != nil {
return err
}
}
return nil
}
func parseFilter(values url.Values) bootstrap.Filter {
ret := bootstrap.Filter{
FullMatch: make(map[string]string),
PartialMatch: make(map[string]string),
}
for k := range values {
if contains(fullMatch, k) {
ret.FullMatch[k] = values.Get(k)
}
if contains(partialMatch, k) {
ret.PartialMatch[k] = strings.ToLower(values.Get(k))
}
}
return ret
}
func contains(l []string, s string) bool {
for _, v := range l {
if v == s {
return true
}
}
return false
}
func decodeCreateProfileRequest(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
var req createProfileReq
if err := json.NewDecoder(r.Body).Decode(&req.Profile); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeUploadProfileRequest(_ context.Context, r *http.Request) (any, error) {
contentType := r.Header.Get("Content-Type")
var req uploadProfileReq
var inferredFormat bootstrap.ContentFormat
switch {
case strings.Contains(contentType, "json"):
inferredFormat = bootstrap.ContentFormatJSON
if err := json.NewDecoder(r.Body).Decode(&req.Profile); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
case strings.Contains(contentType, yamlContentType):
inferredFormat = bootstrap.ContentFormatYAML
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
if err := decodeYAMLProfile(body, &req.Profile); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
case strings.Contains(contentType, tomlContentType):
inferredFormat = bootstrap.ContentFormatTOML
body, err := io.ReadAll(r.Body)
if err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
if err := decodeTOMLProfile(body, &req.Profile); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
default:
return nil, apiutil.ErrUnsupportedContentType
}
if req.Profile.ContentFormat == "" {
req.Profile.ContentFormat = inferredFormat
}
return req, nil
}
func decodeYAMLProfile(body []byte, profile *bootstrap.Profile) error {
var raw map[string]any
if err := yaml.Unmarshal(body, &raw); err != nil {
return err
}
return decodeProfileMap(raw, profile)
}
func decodeTOMLProfile(body []byte, profile *bootstrap.Profile) error {
var raw map[string]any
if err := toml.Unmarshal(body, &raw); err != nil {
return err
}
return decodeProfileMap(raw, profile)
}
func decodeProfileMap(raw map[string]any, profile *bootstrap.Profile) error {
body, err := json.Marshal(raw)
if err != nil {
return err
}
return json.Unmarshal(body, profile)
}
func decodeListProfilesRequest(_ context.Context, r *http.Request) (any, error) {
o, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
l, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
n, err := apiutil.ReadStringQuery(r, api.NameKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
return listProfilesReq{offset: o, limit: l, name: n}, nil
}
func decodeProfileEntityRequest(_ context.Context, r *http.Request) (any, error) {
return viewProfileReq{profileID: chi.URLParam(r, "profileID")}, nil
}
func decodeDeleteProfileRequest(_ context.Context, r *http.Request) (any, error) {
return deleteProfileReq{profileID: chi.URLParam(r, "profileID")}, nil
}
func decodeUpdateProfileRequest(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
req := updateProfileReq{profileID: chi.URLParam(r, "profileID")}
if err := json.NewDecoder(r.Body).Decode(&req.Profile); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeRenderPreviewRequest(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
req := renderPreviewReq{profileID: chi.URLParam(r, "profileID")}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeAssignProfileRequest(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
req := assignProfileReq{configID: chi.URLParam(r, "configID")}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeBindResourcesRequest(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
req := bindResourcesReq{
token: apiutil.ExtractBearerToken(r),
configID: chi.URLParam(r, "configID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeEnrollmentEntityRequest(_ context.Context, r *http.Request) (any, error) {
return listBindingsReq{configID: chi.URLParam(r, "configID")}, nil
}
func decodeRefreshBindingsRequest(_ context.Context, r *http.Request) (any, error) {
return refreshBindingsReq{
token: apiutil.ExtractBearerToken(r),
configID: chi.URLParam(r, "configID"),
}, nil
}
-73
View File
@@ -1,73 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package bootstrap
import "context"
// Config represents a bootstrap enrollment.
type Config struct {
ID string `json:"id"`
DomainID string `json:"domain_id,omitempty"`
Name string `json:"name,omitempty"`
ClientCert string `json:"client_cert,omitempty"`
ClientKey string `json:"client_key,omitempty"`
CACert string `json:"ca_cert,omitempty"`
ExternalID string `json:"external_id"`
ExternalKey string `json:"external_key"`
Content string `json:"content,omitempty"`
Status Status `json:"status"`
ProfileID string `json:"profile_id,omitempty"`
RenderContext map[string]any `json:"render_context,omitempty"`
}
// Filter is used for the search filters.
type Filter struct {
FullMatch map[string]string
PartialMatch map[string]string
}
// ConfigsPage contains page related metadata as well as list of Configs that
// belong to this page.
type ConfigsPage struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Configs []Config `json:"configs"`
}
// ConfigRepository specifies a Config persistence API.
type ConfigRepository interface {
// Save persists the Config. Successful operation is indicated by non-nil
// error response.
Save(ctx context.Context, cfg Config) (string, error)
// RetrieveByID retrieves the Config having the provided identifier, that is owned
// by the specified user.
RetrieveByID(ctx context.Context, domainID, id string) (Config, error)
// RetrieveAll retrieves a subset of Configs that belong to the given domain,
// with given filter parameters.
RetrieveAll(ctx context.Context, domainID string, filter Filter, offset, limit uint64) ConfigsPage
// RetrieveByExternalID returns Config for given external ID.
RetrieveByExternalID(ctx context.Context, externalID string) (Config, error)
// Update updates an existing Config. A non-nil error is returned
// to indicate operation failure.
Update(ctx context.Context, cfg Config) error
// AssignProfile sets the profile reference for the given Config.
AssignProfile(ctx context.Context, domainID, id, profileID string) error
// UpdateCerts updates and returns an existing Config certificate and domainID.
// A non-nil error is returned to indicate operation failure.
UpdateCert(ctx context.Context, domainID, id, clientCert, clientKey, caCert string) (Config, error)
// Remove removes the Config having the provided identifier, that is owned
// by the specified user.
Remove(ctx context.Context, domainID, id string) error
// ChangeStatus changes the Status of the Config owned by the specific user.
ChangeStatus(ctx context.Context, domainID, id string, status Status) error
}
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package bootstrap contains the domain concept definitions needed to support
// Magistrala bootstrap service functionality.
package bootstrap
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package events provides the domain concept definitions needed to support
// bootstrap events functionality.
package events
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package producer contains the domain events needed to support
// event sourcing of Bootstrap service actions.
package producer
-288
View File
@@ -1,288 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package producer
import (
"github.com/absmach/magistrala/bootstrap"
"github.com/absmach/magistrala/pkg/events"
)
const (
configPrefix = "bootstrap.config."
configCreate = configPrefix + "create"
configUpdate = configPrefix + "update"
configRemove = configPrefix + "remove"
configView = configPrefix + "view"
configList = configPrefix + "list"
clientPrefix = "bootstrap.client."
clientBootstrap = clientPrefix + "bootstrap"
configEnable = configPrefix + "enable"
configDisable = configPrefix + "disable"
certUpdate = "bootstrap.cert.update"
profilePrefix = "bootstrap.profile."
profileCreate = profilePrefix + "create"
profileView = profilePrefix + "view"
profileUpdate = profilePrefix + "update"
profileList = profilePrefix + "list"
profileDelete = profilePrefix + "delete"
profileAssign = profilePrefix + "assign"
bindingsPrefix = "bootstrap.bindings."
bindingsBind = bindingsPrefix + "bind"
bindingsList = bindingsPrefix + "list"
bindingsRefresh = bindingsPrefix + "refresh"
)
var (
_ events.Event = (*configEvent)(nil)
_ events.Event = (*removeConfigEvent)(nil)
_ events.Event = (*bootstrapEvent)(nil)
_ events.Event = (*enableConfigEvent)(nil)
_ events.Event = (*disableConfigEvent)(nil)
_ events.Event = (*updateCertEvent)(nil)
_ events.Event = (*listConfigsEvent)(nil)
_ events.Event = (*profileEvent)(nil)
_ events.Event = (*deleteProfileEvent)(nil)
_ events.Event = (*assignProfileEvent)(nil)
_ events.Event = (*bindResourcesEvent)(nil)
_ events.Event = (*listBindingsEvent)(nil)
_ events.Event = (*refreshBindingsEvent)(nil)
)
type configEvent struct {
bootstrap.Config
operation string
}
func (ce configEvent) Encode() (map[string]any, error) {
val := map[string]any{
"status": ce.Status.String(),
"operation": ce.operation,
}
if ce.ID != "" {
val["config_id"] = ce.ID
}
if ce.Content != "" {
val["content"] = ce.Content
}
if ce.DomainID != "" {
val["domain_id"] = ce.DomainID
}
if ce.Name != "" {
val["name"] = ce.Name
}
if ce.ExternalID != "" {
val["external_id"] = ce.ExternalID
}
if ce.ClientCert != "" {
val["client_cert"] = ce.ClientCert
}
if ce.ClientKey != "" {
val["client_key"] = ce.ClientKey
}
if ce.CACert != "" {
val["ca_cert"] = ce.CACert
}
if ce.Content != "" {
val["content"] = ce.Content
}
return val, nil
}
type removeConfigEvent struct {
config string
}
func (rce removeConfigEvent) Encode() (map[string]any, error) {
return map[string]any{
"config_id": rce.config,
"operation": configRemove,
}, nil
}
type listConfigsEvent struct {
offset uint64
limit uint64
fullMatch map[string]string
partialMatch map[string]string
}
func (rce listConfigsEvent) Encode() (map[string]any, error) {
val := map[string]any{
"offset": rce.offset,
"limit": rce.limit,
"operation": configList,
}
if len(rce.fullMatch) > 0 {
val["full_match"] = rce.fullMatch
}
if len(rce.partialMatch) > 0 {
val["full_match"] = rce.partialMatch
}
return val, nil
}
type bootstrapEvent struct {
bootstrap.Config
externalID string
success bool
}
func (be bootstrapEvent) Encode() (map[string]any, error) {
val := map[string]any{
"external_id": be.externalID,
"success": be.success,
"operation": clientBootstrap,
}
if be.ID != "" {
val["config_id"] = be.ID
}
if be.Content != "" {
val["content"] = be.Content
}
if be.DomainID != "" {
val["domain_id"] = be.DomainID
}
if be.Name != "" {
val["name"] = be.Name
}
if be.ExternalID != "" {
val["external_id"] = be.ExternalID
}
if be.ClientCert != "" {
val["client_cert"] = be.ClientCert
}
if be.ClientKey != "" {
val["client_key"] = be.ClientKey
}
if be.CACert != "" {
val["ca_cert"] = be.CACert
}
if be.Content != "" {
val["content"] = be.Content
}
return val, nil
}
type enableConfigEvent struct {
configID string
}
func (e enableConfigEvent) Encode() (map[string]any, error) {
return map[string]any{
"config_id": e.configID,
"operation": configEnable,
}, nil
}
type disableConfigEvent struct {
configID string
}
func (e disableConfigEvent) Encode() (map[string]any, error) {
return map[string]any{
"config_id": e.configID,
"operation": configDisable,
}, nil
}
type updateCertEvent struct {
configID string
clientCert string
clientKey string
caCert string
}
func (uce updateCertEvent) Encode() (map[string]any, error) {
return map[string]any{
"config_id": uce.configID,
"client_cert": uce.clientCert,
"client_key": uce.clientKey,
"ca_cert": uce.caCert,
"operation": certUpdate,
}, nil
}
type profileEvent struct {
bootstrap.Profile
operation string
}
func (pe profileEvent) Encode() (map[string]any, error) {
val := map[string]any{
"operation": pe.operation,
}
if pe.ID != "" {
val["profile_id"] = pe.ID
}
if pe.DomainID != "" {
val["domain_id"] = pe.DomainID
}
if pe.Name != "" {
val["name"] = pe.Name
}
return val, nil
}
type deleteProfileEvent struct {
profileID string
}
func (dpe deleteProfileEvent) Encode() (map[string]any, error) {
return map[string]any{
"profile_id": dpe.profileID,
"operation": profileDelete,
}, nil
}
type assignProfileEvent struct {
configID string
profileID string
}
func (ape assignProfileEvent) Encode() (map[string]any, error) {
return map[string]any{
"config_id": ape.configID,
"profile_id": ape.profileID,
"operation": profileAssign,
}, nil
}
type bindResourcesEvent struct {
configID string
slots []string
}
func (bre bindResourcesEvent) Encode() (map[string]any, error) {
return map[string]any{
"config_id": bre.configID,
"slots": bre.slots,
"operation": bindingsBind,
}, nil
}
type listBindingsEvent struct {
configID string
}
func (lbe listBindingsEvent) Encode() (map[string]any, error) {
return map[string]any{
"config_id": lbe.configID,
"operation": bindingsList,
}, nil
}
type refreshBindingsEvent struct {
configID string
}
func (rbe refreshBindingsEvent) Encode() (map[string]any, error) {
return map[string]any{
"config_id": rbe.configID,
"operation": bindingsRefresh,
}, nil
}
-61
View File
@@ -1,61 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package producer_test
import (
"context"
"fmt"
"log"
"os"
"testing"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/redis/go-redis/v9"
)
var (
redisClient *redis.Client
redisURL string
)
func TestMain(m *testing.M) {
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
container, err := pool.RunWithOptions(&dockertest.RunOptions{
Repository: "redis",
Tag: "7.2.4-alpine",
}, func(config *docker.HostConfig) {
config.AutoRemove = true
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
})
if err != nil {
log.Fatalf("Could not start container: %s", err)
}
redisURL = fmt.Sprintf("redis://localhost:%s/0", container.GetPort("6379/tcp"))
opts, err := redis.ParseURL(redisURL)
if err != nil {
log.Fatalf("Could not parse redis URL: %s", err)
}
if err := pool.Retry(func() error {
redisClient = redis.NewClient(opts)
return redisClient.Ping(context.Background()).Err()
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
code := m.Run()
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(code)
}
-284
View File
@@ -1,284 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package producer
import (
"context"
"github.com/absmach/magistrala/bootstrap"
smqauthn "github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/events"
)
var _ bootstrap.Service = (*eventStore)(nil)
const (
magistralaPrefix = "magistrala."
createStream = magistralaPrefix + configCreate
listStream = magistralaPrefix + configList
removeStream = magistralaPrefix + configRemove
updateCertStream = magistralaPrefix + certUpdate
bootstrapStream = magistralaPrefix + clientBootstrap
enableConfigStream = magistralaPrefix + configEnable
disableConfigStream = magistralaPrefix + configDisable
createProfileStream = magistralaPrefix + profileCreate
viewProfileStream = magistralaPrefix + profileView
updateProfileStream = magistralaPrefix + profileUpdate
listProfilesStream = magistralaPrefix + profileList
deleteProfileStream = magistralaPrefix + profileDelete
assignProfileStream = magistralaPrefix + profileAssign
bindResourcesStream = magistralaPrefix + bindingsBind
listBindingsStream = magistralaPrefix + bindingsList
refreshBindingsStream = magistralaPrefix + bindingsRefresh
)
type eventStore struct {
events.Publisher
svc bootstrap.Service
}
// NewEventStoreMiddleware returns wrapper around bootstrap service that sends
// events to event store.
func NewEventStoreMiddleware(svc bootstrap.Service, publisher events.Publisher) bootstrap.Service {
return &eventStore{
svc: svc,
Publisher: publisher,
}
}
func (es *eventStore) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (bootstrap.Config, error) {
saved, err := es.svc.Add(ctx, session, token, cfg)
if err != nil {
return saved, err
}
ev := configEvent{
saved, configCreate,
}
if err := es.Publish(ctx, createStream, ev); err != nil {
return saved, err
}
return saved, err
}
func (es *eventStore) View(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
cfg, err := es.svc.View(ctx, session, id)
if err != nil {
return cfg, err
}
ev := configEvent{
cfg, configView,
}
if err := es.Publish(ctx, magistralaPrefix+configView, ev); err != nil {
return cfg, err
}
return cfg, err
}
func (es *eventStore) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) error {
if err := es.svc.Update(ctx, session, cfg); err != nil {
return err
}
ev := configEvent{
cfg, configUpdate,
}
return es.Publish(ctx, magistralaPrefix+configUpdate, ev)
}
func (es eventStore) UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (bootstrap.Config, error) {
cfg, err := es.svc.UpdateCert(ctx, session, id, clientCert, clientKey, caCert)
if err != nil {
return cfg, err
}
ev := updateCertEvent{
configID: id,
clientCert: clientCert,
clientKey: clientKey,
caCert: caCert,
}
if err := es.Publish(ctx, updateCertStream, ev); err != nil {
return cfg, err
}
return cfg, nil
}
func (es *eventStore) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (bootstrap.ConfigsPage, error) {
bp, err := es.svc.List(ctx, session, filter, offset, limit)
if err != nil {
return bp, err
}
ev := listConfigsEvent{
offset: offset,
limit: limit,
fullMatch: filter.FullMatch,
partialMatch: filter.PartialMatch,
}
if err := es.Publish(ctx, listStream, ev); err != nil {
return bp, err
}
return bp, nil
}
func (es *eventStore) Remove(ctx context.Context, session smqauthn.Session, id string) error {
if err := es.svc.Remove(ctx, session, id); err != nil {
return err
}
ev := removeConfigEvent{
config: id,
}
return es.Publish(ctx, removeStream, ev)
}
func (es *eventStore) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (bootstrap.Config, error) {
cfg, err := es.svc.Bootstrap(ctx, externalKey, externalID, secure)
ev := bootstrapEvent{
cfg,
externalID,
true,
}
if err != nil {
ev.success = false
}
if err := es.Publish(ctx, bootstrapStream, ev); err != nil {
return cfg, err
}
return cfg, err
}
func (es *eventStore) EnableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
cfg, err := es.svc.EnableConfig(ctx, session, id)
if err != nil {
return cfg, err
}
ev := enableConfigEvent{configID: id}
if err := es.Publish(ctx, enableConfigStream, ev); err != nil {
return cfg, err
}
return cfg, nil
}
func (es *eventStore) DisableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
cfg, err := es.svc.DisableConfig(ctx, session, id)
if err != nil {
return cfg, err
}
ev := disableConfigEvent{configID: id}
if err := es.Publish(ctx, disableConfigStream, ev); err != nil {
return cfg, err
}
return cfg, nil
}
func (es *eventStore) CreateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
saved, err := es.svc.CreateProfile(ctx, session, p)
if err != nil {
return saved, err
}
ev := profileEvent{saved, profileCreate}
if err := es.Publish(ctx, createProfileStream, ev); err != nil {
return saved, err
}
return saved, nil
}
func (es *eventStore) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (bootstrap.Profile, error) {
p, err := es.svc.ViewProfile(ctx, session, profileID)
if err != nil {
return p, err
}
ev := profileEvent{p, profileView}
if err := es.Publish(ctx, viewProfileStream, ev); err != nil {
return p, err
}
return p, nil
}
func (es *eventStore) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
updated, err := es.svc.UpdateProfile(ctx, session, p)
if err != nil {
return bootstrap.Profile{}, err
}
ev := profileEvent{updated, profileUpdate}
return updated, es.Publish(ctx, updateProfileStream, ev)
}
func (es *eventStore) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
pp, err := es.svc.ListProfiles(ctx, session, offset, limit, name)
if err != nil {
return pp, err
}
ev := profileEvent{operation: profileList}
if err := es.Publish(ctx, listProfilesStream, ev); err != nil {
return pp, err
}
return pp, nil
}
func (es *eventStore) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
if err := es.svc.DeleteProfile(ctx, session, profileID); err != nil {
return err
}
ev := deleteProfileEvent{profileID: profileID}
return es.Publish(ctx, deleteProfileStream, ev)
}
func (es *eventStore) AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) error {
if err := es.svc.AssignProfile(ctx, session, configID, profileID); err != nil {
return err
}
ev := assignProfileEvent{configID: configID, profileID: profileID}
return es.Publish(ctx, assignProfileStream, ev)
}
func (es *eventStore) BindResources(ctx context.Context, session smqauthn.Session, token, configID string, bindings []bootstrap.BindingRequest) error {
if err := es.svc.BindResources(ctx, session, token, configID, bindings); err != nil {
return err
}
slots := make([]string, len(bindings))
for i, b := range bindings {
slots[i] = b.Slot
}
ev := bindResourcesEvent{configID: configID, slots: slots}
return es.Publish(ctx, bindResourcesStream, ev)
}
func (es *eventStore) ListBindings(ctx context.Context, session smqauthn.Session, configID string) ([]bootstrap.BindingSnapshot, error) {
bs, err := es.svc.ListBindings(ctx, session, configID)
if err != nil {
return bs, err
}
ev := listBindingsEvent{configID: configID}
if err := es.Publish(ctx, listBindingsStream, ev); err != nil {
return bs, err
}
return bs, nil
}
func (es *eventStore) RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) error {
if err := es.svc.RefreshBindings(ctx, session, token, configID); err != nil {
return err
}
ev := refreshBindingsEvent{configID: configID}
return es.Publish(ctx, refreshBindingsStream, ev)
}
-914
View File
@@ -1,914 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package producer_test
import (
"context"
"fmt"
"strconv"
"strings"
"testing"
"time"
"github.com/absmach/magistrala/bootstrap"
"github.com/absmach/magistrala/bootstrap/events/producer"
bootstraphasher "github.com/absmach/magistrala/bootstrap/hasher"
"github.com/absmach/magistrala/bootstrap/mocks"
"github.com/absmach/magistrala/internal/testsutil"
smqauthn "github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/events/store"
sdkmocks "github.com/absmach/magistrala/pkg/sdk/mocks"
"github.com/absmach/magistrala/pkg/uuid"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
const (
streamID = "magistrala.bootstrap"
validToken = "validToken"
unknownID = "unknown"
configPrefix = "config."
configCreate = configPrefix + "create"
configView = configPrefix + "view"
configUpdate = configPrefix + "update"
configRemove = configPrefix + "remove"
configList = configPrefix + "list"
clientPrefix = "client."
clientBootstrap = clientPrefix + "bootstrap"
configEnable = configPrefix + "enable"
configDisable = configPrefix + "disable"
certUpdate = "cert.update"
)
var (
encKey = []byte("1234567891011121")
domainID = testsutil.GenerateUUID(&testing.T{})
validID = testsutil.GenerateUUID(&testing.T{})
config = bootstrap.Config{
ID: testsutil.GenerateUUID(&testing.T{}),
ExternalID: testsutil.GenerateUUID(&testing.T{}),
ExternalKey: testsutil.GenerateUUID(&testing.T{}),
Content: "config",
Status: bootstrap.EnabledStatus,
}
)
type testVariable struct {
svc bootstrap.Service
boot *mocks.ConfigRepository
sdk *sdkmocks.SDK
}
func newTestVariable(t *testing.T, redisURL string) testVariable {
boot := new(mocks.ConfigRepository)
sdk := new(sdkmocks.SDK)
idp := uuid.NewMock()
svc := bootstrap.New(boot, nil, nil, nil, nil, sdk, bootstraphasher.New(), encKey, idp)
publisher, err := store.NewPublisher(context.Background(), redisURL, "bootstrap-es-pub-test")
require.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
svc = producer.NewEventStoreMiddleware(svc, publisher)
return testVariable{
svc: svc,
boot: boot,
sdk: sdk,
}
}
func TestAdd(t *testing.T) {
err := redisClient.FlushAll(context.Background()).Err()
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
tv := newTestVariable(t, redisURL)
cases := []struct {
desc string
config bootstrap.Config
token string
session smqauthn.Session
id string
domainID string
saveErr error
err error
event map[string]any
}{
{
desc: "create config successfully",
config: config,
token: validToken,
id: validID,
domainID: domainID,
event: map[string]any{
"config_id": "1",
"domain_id": domainID,
"name": config.Name,
"external_id": config.ExternalID,
"content": config.Content,
"timestamp": time.Now().Unix(),
"operation": configCreate,
},
err: nil,
},
{
desc: "create config with failed to save",
config: config,
token: validToken,
id: validID,
domainID: domainID,
event: nil,
saveErr: svcerr.ErrCreateEntity,
err: svcerr.ErrCreateEntity,
},
}
lastID := "0"
for _, tc := range cases {
tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID}
repoCall := tv.boot.On("Save", context.Background(), mock.Anything).Return(mock.Anything, tc.saveErr)
_, err := tv.svc.Add(context.Background(), tc.session, tc.token, tc.config)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
Streams: []string{streamID, lastID},
Count: 1,
Block: time.Second,
}).Val()
var event map[string]any
if len(streams) > 0 && len(streams[0].Messages) > 0 {
event := streams[0].Messages
lastID = event[0].ID
}
test(t, tc.event, event, tc.desc)
repoCall.Unset()
}
}
func TestView(t *testing.T) {
err := redisClient.FlushAll(context.Background()).Err()
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
tv := newTestVariable(t, redisURL)
nonExisting := config
nonExisting.ID = unknownID
cases := []struct {
desc string
config bootstrap.Config
token string
session smqauthn.Session
id string
domainID string
retrieveErr error
err error
event map[string]any
}{
{
desc: "view successfully",
config: config,
token: validToken,
id: validID,
domainID: domainID,
err: nil,
event: map[string]any{
"config_id": config.ID,
"domain_id": config.DomainID,
"name": config.Name,
"external_id": config.ExternalID,
"content": config.Content,
"timestamp": time.Now().Unix(),
"operation": configView,
},
},
{
desc: "view with failed retrieve",
config: nonExisting,
token: validToken,
id: validID,
domainID: domainID,
retrieveErr: svcerr.ErrViewEntity,
err: svcerr.ErrViewEntity,
event: nil,
},
}
lastID := "0"
for _, tc := range cases {
tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID}
repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.config.ID).Return(config, tc.retrieveErr)
_, err := tv.svc.View(context.Background(), tc.session, tc.config.ID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
Streams: []string{streamID, lastID},
Count: 1,
Block: time.Second,
}).Val()
var event map[string]any
if len(streams) > 0 && len(streams[0].Messages) > 0 {
msg := streams[0].Messages[0]
event = msg.Values
event["timestamp"] = msg.ID
lastID = msg.ID
}
test(t, tc.event, event, tc.desc)
repoCall.Unset()
}
}
func TestUpdate(t *testing.T) {
err := redisClient.FlushAll(context.Background()).Err()
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
tv := newTestVariable(t, redisURL)
modified := config
modified.Content = "new-config"
modified.Name = "new name"
nonExisting := config
nonExisting.ID = unknownID
cases := []struct {
desc string
config bootstrap.Config
token string
session smqauthn.Session
id string
domainID string
updateErr error
err error
event map[string]any
}{
{
desc: "update config successfully",
config: modified,
token: validToken,
id: validID,
domainID: domainID,
err: nil,
event: map[string]any{
"name": modified.Name,
"content": modified.Content,
"timestamp": time.Now().UnixNano(),
"operation": configUpdate,
"external_id": modified.ExternalID,
"config_id": modified.ID,
"domain_id": domainID,
"status": bootstrap.Disabled,
"occurred_at": time.Now().UnixNano(),
},
},
{
desc: "update with failed update",
config: nonExisting,
token: validToken,
id: validID,
domainID: domainID,
updateErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
event: nil,
},
}
lastID := "0"
for _, tc := range cases {
tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID}
repoCall := tv.boot.On("Update", context.Background(), mock.Anything).Return(tc.updateErr)
err := tv.svc.Update(context.Background(), tc.session, tc.config)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
Streams: []string{streamID, lastID},
Count: 1,
Block: time.Second,
}).Val()
var event map[string]any
if len(streams) > 0 && len(streams[0].Messages) > 0 {
msg := streams[0].Messages[0]
event = msg.Values
event["timestamp"] = msg.ID
lastID = msg.ID
}
test(t, tc.event, event, tc.desc)
repoCall.Unset()
}
}
func TestUpdateCert(t *testing.T) {
err := redisClient.FlushAll(context.Background()).Err()
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
tv := newTestVariable(t, redisURL)
cases := []struct {
desc string
configID string
userID string
domainID string
token string
session smqauthn.Session
clientCert string
clientKey string
caCert string
updateErr error
err error
event map[string]any
}{
{
desc: "update cert successfully",
configID: config.ID,
userID: validID,
domainID: domainID,
token: validToken,
clientCert: "clientCert",
clientKey: "clientKey",
caCert: "caCert",
err: nil,
event: map[string]any{
"client_cert": "clientCert",
"client_key": "clientKey",
"ca_cert": "caCert",
"operation": certUpdate,
},
},
{
desc: "update cert with failed update",
configID: "clientID",
token: validToken,
userID: validID,
domainID: domainID,
clientCert: "clientCert",
clientKey: "clientKey",
caCert: "caCert",
updateErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
event: nil,
},
{
desc: "update cert with empty client certificate",
configID: config.ID,
token: validToken,
userID: validID,
domainID: domainID,
clientCert: "",
clientKey: "clientKey",
caCert: "caCert",
err: nil,
event: nil,
},
{
desc: "update cert with empty client key",
configID: config.ID,
token: validToken,
userID: validID,
domainID: domainID,
clientCert: "clientCert",
clientKey: "",
caCert: "caCert",
err: nil,
event: nil,
},
{
desc: "update cert with empty CA certificate",
configID: config.ID,
token: validToken,
userID: validID,
domainID: domainID,
clientCert: "clientCert",
clientKey: "clientKey",
caCert: "",
err: nil,
event: nil,
},
}
lastID := "0"
for _, tc := range cases {
tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID}
repoCall := tv.boot.On("UpdateCert", context.Background(), tc.domainID, tc.configID, tc.clientCert, tc.clientKey, tc.caCert).Return(config, tc.updateErr)
_, err := tv.svc.UpdateCert(context.Background(), tc.session, tc.configID, tc.clientCert, tc.clientKey, tc.caCert)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
Streams: []string{streamID, lastID},
Count: 1,
Block: time.Second,
}).Val()
var event map[string]any
if len(streams) > 0 && len(streams[0].Messages) > 0 {
event := streams[0].Messages
lastID = event[0].ID
}
test(t, tc.event, event, tc.desc)
repoCall.Unset()
}
}
func TestList(t *testing.T) {
tv := newTestVariable(t, redisURL)
numClients := 101
var c bootstrap.Config
saved := make([]bootstrap.Config, 0)
for i := 0; i < numClients; i++ {
c = config
c.ExternalID = testsutil.GenerateUUID(t)
c.ExternalKey = testsutil.GenerateUUID(t)
c.Name = fmt.Sprintf("%s-%d", config.Name, i)
if i == 41 {
c.Status = bootstrap.Active
}
saved = append(saved, c)
}
cases := []struct {
desc string
token string
session smqauthn.Session
userID string
domainID string
config bootstrap.ConfigsPage
filter bootstrap.Filter
offset uint64
limit uint64
retrieveErr error
err error
event map[string]any
}{
{
desc: "list successfully as super admin",
token: validToken,
userID: validID,
domainID: domainID,
session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true},
config: bootstrap.ConfigsPage{
Total: uint64(len(saved)),
Offset: 0,
Limit: 10,
Configs: saved[0:10],
},
filter: bootstrap.Filter{},
offset: 0,
limit: 10,
err: nil,
event: map[string]any{
"config_id": c.ID,
"domain_id": c.DomainID,
"name": c.Name,
"external_id": c.ExternalID,
"content": c.Content,
"timestamp": time.Now().Unix(),
"operation": configList,
},
},
{
desc: "list successfully as domain admin",
token: validToken,
userID: validID,
domainID: domainID,
session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true},
config: bootstrap.ConfigsPage{
Total: uint64(len(saved)),
Offset: 0,
Limit: 10,
Configs: saved[0:10],
},
filter: bootstrap.Filter{},
offset: 0,
limit: 10,
err: nil,
event: map[string]any{
"config_id": c.ID,
"domain_id": c.DomainID,
"name": c.Name,
"external_id": c.ExternalID,
"content": c.Content,
"timestamp": time.Now().Unix(),
"operation": configList,
},
},
{
desc: "list successfully as non admin",
token: validToken,
userID: validID,
domainID: domainID,
session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID},
config: bootstrap.ConfigsPage{
Total: uint64(len(saved)),
Offset: 0,
Limit: 10,
Configs: saved[0:10],
},
filter: bootstrap.Filter{},
offset: 0,
limit: 10,
err: nil,
event: map[string]any{
"config_id": c.ID,
"domain_id": c.DomainID,
"name": c.Name,
"external_id": c.ExternalID,
"content": c.Content,
"timestamp": time.Now().Unix(),
"operation": configList,
},
},
{
desc: "list as super admin with failed retrieve all",
token: validToken,
userID: validID,
domainID: domainID,
session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true},
filter: bootstrap.Filter{},
offset: 0,
limit: 10,
retrieveErr: nil,
err: nil,
event: nil,
},
{
desc: "list as domain admin with failed retrieve all",
token: validToken,
userID: validID,
domainID: domainID,
session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true},
filter: bootstrap.Filter{},
offset: 0,
limit: 10,
retrieveErr: nil,
err: nil,
event: nil,
},
{
desc: "list as non admin with failed retrieve all",
token: validToken,
userID: validID,
domainID: domainID,
session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID},
filter: bootstrap.Filter{},
offset: 0,
limit: 10,
retrieveErr: nil,
err: nil,
event: nil,
},
}
lastID := "0"
for _, tc := range cases {
repoCall := tv.boot.On("RetrieveAll", context.Background(), mock.Anything, tc.filter, tc.offset, tc.limit).Return(tc.config, tc.retrieveErr)
_, err := tv.svc.List(context.Background(), tc.session, tc.filter, tc.offset, tc.limit)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
Streams: []string{streamID, lastID},
Count: 1,
Block: time.Second,
}).Val()
var event map[string]any
if len(streams) > 0 && len(streams[0].Messages) > 0 {
event := streams[0].Messages
lastID = event[0].ID
}
test(t, tc.event, event, tc.desc)
repoCall.Unset()
}
}
func TestRemove(t *testing.T) {
err := redisClient.FlushAll(context.Background()).Err()
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
tv := newTestVariable(t, redisURL)
nonExisting := config
nonExisting.ID = unknownID
cases := []struct {
desc string
configID string
userID string
domainID string
token string
session smqauthn.Session
removeErr error
err error
event map[string]any
}{
{
desc: "remove config successfully",
configID: config.ID,
token: validToken,
userID: validID,
domainID: domainID,
err: nil,
event: map[string]any{
"config_id": config.ID,
"timestamp": time.Now().Unix(),
"operation": configRemove,
},
},
{
desc: "remove config with failed removal",
configID: nonExisting.ID,
token: validToken,
userID: validID,
domainID: domainID,
removeErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
event: nil,
},
}
lastID := "0"
for _, tc := range cases {
tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID}
repoCall := tv.boot.On("Remove", context.Background(), mock.Anything, mock.Anything).Return(tc.removeErr)
err := tv.svc.Remove(context.Background(), tc.session, tc.configID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
Streams: []string{streamID, lastID},
Count: 1,
Block: time.Second,
}).Val()
var event map[string]any
if len(streams) > 0 && len(streams[0].Messages) > 0 {
event := streams[0].Messages
lastID = event[0].ID
}
test(t, tc.event, event, tc.desc)
repoCall.Unset()
}
}
func TestBootstrap(t *testing.T) {
err := redisClient.FlushAll(context.Background()).Err()
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
tv := newTestVariable(t, redisURL)
cases := []struct {
desc string
externalID string
externalKey string
err error
retrieveErr error
event map[string]any
}{
{
desc: "bootstrap successfully",
externalID: config.ExternalID,
externalKey: config.ExternalKey,
err: nil,
event: map[string]any{
"external_id": config.ExternalID,
"success": "1",
"timestamp": time.Now().Unix(),
"operation": clientBootstrap,
},
},
{
desc: "bootstrap with an error",
externalID: "external_id1",
externalKey: "external_id",
retrieveErr: bootstrap.ErrBootstrap,
err: bootstrap.ErrBootstrap,
event: map[string]any{
"external_id": "external_id",
"success": "0",
"timestamp": time.Now().Unix(),
"operation": clientBootstrap,
},
},
}
lastID := "0"
for _, tc := range cases {
repoCall := tv.boot.On("RetrieveByExternalID", context.Background(), mock.Anything).Return(config, tc.retrieveErr)
_, err = tv.svc.Bootstrap(context.Background(), tc.externalKey, tc.externalID, false)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
Streams: []string{streamID, lastID},
Count: 1,
Block: time.Second,
}).Val()
var event map[string]any
if len(streams) > 0 && len(streams[0].Messages) > 0 {
event := streams[0].Messages
lastID = event[0].ID
}
test(t, tc.event, event, tc.desc)
repoCall.Unset()
}
}
func TestEnableConfig(t *testing.T) {
err := redisClient.FlushAll(context.Background()).Err()
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
tv := newTestVariable(t, redisURL)
cases := []struct {
desc string
id string
userID string
domainID string
session smqauthn.Session
retrieveErr error
statusErr error
err error
event map[string]any
}{
{
desc: "enable config",
id: config.ID,
userID: validID,
domainID: domainID,
err: nil,
event: map[string]any{
"config_id": config.ID,
"timestamp": time.Now().Unix(),
"operation": configEnable,
},
},
{
desc: "enable with failed retrieve by ID",
id: "",
userID: validID,
domainID: domainID,
retrieveErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
event: nil,
},
{
desc: "enable with repo status error",
id: config.ID,
userID: validID,
domainID: domainID,
statusErr: svcerr.ErrUpdateEntity,
err: svcerr.ErrUpdateEntity,
event: nil,
},
}
disabledConfig := config
disabledConfig.Status = bootstrap.DisabledStatus
lastID := "0"
for _, tc := range cases {
tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID}
repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(disabledConfig, tc.retrieveErr)
repoCall1 := tv.boot.On("ChangeStatus", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(tc.statusErr)
_, err := tv.svc.EnableConfig(context.Background(), tc.session, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
Streams: []string{streamID, lastID},
Count: 1,
Block: time.Second,
}).Val()
var event map[string]any
if len(streams) > 0 && len(streams[0].Messages) > 0 {
event := streams[0].Messages
lastID = event[0].ID
}
test(t, tc.event, event, tc.desc)
repoCall.Unset()
repoCall1.Unset()
}
}
func TestDisableConfig(t *testing.T) {
err := redisClient.FlushAll(context.Background()).Err()
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
tv := newTestVariable(t, redisURL)
cases := []struct {
desc string
id string
userID string
domainID string
session smqauthn.Session
retrieveErr error
statusErr error
err error
event map[string]any
}{
{
desc: "disable config",
id: config.ID,
userID: validID,
domainID: domainID,
err: nil,
event: map[string]any{
"config_id": config.ID,
"timestamp": time.Now().Unix(),
"operation": configDisable,
},
},
{
desc: "disable with failed retrieve by ID",
id: "",
userID: validID,
domainID: domainID,
retrieveErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
event: nil,
},
{
desc: "disable with repo status error",
id: config.ID,
userID: validID,
domainID: domainID,
statusErr: svcerr.ErrUpdateEntity,
err: svcerr.ErrUpdateEntity,
event: nil,
},
}
lastID := "0"
for _, tc := range cases {
tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID}
repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(config, tc.retrieveErr)
repoCall1 := tv.boot.On("ChangeStatus", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(tc.statusErr)
_, err := tv.svc.DisableConfig(context.Background(), tc.session, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
Streams: []string{streamID, lastID},
Count: 1,
Block: time.Second,
}).Val()
var event map[string]any
if len(streams) > 0 && len(streams[0].Messages) > 0 {
event := streams[0].Messages
lastID = event[0].ID
}
test(t, tc.event, event, tc.desc)
repoCall.Unset()
repoCall1.Unset()
}
}
func test(t *testing.T, expected, actual map[string]any, description string) {
if expected != nil && actual != nil {
ts1 := expected["timestamp"].(int64)
ats := actual["timestamp"].(string)
ts2, err := strconv.ParseInt(strings.Split(ats, "-")[0], 10, 64)
require.Nil(t, err, fmt.Sprintf("%s: expected to get a valid timestamp, got %s", description, err))
ts1 = ts1 / 1e9
ts2 = ts2 / 1e3
if assert.WithinDuration(t, time.Unix(ts1, 0), time.Unix(ts2, 0), time.Second, fmt.Sprintf("%s: timestamp is not in valid range of 1 second", description)) {
delete(expected, "timestamp")
delete(actual, "timestamp")
}
oa1 := expected["occurred_at"].(int64)
aoa := actual["occurred_at"].(string)
oa2, err := strconv.ParseInt(aoa, 10, 64)
require.Nil(t, err, fmt.Sprintf("%s: expected to get a valid occurred_at, got %s", description, err))
oa1 = oa1 / 1e9
oa2 = oa2 / 1e9
if assert.WithinDuration(t, time.Unix(oa1, 0), time.Unix(oa2, 0), time.Second, fmt.Sprintf("%s: occurred_at is not in valid range of 1 second", description)) {
delete(expected, "occurred_at")
delete(actual, "occurred_at")
}
assert.Equal(t, expected, actual, fmt.Sprintf("%s: got incorrect event\n", description))
}
}
-219
View File
@@ -1,219 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"github.com/absmach/magistrala/auth"
"github.com/absmach/magistrala/bootstrap"
smqauthn "github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/authz"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/policies"
)
const (
createOperation = "create"
viewOperation = "view"
updateOperation = "update"
updateCertOperation = "update_cert"
listOperation = "list"
removeOperation = "remove"
changeStateOperation = "change_state"
)
var _ bootstrap.Service = (*authorizationMiddleware)(nil)
type authorizationMiddleware struct {
svc bootstrap.Service
authz authz.Authorization
}
// AuthorizationMiddleware adds authorization to the clients service.
func AuthorizationMiddleware(svc bootstrap.Service, authz authz.Authorization) bootstrap.Service {
return &authorizationMiddleware{
svc: svc,
authz: authz,
}
}
func (am *authorizationMiddleware) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (bootstrap.Config, error) {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, createOperation, auth.AnyIDs); err != nil {
return bootstrap.Config{}, err
}
return am.svc.Add(ctx, session, token, cfg)
}
func (am *authorizationMiddleware) View(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, viewOperation, id); err != nil {
return bootstrap.Config{}, err
}
return am.svc.View(ctx, session, id)
}
func (am *authorizationMiddleware) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) error {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, cfg.ID); err != nil {
return err
}
return am.svc.Update(ctx, session, cfg)
}
func (am *authorizationMiddleware) UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (bootstrap.Config, error) {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateCertOperation, id); err != nil {
return bootstrap.Config{}, err
}
return am.svc.UpdateCert(ctx, session, id, clientCert, clientKey, caCert)
}
func (am *authorizationMiddleware) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (bootstrap.ConfigsPage, error) {
if err := am.checkSuperAdmin(ctx, session); err == nil {
session.SuperAdmin = true
}
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.AdminPermission, policies.DomainType, session.DomainID, listOperation, auth.AnyIDs); err == nil {
session.SuperAdmin = true
}
return am.svc.List(ctx, session, filter, offset, limit)
}
func (am *authorizationMiddleware) Remove(ctx context.Context, session smqauthn.Session, id string) error {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, removeOperation, id); err != nil {
return err
}
return am.svc.Remove(ctx, session, id)
}
func (am *authorizationMiddleware) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (bootstrap.Config, error) {
return am.svc.Bootstrap(ctx, externalKey, externalID, secure)
}
func (am *authorizationMiddleware) EnableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, changeStateOperation, id); err != nil {
return bootstrap.Config{}, err
}
return am.svc.EnableConfig(ctx, session, id)
}
func (am *authorizationMiddleware) DisableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, changeStateOperation, id); err != nil {
return bootstrap.Config{}, err
}
return am.svc.DisableConfig(ctx, session, id)
}
func (am *authorizationMiddleware) CreateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, createOperation, auth.AnyIDs); err != nil {
return bootstrap.Profile{}, err
}
return am.svc.CreateProfile(ctx, session, p)
}
func (am *authorizationMiddleware) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (bootstrap.Profile, error) {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, viewOperation, auth.AnyIDs); err != nil {
return bootstrap.Profile{}, err
}
return am.svc.ViewProfile(ctx, session, profileID)
}
func (am *authorizationMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, auth.AnyIDs); err != nil {
return bootstrap.Profile{}, err
}
return am.svc.UpdateProfile(ctx, session, p)
}
func (am *authorizationMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, listOperation, auth.AnyIDs); err != nil {
return bootstrap.ProfilesPage{}, err
}
return am.svc.ListProfiles(ctx, session, offset, limit, name)
}
func (am *authorizationMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, removeOperation, auth.AnyIDs); err != nil {
return err
}
return am.svc.DeleteProfile(ctx, session, profileID)
}
func (am *authorizationMiddleware) AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) error {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, configID); err != nil {
return err
}
return am.svc.AssignProfile(ctx, session, configID, profileID)
}
func (am *authorizationMiddleware) BindResources(ctx context.Context, session smqauthn.Session, token, configID string, bindings []bootstrap.BindingRequest) error {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, configID); err != nil {
return err
}
return am.svc.BindResources(ctx, session, token, configID, bindings)
}
func (am *authorizationMiddleware) ListBindings(ctx context.Context, session smqauthn.Session, configID string) ([]bootstrap.BindingSnapshot, error) {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, viewOperation, configID); err != nil {
return nil, err
}
return am.svc.ListBindings(ctx, session, configID)
}
func (am *authorizationMiddleware) RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) error {
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, configID); err != nil {
return err
}
return am.svc.RefreshBindings(ctx, session, token, configID)
}
func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session smqauthn.Session) error {
if session.Role != smqauthn.SuperAdminRole {
return svcerr.ErrSuperAdminAction
}
if err := am.authz.Authorize(ctx, authz.PolicyReq{
SubjectType: policies.UserType,
Subject: session.UserID,
Permission: policies.AdminPermission,
ObjectType: policies.PlatformType,
Object: policies.MagistralaObject,
}, nil); err != nil {
return err
}
return nil
}
func (am *authorizationMiddleware) authorize(ctx context.Context, session smqauthn.Session, domain, subjType, subjKind, subj, perm, objType, obj, operation, entityID string) error {
req := authz.PolicyReq{
Domain: domain,
SubjectType: subjType,
SubjectKind: subjKind,
Subject: subj,
Permission: perm,
ObjectType: objType,
Object: obj,
}
var pat *authz.PATReq
if session.PatID != "" {
pat = &authz.PATReq{
UserID: session.UserID,
PatID: session.PatID,
EntityID: entityID,
EntityType: auth.BootstrapType.String(),
Operation: operation,
Domain: session.DomainID,
}
}
if err := am.authz.Authorize(ctx, req, pat); err != nil {
return err
}
return nil
}
-355
View File
@@ -1,355 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
//go:build !test
package middleware
import (
"context"
"log/slog"
"time"
"github.com/absmach/magistrala/bootstrap"
smqauthn "github.com/absmach/magistrala/pkg/authn"
)
var _ bootstrap.Service = (*loggingMiddleware)(nil)
type loggingMiddleware struct {
logger *slog.Logger
svc bootstrap.Service
}
// LoggingMiddleware adds logging facilities to the bootstrap service.
func LoggingMiddleware(svc bootstrap.Service, logger *slog.Logger) bootstrap.Service {
return &loggingMiddleware{logger, svc}
}
// Add logs the add request. It logs the client ID and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (saved bootstrap.Config, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("config_id", saved.ID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Add new bootstrap failed", args...)
return
}
lm.logger.Info("Add new bootstrap completed successfully", args...)
}(time.Now())
return lm.svc.Add(ctx, session, token, cfg)
}
// View logs the view request. It logs the client ID and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) View(ctx context.Context, session smqauthn.Session, id string) (saved bootstrap.Config, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("config_id", id),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("View client config failed", args...)
return
}
lm.logger.Info("View client config completed successfully", args...)
}(time.Now())
return lm.svc.View(ctx, session, id)
}
// Update logs the update request. It logs bootstrap client ID and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.Group("config",
slog.String("config_id", cfg.ID),
slog.String("name", cfg.Name),
),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Update bootstrap config failed", args...)
return
}
lm.logger.Info("Update bootstrap config completed successfully", args...)
}(time.Now())
return lm.svc.Update(ctx, session, cfg)
}
// UpdateCert logs the update_cert request. It logs config ID and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (cfg bootstrap.Config, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("config_id", cfg.ID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Update bootstrap config certificate failed", args...)
return
}
lm.logger.Info("Update bootstrap config certificate completed successfully", args...)
}(time.Now())
return lm.svc.UpdateCert(ctx, session, id, clientCert, clientKey, caCert)
}
// List logs the list request. It logs offset, limit and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (res bootstrap.ConfigsPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.Group("page",
slog.Any("filter", filter),
slog.Uint64("offset", offset),
slog.Uint64("limit", limit),
slog.Uint64("total", res.Total),
),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("List configs failed", args...)
return
}
lm.logger.Info("List configs completed successfully", args...)
}(time.Now())
return lm.svc.List(ctx, session, filter, offset, limit)
}
// Remove logs the remove request. It logs bootstrap ID and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) Remove(ctx context.Context, session smqauthn.Session, id string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("config_id", id),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Remove bootstrap config failed", args...)
return
}
lm.logger.Info("Remove bootstrap config completed successfully", args...)
}(time.Now())
return lm.svc.Remove(ctx, session, id)
}
func (lm *loggingMiddleware) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (cfg bootstrap.Config, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("external_id", externalID),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("View bootstrap config failed", args...)
return
}
lm.logger.Info("View bootstrap completed successfully", args...)
}(time.Now())
return lm.svc.Bootstrap(ctx, externalKey, externalID, secure)
}
func (lm *loggingMiddleware) EnableConfig(ctx context.Context, session smqauthn.Session, id string) (cfg bootstrap.Config, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("id", id),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Enable config failed", args...)
return
}
lm.logger.Info("Enable config completed successfully", args...)
}(time.Now())
return lm.svc.EnableConfig(ctx, session, id)
}
func (lm *loggingMiddleware) DisableConfig(ctx context.Context, session smqauthn.Session, id string) (cfg bootstrap.Config, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("id", id),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Disable config failed", args...)
return
}
lm.logger.Info("Disable config completed successfully", args...)
}(time.Now())
return lm.svc.DisableConfig(ctx, session, id)
}
func (lm *loggingMiddleware) CreateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (saved bootstrap.Profile, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("profile_id", saved.ID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Create profile failed", args...)
return
}
lm.logger.Info("Create profile completed successfully", args...)
}(time.Now())
return lm.svc.CreateProfile(ctx, session, p)
}
func (lm *loggingMiddleware) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (p bootstrap.Profile, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("profile_id", profileID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("View profile failed", args...)
return
}
lm.logger.Info("View profile completed successfully", args...)
}(time.Now())
return lm.svc.ViewProfile(ctx, session, profileID)
}
func (lm *loggingMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (updated bootstrap.Profile, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("profile_id", p.ID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Update profile failed", args...)
return
}
lm.logger.Info("Update profile completed successfully", args...)
}(time.Now())
return lm.svc.UpdateProfile(ctx, session, p)
}
func (lm *loggingMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (page bootstrap.ProfilesPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.Uint64("offset", offset),
slog.Uint64("limit", limit),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("List profiles failed", args...)
return
}
lm.logger.Info("List profiles completed successfully", args...)
}(time.Now())
return lm.svc.ListProfiles(ctx, session, offset, limit, name)
}
func (lm *loggingMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("profile_id", profileID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Delete profile failed", args...)
return
}
lm.logger.Info("Delete profile completed successfully", args...)
}(time.Now())
return lm.svc.DeleteProfile(ctx, session, profileID)
}
func (lm *loggingMiddleware) AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("config_id", configID),
slog.String("profile_id", profileID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Assign profile failed", args...)
return
}
lm.logger.Info("Assign profile completed successfully", args...)
}(time.Now())
return lm.svc.AssignProfile(ctx, session, configID, profileID)
}
func (lm *loggingMiddleware) BindResources(ctx context.Context, session smqauthn.Session, token, configID string, bindings []bootstrap.BindingRequest) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("config_id", configID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Bind resources failed", args...)
return
}
lm.logger.Info("Bind resources completed successfully", args...)
}(time.Now())
return lm.svc.BindResources(ctx, session, token, configID, bindings)
}
func (lm *loggingMiddleware) ListBindings(ctx context.Context, session smqauthn.Session, configID string) (snapshots []bootstrap.BindingSnapshot, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("config_id", configID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("List bindings failed", args...)
return
}
lm.logger.Info("List bindings completed successfully", args...)
}(time.Now())
return lm.svc.ListBindings(ctx, session, configID)
}
func (lm *loggingMiddleware) RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("config_id", configID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Refresh bindings failed", args...)
return
}
lm.logger.Info("Refresh bindings completed successfully", args...)
}(time.Now())
return lm.svc.RefreshBindings(ctx, session, token, configID)
}
-192
View File
@@ -1,192 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
//go:build !test
package middleware
import (
"context"
"time"
"github.com/absmach/magistrala/bootstrap"
smqauthn "github.com/absmach/magistrala/pkg/authn"
"github.com/go-kit/kit/metrics"
)
var _ bootstrap.Service = (*metricsMiddleware)(nil)
type metricsMiddleware struct {
counter metrics.Counter
latency metrics.Histogram
svc bootstrap.Service
}
// MetricsMiddleware instruments core service by tracking request count and latency.
func MetricsMiddleware(svc bootstrap.Service, counter metrics.Counter, latency metrics.Histogram) bootstrap.Service {
return &metricsMiddleware{
counter: counter,
latency: latency,
svc: svc,
}
}
// Add instruments Add method with metrics.
func (mm *metricsMiddleware) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (saved bootstrap.Config, err error) {
defer func(begin time.Time) {
mm.counter.With("method", "add").Add(1)
mm.latency.With("method", "add").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.Add(ctx, session, token, cfg)
}
// View instruments View method with metrics.
func (mm *metricsMiddleware) View(ctx context.Context, session smqauthn.Session, id string) (saved bootstrap.Config, err error) {
defer func(begin time.Time) {
mm.counter.With("method", "view").Add(1)
mm.latency.With("method", "view").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.View(ctx, session, id)
}
// Update instruments Update method with metrics.
func (mm *metricsMiddleware) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) (err error) {
defer func(begin time.Time) {
mm.counter.With("method", "update").Add(1)
mm.latency.With("method", "update").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.Update(ctx, session, cfg)
}
// UpdateCert instruments UpdateCert method with metrics.
func (mm *metricsMiddleware) UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (cfg bootstrap.Config, err error) {
defer func(begin time.Time) {
mm.counter.With("method", "update_cert").Add(1)
mm.latency.With("method", "update_cert").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.UpdateCert(ctx, session, id, clientCert, clientKey, caCert)
}
// List instruments List method with metrics.
func (mm *metricsMiddleware) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (saved bootstrap.ConfigsPage, err error) {
defer func(begin time.Time) {
mm.counter.With("method", "list").Add(1)
mm.latency.With("method", "list").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.List(ctx, session, filter, offset, limit)
}
// Remove instruments Remove method with metrics.
func (mm *metricsMiddleware) Remove(ctx context.Context, session smqauthn.Session, id string) (err error) {
defer func(begin time.Time) {
mm.counter.With("method", "remove").Add(1)
mm.latency.With("method", "remove").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.Remove(ctx, session, id)
}
// Bootstrap instruments Bootstrap method with metrics.
func (mm *metricsMiddleware) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (cfg bootstrap.Config, err error) {
defer func(begin time.Time) {
mm.counter.With("method", "bootstrap").Add(1)
mm.latency.With("method", "bootstrap").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.Bootstrap(ctx, externalKey, externalID, secure)
}
func (mm *metricsMiddleware) EnableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
defer func(begin time.Time) {
mm.counter.With("method", "enable_config").Add(1)
mm.latency.With("method", "enable_config").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.EnableConfig(ctx, session, id)
}
func (mm *metricsMiddleware) DisableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
defer func(begin time.Time) {
mm.counter.With("method", "disable_config").Add(1)
mm.latency.With("method", "disable_config").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.DisableConfig(ctx, session, id)
}
func (mm *metricsMiddleware) CreateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
defer func(begin time.Time) {
mm.counter.With("method", "create_profile").Add(1)
mm.latency.With("method", "create_profile").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.CreateProfile(ctx, session, p)
}
func (mm *metricsMiddleware) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (bootstrap.Profile, error) {
defer func(begin time.Time) {
mm.counter.With("method", "view_profile").Add(1)
mm.latency.With("method", "view_profile").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.ViewProfile(ctx, session, profileID)
}
func (mm *metricsMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
defer func(begin time.Time) {
mm.counter.With("method", "update_profile").Add(1)
mm.latency.With("method", "update_profile").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.UpdateProfile(ctx, session, p)
}
func (mm *metricsMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
defer func(begin time.Time) {
mm.counter.With("method", "list_profiles").Add(1)
mm.latency.With("method", "list_profiles").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.ListProfiles(ctx, session, offset, limit, name)
}
func (mm *metricsMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
defer func(begin time.Time) {
mm.counter.With("method", "delete_profile").Add(1)
mm.latency.With("method", "delete_profile").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.DeleteProfile(ctx, session, profileID)
}
func (mm *metricsMiddleware) AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) error {
defer func(begin time.Time) {
mm.counter.With("method", "assign_profile").Add(1)
mm.latency.With("method", "assign_profile").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.AssignProfile(ctx, session, configID, profileID)
}
func (mm *metricsMiddleware) BindResources(ctx context.Context, session smqauthn.Session, token, configID string, bindings []bootstrap.BindingRequest) error {
defer func(begin time.Time) {
mm.counter.With("method", "bind_resources").Add(1)
mm.latency.With("method", "bind_resources").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.BindResources(ctx, session, token, configID, bindings)
}
func (mm *metricsMiddleware) ListBindings(ctx context.Context, session smqauthn.Session, configID string) ([]bootstrap.BindingSnapshot, error) {
defer func(begin time.Time) {
mm.counter.With("method", "list_bindings").Add(1)
mm.latency.With("method", "list_bindings").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.ListBindings(ctx, session, configID)
}
func (mm *metricsMiddleware) RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) error {
defer func(begin time.Time) {
mm.counter.With("method", "refresh_bindings").Add(1)
mm.latency.With("method", "refresh_bindings").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.RefreshBindings(ctx, session, token, configID)
}
-109
View File
@@ -1,109 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
"github.com/absmach/magistrala/bootstrap"
mock "github.com/stretchr/testify/mock"
)
// NewConfigReader creates a new instance of ConfigReader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewConfigReader(t interface {
mock.TestingT
Cleanup(func())
}) *ConfigReader {
mock := &ConfigReader{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// ConfigReader is an autogenerated mock type for the ConfigReader type
type ConfigReader struct {
mock.Mock
}
type ConfigReader_Expecter struct {
mock *mock.Mock
}
func (_m *ConfigReader) EXPECT() *ConfigReader_Expecter {
return &ConfigReader_Expecter{mock: &_m.Mock}
}
// ReadConfig provides a mock function for the type ConfigReader
func (_mock *ConfigReader) ReadConfig(config bootstrap.Config, b bool) (any, error) {
ret := _mock.Called(config, b)
if len(ret) == 0 {
panic("no return value specified for ReadConfig")
}
var r0 any
var r1 error
if returnFunc, ok := ret.Get(0).(func(bootstrap.Config, bool) (any, error)); ok {
return returnFunc(config, b)
}
if returnFunc, ok := ret.Get(0).(func(bootstrap.Config, bool) any); ok {
r0 = returnFunc(config, b)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(any)
}
}
if returnFunc, ok := ret.Get(1).(func(bootstrap.Config, bool) error); ok {
r1 = returnFunc(config, b)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ConfigReader_ReadConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadConfig'
type ConfigReader_ReadConfig_Call struct {
*mock.Call
}
// ReadConfig is a helper method to define mock.On call
// - config bootstrap.Config
// - b bool
func (_e *ConfigReader_Expecter) ReadConfig(config interface{}, b interface{}) *ConfigReader_ReadConfig_Call {
return &ConfigReader_ReadConfig_Call{Call: _e.mock.On("ReadConfig", config, b)}
}
func (_c *ConfigReader_ReadConfig_Call) Run(run func(config bootstrap.Config, b bool)) *ConfigReader_ReadConfig_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 bootstrap.Config
if args[0] != nil {
arg0 = args[0].(bootstrap.Config)
}
var arg1 bool
if args[1] != nil {
arg1 = args[1].(bool)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *ConfigReader_ReadConfig_Call) Return(v any, err error) *ConfigReader_ReadConfig_Call {
_c.Call.Return(v, err)
return _c
}
func (_c *ConfigReader_ReadConfig_Call) RunAndReturn(run func(config bootstrap.Config, b bool) (any, error)) *ConfigReader_ReadConfig_Call {
_c.Call.Return(run)
return _c
}
-670
View File
@@ -1,670 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
"context"
"github.com/absmach/magistrala/bootstrap"
mock "github.com/stretchr/testify/mock"
)
// NewConfigRepository creates a new instance of ConfigRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewConfigRepository(t interface {
mock.TestingT
Cleanup(func())
}) *ConfigRepository {
mock := &ConfigRepository{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// ConfigRepository is an autogenerated mock type for the ConfigRepository type
type ConfigRepository struct {
mock.Mock
}
type ConfigRepository_Expecter struct {
mock *mock.Mock
}
func (_m *ConfigRepository) EXPECT() *ConfigRepository_Expecter {
return &ConfigRepository_Expecter{mock: &_m.Mock}
}
// AssignProfile provides a mock function for the type ConfigRepository
func (_mock *ConfigRepository) AssignProfile(ctx context.Context, domainID string, id string, profileID string) error {
ret := _mock.Called(ctx, domainID, id, profileID)
if len(ret) == 0 {
panic("no return value specified for AssignProfile")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok {
r0 = returnFunc(ctx, domainID, id, profileID)
} else {
r0 = ret.Error(0)
}
return r0
}
// ConfigRepository_AssignProfile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AssignProfile'
type ConfigRepository_AssignProfile_Call struct {
*mock.Call
}
// AssignProfile is a helper method to define mock.On call
// - ctx context.Context
// - domainID string
// - id string
// - profileID string
func (_e *ConfigRepository_Expecter) AssignProfile(ctx interface{}, domainID interface{}, id interface{}, profileID interface{}) *ConfigRepository_AssignProfile_Call {
return &ConfigRepository_AssignProfile_Call{Call: _e.mock.On("AssignProfile", ctx, domainID, id, profileID)}
}
func (_c *ConfigRepository_AssignProfile_Call) Run(run func(ctx context.Context, domainID string, id string, profileID string)) *ConfigRepository_AssignProfile_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
var arg3 string
if args[3] != nil {
arg3 = args[3].(string)
}
run(
arg0,
arg1,
arg2,
arg3,
)
})
return _c
}
func (_c *ConfigRepository_AssignProfile_Call) Return(err error) *ConfigRepository_AssignProfile_Call {
_c.Call.Return(err)
return _c
}
func (_c *ConfigRepository_AssignProfile_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string, profileID string) error) *ConfigRepository_AssignProfile_Call {
_c.Call.Return(run)
return _c
}
// ChangeStatus provides a mock function for the type ConfigRepository
func (_mock *ConfigRepository) ChangeStatus(ctx context.Context, domainID string, id string, status bootstrap.Status) error {
ret := _mock.Called(ctx, domainID, id, status)
if len(ret) == 0 {
panic("no return value specified for ChangeStatus")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, bootstrap.Status) error); ok {
r0 = returnFunc(ctx, domainID, id, status)
} else {
r0 = ret.Error(0)
}
return r0
}
// ConfigRepository_ChangeStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ChangeStatus'
type ConfigRepository_ChangeStatus_Call struct {
*mock.Call
}
// ChangeStatus is a helper method to define mock.On call
// - ctx context.Context
// - domainID string
// - id string
// - status bootstrap.Status
func (_e *ConfigRepository_Expecter) ChangeStatus(ctx interface{}, domainID interface{}, id interface{}, status interface{}) *ConfigRepository_ChangeStatus_Call {
return &ConfigRepository_ChangeStatus_Call{Call: _e.mock.On("ChangeStatus", ctx, domainID, id, status)}
}
func (_c *ConfigRepository_ChangeStatus_Call) Run(run func(ctx context.Context, domainID string, id string, status bootstrap.Status)) *ConfigRepository_ChangeStatus_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
var arg3 bootstrap.Status
if args[3] != nil {
arg3 = args[3].(bootstrap.Status)
}
run(
arg0,
arg1,
arg2,
arg3,
)
})
return _c
}
func (_c *ConfigRepository_ChangeStatus_Call) Return(err error) *ConfigRepository_ChangeStatus_Call {
_c.Call.Return(err)
return _c
}
func (_c *ConfigRepository_ChangeStatus_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string, status bootstrap.Status) error) *ConfigRepository_ChangeStatus_Call {
_c.Call.Return(run)
return _c
}
// Remove provides a mock function for the type ConfigRepository
func (_mock *ConfigRepository) Remove(ctx context.Context, domainID string, id string) error {
ret := _mock.Called(ctx, domainID, id)
if len(ret) == 0 {
panic("no return value specified for Remove")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = returnFunc(ctx, domainID, id)
} else {
r0 = ret.Error(0)
}
return r0
}
// ConfigRepository_Remove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remove'
type ConfigRepository_Remove_Call struct {
*mock.Call
}
// Remove is a helper method to define mock.On call
// - ctx context.Context
// - domainID string
// - id string
func (_e *ConfigRepository_Expecter) Remove(ctx interface{}, domainID interface{}, id interface{}) *ConfigRepository_Remove_Call {
return &ConfigRepository_Remove_Call{Call: _e.mock.On("Remove", ctx, domainID, id)}
}
func (_c *ConfigRepository_Remove_Call) Run(run func(ctx context.Context, domainID string, id string)) *ConfigRepository_Remove_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
run(
arg0,
arg1,
arg2,
)
})
return _c
}
func (_c *ConfigRepository_Remove_Call) Return(err error) *ConfigRepository_Remove_Call {
_c.Call.Return(err)
return _c
}
func (_c *ConfigRepository_Remove_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string) error) *ConfigRepository_Remove_Call {
_c.Call.Return(run)
return _c
}
// RetrieveAll provides a mock function for the type ConfigRepository
func (_mock *ConfigRepository) RetrieveAll(ctx context.Context, domainID string, filter bootstrap.Filter, offset uint64, limit uint64) bootstrap.ConfigsPage {
ret := _mock.Called(ctx, domainID, filter, offset, limit)
if len(ret) == 0 {
panic("no return value specified for RetrieveAll")
}
var r0 bootstrap.ConfigsPage
if returnFunc, ok := ret.Get(0).(func(context.Context, string, bootstrap.Filter, uint64, uint64) bootstrap.ConfigsPage); ok {
r0 = returnFunc(ctx, domainID, filter, offset, limit)
} else {
r0 = ret.Get(0).(bootstrap.ConfigsPage)
}
return r0
}
// ConfigRepository_RetrieveAll_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveAll'
type ConfigRepository_RetrieveAll_Call struct {
*mock.Call
}
// RetrieveAll is a helper method to define mock.On call
// - ctx context.Context
// - domainID string
// - filter bootstrap.Filter
// - offset uint64
// - limit uint64
func (_e *ConfigRepository_Expecter) RetrieveAll(ctx interface{}, domainID interface{}, filter interface{}, offset interface{}, limit interface{}) *ConfigRepository_RetrieveAll_Call {
return &ConfigRepository_RetrieveAll_Call{Call: _e.mock.On("RetrieveAll", ctx, domainID, filter, offset, limit)}
}
func (_c *ConfigRepository_RetrieveAll_Call) Run(run func(ctx context.Context, domainID string, filter bootstrap.Filter, offset uint64, limit uint64)) *ConfigRepository_RetrieveAll_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 bootstrap.Filter
if args[2] != nil {
arg2 = args[2].(bootstrap.Filter)
}
var arg3 uint64
if args[3] != nil {
arg3 = args[3].(uint64)
}
var arg4 uint64
if args[4] != nil {
arg4 = args[4].(uint64)
}
run(
arg0,
arg1,
arg2,
arg3,
arg4,
)
})
return _c
}
func (_c *ConfigRepository_RetrieveAll_Call) Return(configsPage bootstrap.ConfigsPage) *ConfigRepository_RetrieveAll_Call {
_c.Call.Return(configsPage)
return _c
}
func (_c *ConfigRepository_RetrieveAll_Call) RunAndReturn(run func(ctx context.Context, domainID string, filter bootstrap.Filter, offset uint64, limit uint64) bootstrap.ConfigsPage) *ConfigRepository_RetrieveAll_Call {
_c.Call.Return(run)
return _c
}
// RetrieveByExternalID provides a mock function for the type ConfigRepository
func (_mock *ConfigRepository) RetrieveByExternalID(ctx context.Context, externalID string) (bootstrap.Config, error) {
ret := _mock.Called(ctx, externalID)
if len(ret) == 0 {
panic("no return value specified for RetrieveByExternalID")
}
var r0 bootstrap.Config
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string) (bootstrap.Config, error)); ok {
return returnFunc(ctx, externalID)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string) bootstrap.Config); ok {
r0 = returnFunc(ctx, externalID)
} else {
r0 = ret.Get(0).(bootstrap.Config)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = returnFunc(ctx, externalID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ConfigRepository_RetrieveByExternalID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveByExternalID'
type ConfigRepository_RetrieveByExternalID_Call struct {
*mock.Call
}
// RetrieveByExternalID is a helper method to define mock.On call
// - ctx context.Context
// - externalID string
func (_e *ConfigRepository_Expecter) RetrieveByExternalID(ctx interface{}, externalID interface{}) *ConfigRepository_RetrieveByExternalID_Call {
return &ConfigRepository_RetrieveByExternalID_Call{Call: _e.mock.On("RetrieveByExternalID", ctx, externalID)}
}
func (_c *ConfigRepository_RetrieveByExternalID_Call) Run(run func(ctx context.Context, externalID string)) *ConfigRepository_RetrieveByExternalID_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *ConfigRepository_RetrieveByExternalID_Call) Return(config bootstrap.Config, err error) *ConfigRepository_RetrieveByExternalID_Call {
_c.Call.Return(config, err)
return _c
}
func (_c *ConfigRepository_RetrieveByExternalID_Call) RunAndReturn(run func(ctx context.Context, externalID string) (bootstrap.Config, error)) *ConfigRepository_RetrieveByExternalID_Call {
_c.Call.Return(run)
return _c
}
// RetrieveByID provides a mock function for the type ConfigRepository
func (_mock *ConfigRepository) RetrieveByID(ctx context.Context, domainID string, id string) (bootstrap.Config, error) {
ret := _mock.Called(ctx, domainID, id)
if len(ret) == 0 {
panic("no return value specified for RetrieveByID")
}
var r0 bootstrap.Config
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (bootstrap.Config, error)); ok {
return returnFunc(ctx, domainID, id)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) bootstrap.Config); ok {
r0 = returnFunc(ctx, domainID, id)
} else {
r0 = ret.Get(0).(bootstrap.Config)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = returnFunc(ctx, domainID, id)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ConfigRepository_RetrieveByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveByID'
type ConfigRepository_RetrieveByID_Call struct {
*mock.Call
}
// RetrieveByID is a helper method to define mock.On call
// - ctx context.Context
// - domainID string
// - id string
func (_e *ConfigRepository_Expecter) RetrieveByID(ctx interface{}, domainID interface{}, id interface{}) *ConfigRepository_RetrieveByID_Call {
return &ConfigRepository_RetrieveByID_Call{Call: _e.mock.On("RetrieveByID", ctx, domainID, id)}
}
func (_c *ConfigRepository_RetrieveByID_Call) Run(run func(ctx context.Context, domainID string, id string)) *ConfigRepository_RetrieveByID_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
run(
arg0,
arg1,
arg2,
)
})
return _c
}
func (_c *ConfigRepository_RetrieveByID_Call) Return(config bootstrap.Config, err error) *ConfigRepository_RetrieveByID_Call {
_c.Call.Return(config, err)
return _c
}
func (_c *ConfigRepository_RetrieveByID_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string) (bootstrap.Config, error)) *ConfigRepository_RetrieveByID_Call {
_c.Call.Return(run)
return _c
}
// Save provides a mock function for the type ConfigRepository
func (_mock *ConfigRepository) Save(ctx context.Context, cfg bootstrap.Config) (string, error) {
ret := _mock.Called(ctx, cfg)
if len(ret) == 0 {
panic("no return value specified for Save")
}
var r0 string
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Config) (string, error)); ok {
return returnFunc(ctx, cfg)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Config) string); ok {
r0 = returnFunc(ctx, cfg)
} else {
r0 = ret.Get(0).(string)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, bootstrap.Config) error); ok {
r1 = returnFunc(ctx, cfg)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ConfigRepository_Save_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Save'
type ConfigRepository_Save_Call struct {
*mock.Call
}
// Save is a helper method to define mock.On call
// - ctx context.Context
// - cfg bootstrap.Config
func (_e *ConfigRepository_Expecter) Save(ctx interface{}, cfg interface{}) *ConfigRepository_Save_Call {
return &ConfigRepository_Save_Call{Call: _e.mock.On("Save", ctx, cfg)}
}
func (_c *ConfigRepository_Save_Call) Run(run func(ctx context.Context, cfg bootstrap.Config)) *ConfigRepository_Save_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 bootstrap.Config
if args[1] != nil {
arg1 = args[1].(bootstrap.Config)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *ConfigRepository_Save_Call) Return(s string, err error) *ConfigRepository_Save_Call {
_c.Call.Return(s, err)
return _c
}
func (_c *ConfigRepository_Save_Call) RunAndReturn(run func(ctx context.Context, cfg bootstrap.Config) (string, error)) *ConfigRepository_Save_Call {
_c.Call.Return(run)
return _c
}
// Update provides a mock function for the type ConfigRepository
func (_mock *ConfigRepository) Update(ctx context.Context, cfg bootstrap.Config) error {
ret := _mock.Called(ctx, cfg)
if len(ret) == 0 {
panic("no return value specified for Update")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Config) error); ok {
r0 = returnFunc(ctx, cfg)
} else {
r0 = ret.Error(0)
}
return r0
}
// ConfigRepository_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update'
type ConfigRepository_Update_Call struct {
*mock.Call
}
// Update is a helper method to define mock.On call
// - ctx context.Context
// - cfg bootstrap.Config
func (_e *ConfigRepository_Expecter) Update(ctx interface{}, cfg interface{}) *ConfigRepository_Update_Call {
return &ConfigRepository_Update_Call{Call: _e.mock.On("Update", ctx, cfg)}
}
func (_c *ConfigRepository_Update_Call) Run(run func(ctx context.Context, cfg bootstrap.Config)) *ConfigRepository_Update_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 bootstrap.Config
if args[1] != nil {
arg1 = args[1].(bootstrap.Config)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *ConfigRepository_Update_Call) Return(err error) *ConfigRepository_Update_Call {
_c.Call.Return(err)
return _c
}
func (_c *ConfigRepository_Update_Call) RunAndReturn(run func(ctx context.Context, cfg bootstrap.Config) error) *ConfigRepository_Update_Call {
_c.Call.Return(run)
return _c
}
// UpdateCert provides a mock function for the type ConfigRepository
func (_mock *ConfigRepository) UpdateCert(ctx context.Context, domainID string, id string, clientCert string, clientKey string, caCert string) (bootstrap.Config, error) {
ret := _mock.Called(ctx, domainID, id, clientCert, clientKey, caCert)
if len(ret) == 0 {
panic("no return value specified for UpdateCert")
}
var r0 bootstrap.Config
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) (bootstrap.Config, error)); ok {
return returnFunc(ctx, domainID, id, clientCert, clientKey, caCert)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) bootstrap.Config); ok {
r0 = returnFunc(ctx, domainID, id, clientCert, clientKey, caCert)
} else {
r0 = ret.Get(0).(bootstrap.Config)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string, string) error); ok {
r1 = returnFunc(ctx, domainID, id, clientCert, clientKey, caCert)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ConfigRepository_UpdateCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCert'
type ConfigRepository_UpdateCert_Call struct {
*mock.Call
}
// UpdateCert is a helper method to define mock.On call
// - ctx context.Context
// - domainID string
// - id string
// - clientCert string
// - clientKey string
// - caCert string
func (_e *ConfigRepository_Expecter) UpdateCert(ctx interface{}, domainID interface{}, id interface{}, clientCert interface{}, clientKey interface{}, caCert interface{}) *ConfigRepository_UpdateCert_Call {
return &ConfigRepository_UpdateCert_Call{Call: _e.mock.On("UpdateCert", ctx, domainID, id, clientCert, clientKey, caCert)}
}
func (_c *ConfigRepository_UpdateCert_Call) Run(run func(ctx context.Context, domainID string, id string, clientCert string, clientKey string, caCert string)) *ConfigRepository_UpdateCert_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
var arg3 string
if args[3] != nil {
arg3 = args[3].(string)
}
var arg4 string
if args[4] != nil {
arg4 = args[4].(string)
}
var arg5 string
if args[5] != nil {
arg5 = args[5].(string)
}
run(
arg0,
arg1,
arg2,
arg3,
arg4,
arg5,
)
})
return _c
}
func (_c *ConfigRepository_UpdateCert_Call) Return(config bootstrap.Config, err error) *ConfigRepository_UpdateCert_Call {
_c.Call.Return(config, err)
return _c
}
func (_c *ConfigRepository_UpdateCert_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string, clientCert string, clientKey string, caCert string) (bootstrap.Config, error)) *ConfigRepository_UpdateCert_Call {
_c.Call.Return(run)
return _c
}
File diff suppressed because it is too large Load Diff
-420
View File
@@ -1,420 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"log/slog"
"strings"
"github.com/absmach/magistrala/bootstrap"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
"github.com/absmach/magistrala/pkg/postgres"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5/pgconn"
)
const jsonNull = "null"
var _ bootstrap.ConfigRepository = (*configRepository)(nil)
type configRepository struct {
db postgres.Database
log *slog.Logger
}
// NewConfigRepository instantiates a PostgreSQL implementation of config
// repository.
func NewConfigRepository(db postgres.Database, log *slog.Logger) bootstrap.ConfigRepository {
return &configRepository{db: db, log: log}
}
func (cr configRepository) Save(ctx context.Context, cfg bootstrap.Config) (string, error) {
q := `INSERT INTO configs (id, domain_id, name, client_cert, client_key, ca_cert, external_id, external_key, content, status, profile_id, render_context)
VALUES (:id, :domain_id, :name, :client_cert, :client_key, :ca_cert, :external_id, :external_key, :content, :status, :profile_id, :render_context)`
dbcfg, err := toDBConfig(cfg)
if err != nil {
return "", errors.Wrap(repoerr.ErrCreateEntity, err)
}
if _, err := cr.db.NamedExecContext(ctx, q, dbcfg); err != nil {
switch pgErr := err.(type) {
case *pgconn.PgError:
if pgErr.Code == pgerrcode.UniqueViolation {
return "", repoerr.ErrConflict
}
}
return "", errors.Wrap(repoerr.ErrCreateEntity, err)
}
return cfg.ID, nil
}
func (cr configRepository) RetrieveByID(ctx context.Context, domainID, id string) (bootstrap.Config, error) {
q := `SELECT id, external_id, name, content, status, client_cert, client_key, ca_cert, profile_id, render_context
FROM configs
WHERE id = :id AND domain_id = :domain_id`
dbcfg := dbConfig{
ID: id,
DomainID: domainID,
}
row, err := cr.db.NamedQueryContext(ctx, q, dbcfg)
if err != nil {
return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
if !row.Next() {
return bootstrap.Config{}, repoerr.ErrNotFound
}
if err := row.StructScan(&dbcfg); err != nil {
return bootstrap.Config{}, err
}
cfg, err := toConfig(dbcfg)
if err != nil {
return bootstrap.Config{}, err
}
return cfg, nil
}
func (cr configRepository) RetrieveAll(ctx context.Context, domainID string, filter bootstrap.Filter, offset, limit uint64) bootstrap.ConfigsPage {
search, params := buildRetrieveQueryParams(domainID, filter)
n := len(params)
q := `SELECT id, external_id, name, content, status, profile_id, render_context
FROM configs %s ORDER BY id LIMIT $%d OFFSET $%d`
q = fmt.Sprintf(q, search, n+1, n+2)
rows, err := cr.db.QueryContext(ctx, q, append(params, limit, offset)...)
if err != nil {
cr.log.Error(fmt.Sprintf("Failed to retrieve configs due to %s", err))
return bootstrap.ConfigsPage{}
}
defer rows.Close()
var name, content, profileID sql.NullString
var renderContext []byte
configs := []bootstrap.Config{}
for rows.Next() {
c := bootstrap.Config{DomainID: domainID}
if err := rows.Scan(&c.ID, &c.ExternalID, &name, &content, &c.Status, &profileID, &renderContext); err != nil {
cr.log.Error(fmt.Sprintf("Failed to read retrieved config due to %s", err))
return bootstrap.ConfigsPage{}
}
c.Name = name.String
c.Content = content.String
if profileID.Valid {
c.ProfileID = profileID.String
}
if len(renderContext) > 0 && string(renderContext) != jsonNull {
if err := json.Unmarshal(renderContext, &c.RenderContext); err != nil {
cr.log.Error(fmt.Sprintf("Failed to decode render context due to %s", err))
return bootstrap.ConfigsPage{}
}
}
configs = append(configs, c)
}
q = fmt.Sprintf(`SELECT COUNT(*) FROM configs %s`, search)
var total uint64
if err := cr.db.QueryRowxContext(ctx, q, params...).Scan(&total); err != nil {
cr.log.Error(fmt.Sprintf("Failed to count configs due to %s", err))
return bootstrap.ConfigsPage{}
}
return bootstrap.ConfigsPage{
Total: total,
Limit: limit,
Offset: offset,
Configs: configs,
}
}
func (cr configRepository) RetrieveByExternalID(ctx context.Context, externalID string) (bootstrap.Config, error) {
q := `SELECT id, external_key, domain_id, name, client_cert, client_key, ca_cert, content, status, profile_id, render_context
FROM configs
WHERE external_id = :external_id`
dbcfg := dbConfig{
ExternalID: externalID,
}
row, err := cr.db.NamedQueryContext(ctx, q, dbcfg)
if err != nil {
return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
if !row.Next() {
return bootstrap.Config{}, repoerr.ErrNotFound
}
if err := row.StructScan(&dbcfg); err != nil {
return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
cfg, err := toConfig(dbcfg)
if err != nil {
return bootstrap.Config{}, err
}
return cfg, nil
}
func (cr configRepository) Update(ctx context.Context, cfg bootstrap.Config) error {
q := `UPDATE configs SET name = :name, content = :content, render_context = :render_context WHERE id = :id AND domain_id = :domain_id `
renderContext, err := json.Marshal(cfg.RenderContext)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
dbcfg := dbConfig{
Name: nullString(cfg.Name),
Content: nullString(cfg.Content),
RenderContext: renderContext,
ID: cfg.ID,
DomainID: cfg.DomainID,
}
res, err := cr.db.NamedExecContext(ctx, q, dbcfg)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
cnt, err := res.RowsAffected()
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if cnt == 0 {
return repoerr.ErrNotFound
}
return nil
}
func (cr configRepository) AssignProfile(ctx context.Context, domainID, id, profileID string) error {
q := `UPDATE configs SET profile_id = :profile_id WHERE id = :id AND domain_id = :domain_id`
dbcfg := dbConfig{
ID: id,
DomainID: domainID,
ProfileID: nullString(profileID),
}
res, err := cr.db.NamedExecContext(ctx, q, dbcfg)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
cnt, err := res.RowsAffected()
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if cnt == 0 {
return repoerr.ErrNotFound
}
return nil
}
func (cr configRepository) UpdateCert(ctx context.Context, domainID, id, clientCert, clientKey, caCert string) (bootstrap.Config, error) {
q := `UPDATE configs SET client_cert = :client_cert, client_key = :client_key, ca_cert = :ca_cert WHERE id = :id AND domain_id = :domain_id
RETURNING id, client_cert, client_key, ca_cert, domain_id`
dbcfg := dbConfig{
ID: id,
ClientCert: nullString(clientCert),
DomainID: domainID,
ClientKey: nullString(clientKey),
CaCert: nullString(caCert),
}
row, err := cr.db.NamedQueryContext(ctx, q, dbcfg)
if err != nil {
return bootstrap.Config{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
defer row.Close()
if ok := row.Next(); !ok {
return bootstrap.Config{}, errors.Wrap(repoerr.ErrNotFound, row.Err())
}
if err := row.StructScan(&dbcfg); err != nil {
return bootstrap.Config{}, err
}
cfg, err := toConfig(dbcfg)
if err != nil {
return bootstrap.Config{}, err
}
return cfg, nil
}
func (cr configRepository) Remove(ctx context.Context, domainID, id string) error {
q := `DELETE FROM configs WHERE id = :id AND domain_id = :domain_id`
dbcfg := dbConfig{
ID: id,
DomainID: domainID,
}
if _, err := cr.db.NamedExecContext(ctx, q, dbcfg); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
return nil
}
func (cr configRepository) ChangeStatus(ctx context.Context, domainID, id string, status bootstrap.Status) error {
q := `UPDATE configs SET status = :status WHERE id = :id AND domain_id = :domain_id;`
dbcfg := dbConfig{
ID: id,
Status: status,
DomainID: domainID,
}
res, err := cr.db.NamedExecContext(ctx, q, dbcfg)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
cnt, err := res.RowsAffected()
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if cnt == 0 {
return repoerr.ErrNotFound
}
return nil
}
func buildRetrieveQueryParams(domainID string, filter bootstrap.Filter) (string, []any) {
params := []any{}
queries := []string{}
if domainID != "" {
params = append(params, domainID)
queries = append(queries, fmt.Sprintf("domain_id = $%d", len(params)))
}
counter := len(params) + 1
for k, v := range filter.FullMatch {
if k == "status" {
status, err := bootstrap.ToStatus(v)
if err != nil {
return "", nil
}
if status == bootstrap.AllStatus {
continue
}
params = append(params, status)
queries = append(queries, fmt.Sprintf("%s = $%d", k, counter))
counter++
continue
}
params = append(params, v)
queries = append(queries, fmt.Sprintf("%s = $%d", k, counter))
counter++
}
for k, v := range filter.PartialMatch {
params = append(params, v)
queries = append(queries, fmt.Sprintf("LOWER(%s) LIKE '%%' || $%d || '%%'", k, counter))
counter++
}
if len(queries) > 0 {
return "WHERE " + strings.Join(queries, " AND "), params
}
return "", params
}
func nullString(s string) sql.NullString {
if s == "" {
return sql.NullString{}
}
return sql.NullString{String: s, Valid: true}
}
type dbConfig struct {
DomainID string `db:"domain_id"`
ID string `db:"id"`
Name sql.NullString `db:"name"`
ClientCert sql.NullString `db:"client_cert"`
ClientKey sql.NullString `db:"client_key"`
CaCert sql.NullString `db:"ca_cert"`
ExternalID string `db:"external_id"`
ExternalKey string `db:"external_key"`
Content sql.NullString `db:"content"`
Status bootstrap.Status `db:"status"`
ProfileID sql.NullString `db:"profile_id"`
RenderContext []byte `db:"render_context"`
}
func toDBConfig(cfg bootstrap.Config) (dbConfig, error) {
renderContext, err := json.Marshal(cfg.RenderContext)
if err != nil {
return dbConfig{}, err
}
return dbConfig{
ID: cfg.ID,
DomainID: cfg.DomainID,
Name: nullString(cfg.Name),
ClientCert: nullString(cfg.ClientCert),
ClientKey: nullString(cfg.ClientKey),
CaCert: nullString(cfg.CACert),
ExternalID: cfg.ExternalID,
ExternalKey: cfg.ExternalKey,
Content: nullString(cfg.Content),
Status: cfg.Status,
ProfileID: nullString(cfg.ProfileID),
RenderContext: renderContext,
}, nil
}
func toConfig(dbcfg dbConfig) (bootstrap.Config, error) {
cfg := bootstrap.Config{
ID: dbcfg.ID,
DomainID: dbcfg.DomainID,
ExternalID: dbcfg.ExternalID,
ExternalKey: dbcfg.ExternalKey,
Status: dbcfg.Status,
}
if dbcfg.ProfileID.Valid {
cfg.ProfileID = dbcfg.ProfileID.String
}
if dbcfg.Name.Valid {
cfg.Name = dbcfg.Name.String
}
if dbcfg.Content.Valid {
cfg.Content = dbcfg.Content.String
}
if len(dbcfg.RenderContext) > 0 && string(dbcfg.RenderContext) != jsonNull {
if err := json.Unmarshal(dbcfg.RenderContext, &cfg.RenderContext); err != nil {
return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
}
if dbcfg.ClientCert.Valid {
cfg.ClientCert = dbcfg.ClientCert.String
}
if dbcfg.ClientKey.Valid {
cfg.ClientKey = dbcfg.ClientKey.String
}
if dbcfg.CaCert.Valid {
cfg.CACert = dbcfg.CaCert.String
}
return cfg, nil
}
-471
View File
@@ -1,471 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres_test
import (
"context"
"fmt"
"strconv"
"testing"
"github.com/absmach/magistrala/bootstrap"
"github.com/absmach/magistrala/bootstrap/postgres"
"github.com/absmach/magistrala/internal/testsutil"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
"github.com/gofrs/uuid/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const numConfigs = 10
var config = bootstrap.Config{
ID: "mg-client",
ExternalID: "external-id",
ExternalKey: "external-key",
DomainID: testsutil.GenerateUUID(&testing.T{}),
Content: "content",
Status: bootstrap.Inactive,
}
func TestSave(t *testing.T) {
repo := postgres.NewConfigRepository(db, testLog)
diff := "different"
duplicateClient := config
duplicateClient.ExternalID = diff
duplicateExternal := config
duplicateExternal.ID = diff
cases := []struct {
desc string
config bootstrap.Config
err error
}{
{
desc: "save a config",
config: config,
err: nil,
},
{
desc: "save config with same Client ID",
config: duplicateClient,
err: repoerr.ErrConflict,
},
{
desc: "save config with same external ID",
config: duplicateExternal,
err: repoerr.ErrConflict,
},
}
for _, tc := range cases {
id, err := repo.Save(context.Background(), tc.config)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, id, tc.config.ID, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.config.ID, id))
}
}
}
func TestRetrieveByID(t *testing.T) {
repo := postgres.NewConfigRepository(db, testLog)
c := config
// Use UUID to prevent conflicts.
uid, err := uuid.NewV4()
require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
c.ID = uid.String()
c.ExternalID = uid.String()
c.ExternalKey = uid.String()
id, err := repo.Save(context.Background(), c)
require.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
nonexistentConfID, err := uuid.NewV4()
require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
cases := []struct {
desc string
domainID string
id string
err error
}{
{
desc: "retrieve config",
domainID: c.DomainID,
id: id,
err: nil,
},
{
desc: "retrieve config with wrong domain ID ",
domainID: "2",
id: id,
err: repoerr.ErrNotFound,
},
{
desc: "retrieve a non-existing config",
domainID: c.DomainID,
id: nonexistentConfID.String(),
err: repoerr.ErrNotFound,
},
{
desc: "retrieve a config with invalid ID",
domainID: c.DomainID,
id: "invalid",
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
_, err := repo.RetrieveByID(context.Background(), tc.domainID, tc.id)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestRetrieveAll(t *testing.T) {
repo := postgres.NewConfigRepository(db, testLog)
for i := 0; i < numConfigs; i++ {
c := config
// Use UUID to prevent conflict errors.
uid, err := uuid.NewV4()
require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
c.ExternalID = uid.String()
c.Name = fmt.Sprintf("name %d", i)
c.ID = uid.String()
if i%2 == 0 {
c.Status = bootstrap.Active
}
_, err = repo.Save(context.Background(), c)
require.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
}
cases := []struct {
desc string
domainID string
offset uint64
limit uint64
filter bootstrap.Filter
size int
}{
{
desc: "retrieve all configs",
domainID: config.DomainID,
offset: 0,
limit: uint64(numConfigs),
size: numConfigs,
},
{
desc: "retrieve a subset of configs",
domainID: config.DomainID,
offset: 5,
limit: uint64(numConfigs - 5),
size: numConfigs - 5,
},
{
desc: "retrieve with wrong domain ID ",
domainID: "2",
offset: 0,
limit: uint64(numConfigs),
size: 0,
},
{
desc: "retrieve all active configs ",
domainID: config.DomainID,
offset: 0,
limit: uint64(numConfigs),
filter: bootstrap.Filter{FullMatch: map[string]string{"status": bootstrap.Active.String()}},
size: numConfigs / 2,
},
{
desc: "retrieve all with partial match filter",
domainID: config.DomainID,
offset: 0,
limit: uint64(numConfigs),
filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "1"}},
size: 1,
},
{
desc: "retrieve search by name",
domainID: config.DomainID,
offset: 0,
limit: uint64(numConfigs),
filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "1"}},
size: 1,
},
}
for _, tc := range cases {
ret := repo.RetrieveAll(context.Background(), tc.domainID, tc.filter, tc.offset, tc.limit)
size := len(ret.Configs)
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.size, size))
}
}
func TestRetrieveByExternalID(t *testing.T) {
repo := postgres.NewConfigRepository(db, testLog)
c := config
// Use UUID to prevent conflicts.
uid, err := uuid.NewV4()
assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
c.ID = uid.String()
c.ExternalID = uid.String()
c.ExternalKey = uid.String()
_, err = repo.Save(context.Background(), c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
cases := []struct {
desc string
externalID string
err error
}{
{
desc: "retrieve with invalid external ID",
externalID: strconv.Itoa(numConfigs + 1),
err: repoerr.ErrNotFound,
},
{
desc: "retrieve with external key",
externalID: c.ExternalID,
err: nil,
},
}
for _, tc := range cases {
_, err := repo.RetrieveByExternalID(context.Background(), tc.externalID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestUpdate(t *testing.T) {
repo := postgres.NewConfigRepository(db, testLog)
c := config
// Use UUID to prevent conflicts.
uid, err := uuid.NewV4()
assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
c.ID = uid.String()
c.ExternalID = uid.String()
c.ExternalKey = uid.String()
_, err = repo.Save(context.Background(), c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
c.Content = "new content"
c.Name = "new name"
withRenderContext := c
withRenderContext.RenderContext = map[string]any{
"site": "warehouse-2",
"region": "mombasa",
}
wrongDomainID := c
wrongDomainID.DomainID = "3"
cases := []struct {
desc string
config bootstrap.Config
renderContext map[string]any
err error
}{
{
desc: "update with wrong domainID",
config: wrongDomainID,
err: repoerr.ErrNotFound,
},
{
desc: "update a config",
config: c,
err: nil,
},
{
desc: "update a config render_context",
config: withRenderContext,
renderContext: map[string]any{"site": "warehouse-2", "region": "mombasa"},
err: nil,
},
}
for _, tc := range cases {
err := repo.Update(context.Background(), tc.config)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if tc.err == nil && tc.renderContext != nil {
saved, err := repo.RetrieveByID(context.Background(), tc.config.DomainID, tc.config.ID)
require.Nil(t, err, fmt.Sprintf("%s: unexpected retrieve error: %s\n", tc.desc, err))
assert.Equal(t, tc.renderContext, saved.RenderContext, fmt.Sprintf("%s: expected render_context %v got %v\n", tc.desc, tc.renderContext, saved.RenderContext))
}
}
}
func TestUpdateCert(t *testing.T) {
repo := postgres.NewConfigRepository(db, testLog)
c := config
// Use UUID to prevent conflicts.
uid, err := uuid.NewV4()
assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
c.ID = uid.String()
c.ExternalID = uid.String()
c.ExternalKey = uid.String()
_, err = repo.Save(context.Background(), c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
c.Content = "new content"
c.Name = "new name"
wrongDomainID := c
wrongDomainID.DomainID = "3"
cases := []struct {
desc string
configID string
domainID string
cert string
certKey string
ca string
expectedConfig bootstrap.Config
err error
}{
{
desc: "update with wrong domain ID ",
configID: "",
cert: "cert",
certKey: "certKey",
ca: "",
domainID: wrongDomainID.DomainID,
expectedConfig: bootstrap.Config{},
err: repoerr.ErrNotFound,
},
{
desc: "update a config",
configID: c.ID,
cert: "cert",
certKey: "certKey",
ca: "ca",
domainID: c.DomainID,
expectedConfig: bootstrap.Config{
ID: c.ID,
ClientCert: "cert",
CACert: "ca",
ClientKey: "certKey",
DomainID: c.DomainID,
},
err: nil,
},
}
for _, tc := range cases {
cfg, err := repo.UpdateCert(context.Background(), tc.domainID, tc.configID, tc.cert, tc.certKey, tc.ca)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.expectedConfig, cfg, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.expectedConfig, cfg))
}
}
func TestRemove(t *testing.T) {
repo := postgres.NewConfigRepository(db, testLog)
c := config
// Use UUID to prevent conflicts.
uid, err := uuid.NewV4()
assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
c.ID = uid.String()
c.ExternalID = uid.String()
c.ExternalKey = uid.String()
id, err := repo.Save(context.Background(), c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
// Removal works the same for both existing and non-existing
// (removed) config
for i := 0; i < 2; i++ {
err := repo.Remove(context.Background(), c.DomainID, id)
assert.Nil(t, err, fmt.Sprintf("%d: failed to remove config due to: %s", i, err))
_, err = repo.RetrieveByID(context.Background(), c.DomainID, id)
assert.True(t, errors.Contains(err, repoerr.ErrNotFound), fmt.Sprintf("%d: expected %s got %s", i, repoerr.ErrNotFound, err))
}
}
func TestChangeStatus(t *testing.T) {
repo := postgres.NewConfigRepository(db, testLog)
c := config
// Use UUID to prevent conflicts.
uid, err := uuid.NewV4()
assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
c.ID = uid.String()
c.ExternalID = uid.String()
c.ExternalKey = uid.String()
saved, err := repo.Save(context.Background(), c)
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
cases := []struct {
desc string
domainID string
id string
status bootstrap.Status
err error
}{
{
desc: "change status with wrong domain ID ",
id: saved,
domainID: "2",
err: repoerr.ErrNotFound,
},
{
desc: "change status with wrong id",
id: "wrong",
domainID: c.DomainID,
err: repoerr.ErrNotFound,
},
{
desc: "change status to Active",
id: saved,
domainID: c.DomainID,
status: bootstrap.Active,
err: nil,
},
{
desc: "change status to Inactive",
id: saved,
domainID: c.DomainID,
status: bootstrap.Inactive,
err: nil,
},
}
for _, tc := range cases {
err := repo.ChangeStatus(context.Background(), tc.domainID, tc.id, tc.status)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestAssignProfile(t *testing.T) {
configRepo := postgres.NewConfigRepository(db, testLog)
profileRepo := postgres.NewProfileRepository(db, testLog)
c := config
uid, err := uuid.NewV4()
require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
c.ID = uid.String()
c.ExternalID = uid.String()
c.ExternalKey = uid.String()
saved, err := configRepo.Save(context.Background(), c)
require.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
profileID := testsutil.GenerateUUID(t)
_, err = profileRepo.Save(context.Background(), bootstrap.Profile{
ID: profileID,
DomainID: c.DomainID,
Name: "edge-gateway",
ContentFormat: bootstrap.ContentFormatGoTemplate,
Version: 1,
})
require.Nil(t, err, fmt.Sprintf("Saving profile expected to succeed: %s.\n", err))
err = configRepo.AssignProfile(context.Background(), c.DomainID, saved, profileID)
require.Nil(t, err, fmt.Sprintf("Assigning profile expected to succeed: %s.\n", err))
stored, err := configRepo.RetrieveByID(context.Background(), c.DomainID, saved)
require.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
assert.Equal(t, profileID, stored.ProfileID, "expected profile assignment to round-trip through the repository")
}
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package postgres contains repository implementations using PostgreSQL as
// the underlying database.
package postgres
-329
View File
@@ -1,329 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import migrate "github.com/rubenv/sql-migrate"
// Migration of bootstrap service.
func Migration() *migrate.MemoryMigrationSource {
return &migrate.MemoryMigrationSource{
Migrations: []*migrate.Migration{
{
Id: "configs_1",
Up: []string{
`CREATE TABLE IF NOT EXISTS configs (
mainflux_client TEXT UNIQUE NOT NULL,
owner VARCHAR(254),
name TEXT,
mainflux_key CHAR(36) UNIQUE NOT NULL,
external_id TEXT UNIQUE NOT NULL,
external_key TEXT NOT NULL,
content TEXT,
client_cert TEXT,
client_key TEXT,
ca_cert TEXT,
state BIGINT NOT NULL,
PRIMARY KEY (mainflux_client, owner)
)`,
`CREATE TABLE IF NOT EXISTS unknown_configs (
external_id TEXT UNIQUE NOT NULL,
external_key TEXT NOT NULL,
PRIMARY KEY (external_id, external_key)
)`,
`CREATE TABLE IF NOT EXISTS channels (
mainflux_channel TEXT UNIQUE NOT NULL,
owner VARCHAR(254),
name TEXT,
metadata JSON,
PRIMARY KEY (mainflux_channel, owner)
)`,
`CREATE TABLE IF NOT EXISTS connections (
channel_id TEXT,
channel_owner VARCHAR(256),
config_id TEXT,
config_owner VARCHAR(256),
FOREIGN KEY (channel_id, channel_owner) REFERENCES channels (mainflux_channel, owner) ON DELETE CASCADE ON UPDATE CASCADE,
FOREIGN KEY (config_id, config_owner) REFERENCES configs (mainflux_client, owner) ON DELETE CASCADE ON UPDATE CASCADE,
PRIMARY KEY (channel_id, channel_owner, config_id, config_owner)
)`,
},
Down: []string{
"DROP TABLE connections",
"DROP TABLE configs",
"DROP TABLE channels",
"DROP TABLE unknown_configs",
},
},
{
Id: "configs_2",
Up: []string{
"DROP TABLE IF EXISTS unknown_configs",
},
Down: []string{
"CREATE TABLE IF NOT EXISTS unknown_configs",
},
},
{
Id: "configs_3",
Up: []string{
`ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS parent_id VARCHAR(36)`,
`ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS description VARCHAR(1024)`,
`ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS created_at TIMESTAMP`,
`ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP`,
`ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS updated_by VARCHAR(254)`,
`ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS status SMALLINT NOT NULL DEFAULT 0 CHECK (status >= 0)`,
},
},
{
Id: "configs_4",
Up: []string{
`ALTER TABLE IF EXISTS configs RENAME COLUMN mainflux_client TO magistrala_client`,
`ALTER TABLE IF EXISTS configs RENAME COLUMN mainflux_key TO magistrala_secret`,
`ALTER TABLE IF EXISTS channels RENAME COLUMN mainflux_channel TO magistrala_channel`,
},
},
{
Id: "configs_5",
Up: []string{
`ALTER TABLE IF EXISTS configs RENAME COLUMN owner TO domain_id`,
`ALTER TABLE IF EXISTS channels RENAME COLUMN owner TO domain_id`,
`ALTER TABLE IF EXISTS configs ADD CONSTRAINT configs_name_domain_id_key UNIQUE (name, domain_id)`,
},
},
{
Id: "configs_6",
Up: []string{
`ALTER TABLE IF EXISTS connections DROP CONSTRAINT IF EXISTS connections_pkey`,
`ALTER TABLE IF EXISTS connections DROP COLUMN IF EXISTS channel_owner`,
`ALTER TABLE IF EXISTS connections DROP COLUMN IF EXISTS config_owner`,
`ALTER TABLE IF EXISTS connections ADD COLUMN IF NOT EXISTS domain_id VARCHAR(256) NOT NULL`,
`ALTER TABLE IF EXISTS connections ADD CONSTRAINT connections_pkey PRIMARY KEY (channel_id, config_id, domain_id)`,
`ALTER TABLE IF EXISTS connections ADD FOREIGN KEY (channel_id, domain_id) REFERENCES channels (magistrala_channel, domain_id) ON DELETE CASCADE ON UPDATE CASCADE`,
`ALTER TABLE IF EXISTS connections ADD FOREIGN KEY (config_id, domain_id) REFERENCES configs (magistrala_client, domain_id) ON DELETE CASCADE ON UPDATE CASCADE`,
},
},
{
Id: "configs_7",
Up: []string{
`ALTER TABLE IF EXISTS configs RENAME COLUMN magistrala_client TO client_id`,
`ALTER TABLE IF EXISTS configs RENAME COLUMN magistrala_secret TO client_secret`,
`CREATE UNIQUE INDEX IF NOT EXISTS configs_client_id_key ON configs (client_id)`,
`CREATE UNIQUE INDEX IF NOT EXISTS configs_client_id_domain_id_key ON configs (client_id, domain_id)`,
`DROP TABLE IF EXISTS connections`,
`DROP TABLE IF EXISTS channels`,
},
Down: []string{
`ALTER TABLE IF EXISTS configs RENAME COLUMN client_id TO magistrala_client`,
`ALTER TABLE IF EXISTS configs RENAME COLUMN client_secret TO magistrala_secret`,
},
},
{
Id: "configs_8",
Up: []string{
`DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_name = 'configs' AND column_name = 'client_id'
) AND NOT EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_name = 'configs' AND column_name = 'id'
) THEN
ALTER TABLE configs RENAME COLUMN client_id TO id;
END IF;
END $$`,
`ALTER TABLE IF EXISTS configs DROP COLUMN IF EXISTS client_secret`,
},
Down: []string{
`ALTER TABLE IF EXISTS configs ADD COLUMN IF NOT EXISTS client_secret TEXT`,
`DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_name = 'configs' AND column_name = 'id'
) AND NOT EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_name = 'configs' AND column_name = 'client_id'
) THEN
ALTER TABLE configs RENAME COLUMN id TO client_id;
END IF;
END $$`,
},
},
{
Id: "configs_10",
Up: []string{
`CREATE TABLE IF NOT EXISTS profiles (
id VARCHAR(36) PRIMARY KEY,
domain_id VARCHAR(36) NOT NULL,
name VARCHAR(1024) NOT NULL,
description TEXT,
template_format VARCHAR(64) NOT NULL DEFAULT 'go-template',
content_template TEXT,
defaults JSONB,
binding_slots JSONB,
version INT NOT NULL DEFAULT 1,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
UNIQUE (domain_id, name)
)`,
`CREATE INDEX IF NOT EXISTS idx_profiles_domain_id ON profiles (domain_id)`,
},
Down: []string{
`DROP TABLE IF EXISTS profiles`,
},
},
{
Id: "configs_11",
Up: []string{
`ALTER TABLE IF EXISTS configs ADD COLUMN IF NOT EXISTS profile_id VARCHAR(36) REFERENCES profiles (id) ON DELETE SET NULL`,
`ALTER TABLE IF EXISTS configs ADD COLUMN IF NOT EXISTS render_context JSONB`,
},
Down: []string{
`ALTER TABLE IF EXISTS configs DROP COLUMN IF EXISTS render_context`,
`ALTER TABLE IF EXISTS configs DROP COLUMN IF EXISTS profile_id`,
},
},
{
Id: "configs_12",
Up: []string{
`CREATE TABLE IF NOT EXISTS bindings (
config_id TEXT NOT NULL,
slot VARCHAR(256) NOT NULL,
type VARCHAR(64) NOT NULL,
resource_id TEXT NOT NULL,
snapshot JSONB,
secret_snapshot BYTEA,
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
PRIMARY KEY (config_id, slot)
)`,
`CREATE INDEX IF NOT EXISTS idx_bindings_config_id ON bindings (config_id)`,
},
Down: []string{
`DROP TABLE IF EXISTS bindings`,
},
},
{
Id: "configs_13",
Up: []string{
`DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_name = 'configs' AND column_name = 'state'
) AND NOT EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_name = 'configs' AND column_name = 'status'
) THEN
ALTER TABLE configs RENAME COLUMN state TO status;
END IF;
END $$`,
},
Down: []string{
`DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_name = 'configs' AND column_name = 'status'
) AND NOT EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_name = 'configs' AND column_name = 'state'
) THEN
ALTER TABLE configs RENAME COLUMN status TO state;
END IF;
END $$`,
},
},
{
Id: "configs_14",
Up: []string{
`DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_name = 'binding_snapshots'
) AND NOT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_name = 'bindings'
) THEN
ALTER TABLE binding_snapshots RENAME TO bindings;
END IF;
END $$`,
`DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM pg_class
WHERE relname = 'idx_binding_snapshots_config_id'
) AND NOT EXISTS (
SELECT 1
FROM pg_class
WHERE relname = 'idx_bindings_config_id'
) THEN
ALTER INDEX idx_binding_snapshots_config_id RENAME TO idx_bindings_config_id;
END IF;
END $$`,
},
Down: []string{
`DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_name = 'bindings'
) AND NOT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_name = 'binding_snapshots'
) THEN
ALTER TABLE bindings RENAME TO binding_snapshots;
END IF;
END $$`,
`DO $$
BEGIN
IF EXISTS (
SELECT 1
FROM pg_class
WHERE relname = 'idx_bindings_config_id'
) AND NOT EXISTS (
SELECT 1
FROM pg_class
WHERE relname = 'idx_binding_snapshots_config_id'
) THEN
ALTER INDEX idx_bindings_config_id RENAME TO idx_binding_snapshots_config_id;
END IF;
END $$`,
},
},
{
Id: "configs_15",
Up: []string{
`ALTER TABLE IF EXISTS profiles ADD COLUMN IF NOT EXISTS binding_slots JSONB`,
},
Down: []string{
`ALTER TABLE IF EXISTS profiles DROP COLUMN IF EXISTS binding_slots`,
},
},
{
Id: "configs_16",
Up: []string{
`ALTER TABLE IF EXISTS profiles RENAME COLUMN template_format TO content_format`,
},
Down: []string{
`ALTER TABLE IF EXISTS profiles RENAME COLUMN content_format TO template_format`,
},
},
},
}
}
-88
View File
@@ -1,88 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres_test
import (
"fmt"
"log"
"os"
"testing"
"github.com/absmach/magistrala/bootstrap/postgres"
mglog "github.com/absmach/magistrala/logger"
pgclient "github.com/absmach/magistrala/pkg/postgres"
"github.com/jmoiron/sqlx"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
)
var (
testLog, _ = mglog.New(os.Stdout, "info")
db *sqlx.DB
)
func TestMain(m *testing.M) {
pool, err := dockertest.NewPool("")
if err != nil {
testLog.Error(fmt.Sprintf("Could not connect to docker: %s", err))
}
container, err := pool.RunWithOptions(&dockertest.RunOptions{
Repository: "postgres",
Tag: "16.2-alpine",
Env: []string{
"POSTGRES_USER=test",
"POSTGRES_PASSWORD=test",
"POSTGRES_DB=test",
"listen_addresses = '*'",
},
}, func(config *docker.HostConfig) {
config.AutoRemove = true
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
})
if err != nil {
log.Fatalf("Could not start container: %s", err)
}
port := container.GetPort("5432/tcp")
if err := pool.Retry(func() error {
url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port)
db, err = sqlx.Open("pgx", url)
if err != nil {
return err
}
return db.Ping()
}); err != nil {
testLog.Error(fmt.Sprintf("Could not connect to docker: %s", err))
}
dbConfig := pgclient.Config{
Host: "localhost",
Port: port,
User: "test",
Pass: "test",
Name: "test",
SSLMode: "disable",
SSLCert: "",
SSLKey: "",
SSLRootCert: "",
}
migration := postgres.Migration()
if db, err = pgclient.Setup(dbConfig, *migration); err != nil {
testLog.Error(fmt.Sprintf("Could not setup test DB connection: %s", err))
}
code := m.Run()
// Defers will not be run when using os.Exit
db.Close()
if err := pool.Purge(container); err != nil {
testLog.Error(fmt.Sprintf("Could not purge container: %s", err))
}
os.Exit(code)
}
-80
View File
@@ -1,80 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package bootstrap
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/json"
"io"
"net/http"
)
// bootstrapRes represent Magistrala Response to the Bootstrap request.
// This is used as a response from ConfigReader and can easily be
// replaced with any other response format.
type bootstrapRes struct {
ID string `json:"id,omitempty"`
Content string `json:"content,omitempty"`
ClientCert string `json:"client_cert,omitempty"`
ClientKey string `json:"client_key,omitempty"`
CACert string `json:"ca_cert,omitempty"`
}
func (res bootstrapRes) Code() int {
return http.StatusOK
}
func (res bootstrapRes) Headers() map[string]string {
return map[string]string{}
}
func (res bootstrapRes) Empty() bool {
return false
}
type reader struct {
encKey []byte
}
// NewConfigReader return new reader which is used to generate response
// from the config.
func NewConfigReader(encKey []byte) ConfigReader {
return reader{encKey: encKey}
}
func (r reader) ReadConfig(cfg Config, secure bool) (any, error) {
res := bootstrapRes{
ID: cfg.ID,
Content: cfg.Content,
ClientCert: cfg.ClientCert,
ClientKey: cfg.ClientKey,
CACert: cfg.CACert,
}
if secure {
b, err := json.Marshal(res)
if err != nil {
return nil, err
}
return r.encrypt(b)
}
return res, nil
}
func (r reader) encrypt(in []byte) ([]byte, error) {
block, err := aes.NewCipher(r.encKey)
if err != nil {
return nil, err
}
ciphertext := make([]byte, aes.BlockSize+len(in))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, err
}
stream := cipher.NewCFBEncrypter(block, iv)
stream.XORKeyStream(ciphertext[aes.BlockSize:], in)
return ciphertext, nil
}
-102
View File
@@ -1,102 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package bootstrap_test
import (
"crypto/aes"
"crypto/cipher"
"encoding/json"
"fmt"
"net/http"
"testing"
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/bootstrap"
"github.com/absmach/magistrala/pkg/errors"
"github.com/stretchr/testify/assert"
)
type readResp struct {
ID string `json:"id"`
Content string `json:"content,omitempty"`
ClientCert string `json:"client_cert,omitempty"`
ClientKey string `json:"client_key,omitempty"`
CACert string `json:"ca_cert,omitempty"`
}
func dec(in []byte) ([]byte, error) {
block, err := aes.NewCipher(encKey)
if err != nil {
return nil, err
}
if len(in) < aes.BlockSize {
return nil, errors.ErrMalformedEntity
}
iv := in[:aes.BlockSize]
in = in[aes.BlockSize:]
stream := cipher.NewCFBDecrypter(block, iv)
stream.XORKeyStream(in, in)
return in, nil
}
func TestReadConfig(t *testing.T) {
cfg := bootstrap.Config{
ID: "smq_id",
ClientCert: "client_cert",
ClientKey: "client_key",
CACert: "ca_cert",
Content: "content",
}
ret := readResp{
ID: "smq_id",
Content: "content",
ClientCert: "client_cert",
ClientKey: "client_key",
CACert: "ca_cert",
}
bin, err := json.Marshal(ret)
assert.Nil(t, err, fmt.Sprintf("Marshalling expected to succeed: %s.\n", err))
reader := bootstrap.NewConfigReader(encKey)
cases := []struct {
desc string
config bootstrap.Config
enc []byte
secret bool
err error
}{
{
desc: "read a config",
config: cfg,
enc: bin,
secret: false,
},
{
desc: "read encrypted config",
config: cfg,
enc: bin,
secret: true,
},
}
for _, tc := range cases {
res, err := reader.ReadConfig(tc.config, tc.secret)
assert.Nil(t, err, fmt.Sprintf("Reading config to succeed: %s.\n", err))
if tc.secret {
d, err := dec(res.([]byte))
assert.Nil(t, err, fmt.Sprintf("Decrypting expected to succeed: %s.\n", err))
assert.Equal(t, tc.enc, d, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.enc, d))
continue
}
b, err := json.Marshal(res)
assert.Nil(t, err, fmt.Sprintf("Marshalling expected to succeed: %s.\n", err))
assert.Equal(t, tc.enc, b, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.enc, b))
resp, ok := res.(magistrala.Response)
assert.True(t, ok, "If not encrypted, reader should return response.")
assert.False(t, resp.Empty(), fmt.Sprintf("Response should not be empty %s.", err))
assert.Equal(t, http.StatusOK, resp.Code(), "Default config response code should be 200.")
}
}
-532
View File
@@ -1,532 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package bootstrap
import (
"context"
"crypto/aes"
"crypto/cipher"
"encoding/hex"
"github.com/absmach/magistrala"
smqauthn "github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
mgsdk "github.com/absmach/magistrala/pkg/sdk"
)
var (
// ErrExternalKey indicates a non-existent bootstrap configuration for given external key.
ErrExternalKey = errors.NewAuthZError("failed to get bootstrap configuration for given external key")
// ErrExternalKeySecure indicates error in getting bootstrap configuration for given encrypted external key.
ErrExternalKeySecure = errors.NewAuthZError("failed to get bootstrap configuration for given encrypted external key")
// ErrBootstrap indicates error in getting bootstrap configuration.
ErrBootstrap = errors.New("failed to read bootstrap configuration")
// ErrAddBootstrap indicates error in adding bootstrap configuration.
ErrAddBootstrap = errors.NewServiceError("failed to add bootstrap configuration")
// ErrBootstrapStatus indicates an invalid bootstrap status.
ErrBootstrapStatus = errors.NewRequestError("invalid bootstrap status")
errRemoveBootstrap = errors.New("failed to remove bootstrap configuration")
errEnableConfig = errors.New("failed to enable bootstrap configuration")
errDisableConfig = errors.New("failed to disable bootstrap configuration")
errUpdateCert = errors.New("failed to update cert")
errCreateProfile = errors.New("failed to create profile")
errViewProfile = errors.New("failed to view profile")
errUpdateProfile = errors.New("failed to update profile")
errDeleteProfile = errors.New("failed to delete profile")
errListProfiles = errors.New("failed to list profiles")
errAssignProfile = errors.New("failed to assign profile to enrollment")
errBindResources = errors.New("failed to bind resources")
errListBindings = errors.New("failed to list bindings")
errRefreshBinding = errors.New("failed to refresh bindings")
errRenderBootstrap = errors.New("failed to render bootstrap configuration")
)
var _ Service = (*bootstrapService)(nil)
// Service specifies an API that must be fulfilled by the domain service
// implementation, and all of its decorators (e.g. logging & metrics).
type Service interface {
// Add adds new Client Config to the user identified by the provided token.
Add(ctx context.Context, session smqauthn.Session, token string, cfg Config) (Config, error)
// View returns Client Config with given ID belonging to the user identified by the given token.
View(ctx context.Context, session smqauthn.Session, id string) (Config, error)
// Update updates editable fields of the provided Config.
Update(ctx context.Context, session smqauthn.Session, cfg Config) error
// UpdateCert updates an existing Config certificate and token.
// A non-nil error is returned to indicate operation failure.
UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (Config, error)
// List returns subset of Configs with given search params that belong to the
// user identified by the given token.
List(ctx context.Context, session smqauthn.Session, filter Filter, offset, limit uint64) (ConfigsPage, error)
// Remove removes Config with specified token that belongs to the user identified by the given token.
Remove(ctx context.Context, session smqauthn.Session, id string) error
// Bootstrap returns Config to the Client with provided external ID using external key.
Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (Config, error)
// EnableConfig enables the Config so its device can successfully bootstrap.
EnableConfig(ctx context.Context, session smqauthn.Session, id string) (Config, error)
// DisableConfig disables the Config, preventing its device from bootstrapping.
DisableConfig(ctx context.Context, session smqauthn.Session, id string) (Config, error)
// CreateProfile persists a new device Profile.
CreateProfile(ctx context.Context, session smqauthn.Session, p Profile) (Profile, error)
// ViewProfile returns the Profile with the given ID.
ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (Profile, error)
// UpdateProfile updates editable fields of the given Profile and returns the updated Profile.
UpdateProfile(ctx context.Context, session smqauthn.Session, p Profile) (Profile, error)
// ListProfiles returns a page of Profiles belonging to the domain.
ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (ProfilesPage, error)
// DeleteProfile removes the Profile with the given ID.
DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error
// AssignProfile sets the ProfileID on an existing enrollment (Config).
AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) error
// BindResources resolves the requested bindings through their owning services,
// stores snapshots, and marks the enrollment renderable when all required slots
// are satisfied.
BindResources(ctx context.Context, session smqauthn.Session, token, configID string, bindings []BindingRequest) error
// ListBindings returns all stored binding snapshots for an enrollment.
ListBindings(ctx context.Context, session smqauthn.Session, configID string) ([]BindingSnapshot, error)
// RefreshBindings re-resolves all existing bindings for an enrollment and
// updates the stored snapshots.
RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) error
}
// ConfigReader is used to parse Config into format which will be encoded
// as a JSON and consumed from the client side. The purpose of this interface
// is to provide convenient way to generate custom configuration response
// based on the specific Config which will be consumed by the client.
type ConfigReader interface {
ReadConfig(Config, bool) (any, error)
}
type bootstrapService struct {
configs ConfigRepository
profiles ProfileRepository
bindings BindingStore
resolver BindingResolver
renderer Renderer
hasher Hasher
sdk mgsdk.SDK
encKey []byte
idProvider magistrala.IDProvider
}
// New returns new Bootstrap service.
func New(
configs ConfigRepository,
profiles ProfileRepository,
bindings BindingStore,
resolver BindingResolver,
renderer Renderer,
sdk mgsdk.SDK,
hasher Hasher,
encKey []byte,
idp magistrala.IDProvider,
) Service {
return &bootstrapService{
configs: configs,
profiles: profiles,
bindings: bindings,
resolver: resolver,
renderer: renderer,
hasher: hasher,
sdk: sdk,
encKey: encKey,
idProvider: idp,
}
}
func (bs bootstrapService) Add(ctx context.Context, session smqauthn.Session, token string, cfg Config) (Config, error) {
id, err := bs.idProvider.ID()
if err != nil {
return Config{}, errors.Wrap(ErrAddBootstrap, err)
}
hashedKey, err := bs.hasher.Hash(cfg.ExternalKey)
if err != nil {
return Config{}, errors.Wrap(ErrAddBootstrap, err)
}
cfg.ID = id
cfg.DomainID = session.DomainID
cfg.Status = Active
cfg.ExternalKey = hashedKey
saved, err := bs.configs.Save(ctx, cfg)
if err != nil {
if errors.Contains(err, repoerr.ErrConflict) {
return Config{}, errors.Wrap(svcerr.ErrConflict, err)
}
return Config{}, errors.Wrap(ErrAddBootstrap, err)
}
cfg.ID = saved
return cfg, nil
}
func (bs bootstrapService) View(ctx context.Context, session smqauthn.Session, id string) (Config, error) {
cfg, err := bs.configs.RetrieveByID(ctx, session.DomainID, id)
if err != nil {
return Config{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
return cfg, nil
}
func (bs bootstrapService) Update(ctx context.Context, session smqauthn.Session, cfg Config) error {
cfg.DomainID = session.DomainID
if err := bs.configs.Update(ctx, cfg); err != nil {
return errors.Wrap(svcerr.ErrUpdateEntity, err)
}
return nil
}
func (bs bootstrapService) UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (Config, error) {
cfg, err := bs.configs.UpdateCert(ctx, session.DomainID, id, clientCert, clientKey, caCert)
if err != nil {
return Config{}, errors.Wrap(errUpdateCert, err)
}
return cfg, nil
}
func (bs bootstrapService) List(ctx context.Context, session smqauthn.Session, filter Filter, offset, limit uint64) (ConfigsPage, error) {
return bs.configs.RetrieveAll(ctx, session.DomainID, filter, offset, limit), nil
}
func (bs bootstrapService) Remove(ctx context.Context, session smqauthn.Session, id string) error {
if err := bs.configs.Remove(ctx, session.DomainID, id); err != nil {
return errors.Wrap(errRemoveBootstrap, err)
}
return nil
}
func (bs bootstrapService) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (Config, error) {
cfg, err := bs.configs.RetrieveByExternalID(ctx, externalID)
if err != nil {
return cfg, errors.Wrap(ErrBootstrap, err)
}
if secure {
dec, err := bs.dec(externalKey)
if err != nil {
return Config{}, errors.Wrap(ErrExternalKeySecure, err)
}
externalKey = dec
}
if err := bs.hasher.Compare(externalKey, cfg.ExternalKey); err != nil {
return Config{}, ErrExternalKey
}
if cfg.Status == DisabledStatus {
return Config{}, ErrBootstrap
}
cfg, err = bs.renderBootstrapConfig(ctx, cfg)
if err != nil {
return Config{}, errors.Wrap(ErrBootstrap, err)
}
return cfg, nil
}
func (bs bootstrapService) renderBootstrapConfig(ctx context.Context, cfg Config) (Config, error) {
if cfg.ProfileID == "" {
return cfg, nil
}
if bs.profiles == nil || bs.bindings == nil || bs.renderer == nil {
return Config{}, errors.Wrap(errRenderBootstrap, errors.New("profile rendering support not configured"))
}
profile, err := bs.profiles.RetrieveByID(ctx, cfg.DomainID, cfg.ProfileID)
if err != nil {
return Config{}, errors.Wrap(errRenderBootstrap, err)
}
bindings, err := bs.bindings.Retrieve(ctx, cfg.ID)
if err != nil {
return Config{}, errors.Wrap(errRenderBootstrap, err)
}
if err := validateRequiredBindings(profile, bindings); err != nil {
return Config{}, errors.Wrap(errRenderBootstrap, err)
}
bindings, err = bs.decryptSecretSnapshots(bindings)
if err != nil {
return Config{}, errors.Wrap(errRenderBootstrap, err)
}
rendered, err := bs.renderer.Render(profile, cfg, bindings)
if err != nil {
return Config{}, errors.Wrap(errRenderBootstrap, err)
}
cfg.Content = string(rendered)
return cfg, nil
}
func (bs bootstrapService) EnableConfig(ctx context.Context, session smqauthn.Session, id string) (Config, error) {
cfg, err := bs.changeConfigStatus(ctx, session.DomainID, id, EnabledStatus)
if err != nil {
return Config{}, errors.Wrap(errEnableConfig, err)
}
return cfg, nil
}
func (bs bootstrapService) DisableConfig(ctx context.Context, session smqauthn.Session, id string) (Config, error) {
cfg, err := bs.changeConfigStatus(ctx, session.DomainID, id, DisabledStatus)
if err != nil {
return Config{}, errors.Wrap(errDisableConfig, err)
}
return cfg, nil
}
func (bs bootstrapService) changeConfigStatus(ctx context.Context, domainID, id string, status Status) (Config, error) {
cfg, err := bs.configs.RetrieveByID(ctx, domainID, id)
if err != nil {
return Config{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
if cfg.Status == status {
return cfg, nil
}
if err := bs.configs.ChangeStatus(ctx, domainID, id, status); err != nil {
return Config{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
cfg.Status = status
return cfg, nil
}
// --- Profile management ---
func (bs bootstrapService) CreateProfile(ctx context.Context, session smqauthn.Session, p Profile) (Profile, error) {
if bs.profiles == nil {
return Profile{}, errors.Wrap(errCreateProfile, errors.New("profile repository not configured"))
}
id, err := bs.idProvider.ID()
if err != nil {
return Profile{}, errors.Wrap(errCreateProfile, err)
}
p.ID = id
p.DomainID = session.DomainID
if p.ContentFormat == "" {
p.ContentFormat = ContentFormatJSON
}
p.Version = 1
if err := validateProfileBindingSlots(p); err != nil {
return Profile{}, errors.Wrap(errCreateProfile, err)
}
if err := validateProfileTemplate(p); err != nil {
return Profile{}, errors.Wrap(errCreateProfile, err)
}
saved, err := bs.profiles.Save(ctx, p)
if err != nil {
return Profile{}, errors.Wrap(errCreateProfile, err)
}
return saved, nil
}
func (bs bootstrapService) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (Profile, error) {
if bs.profiles == nil {
return Profile{}, errors.Wrap(errViewProfile, errors.New("profile repository not configured"))
}
p, err := bs.profiles.RetrieveByID(ctx, session.DomainID, profileID)
if err != nil {
return Profile{}, errors.Wrap(errViewProfile, err)
}
return p, nil
}
func (bs bootstrapService) UpdateProfile(ctx context.Context, session smqauthn.Session, p Profile) (Profile, error) {
if bs.profiles == nil {
return Profile{}, errors.Wrap(errUpdateProfile, errors.New("profile repository not configured"))
}
p.DomainID = session.DomainID
if err := validateProfileBindingSlots(p); err != nil {
return Profile{}, errors.Wrap(errUpdateProfile, err)
}
if err := validateProfileTemplate(p); err != nil {
return Profile{}, errors.Wrap(errUpdateProfile, err)
}
updated, err := bs.profiles.Update(ctx, p)
if err != nil {
return Profile{}, errors.Wrap(errUpdateProfile, err)
}
return updated, nil
}
func (bs bootstrapService) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (ProfilesPage, error) {
if bs.profiles == nil {
return ProfilesPage{}, errors.Wrap(errListProfiles, errors.New("profile repository not configured"))
}
page, err := bs.profiles.RetrieveAll(ctx, session.DomainID, offset, limit, name)
if err != nil {
return ProfilesPage{}, errors.Wrap(errListProfiles, err)
}
return page, nil
}
func (bs bootstrapService) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
if bs.profiles == nil {
return errors.Wrap(errDeleteProfile, errors.New("profile repository not configured"))
}
if err := bs.profiles.Delete(ctx, session.DomainID, profileID); err != nil {
return errors.Wrap(errDeleteProfile, err)
}
return nil
}
// --- Enrollment-profile assignment ---
func (bs bootstrapService) AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) error {
if bs.profiles == nil {
return errors.Wrap(errAssignProfile, errors.New("profile repository not configured"))
}
// Validate profile exists in domain.
if _, err := bs.profiles.RetrieveByID(ctx, session.DomainID, profileID); err != nil {
return errors.Wrap(errAssignProfile, err)
}
if err := bs.configs.AssignProfile(ctx, session.DomainID, configID, profileID); err != nil {
return errors.Wrap(errAssignProfile, err)
}
return nil
}
// --- Binding management ---
func (bs bootstrapService) BindResources(ctx context.Context, session smqauthn.Session, token, configID string, requested []BindingRequest) error {
if bs.profiles == nil || bs.bindings == nil || bs.resolver == nil {
return errors.Wrap(errBindResources, errors.New("binding support not configured"))
}
cfg, err := bs.configs.RetrieveByID(ctx, session.DomainID, configID)
if err != nil {
return errors.Wrap(errBindResources, err)
}
profile, err := bs.profiles.RetrieveByID(ctx, session.DomainID, cfg.ProfileID)
if err != nil {
return errors.Wrap(errBindResources, err)
}
if err := validateRequestedBindings(profile, requested); err != nil {
return errors.Wrap(errBindResources, err)
}
snapshots, err := bs.resolver.Resolve(ctx, ResolveRequest{
Enrollment: cfg,
Token: token,
Requested: requested,
})
if err != nil {
return errors.Wrap(errBindResources, err)
}
existing, err := bs.bindings.Retrieve(ctx, configID)
if err != nil {
return errors.Wrap(errBindResources, err)
}
if err := validateRequiredBindings(profile, mergeBindingSnapshots(existing, snapshots)); err != nil {
return errors.Wrap(errBindResources, err)
}
snapshots, err = bs.encryptSecretSnapshots(snapshots)
if err != nil {
return errors.Wrap(errBindResources, err)
}
if err := bs.bindings.Save(ctx, configID, snapshots); err != nil {
return errors.Wrap(errBindResources, err)
}
return nil
}
func (bs bootstrapService) ListBindings(ctx context.Context, session smqauthn.Session, configID string) ([]BindingSnapshot, error) {
if bs.bindings == nil {
return nil, errors.Wrap(errListBindings, errors.New("binding support not configured"))
}
if _, err := bs.configs.RetrieveByID(ctx, session.DomainID, configID); err != nil {
return nil, errors.Wrap(errListBindings, err)
}
snapshots, err := bs.bindings.Retrieve(ctx, configID)
if err != nil {
return nil, errors.Wrap(errListBindings, err)
}
return hideSecretSnapshots(snapshots), nil
}
func (bs bootstrapService) RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) error {
if bs.profiles == nil || bs.bindings == nil || bs.resolver == nil {
return errors.Wrap(errRefreshBinding, errors.New("binding support not configured"))
}
cfg, err := bs.configs.RetrieveByID(ctx, session.DomainID, configID)
if err != nil {
return errors.Wrap(errRefreshBinding, err)
}
profile, err := bs.profiles.RetrieveByID(ctx, session.DomainID, cfg.ProfileID)
if err != nil {
return errors.Wrap(errRefreshBinding, err)
}
existing, err := bs.bindings.Retrieve(ctx, configID)
if err != nil {
return errors.Wrap(errRefreshBinding, err)
}
if len(existing) == 0 {
return nil
}
// Re-resolve every existing binding to refresh its snapshot.
requested := make([]BindingRequest, len(existing))
for i, b := range existing {
requested[i] = BindingRequest{Slot: b.Slot, Type: b.Type, ResourceID: b.ResourceID}
}
if err := validateRequestedBindings(profile, requested); err != nil {
return errors.Wrap(errRefreshBinding, err)
}
refreshed, err := bs.resolver.Resolve(ctx, ResolveRequest{
Enrollment: cfg,
Token: token,
Requested: requested,
})
if err != nil {
return errors.Wrap(errRefreshBinding, err)
}
if err := validateRequiredBindings(profile, refreshed); err != nil {
return errors.Wrap(errRefreshBinding, err)
}
refreshed, err = bs.encryptSecretSnapshots(refreshed)
if err != nil {
return errors.Wrap(errRefreshBinding, err)
}
return bs.bindings.Save(ctx, configID, refreshed)
}
func (bs bootstrapService) dec(in string) (string, error) {
ciphertext, err := hex.DecodeString(in)
if err != nil {
return "", err
}
block, err := aes.NewCipher(bs.encKey)
if err != nil {
return "", err
}
if len(ciphertext) < aes.BlockSize {
return "", err
}
iv := ciphertext[:aes.BlockSize]
ciphertext = ciphertext[aes.BlockSize:]
stream := cipher.NewCFBDecrypter(block, iv)
stream.XORKeyStream(ciphertext, ciphertext)
return string(ciphertext), nil
}
File diff suppressed because it is too large Load Diff
-12
View File
@@ -1,12 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package tracing provides tracing instrumentation for Magistrala Users service.
//
// This package provides tracing middleware for Magistrala Users service.
// It can be used to trace incoming requests and add tracing capabilities to
// Magistrala Users service.
//
// For more details about tracing instrumentation for Magistrala messaging refer
// to the documentation at https://magistrala.absmach.eu/docs/.
package tracing
-198
View File
@@ -1,198 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package tracing
import (
"context"
"github.com/absmach/magistrala/bootstrap"
smqauthn "github.com/absmach/magistrala/pkg/authn"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
var _ bootstrap.Service = (*tracingMiddleware)(nil)
type tracingMiddleware struct {
tracer trace.Tracer
svc bootstrap.Service
}
// New returns a new bootstrap service with tracing capabilities.
func New(svc bootstrap.Service, tracer trace.Tracer) bootstrap.Service {
return &tracingMiddleware{tracer, svc}
}
// Add traces the "Add" operation of the wrapped bootstrap.Service.
func (tm *tracingMiddleware) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (bootstrap.Config, error) {
ctx, span := tm.tracer.Start(ctx, "svc_register_user", trace.WithAttributes(
attribute.String("config_id", cfg.ID),
attribute.String("domain_id ", cfg.DomainID),
attribute.String("name", cfg.Name),
attribute.String("external_id", cfg.ExternalID),
attribute.String("content", cfg.Content),
attribute.String("status", cfg.Status.String()),
))
defer span.End()
return tm.svc.Add(ctx, session, token, cfg)
}
// View traces the "View" operation of the wrapped bootstrap.Service.
func (tm *tracingMiddleware) View(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
ctx, span := tm.tracer.Start(ctx, "svc_view_user", trace.WithAttributes(
attribute.String("id", id),
))
defer span.End()
return tm.svc.View(ctx, session, id)
}
// Update traces the "Update" operation of the wrapped bootstrap.Service.
func (tm *tracingMiddleware) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) error {
ctx, span := tm.tracer.Start(ctx, "svc_update_user", trace.WithAttributes(
attribute.String("name", cfg.Name),
attribute.String("content", cfg.Content),
attribute.String("config_id", cfg.ID),
attribute.String("domain_id ", cfg.DomainID),
))
defer span.End()
return tm.svc.Update(ctx, session, cfg)
}
// UpdateCert traces the "UpdateCert" operation of the wrapped bootstrap.Service.
func (tm *tracingMiddleware) UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (bootstrap.Config, error) {
ctx, span := tm.tracer.Start(ctx, "svc_update_cert", trace.WithAttributes(
attribute.String("config_id", id),
))
defer span.End()
return tm.svc.UpdateCert(ctx, session, id, clientCert, clientKey, caCert)
}
// List traces the "List" operation of the wrapped bootstrap.Service.
func (tm *tracingMiddleware) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (bootstrap.ConfigsPage, error) {
ctx, span := tm.tracer.Start(ctx, "svc_list_users", trace.WithAttributes(
attribute.Int64("offset", int64(offset)),
attribute.Int64("limit", int64(limit)),
))
defer span.End()
return tm.svc.List(ctx, session, filter, offset, limit)
}
// Remove traces the "Remove" operation of the wrapped bootstrap.Service.
func (tm *tracingMiddleware) Remove(ctx context.Context, session smqauthn.Session, id string) error {
ctx, span := tm.tracer.Start(ctx, "svc_remove_user", trace.WithAttributes(
attribute.String("id", id),
))
defer span.End()
return tm.svc.Remove(ctx, session, id)
}
// Bootstrap traces the "Bootstrap" operation of the wrapped bootstrap.Service.
func (tm *tracingMiddleware) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (bootstrap.Config, error) {
ctx, span := tm.tracer.Start(ctx, "svc_bootstrap_user", trace.WithAttributes(
attribute.String("external_id", externalID),
attribute.Bool("secure", secure),
))
defer span.End()
return tm.svc.Bootstrap(ctx, externalKey, externalID, secure)
}
func (tm *tracingMiddleware) EnableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
ctx, span := tm.tracer.Start(ctx, "svc_enable_config", trace.WithAttributes(
attribute.String("id", id),
))
defer span.End()
return tm.svc.EnableConfig(ctx, session, id)
}
func (tm *tracingMiddleware) DisableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
ctx, span := tm.tracer.Start(ctx, "svc_disable_config", trace.WithAttributes(
attribute.String("id", id),
))
defer span.End()
return tm.svc.DisableConfig(ctx, session, id)
}
func (tm *tracingMiddleware) CreateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
ctx, span := tm.tracer.Start(ctx, "svc_create_profile", trace.WithAttributes(
attribute.String("name", p.Name),
attribute.String("domain_id", p.DomainID),
))
defer span.End()
return tm.svc.CreateProfile(ctx, session, p)
}
func (tm *tracingMiddleware) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (bootstrap.Profile, error) {
ctx, span := tm.tracer.Start(ctx, "svc_view_profile", trace.WithAttributes(
attribute.String("profile_id", profileID),
))
defer span.End()
return tm.svc.ViewProfile(ctx, session, profileID)
}
func (tm *tracingMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
ctx, span := tm.tracer.Start(ctx, "svc_update_profile", trace.WithAttributes(
attribute.String("profile_id", p.ID),
))
defer span.End()
return tm.svc.UpdateProfile(ctx, session, p)
}
func (tm *tracingMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
ctx, span := tm.tracer.Start(ctx, "svc_list_profiles", trace.WithAttributes(
attribute.Int64("offset", int64(offset)),
attribute.Int64("limit", int64(limit)),
))
defer span.End()
return tm.svc.ListProfiles(ctx, session, offset, limit, name)
}
func (tm *tracingMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
ctx, span := tm.tracer.Start(ctx, "svc_delete_profile", trace.WithAttributes(
attribute.String("profile_id", profileID),
))
defer span.End()
return tm.svc.DeleteProfile(ctx, session, profileID)
}
func (tm *tracingMiddleware) AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) error {
ctx, span := tm.tracer.Start(ctx, "svc_assign_profile", trace.WithAttributes(
attribute.String("config_id", configID),
attribute.String("profile_id", profileID),
))
defer span.End()
return tm.svc.AssignProfile(ctx, session, configID, profileID)
}
func (tm *tracingMiddleware) BindResources(ctx context.Context, session smqauthn.Session, token, configID string, bindings []bootstrap.BindingRequest) error {
ctx, span := tm.tracer.Start(ctx, "svc_bind_resources", trace.WithAttributes(
attribute.String("config_id", configID),
))
defer span.End()
return tm.svc.BindResources(ctx, session, token, configID, bindings)
}
func (tm *tracingMiddleware) ListBindings(ctx context.Context, session smqauthn.Session, configID string) ([]bootstrap.BindingSnapshot, error) {
ctx, span := tm.tracer.Start(ctx, "svc_list_bindings", trace.WithAttributes(
attribute.String("config_id", configID),
))
defer span.End()
return tm.svc.ListBindings(ctx, session, configID)
}
func (tm *tracingMiddleware) RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) error {
ctx, span := tm.tracer.Start(ctx, "svc_refresh_bindings", trace.WithAttributes(
attribute.String("config_id", configID),
))
defer span.End()
return tm.svc.RefreshBindings(ctx, session, token, configID)
}
-214
View File
@@ -1,214 +0,0 @@
# Channels
The Channels service is a core component of Magistrala that manages communication channels between devices and applications. It handles channel creation, configuration, access control and message routing within the Magistrala ecosystem.
## Configuration
The service is configured using the following environment variables (unset variables use default values):
| Variable | Description | Default |
| ------------------------- | --------------------------------------------- | ------------------------------ |
| `MG_CHANNELS_LOG_LEVEL` | Log level (debug, info, warn, error) | info |
| `MG_CHANNELS_HTTP_HOST` | HTTP host for Channels service | localhost |
| `MG_CHANNELS_HTTP_PORT` | HTTP port for Channels service | 9005 |
| `MG_CHANNELS_SERVER_CERT` | Path to PEM encoded server certificate | "" |
| `MG_CHANNELS_SERVER_KEY` | Path to PEM encoded server key file | "" |
| `MG_CHANNELS_GRPC_HOST` | gRPC host for Channels service | localhost |
| `MG_CHANNELS_GRPC_PORT` | gRPC port for Channels service | 7005 |
| `MG_CHANNELS_DB_HOST` | Database host address | localhost |
| `MG_CHANNELS_DB_PORT` | Database port | 5432 |
| `MG_CHANNELS_DB_USER` | Database user | magistrala |
| `MG_CHANNELS_DB_PASS` | Database password | magistrala |
| `MG_CHANNELS_DB_NAME` | Name of the database used by the service | channels |
| `MG_CHANNELS_DB_SSL_MODE` | Database connection SSL mode | disable |
| `MG_CHANNELS_CACHE_URL` | Cache database URL | <redis://localhost:6379/0> |
| `MG_JAEGER_URL` | Jaeger tracing server URL | <http://jaeger:4318/v1/traces> |
| `MG_SEND_TELEMETRY` | Send telemetry to Magistrala call-home server | true |
## Features
- **Channel Management**: Create, update, delete and list channels
- **Access Control**: Manage channel permissions and user access
- **Message Routing**: Route messages between connected devices and services
- **Channel Groups**: Organize channels into logical groups
- **Metadata Support**: Attach custom metadata to channels
- **Real-time Updates**: Live channel state synchronization
## Architecture
The service is built using:
- **Go**: Core service implementation
- **gRPC**: Inter-service communication
- **PostgreSQL**: Primary data storage
- **Redis**: Caching and pub/sub messaging
- **Docker**: Containerized deployment
### Channels Table
| Column | Type | Description |
| ----------------- | ------------- | ----------------------------------------------------- |
| `id` | VARCHAR(36) | UUID of the channel (primary key) |
| `name` | VARCHAR(1024) | Human-readable name |
| `domain_id` | VARCHAR(36) | Domain to which the channel belongs |
| `parent_group_id` | VARCHAR(36) | Optional group parent |
| `tags` | TEXT[] | Array of tags |
| `metadata` | JSONB | Free-form structured metadata |
| `created_by` | VARCHAR(254) | User that created the channel |
| `created_at` | TIMESTAMPTZ | Timestamp of creation |
| `updated_at` | TIMESTAMPTZ | Timestamp of last update |
| `updated_by` | VARCHAR(254) | User that performed last update |
| `status` | SMALLINT | 0 = enabled, 1 = disabled |
| `route` | VARCHAR(36) | Optional route identifier unique within domain if set |
### Connections Table
| Column | Type | Description |
| ------------ | ----------- | ----------------------------------------------- |
| `channel_id` | VARCHAR(36) | Channel UUID |
| `domain_id` | VARCHAR(36) | Domain of channel and client |
| `client_id` | VARCHAR(36) | Client UUID |
| `type` | SMALLINT | Connection type: `1 = Publish`, `2 = Subscribe` |
## Deployment
The service is available as a Docker container. Refer to the Docker Compose section for the `channels` service in `docker-compose.yaml` for deployment configuration.
To build and run locally:
```bash
# download the latest version of the service
git clone https://github.com/absmach/magistrala
cd magistrala
# compile the channels
make channels
make install
MG_CHANNELS_HTTP_HOST=localhost \
MG_CHANNELS_HTTP_PORT=9005 \
MG_CHANNELS_DB_HOST=localhost \
MG_CHANNELS_DB_PORT=5432 \
MG_CHANNELS_DB_USER=magistrala \MG_CHANNELS_DB_PASS=magistrala \MG_CHANNELS_DB_NAME=channels \
$GOBIN/magistrala-channels
```
### Running the Service
```bash
# Set environment variables
export MQ_CHANNELS_DB_HOST=localhost
export MQ_CHANNELS_DB_PORT=5432
# Run the service
go run cmd/main.go
```
### Docker Deployment
```bash
docker run -p 8180:8180 magistrala/channels
```
## Testing
```bash
# Run unit tests
go test ./...
# Run integration tests
make test-integration
```
## Usage
The Channels service supports the following operations:
| Operation | Description |
| --------------- | -------------------------------------------- |
| `create` | Create a new channel |
| `list` | Retrieve all channels (paged) |
| `get` | Retrieve a single channel by ID |
| `update` | Update a channels name & metadata |
| `delete` | Permanently delete a channel |
| `enable` | Enable a previously disabled channel |
| `disable` | Disable an active channel |
| `set-parent` | Assign a parent group to a channel |
| `remove-parent` | Remove parent group from a channel |
| `connect` | Connect one or more clients to channels |
| `disconnect` | Disconnect one or more clients from channels |
### Example: Create a Channel
```bash
curl -X POST http://localhost:9005/<domainID>/channels \
-H "Authorization: Bearer <your_access_token>" \
-H "Content-Type: application/json" \
-d '{
"name": "myChannel",
"metadata": { "location": "lab" },
"route": "sensor-data",
"tags": ["sensor","edge"],
"status": "enabled"
}'
```
### Example: Connect Clients & Channels
```bash
curl -X POST http://localhost:9005/<domainID>/channels/connect \
-H "Authorization: Bearer <your_access_token>" \
-H "Content-Type: application/json" \
-d '{
"channel_ids": ["<chanID1>", "<chanID2>"],
"client_ids": ["<clientID1>", "<clientID2>"],
"types": ["publish", "subscribe"]
}'
```
### Example: Disconnect Clients from a Channel
```bash
curl -X POST http://localhost:9005/<domainID>/channels/disconnect \
-H "Authorization: Bearer <your_access_token>" \
-H "Content-Type: application/json" \
-d '{
"channel_ids": ["<chanID>"],
"client_ids": ["<clientID>"],
"types": ["publish"]
}'
```
## Best Practices
- Use tags and metadata to manage and categorize channels (e.g., environment, region, purpose).
- Assign `route` thoughtfully when channels need a predictable identifier.
- Keep channel hierarchies shallow for easier navigation (avoid deep nesting unless required).
- Use `disable` rather than immediate delete when you want to suspend a channel temporarily.
- Clean up unused connections: regularly review which clients are connected to channels and remove stale links.
- Enforce minimal privileges: only allow clients to connect to channels they truly need.
- Monitoring: use the `/health` endpoint and version metadata for service stability.
## Versioning & Health Check
The Channels service exposes a `/health` endpoint to provide operational status and version info.
### Health Check Request
```bash
curl -X GET http://localhost:9005/health \
-H "accept: application/health+json"
```
### Example Response
```json
{
"status": "pass",
"version": "0.18.0",
"commit": "<commit-hash>",
"description": "channels service",
"build_time": "2025-11-19T..."
}
```
-224
View File
@@ -1,224 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
"fmt"
"time"
grpcChannelsV1 "github.com/absmach/magistrala/api/grpc/channels/v1"
grpcCommonV1 "github.com/absmach/magistrala/api/grpc/common/v1"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/go-kit/kit/endpoint"
kitgrpc "github.com/go-kit/kit/transport/grpc"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
const svcName = "channels.v1.ChannelsService"
var _ grpcChannelsV1.ChannelsServiceClient = (*grpcClient)(nil)
type grpcClient struct {
timeout time.Duration
authorize endpoint.Endpoint
removeClientConnections endpoint.Endpoint
unsetParentGroupFromChannels endpoint.Endpoint
retrieveEntity endpoint.Endpoint
retrieveIDByRoute endpoint.Endpoint
}
// NewClient returns new gRPC client instance.
func NewClient(conn *grpc.ClientConn, timeout time.Duration) grpcChannelsV1.ChannelsServiceClient {
return &grpcClient{
authorize: kitgrpc.NewClient(
conn,
svcName,
"Authorize",
encodeAuthorizeRequest,
decodeAuthorizeResponse,
grpcChannelsV1.AuthzRes{},
).Endpoint(),
removeClientConnections: kitgrpc.NewClient(
conn,
svcName,
"RemoveClientConnections",
encodeRemoveClientConnectionsRequest,
decodeRemoveClientConnectionsResponse,
grpcChannelsV1.RemoveClientConnectionsRes{},
).Endpoint(),
unsetParentGroupFromChannels: kitgrpc.NewClient(
conn,
svcName,
"UnsetParentGroupFromChannels",
encodeUnsetParentGroupFromChannelsRequest,
decodeUnsetParentGroupFromChannelsResponse,
grpcChannelsV1.UnsetParentGroupFromChannelsRes{},
).Endpoint(),
retrieveEntity: kitgrpc.NewClient(
conn,
svcName,
"RetrieveEntity",
encodeRetrieveEntityRequest,
decodeRetrieveEntityResponse,
grpcCommonV1.RetrieveEntityRes{},
).Endpoint(),
retrieveIDByRoute: kitgrpc.NewClient(
conn,
svcName,
"RetrieveIDByRoute",
encodeRetrieveIDByRouteRequest,
decodeRetrieveIDByRouteResponse,
grpcCommonV1.RetrieveEntityRes{},
).Endpoint(),
timeout: timeout,
}
}
func (client grpcClient) Authorize(ctx context.Context, req *grpcChannelsV1.AuthzReq, _ ...grpc.CallOption) (r *grpcChannelsV1.AuthzRes, err error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
res, err := client.authorize(ctx, authorizeReq{
domainID: req.GetDomainId(),
clientID: req.GetClientId(),
clientType: req.GetClientType(),
channelID: req.GetChannelId(),
connType: connections.ConnType(req.GetType()),
})
if err != nil {
return &grpcChannelsV1.AuthzRes{}, decodeError(err)
}
ar := res.(authorizeRes)
return &grpcChannelsV1.AuthzRes{Authorized: ar.authorized}, nil
}
func encodeAuthorizeRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(authorizeReq)
return &grpcChannelsV1.AuthzReq{
DomainId: req.domainID,
ClientId: req.clientID,
ClientType: req.clientType,
ChannelId: req.channelID,
Type: uint32(req.connType),
}, nil
}
func decodeAuthorizeResponse(_ context.Context, grpcRes any) (any, error) {
res := grpcRes.(*grpcChannelsV1.AuthzRes)
return authorizeRes{authorized: res.GetAuthorized()}, nil
}
func (client grpcClient) RemoveClientConnections(ctx context.Context, req *grpcChannelsV1.RemoveClientConnectionsReq, _ ...grpc.CallOption) (r *grpcChannelsV1.RemoveClientConnectionsRes, err error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
if _, err := client.removeClientConnections(ctx, req); err != nil {
return &grpcChannelsV1.RemoveClientConnectionsRes{}, decodeError(err)
}
return &grpcChannelsV1.RemoveClientConnectionsRes{}, nil
}
func encodeRemoveClientConnectionsRequest(_ context.Context, grpcReq any) (any, error) {
return grpcReq.(*grpcChannelsV1.RemoveClientConnectionsReq), nil
}
func decodeRemoveClientConnectionsResponse(_ context.Context, grpcRes any) (any, error) {
return grpcRes.(*grpcChannelsV1.RemoveClientConnectionsRes), nil
}
func (client grpcClient) UnsetParentGroupFromChannels(ctx context.Context, req *grpcChannelsV1.UnsetParentGroupFromChannelsReq, _ ...grpc.CallOption) (r *grpcChannelsV1.UnsetParentGroupFromChannelsRes, err error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
if _, err := client.unsetParentGroupFromChannels(ctx, req); err != nil {
return &grpcChannelsV1.UnsetParentGroupFromChannelsRes{}, decodeError(err)
}
return &grpcChannelsV1.UnsetParentGroupFromChannelsRes{}, nil
}
func encodeUnsetParentGroupFromChannelsRequest(_ context.Context, grpcReq any) (any, error) {
return grpcReq.(*grpcChannelsV1.UnsetParentGroupFromChannelsReq), nil
}
func decodeUnsetParentGroupFromChannelsResponse(_ context.Context, grpcRes any) (any, error) {
return grpcRes.(*grpcChannelsV1.UnsetParentGroupFromChannelsRes), nil
}
func (client grpcClient) RetrieveEntity(ctx context.Context, req *grpcCommonV1.RetrieveEntityReq, _ ...grpc.CallOption) (r *grpcCommonV1.RetrieveEntityRes, err error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
res, err := client.retrieveEntity(ctx, req)
if err != nil {
return &grpcCommonV1.RetrieveEntityRes{}, decodeError(err)
}
return res.(*grpcCommonV1.RetrieveEntityRes), nil
}
func encodeRetrieveEntityRequest(_ context.Context, grpcReq any) (any, error) {
return grpcReq.(*grpcCommonV1.RetrieveEntityReq), nil
}
func decodeRetrieveEntityResponse(_ context.Context, grpcRes any) (any, error) {
return grpcRes.(*grpcCommonV1.RetrieveEntityRes), nil
}
func (client grpcClient) RetrieveIDByRoute(ctx context.Context, req *grpcCommonV1.RetrieveIDByRouteReq, _ ...grpc.CallOption) (r *grpcCommonV1.RetrieveEntityRes, err error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
res, err := client.retrieveIDByRoute(ctx, req)
if err != nil {
return &grpcCommonV1.RetrieveEntityRes{}, decodeError(err)
}
return res.(*grpcCommonV1.RetrieveEntityRes), nil
}
func encodeRetrieveIDByRouteRequest(_ context.Context, grpcReq any) (any, error) {
return grpcReq.(*grpcCommonV1.RetrieveIDByRouteReq), nil
}
func decodeRetrieveIDByRouteResponse(_ context.Context, grpcRes any) (any, error) {
return grpcRes.(*grpcCommonV1.RetrieveEntityRes), nil
}
func decodeError(err error) error {
if st, ok := status.FromError(err); ok {
switch st.Code() {
case codes.Unauthenticated:
return errors.Wrap(svcerr.ErrAuthentication, errors.New(st.Message()))
case codes.PermissionDenied:
return errors.Wrap(svcerr.ErrAuthorization, errors.New(st.Message()))
case codes.InvalidArgument:
return errors.Wrap(errors.ErrMalformedEntity, errors.New(st.Message()))
case codes.FailedPrecondition:
return errors.Wrap(errors.ErrMalformedEntity, errors.New(st.Message()))
case codes.NotFound:
return errors.Wrap(svcerr.ErrNotFound, errors.New(st.Message()))
case codes.AlreadyExists:
return errors.Wrap(svcerr.ErrConflict, errors.New(st.Message()))
case codes.OK:
if msg := st.Message(); msg != "" {
return errors.Wrap(errors.ErrUnidentified, errors.New(msg))
}
return nil
default:
return errors.Wrap(fmt.Errorf("unexpected gRPC status: %s (status code:%v)", st.Code().String(), st.Code()), errors.New(st.Message()))
}
}
return err
}
-5
View File
@@ -1,5 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package grpc contains implementation of Auth service gRPC API.
package grpc
-85
View File
@@ -1,85 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
ch "github.com/absmach/magistrala/channels"
channels "github.com/absmach/magistrala/channels/private"
"github.com/go-kit/kit/endpoint"
)
func authorizeEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(authorizeReq)
if err := req.validate(); err != nil {
return authorizeRes{}, err
}
if err := svc.Authorize(ctx, ch.AuthzReq{
DomainID: req.domainID,
ClientID: req.clientID,
ClientType: req.clientType,
ChannelID: req.channelID,
Type: req.connType,
}); err != nil {
return authorizeRes{}, err
}
return authorizeRes{authorized: true}, nil
}
}
func removeClientConnectionsEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(removeClientConnectionsReq)
if err := svc.RemoveClientConnections(ctx, req.clientID); err != nil {
return removeClientConnectionsRes{}, err
}
return removeClientConnectionsRes{}, nil
}
}
func unsetParentGroupFromChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(unsetParentGroupFromChannelsReq)
if err := svc.UnsetParentGroupFromChannels(ctx, req.parentGroupID); err != nil {
return unsetParentGroupFromChannelsRes{}, err
}
return unsetParentGroupFromChannelsRes{}, nil
}
}
func retrieveEntityEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(retrieveEntityReq)
channel, err := svc.RetrieveByID(ctx, req.Id)
if err != nil {
return retrieveEntityRes{}, err
}
return retrieveEntityRes{id: channel.ID, domain: channel.Domain, parentGroup: channel.ParentGroup, status: uint8(channel.Status)}, nil
}
}
func retrieveIDByRouteEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(retrieveIDByRouteReq)
if err := req.validate(); err != nil {
return retrieveIDByRouteRes{}, err
}
id, err := svc.RetrieveIDByRoute(ctx, req.route, req.domainID)
if err != nil {
return retrieveIDByRouteRes{}, err
}
return retrieveIDByRouteRes{id: id}, nil
}
}
-345
View File
@@ -1,345 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package grpc_test
import (
"context"
"fmt"
"net"
"testing"
"time"
grpcChannelsV1 "github.com/absmach/magistrala/api/grpc/channels/v1"
grpcCommonV1 "github.com/absmach/magistrala/api/grpc/common/v1"
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/channels"
ch "github.com/absmach/magistrala/channels"
grpcapi "github.com/absmach/magistrala/channels/api/grpc"
"github.com/absmach/magistrala/channels/private/mocks"
"github.com/absmach/magistrala/internal/testsutil"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/policies"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
)
const port = 7005
var (
validID = testsutil.GenerateUUID(&testing.T{})
validChannel = ch.Channel{
ID: validID,
Domain: testsutil.GenerateUUID(&testing.T{}),
Status: channels.EnabledStatus,
}
)
func startGRPCServer(svc *mocks.Service, port int) *grpc.Server {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
panic(fmt.Sprintf("failed to obtain port: %s", err))
}
server := grpc.NewServer()
grpcChannelsV1.RegisterChannelsServiceServer(server, grpcapi.NewServer(svc))
go func() {
if err := server.Serve(listener); err != nil {
panic(fmt.Sprintf("failed to serve: %s", err))
}
}()
return server
}
func TestAuthorize(t *testing.T) {
svc := new(mocks.Service)
server := startGRPCServer(svc, port)
defer server.GracefulStop()
authAddr := fmt.Sprintf("localhost:%d", port)
conn, _ := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
client := grpcapi.NewClient(conn, time.Second)
cases := []struct {
desc string
domainID string
clientID string
clientType string
channelID string
connType connections.ConnType
err error
authzErr error
res *grpcChannelsV1.AuthzRes
code codes.Code
}{
{
desc: "authorize successfully",
domainID: validID,
clientID: validID,
clientType: policies.UserType,
channelID: validID,
connType: connections.Publish,
res: &grpcChannelsV1.AuthzRes{Authorized: true},
err: nil,
},
{
desc: "authorize with authorization error",
domainID: validID,
clientID: validID,
clientType: policies.UserType,
channelID: validID,
connType: connections.Publish,
res: &grpcChannelsV1.AuthzRes{Authorized: false},
authzErr: svcerr.ErrAuthorization,
err: svcerr.ErrAuthorization,
},
{
desc: "authorize withnot found error",
domainID: validID,
clientID: validID,
clientType: policies.UserType,
channelID: validID,
connType: connections.Publish,
res: &grpcChannelsV1.AuthzRes{Authorized: false},
authzErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
authReq := ch.AuthzReq{
DomainID: tc.domainID,
ClientID: tc.clientID,
ClientType: tc.clientType,
ChannelID: tc.channelID,
Type: tc.connType,
}
svcCall := svc.On("Authorize", mock.Anything, authReq).Return(tc.authzErr)
res, err := client.Authorize(context.Background(), &grpcChannelsV1.AuthzReq{
DomainId: tc.domainID,
ClientId: tc.clientID,
ClientType: tc.clientType,
ChannelId: tc.channelID,
Type: uint32(tc.connType),
})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
assert.Equal(t, tc.res, res, fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.res, res))
svcCall.Unset()
})
}
}
func TestRemoveClientConnections(t *testing.T) {
svc := new(mocks.Service)
server := startGRPCServer(svc, port)
defer server.GracefulStop()
authAddr := fmt.Sprintf("localhost:%d", port)
conn, _ := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
client := grpcapi.NewClient(conn, time.Second)
cases := []struct {
desc string
clientID string
err error
code codes.Code
}{
{
desc: "remove client connections successfully",
clientID: validID,
err: nil,
},
{
desc: "remove client connections with error",
clientID: validID,
err: svcerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("RemoveClientConnections", mock.Anything, tc.clientID).Return(tc.err)
res, err := client.RemoveClientConnections(context.Background(), &grpcChannelsV1.RemoveClientConnectionsReq{
ClientId: tc.clientID,
})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
assert.Equal(t, &grpcChannelsV1.RemoveClientConnectionsRes{}, res)
svcCall.Unset()
})
}
}
func TestUnsetParentGroupFromChannelsEndpoint(t *testing.T) {
svc := new(mocks.Service)
server := startGRPCServer(svc, port)
defer server.GracefulStop()
authAddr := fmt.Sprintf("localhost:%d", port)
conn, _ := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
client := grpcapi.NewClient(conn, time.Second)
cases := []struct {
desc string
parentGroupID string
err error
code codes.Code
}{
{
desc: "unset parent group from channels successfully",
parentGroupID: validID,
err: nil,
},
{
desc: "unset parent group from channels authorization error",
parentGroupID: validID,
err: svcerr.ErrAuthorization,
},
{
desc: "unset parent group from channels with not found error",
parentGroupID: validID,
err: svcerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("UnsetParentGroupFromChannels", mock.Anything, tc.parentGroupID).Return(tc.err)
res, err := client.UnsetParentGroupFromChannels(context.Background(), &grpcChannelsV1.UnsetParentGroupFromChannelsReq{
ParentGroupId: tc.parentGroupID,
})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
assert.Equal(t, &grpcChannelsV1.UnsetParentGroupFromChannelsRes{}, res)
svcCall.Unset()
})
}
}
func TestRetrieveEntity(t *testing.T) {
svc := new(mocks.Service)
server := startGRPCServer(svc, port)
defer server.GracefulStop()
authAddr := fmt.Sprintf("localhost:%d", port)
conn, _ := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
client := grpcapi.NewClient(conn, time.Second)
cases := []struct {
desc string
id string
svcRes ch.Channel
resp *grpcCommonV1.RetrieveEntityRes
code codes.Code
err error
}{
{
desc: "retrieve entity successfully",
id: validID,
svcRes: validChannel,
resp: &grpcCommonV1.RetrieveEntityRes{
Entity: &grpcCommonV1.EntityBasic{
Id: validChannel.ID,
DomainId: validChannel.Domain,
ParentGroupId: validChannel.ParentGroup,
Status: uint32(validChannel.Status),
},
},
err: nil,
},
{
desc: "retrieve entity with error",
id: validID,
resp: &grpcCommonV1.RetrieveEntityRes{},
err: svcerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("RetrieveByID", mock.Anything, tc.id).Return(tc.svcRes, tc.err)
res, err := client.RetrieveEntity(context.Background(), &grpcCommonV1.RetrieveEntityReq{
Id: tc.id,
})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
assert.Equal(t, tc.resp.Entity, res.Entity)
svcCall.Unset()
})
}
}
func TestRetrieveIDByRoute(t *testing.T) {
svc := new(mocks.Service)
server := startGRPCServer(svc, port)
defer server.GracefulStop()
authAddr := fmt.Sprintf("localhost:%d", port)
conn, _ := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
client := grpcapi.NewClient(conn, time.Second)
validRoute := "validRoute"
domainID := testsutil.GenerateUUID(t)
cases := []struct {
desc string
retrieveReq *grpcCommonV1.RetrieveIDByRouteReq
svcRes string
svcErr error
retrieveRes *grpcCommonV1.RetrieveEntityRes
err error
}{
{
desc: "retrieve entity by route successfully",
retrieveReq: &grpcCommonV1.RetrieveIDByRouteReq{
Route: validRoute,
DomainId: domainID,
},
svcRes: validID,
retrieveRes: &grpcCommonV1.RetrieveEntityRes{
Entity: &grpcCommonV1.EntityBasic{
Id: validID,
},
},
err: nil,
},
{
desc: "retrieve entity by route with empty route",
retrieveReq: &grpcCommonV1.RetrieveIDByRouteReq{
Route: "",
DomainId: domainID,
},
svcRes: "",
retrieveRes: &grpcCommonV1.RetrieveEntityRes{},
err: apiutil.ErrMissingRoute,
},
{
desc: "retrieve entity by route with empty domain ID",
retrieveReq: &grpcCommonV1.RetrieveIDByRouteReq{
Route: validRoute,
DomainId: "",
},
svcRes: "",
retrieveRes: &grpcCommonV1.RetrieveEntityRes{},
err: apiutil.ErrMissingDomainID,
},
{
desc: "retrieve entity by route with invalid route",
retrieveReq: &grpcCommonV1.RetrieveIDByRouteReq{
Route: "invalidRoute",
DomainId: domainID,
},
svcRes: "",
svcErr: svcerr.ErrNotFound,
retrieveRes: &grpcCommonV1.RetrieveEntityRes{},
err: svcerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("RetrieveIDByRoute", mock.Anything, tc.retrieveReq.Route, tc.retrieveReq.DomainId).Return(tc.svcRes, tc.svcErr)
res, err := client.RetrieveIDByRoute(context.Background(), tc.retrieveReq)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
assert.Equal(t, tc.retrieveRes.Entity, res.Entity)
svcCall.Unset()
})
}
}
-56
View File
@@ -1,56 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/errors"
"github.com/absmach/magistrala/pkg/policies"
)
var errDomainID = errors.New("domain id required for users")
type authorizeReq struct {
domainID string
channelID string
clientID string
clientType string
connType connections.ConnType
}
func (req authorizeReq) validate() error {
if req.clientType == policies.UserType && req.domainID == "" {
return errDomainID
}
return nil
}
type removeClientConnectionsReq struct {
clientID string
}
type unsetParentGroupFromChannelsReq struct {
parentGroupID string
}
type retrieveEntityReq struct {
Id string
}
type retrieveIDByRouteReq struct {
route string
domainID string
}
func (req retrieveIDByRouteReq) validate() error {
if req.route == "" {
return apiutil.ErrMissingRoute
}
if req.domainID == "" {
return apiutil.ErrMissingDomainID
}
return nil
}
-25
View File
@@ -1,25 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package grpc
type authorizeRes struct {
authorized bool
}
type removeClientConnectionsRes struct{}
type unsetParentGroupFromChannelsRes struct{}
type channelBasic struct {
id string
domain string
parentGroup string
status uint8
}
type retrieveEntityRes channelBasic
type retrieveIDByRouteRes struct {
id string
}
-211
View File
@@ -1,211 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package grpc
import (
"context"
grpcChannelsV1 "github.com/absmach/magistrala/api/grpc/channels/v1"
grpcCommonV1 "github.com/absmach/magistrala/api/grpc/common/v1"
apiutil "github.com/absmach/magistrala/api/http/util"
smqauth "github.com/absmach/magistrala/auth"
channels "github.com/absmach/magistrala/channels/private"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
kitgrpc "github.com/go-kit/kit/transport/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
var _ grpcChannelsV1.ChannelsServiceServer = (*grpcServer)(nil)
type grpcServer struct {
grpcChannelsV1.UnimplementedChannelsServiceServer
authorize kitgrpc.Handler
removeClientConnections kitgrpc.Handler
unsetParentGroupFromChannels kitgrpc.Handler
retrieveEntity kitgrpc.Handler
retrieveIDByRoute kitgrpc.Handler
}
// NewServer returns new AuthServiceServer instance.
func NewServer(svc channels.Service) grpcChannelsV1.ChannelsServiceServer {
return &grpcServer{
authorize: kitgrpc.NewServer(
authorizeEndpoint(svc),
decodeAuthorizeRequest,
encodeAuthorizeResponse,
),
removeClientConnections: kitgrpc.NewServer(
removeClientConnectionsEndpoint(svc),
decodeRemoveClientConnectionsRequest,
encodeRemoveClientConnectionsResponse,
),
unsetParentGroupFromChannels: kitgrpc.NewServer(
unsetParentGroupFromChannelsEndpoint(svc),
decodeUnsetParentGroupFromChannelsRequest,
encodeUnsetParentGroupFromChannelsResponse,
),
retrieveEntity: kitgrpc.NewServer(
retrieveEntityEndpoint(svc),
decodeRetrieveEntityRequest,
encodeRetrieveEntityResponse,
),
retrieveIDByRoute: kitgrpc.NewServer(
retrieveIDByRouteEndpoint(svc),
decodeRetrieveIDByRouteRequest,
encodeRetrieveIDByRouteResponse,
),
}
}
func (s *grpcServer) Authorize(ctx context.Context, req *grpcChannelsV1.AuthzReq) (*grpcChannelsV1.AuthzRes, error) {
_, res, err := s.authorize.ServeGRPC(ctx, req)
if err != nil {
return nil, encodeError(err)
}
return res.(*grpcChannelsV1.AuthzRes), nil
}
func decodeAuthorizeRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(*grpcChannelsV1.AuthzReq)
connType := connections.ConnType(req.GetType())
if err := connections.CheckConnType(connType); err != nil {
return nil, err
}
return authorizeReq{
domainID: req.GetDomainId(),
clientID: req.GetClientId(),
clientType: req.GetClientType(),
channelID: req.GetChannelId(),
connType: connType,
}, nil
}
func encodeAuthorizeResponse(_ context.Context, grpcRes any) (any, error) {
res := grpcRes.(authorizeRes)
return &grpcChannelsV1.AuthzRes{Authorized: res.authorized}, nil
}
func (s *grpcServer) RemoveClientConnections(ctx context.Context, req *grpcChannelsV1.RemoveClientConnectionsReq) (*grpcChannelsV1.RemoveClientConnectionsRes, error) {
_, res, err := s.removeClientConnections.ServeGRPC(ctx, req)
if err != nil {
return nil, encodeError(err)
}
return res.(*grpcChannelsV1.RemoveClientConnectionsRes), nil
}
func decodeRemoveClientConnectionsRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(*grpcChannelsV1.RemoveClientConnectionsReq)
return removeClientConnectionsReq{
clientID: req.GetClientId(),
}, nil
}
func encodeRemoveClientConnectionsResponse(_ context.Context, grpcRes any) (any, error) {
_ = grpcRes.(removeClientConnectionsRes)
return &grpcChannelsV1.RemoveClientConnectionsRes{}, nil
}
func (s *grpcServer) UnsetParentGroupFromChannels(ctx context.Context, req *grpcChannelsV1.UnsetParentGroupFromChannelsReq) (*grpcChannelsV1.UnsetParentGroupFromChannelsRes, error) {
_, res, err := s.unsetParentGroupFromChannels.ServeGRPC(ctx, req)
if err != nil {
return nil, encodeError(err)
}
return res.(*grpcChannelsV1.UnsetParentGroupFromChannelsRes), nil
}
func decodeUnsetParentGroupFromChannelsRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(*grpcChannelsV1.UnsetParentGroupFromChannelsReq)
return unsetParentGroupFromChannelsReq{
parentGroupID: req.GetParentGroupId(),
}, nil
}
func encodeUnsetParentGroupFromChannelsResponse(_ context.Context, grpcRes any) (any, error) {
_ = grpcRes.(unsetParentGroupFromChannelsRes)
return &grpcChannelsV1.UnsetParentGroupFromChannelsRes{}, nil
}
func (s *grpcServer) RetrieveEntity(ctx context.Context, req *grpcCommonV1.RetrieveEntityReq) (*grpcCommonV1.RetrieveEntityRes, error) {
_, res, err := s.retrieveEntity.ServeGRPC(ctx, req)
if err != nil {
return nil, encodeError(err)
}
return res.(*grpcCommonV1.RetrieveEntityRes), nil
}
func decodeRetrieveEntityRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(*grpcCommonV1.RetrieveEntityReq)
return retrieveEntityReq{
Id: req.GetId(),
}, nil
}
func encodeRetrieveEntityResponse(_ context.Context, grpcRes any) (any, error) {
res := grpcRes.(retrieveEntityRes)
return &grpcCommonV1.RetrieveEntityRes{
Entity: &grpcCommonV1.EntityBasic{
Id: res.id,
DomainId: res.domain,
ParentGroupId: res.parentGroup,
Status: uint32(res.status),
},
}, nil
}
func decodeRetrieveIDByRouteRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(*grpcCommonV1.RetrieveIDByRouteReq)
return retrieveIDByRouteReq{
route: req.GetRoute(),
domainID: req.GetDomainId(),
}, nil
}
func encodeRetrieveIDByRouteResponse(_ context.Context, grpcRes any) (any, error) {
res := grpcRes.(retrieveIDByRouteRes)
return &grpcCommonV1.RetrieveEntityRes{
Entity: &grpcCommonV1.EntityBasic{
Id: res.id,
},
}, nil
}
func (s *grpcServer) RetrieveIDByRoute(ctx context.Context, req *grpcCommonV1.RetrieveIDByRouteReq) (*grpcCommonV1.RetrieveEntityRes, error) {
_, res, err := s.retrieveIDByRoute.ServeGRPC(ctx, req)
if err != nil {
return nil, encodeError(err)
}
return res.(*grpcCommonV1.RetrieveEntityRes), nil
}
func encodeError(err error) error {
switch {
case errors.Contains(err, nil):
return nil
case errors.Contains(err, errors.ErrMalformedEntity),
err == apiutil.ErrInvalidAuthKey,
err == apiutil.ErrMissingID,
err == apiutil.ErrMissingMemberType,
err == apiutil.ErrMissingPolicySub,
err == apiutil.ErrMissingPolicyObj,
err == apiutil.ErrMalformedPolicyAct:
return status.Error(codes.InvalidArgument, err.Error())
case errors.Contains(err, svcerr.ErrAuthentication),
errors.Contains(err, smqauth.ErrKeyExpired),
err == apiutil.ErrMissingEmail,
err == apiutil.ErrBearerToken:
return status.Error(codes.Unauthenticated, err.Error())
case errors.Contains(err, svcerr.ErrAuthorization):
return status.Error(codes.PermissionDenied, err.Error())
default:
return status.Error(codes.Internal, err.Error())
}
}
-329
View File
@@ -1,329 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package http
import (
"context"
"encoding/json"
"net/http"
"strings"
"time"
api "github.com/absmach/magistrala/api/http"
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/internal/nullable"
"github.com/absmach/magistrala/pkg/errors"
"github.com/go-chi/chi/v5"
)
func decodeViewChannel(_ context.Context, r *http.Request) (any, error) {
roles, err := apiutil.ReadBoolQuery(r, api.RolesKey, false)
if err != nil {
return viewChannelReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
req := viewChannelReq{
id: chi.URLParam(r, "channelID"),
roles: roles,
}
return req, nil
}
func decodeCreateChannelReq(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
req := createChannelReq{}
if err := json.NewDecoder(r.Body).Decode(&req.Channel); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeCreateChannelsReq(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
req := createChannelsReq{}
if err := json.NewDecoder(r.Body).Decode(&req.Channels); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeListChannels(_ context.Context, r *http.Request) (any, error) {
name, err := apiutil.ReadStringQuery(r, api.NameKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
tags, err := apiutil.ReadStringQuery(r, api.TagsKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
var tq channels.TagsQuery
if tags != "" {
tq = channels.ToTagsQuery(tags)
}
s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefGroupStatus)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
status, err := channels.ToStatus(s)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
meta, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
dir, err := apiutil.ReadStringQuery(r, api.DirKey, api.DefDir)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
order, err := apiutil.ReadStringQuery(r, api.OrderKey, api.DefOrder)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
allActions, err := apiutil.ReadStringQuery(r, api.ActionsKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
actions := []string{}
allActions = strings.TrimSpace(allActions)
if allActions != "" {
actions = strings.Split(allActions, ",")
}
roleID, err := apiutil.ReadStringQuery(r, api.RoleIDKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
roleName, err := apiutil.ReadStringQuery(r, api.RoleNameKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
accessType, err := apiutil.ReadStringQuery(r, api.AccessTypeKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
userID, err := apiutil.ReadStringQuery(r, api.UserKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
groupID, err := nullable.Parse(r.URL.Query(), api.GroupKey, nullable.ParseString)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
clientID, err := apiutil.ReadStringQuery(r, api.ClientKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
id, err := apiutil.ReadStringQuery(r, api.IDOrder, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
ot, err := apiutil.ReadBoolQuery(r, api.OnlyTotal, false)
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
connectionType, err := apiutil.ReadStringQuery(r, api.ConnTypeKey, "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
cfrom, err := apiutil.ReadStringQuery(r, "created_from", "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
cto, err := apiutil.ReadStringQuery(r, "created_to", "")
if err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
}
var createdFrom, createdTo time.Time
if cfrom != "" {
if createdFrom, err = time.Parse(time.RFC3339, cfrom); err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrInvalidQueryParams, err)
}
}
if cto != "" {
if createdTo, err = time.Parse(time.RFC3339, cto); err != nil {
return listChannelsReq{}, errors.Wrap(apiutil.ErrInvalidQueryParams, err)
}
}
req := listChannelsReq{
Page: channels.Page{
Name: name,
Tags: tq,
Status: status,
Metadata: meta,
RoleName: roleName,
RoleID: roleID,
Actions: actions,
AccessType: accessType,
Order: order,
Dir: dir,
Offset: offset,
Limit: limit,
Group: groupID,
Client: clientID,
ConnectionType: connectionType,
ID: id,
OnlyTotal: ot,
CreatedFrom: createdFrom,
CreatedTo: createdTo,
},
userID: userID,
}
return req, nil
}
func decodeUpdateChannel(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
req := updateChannelReq{
id: chi.URLParam(r, "channelID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeUpdateChannelTags(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
req := updateChannelTagsReq{
id: chi.URLParam(r, "channelID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeSetChannelParentGroupStatus(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
req := setChannelParentGroupReq{
id: chi.URLParam(r, "channelID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeRemoveChannelParentGroupStatus(_ context.Context, r *http.Request) (any, error) {
req := removeChannelParentGroupReq{
id: chi.URLParam(r, "channelID"),
}
return req, nil
}
func decodeChangeChannelStatus(_ context.Context, r *http.Request) (any, error) {
req := changeChannelStatusReq{
id: chi.URLParam(r, "channelID"),
}
return req, nil
}
func decodeDeleteChannelReq(_ context.Context, r *http.Request) (any, error) {
req := deleteChannelReq{
id: chi.URLParam(r, "channelID"),
}
return req, nil
}
func decodeConnectChannelClientRequest(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
req := connectChannelClientsRequest{
channelID: chi.URLParam(r, "channelID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeDisconnectChannelClientsRequest(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
req := disconnectChannelClientsRequest{
channelID: chi.URLParam(r, "channelID"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeConnectRequest(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
req := connectRequest{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeDisconnectRequest(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
req := disconnectRequest{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
File diff suppressed because it is too large Load Diff
-364
View File
@@ -1,364 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package http
import (
"context"
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/go-kit/kit/endpoint"
)
func createChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(createChannelReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
channels, _, err := svc.CreateChannels(ctx, session, req.Channel)
if err != nil {
return nil, err
}
return createChannelRes{
Channel: channels[0],
created: true,
}, nil
}
}
func createChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(createChannelsReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
channels, _, err := svc.CreateChannels(ctx, session, req.Channels...)
if err != nil {
return nil, err
}
res := channelsPageRes{
pageRes: pageRes{
Total: uint64(len(channels)),
},
Channels: []viewChannelRes{},
}
for _, c := range channels {
res.Channels = append(res.Channels, viewChannelRes{Channel: c})
}
return res, nil
}
}
func viewChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(viewChannelReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
c, err := svc.ViewChannel(ctx, session, req.id, req.roles)
if err != nil {
return nil, err
}
return viewChannelRes{Channel: c}, nil
}
}
func listChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(listChannelsReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
var page channels.ChannelsPage
var err error
switch req.userID != "" {
case true:
page, err = svc.ListUserChannels(ctx, session, req.userID, req.Page)
default:
page, err = svc.ListChannels(ctx, session, req.Page)
}
if err != nil {
return channelsPageRes{}, err
}
res := channelsPageRes{
pageRes: pageRes{
Total: page.Total,
Offset: page.Offset,
Limit: page.Limit,
},
Channels: []viewChannelRes{},
}
for _, c := range page.Channels {
res.Channels = append(res.Channels, viewChannelRes{Channel: c})
}
return res, nil
}
}
func updateChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(updateChannelReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
ch := channels.Channel{
ID: req.id,
Name: req.Name,
Metadata: req.Metadata,
}
ch, err := svc.UpdateChannel(ctx, session, ch)
if err != nil {
return nil, err
}
return updateChannelRes{Channel: ch}, nil
}
}
func updateChannelTagsEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(updateChannelTagsReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
ch := channels.Channel{
ID: req.id,
Tags: req.Tags,
}
ch, err := svc.UpdateChannelTags(ctx, session, ch)
if err != nil {
return nil, err
}
return updateChannelRes{Channel: ch}, nil
}
}
func setChannelParentGroupEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(setChannelParentGroupReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
if err := svc.SetParentGroup(ctx, session, req.ParentGroupID, req.id); err != nil {
return nil, err
}
return setChannelParentGroupRes{}, nil
}
}
func removeChannelParentGroupEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(removeChannelParentGroupReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
if err := svc.RemoveParentGroup(ctx, session, req.id); err != nil {
return nil, err
}
return removeChannelParentGroupRes{}, nil
}
}
func enableChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(changeChannelStatusReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
ch, err := svc.EnableChannel(ctx, session, req.id)
if err != nil {
return nil, err
}
return changeChannelStatusRes{Channel: ch}, nil
}
}
func disableChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(changeChannelStatusReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
ch, err := svc.DisableChannel(ctx, session, req.id)
if err != nil {
return nil, err
}
return changeChannelStatusRes{Channel: ch}, nil
}
}
func connectChannelClientEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(connectChannelClientsRequest)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
if err := svc.Connect(ctx, session, []string{req.channelID}, req.ClientIDs, req.Types); err != nil {
return nil, err
}
return connectChannelClientsRes{}, nil
}
}
func disconnectChannelClientsEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(disconnectChannelClientsRequest)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
if err := svc.Disconnect(ctx, session, []string{req.channelID}, req.ClientIds, req.Types); err != nil {
return nil, err
}
return disconnectChannelClientsRes{}, nil
}
}
func connectEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(connectRequest)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
if err := svc.Connect(ctx, session, req.ChannelIds, req.ClientIds, req.Types); err != nil {
return nil, err
}
return connectRes{}, nil
}
}
func disconnectEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(disconnectRequest)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
if err := svc.Disconnect(ctx, session, req.ChannelIds, req.ClientIds, req.Types); err != nil {
return nil, err
}
return disconnectRes{}, nil
}
}
func deleteChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(deleteChannelReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
if err := svc.RemoveChannel(ctx, session, req.id); err != nil {
return nil, err
}
return deleteChannelRes{}, nil
}
}
-321
View File
@@ -1,321 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package http
import (
"strings"
api "github.com/absmach/magistrala/api/http"
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/pkg/connections"
)
type createChannelReq struct {
Channel channels.Channel
}
func (req createChannelReq) validate() error {
if len(req.Channel.Name) > api.MaxNameSize {
return apiutil.ErrNameSize
}
if req.Channel.ID != "" {
if strings.TrimSpace(req.Channel.ID) == "" {
return apiutil.ErrMissingChannelID
}
}
if req.Channel.Route != "" {
if err := api.ValidateRoute(req.Channel.Route); err != nil {
return err
}
if err := api.ValidateUUID(req.Channel.Route); err == nil {
return apiutil.ErrInvalidRouteFormat
}
}
return nil
}
type createChannelsReq struct {
Channels []channels.Channel
}
func (req createChannelsReq) validate() error {
if len(req.Channels) == 0 {
return apiutil.ErrEmptyList
}
for _, channel := range req.Channels {
if channel.ID != "" {
if strings.TrimSpace(channel.ID) == "" {
return apiutil.ErrMissingChannelID
}
}
if len(channel.Name) > api.MaxNameSize {
return apiutil.ErrNameSize
}
if channel.Route != "" {
if err := api.ValidateRoute(channel.Route); err != nil {
return err
}
if err := api.ValidateUUID(channel.Route); err == nil {
return apiutil.ErrInvalidRouteFormat
}
}
}
return nil
}
type viewChannelReq struct {
id string
roles bool
}
func (req viewChannelReq) validate() error {
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type listChannelsReq struct {
channels.Page
userID string
}
func (req listChannelsReq) validate() error {
if req.Limit > api.MaxLimitSize || req.Limit < 1 {
return apiutil.ErrLimitSize
}
if len(req.Name) > api.MaxNameSize {
return apiutil.ErrNameSize
}
switch req.Order {
case "", api.NameOrder, api.CreatedAtOrder, api.UpdatedAtOrder:
default:
return apiutil.ErrInvalidOrder
}
if req.Dir != "" && (req.Dir != api.DescDir && req.Dir != api.AscDir) {
return apiutil.ErrInvalidDirection
}
if req.ConnectionType != "" {
if _, err := connections.ParseConnType(req.ConnectionType); err != nil {
return apiutil.ErrValidation
}
}
return nil
}
type updateChannelReq struct {
id string
Name string `json:"name,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Tags []string `json:"tags,omitempty"`
}
func (req updateChannelReq) validate() error {
if req.id == "" {
return apiutil.ErrMissingID
}
if len(req.Name) > api.MaxNameSize {
return apiutil.ErrNameSize
}
return nil
}
type updateChannelTagsReq struct {
id string
Tags []string `json:"tags,omitempty"`
}
func (req updateChannelTagsReq) validate() error {
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type setChannelParentGroupReq struct {
id string
ParentGroupID string `json:"parent_group_id"`
}
func (req setChannelParentGroupReq) validate() error {
if req.id == "" {
return apiutil.ErrMissingID
}
if req.ParentGroupID == "" {
return apiutil.ErrMissingParentGroupID
}
return nil
}
type removeChannelParentGroupReq struct {
id string
}
func (req removeChannelParentGroupReq) validate() error {
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type changeChannelStatusReq struct {
id string
}
func (req changeChannelStatusReq) validate() error {
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
type connectChannelClientsRequest struct {
channelID string
ClientIDs []string `json:"client_ids,omitempty"`
Types []connections.ConnType `json:"types,omitempty"`
}
func (req *connectChannelClientsRequest) validate() error {
if req.channelID == "" || strings.TrimSpace(req.channelID) == "" {
return apiutil.ErrMissingID
}
if len(req.ClientIDs) == 0 {
return apiutil.ErrMissingID
}
for _, tid := range req.ClientIDs {
if err := api.ValidateUUID(tid); err != nil {
return err
}
}
if len(req.Types) == 0 {
return apiutil.ErrMissingConnectionType
}
return nil
}
type disconnectChannelClientsRequest struct {
channelID string
ClientIds []string `json:"client_ids,omitempty"`
Types []connections.ConnType `json:"types,omitempty"`
}
func (req *disconnectChannelClientsRequest) validate() error {
if req.channelID == "" {
return apiutil.ErrMissingID
}
if err := api.ValidateUUID(req.channelID); err != nil {
return err
}
if len(req.ClientIds) == 0 {
return apiutil.ErrMissingID
}
for _, tid := range req.ClientIds {
if err := api.ValidateUUID(tid); err != nil {
return err
}
}
if len(req.Types) == 0 {
return apiutil.ErrMissingConnectionType
}
return nil
}
type connectRequest struct {
ChannelIds []string `json:"channel_ids,omitempty"`
ClientIds []string `json:"client_ids,omitempty"`
Types []connections.ConnType `json:"types,omitempty"`
}
func (req *connectRequest) validate() error {
if len(req.ChannelIds) == 0 {
return apiutil.ErrMissingID
}
for _, cid := range req.ChannelIds {
if strings.TrimSpace(cid) == "" {
return apiutil.ErrMissingChannelID
}
}
if len(req.ClientIds) == 0 {
return apiutil.ErrMissingID
}
for _, tid := range req.ClientIds {
if strings.TrimSpace(tid) == "" {
return apiutil.ErrMissingChannelID
}
}
if len(req.Types) == 0 {
return apiutil.ErrMissingConnectionType
}
return nil
}
type disconnectRequest struct {
ChannelIds []string `json:"channel_ids,omitempty"`
ClientIds []string `json:"client_ids,omitempty"`
Types []connections.ConnType `json:"types,omitempty"`
}
func (req *disconnectRequest) validate() error {
if len(req.ChannelIds) == 0 {
return apiutil.ErrMissingID
}
for _, cid := range req.ChannelIds {
if err := api.ValidateUUID(cid); err != nil {
return err
}
}
if len(req.ClientIds) == 0 {
return apiutil.ErrMissingID
}
for _, tid := range req.ClientIds {
if err := api.ValidateUUID(tid); err != nil {
return err
}
}
if len(req.Types) == 0 {
return apiutil.ErrMissingConnectionType
}
return nil
}
type deleteChannelReq struct {
id string
}
func (req deleteChannelReq) validate() error {
if req.id == "" {
return apiutil.ErrMissingID
}
return nil
}
-628
View File
@@ -1,628 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package http
import (
"fmt"
"strings"
"testing"
api "github.com/absmach/magistrala/api/http"
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/internal/testsutil"
"github.com/absmach/magistrala/pkg/connections"
"github.com/stretchr/testify/assert"
)
func TestCreateChannelReqValidation(t *testing.T) {
cases := []struct {
desc string
req createChannelReq
err error
}{
{
desc: "valid request",
req: createChannelReq{
Channel: channels.Channel{
Name: valid,
Route: valid,
},
},
err: nil,
},
{
desc: "long name",
req: createChannelReq{
Channel: channels.Channel{
Name: strings.Repeat("a", api.MaxNameSize+1),
Route: valid,
},
},
err: apiutil.ErrNameSize,
},
{
desc: "invalid route",
req: createChannelReq{
Channel: channels.Channel{
Name: valid,
Route: "__invalid",
},
},
err: apiutil.ErrInvalidRouteFormat,
},
{
desc: "uuid as route",
req: createChannelReq{
Channel: channels.Channel{
Name: valid,
Route: testsutil.GenerateUUID(t),
},
},
err: apiutil.ErrInvalidRouteFormat,
},
{
desc: "missing channel ID",
req: createChannelReq{
Channel: channels.Channel{
ID: " ",
},
},
err: apiutil.ErrMissingChannelID,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestCreateChannelsReqValidation(t *testing.T) {
cases := []struct {
desc string
req createChannelsReq
err error
}{
{
desc: "valid request",
req: createChannelsReq{
Channels: []channels.Channel{
{
Name: valid,
Route: valid,
},
},
},
err: nil,
},
{
desc: "long name",
req: createChannelsReq{
Channels: []channels.Channel{
{
Name: strings.Repeat("a", api.MaxNameSize+1),
Route: valid,
},
},
},
err: apiutil.ErrNameSize,
},
{
desc: "missing channel ID",
req: createChannelsReq{
Channels: []channels.Channel{
{
ID: " ",
},
},
},
err: apiutil.ErrMissingChannelID,
},
{
desc: "empty list",
req: createChannelsReq{
Channels: []channels.Channel{},
},
err: apiutil.ErrEmptyList,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestViewChannelReqValidation(t *testing.T) {
cases := []struct {
desc string
req viewChannelReq
err error
}{
{
desc: "valid request",
req: viewChannelReq{
id: valid,
},
err: nil,
},
{
desc: "missing ID",
req: viewChannelReq{
id: "",
},
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestListChannelsReqValidation(t *testing.T) {
cases := []struct {
desc string
req listChannelsReq
err error
}{
{
desc: "valid request",
req: listChannelsReq{
Page: channels.Page{Limit: 10},
},
err: nil,
},
{
desc: "limit is 0",
req: listChannelsReq{
Page: channels.Page{Limit: 0},
},
err: apiutil.ErrLimitSize,
},
{
desc: "limit is greater than max limit",
req: listChannelsReq{
Page: channels.Page{Limit: api.MaxLimitSize + 1},
},
err: apiutil.ErrLimitSize,
},
{
desc: "name is too long",
req: listChannelsReq{
Page: channels.Page{Limit: 10, Name: strings.Repeat("a", api.MaxNameSize+1)},
},
err: apiutil.ErrNameSize,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestUpdateChannelReqValidate(t *testing.T) {
cases := []struct {
desc string
req updateChannelReq
err error
}{
{
desc: "valid request",
req: updateChannelReq{
id: valid,
},
err: nil,
},
{
desc: "missing ID",
req: updateChannelReq{
id: "",
},
err: apiutil.ErrMissingID,
},
{
desc: "name is too long",
req: updateChannelReq{
id: valid,
Name: strings.Repeat("a", api.MaxNameSize+1),
},
err: apiutil.ErrNameSize,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestUpdateChannelTagsReqValidate(t *testing.T) {
cases := []struct {
desc string
req updateChannelTagsReq
err error
}{
{
desc: "valid request",
req: updateChannelTagsReq{
id: valid,
Tags: []string{"tag1", "tag2"},
},
err: nil,
},
{
desc: "missing ID",
req: updateChannelTagsReq{
id: "",
Tags: []string{"tag1", "tag2"},
},
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestSetChannelsParentGroupReqValidate(t *testing.T) {
cases := []struct {
desc string
req setChannelParentGroupReq
err error
}{
{
desc: "valid request",
req: setChannelParentGroupReq{
id: valid,
ParentGroupID: valid,
},
err: nil,
},
{
desc: "missing ID",
req: setChannelParentGroupReq{
id: "",
ParentGroupID: valid,
},
err: apiutil.ErrMissingID,
},
{
desc: "missing parent group ID",
req: setChannelParentGroupReq{
id: valid,
ParentGroupID: "",
},
err: apiutil.ErrMissingParentGroupID,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestRemoveChannelParentGroupReqValidate(t *testing.T) {
cases := []struct {
desc string
req removeChannelParentGroupReq
err error
}{
{
desc: "valid request",
req: removeChannelParentGroupReq{
id: valid,
},
err: nil,
},
{
desc: "missing ID",
req: removeChannelParentGroupReq{
id: "",
},
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestChangeChannelStatusReqValidate(t *testing.T) {
cases := []struct {
desc string
req changeChannelStatusReq
err error
}{
{
desc: "valid request",
req: changeChannelStatusReq{
id: valid,
},
err: nil,
},
{
desc: "missing ID",
req: changeChannelStatusReq{
id: "",
},
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestConnectChannelClientsReqValidate(t *testing.T) {
cases := []struct {
desc string
req connectChannelClientsRequest
err error
}{
{
desc: "valid request",
req: connectChannelClientsRequest{
channelID: valid,
ClientIDs: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{connections.Publish},
},
err: nil,
},
{
desc: "missing channel ID",
req: connectChannelClientsRequest{
channelID: "",
ClientIDs: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrMissingID,
},
{
desc: "missing client IDs",
req: connectChannelClientsRequest{
channelID: valid,
ClientIDs: []string{},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrMissingID,
},
{
desc: "missing connection types",
req: connectChannelClientsRequest{
channelID: valid,
ClientIDs: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{},
},
err: apiutil.ErrMissingConnectionType,
},
{
desc: "invalid client ID",
req: connectChannelClientsRequest{
channelID: valid,
ClientIDs: []string{"client1", "invalid"},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrInvalidIDFormat,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestDisconnectChannelClientReqValidate(t *testing.T) {
cases := []struct {
desc string
req disconnectChannelClientsRequest
err error
}{
{
desc: "valid request",
req: disconnectChannelClientsRequest{
channelID: testsutil.GenerateUUID(t),
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{connections.Publish},
},
err: nil,
},
{
desc: "missing channel ID",
req: disconnectChannelClientsRequest{
channelID: "",
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrMissingID,
},
{
desc: "invalid channel ID",
req: disconnectChannelClientsRequest{
channelID: "invalid",
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrInvalidIDFormat,
},
{
desc: "missing client IDs",
req: disconnectChannelClientsRequest{
channelID: testsutil.GenerateUUID(t),
ClientIds: []string{},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrMissingID,
},
{
desc: "missing connection types",
req: disconnectChannelClientsRequest{
channelID: testsutil.GenerateUUID(t),
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{},
},
err: apiutil.ErrMissingConnectionType,
},
{
desc: "invalid client ID",
req: disconnectChannelClientsRequest{
channelID: testsutil.GenerateUUID(t),
ClientIds: []string{"client1", "invalid"},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrInvalidIDFormat,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestConnectReqValidate(t *testing.T) {
cases := []struct {
desc string
req connectRequest
err error
}{
{
desc: "valid request",
req: connectRequest{
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{connections.Publish},
},
err: nil,
},
{
desc: "missing channel IDs",
req: connectRequest{
ChannelIds: []string{},
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrMissingID,
},
{
desc: "missing client IDs",
req: connectRequest{
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
ClientIds: []string{},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrMissingID,
},
{
desc: "missing connection types",
req: connectRequest{
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{},
},
err: apiutil.ErrMissingConnectionType,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestDisconnectReqValidate(t *testing.T) {
cases := []struct {
desc string
req disconnectRequest
err error
}{
{
desc: "valid request",
req: disconnectRequest{
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{connections.Publish},
},
err: nil,
},
{
desc: "missing channel IDs",
req: disconnectRequest{
ChannelIds: []string{},
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrMissingID,
},
{
desc: "missing client IDs",
req: disconnectRequest{
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
ClientIds: []string{},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrMissingID,
},
{
desc: "missing connection types",
req: disconnectRequest{
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{},
},
err: apiutil.ErrMissingConnectionType,
},
{
desc: "invalid client ID",
req: disconnectRequest{
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
ClientIds: []string{"client1", "invalid"},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrInvalidIDFormat,
},
{
desc: "invalid channel ID",
req: disconnectRequest{
ChannelIds: []string{"invalid", testsutil.GenerateUUID(t)},
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
Types: []connections.ConnType{connections.Publish},
},
err: apiutil.ErrInvalidIDFormat,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestDeleteChannelReqValidate(t *testing.T) {
cases := []struct {
desc string
req deleteChannelReq
err error
}{
{
desc: "valid request",
req: deleteChannelReq{
id: valid,
},
err: nil,
},
{
desc: "missing ID",
req: deleteChannelReq{
id: "",
},
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
-221
View File
@@ -1,221 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package http
import (
"fmt"
"net/http"
"github.com/absmach/magistrala"
"github.com/absmach/magistrala/channels"
)
var (
_ magistrala.Response = (*createChannelRes)(nil)
_ magistrala.Response = (*viewChannelRes)(nil)
_ magistrala.Response = (*channelsPageRes)(nil)
_ magistrala.Response = (*updateChannelRes)(nil)
_ magistrala.Response = (*deleteChannelRes)(nil)
_ magistrala.Response = (*connectChannelClientsRes)(nil)
_ magistrala.Response = (*disconnectChannelClientsRes)(nil)
_ magistrala.Response = (*connectRes)(nil)
_ magistrala.Response = (*disconnectRes)(nil)
_ magistrala.Response = (*changeChannelStatusRes)(nil)
)
type pageRes struct {
Limit uint64 `json:"limit,omitempty"`
Offset uint64 `json:"offset,omitempty"`
Total uint64 `json:"total"`
}
type createChannelRes struct {
channels.Channel
created bool
}
func (res createChannelRes) Code() int {
if res.created {
return http.StatusCreated
}
return http.StatusOK
}
func (res createChannelRes) Headers() map[string]string {
if res.created {
return map[string]string{
"Location": fmt.Sprintf("/channels/%s", res.ID),
}
}
return map[string]string{}
}
func (res createChannelRes) Empty() bool {
return false
}
type viewChannelRes struct {
channels.Channel
}
func (res viewChannelRes) Code() int {
return http.StatusOK
}
func (res viewChannelRes) Headers() map[string]string {
return map[string]string{}
}
func (res viewChannelRes) Empty() bool {
return false
}
type channelsPageRes struct {
pageRes
Channels []viewChannelRes `json:"channels,omitempty"`
}
func (res channelsPageRes) Code() int {
return http.StatusOK
}
func (res channelsPageRes) Headers() map[string]string {
return map[string]string{}
}
func (res channelsPageRes) Empty() bool {
return false
}
type changeChannelStatusRes struct {
channels.Channel
}
func (res changeChannelStatusRes) Code() int {
return http.StatusOK
}
func (res changeChannelStatusRes) Headers() map[string]string {
return map[string]string{}
}
func (res changeChannelStatusRes) Empty() bool {
return false
}
type updateChannelRes struct {
channels.Channel
}
func (res updateChannelRes) Code() int {
return http.StatusOK
}
func (res updateChannelRes) Headers() map[string]string {
return map[string]string{}
}
func (res updateChannelRes) Empty() bool {
return false
}
type setChannelParentGroupRes struct{}
func (res setChannelParentGroupRes) Code() int {
return http.StatusOK
}
func (res setChannelParentGroupRes) Headers() map[string]string {
return map[string]string{}
}
func (res setChannelParentGroupRes) Empty() bool {
return true
}
type removeChannelParentGroupRes struct{}
func (res removeChannelParentGroupRes) Code() int {
return http.StatusNoContent
}
func (res removeChannelParentGroupRes) Headers() map[string]string {
return map[string]string{}
}
func (res removeChannelParentGroupRes) Empty() bool {
return true
}
type deleteChannelRes struct{}
func (res deleteChannelRes) Code() int {
return http.StatusNoContent
}
func (res deleteChannelRes) Headers() map[string]string {
return map[string]string{}
}
func (res deleteChannelRes) Empty() bool {
return true
}
type connectChannelClientsRes struct{}
func (res connectChannelClientsRes) Code() int {
return http.StatusCreated
}
func (res connectChannelClientsRes) Headers() map[string]string {
return map[string]string{}
}
func (res connectChannelClientsRes) Empty() bool {
return true
}
type disconnectChannelClientsRes struct{}
func (res disconnectChannelClientsRes) Code() int {
return http.StatusNoContent
}
func (res disconnectChannelClientsRes) Headers() map[string]string {
return map[string]string{}
}
func (res disconnectChannelClientsRes) Empty() bool {
return true
}
type connectRes struct{}
func (res connectRes) Code() int {
return http.StatusCreated
}
func (res connectRes) Headers() map[string]string {
return map[string]string{}
}
func (res connectRes) Empty() bool {
return true
}
type disconnectRes struct{}
func (res disconnectRes) Code() int {
return http.StatusNoContent
}
func (res disconnectRes) Headers() map[string]string {
return map[string]string{}
}
func (res disconnectRes) Empty() bool {
return true
}
-149
View File
@@ -1,149 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package http
import (
"log/slog"
"github.com/absmach/magistrala"
api "github.com/absmach/magistrala/api/http"
apiutil "github.com/absmach/magistrala/api/http/util"
"github.com/absmach/magistrala/channels"
smqauthn "github.com/absmach/magistrala/pkg/authn"
roleManagerHttp "github.com/absmach/magistrala/pkg/roles/rolemanager/api"
"github.com/go-chi/chi/v5"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)
// MakeHandler returns a HTTP handler for Channels API endpoints.
func MakeHandler(svc channels.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp magistrala.IDProvider) *chi.Mux {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
d := roleManagerHttp.NewDecoder("channelID")
mux.Route("/{domainID}/channels", func(r chi.Router) {
r.Use(authn.Middleware())
r.Use(api.RequestIDMiddleware(idp))
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
createChannelEndpoint(svc),
decodeCreateChannelReq,
api.EncodeResponse,
opts...,
), "create_channel").ServeHTTP)
r.Post("/bulk", otelhttp.NewHandler(kithttp.NewServer(
createChannelsEndpoint(svc),
decodeCreateChannelsReq,
api.EncodeResponse,
opts...,
), "create_channels").ServeHTTP)
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
listChannelsEndpoint(svc),
decodeListChannels,
api.EncodeResponse,
opts...,
), "list_channels").ServeHTTP)
r.Post("/connect", otelhttp.NewHandler(kithttp.NewServer(
connectEndpoint(svc),
decodeConnectRequest,
api.EncodeResponse,
opts...,
), "connect").ServeHTTP)
r.Post("/disconnect", otelhttp.NewHandler(kithttp.NewServer(
disconnectEndpoint(svc),
decodeDisconnectRequest,
api.EncodeResponse,
opts...,
), "disconnect").ServeHTTP)
r = roleManagerHttp.EntityAvailableActionsRouter(svc, d, r, opts)
r.Route("/{channelID}", func(r chi.Router) {
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
viewChannelEndpoint(svc),
decodeViewChannel,
api.EncodeResponse,
opts...,
), "view_channel").ServeHTTP)
r.Patch("/", otelhttp.NewHandler(kithttp.NewServer(
updateChannelEndpoint(svc),
decodeUpdateChannel,
api.EncodeResponse,
opts...,
), "update_channel_name_and_metadata").ServeHTTP)
r.Patch("/tags", otelhttp.NewHandler(kithttp.NewServer(
updateChannelTagsEndpoint(svc),
decodeUpdateChannelTags,
api.EncodeResponse,
opts...,
), "update_channel_tag").ServeHTTP)
r.Delete("/", otelhttp.NewHandler(kithttp.NewServer(
deleteChannelEndpoint(svc),
decodeDeleteChannelReq,
api.EncodeResponse,
opts...,
), "delete_channel").ServeHTTP)
r.Post("/enable", otelhttp.NewHandler(kithttp.NewServer(
enableChannelEndpoint(svc),
decodeChangeChannelStatus,
api.EncodeResponse,
opts...,
), "enable_channel").ServeHTTP)
r.Post("/disable", otelhttp.NewHandler(kithttp.NewServer(
disableChannelEndpoint(svc),
decodeChangeChannelStatus,
api.EncodeResponse,
opts...,
), "disable_channel").ServeHTTP)
r.Post("/parent", otelhttp.NewHandler(kithttp.NewServer(
setChannelParentGroupEndpoint(svc),
decodeSetChannelParentGroupStatus,
api.EncodeResponse,
opts...,
), "set_channel_parent_group").ServeHTTP)
r.Delete("/parent", otelhttp.NewHandler(kithttp.NewServer(
removeChannelParentGroupEndpoint(svc),
decodeRemoveChannelParentGroupStatus,
api.EncodeResponse,
opts...,
), "remove_channel_parent_group").ServeHTTP)
r.Post("/connect", otelhttp.NewHandler(kithttp.NewServer(
connectChannelClientEndpoint(svc),
decodeConnectChannelClientRequest,
api.EncodeResponse,
opts...,
), "connect_channel_client").ServeHTTP)
r.Post("/disconnect", otelhttp.NewHandler(kithttp.NewServer(
disconnectChannelClientsEndpoint(svc),
decodeDisconnectChannelClientsRequest,
api.EncodeResponse,
opts...,
), "disconnect_channel_client").ServeHTTP)
roleManagerHttp.EntityRoleMangerRouter(svc, d, r, opts)
})
})
mux.Get("/health", magistrala.Health("channels", instanceID))
mux.Handle("/metrics", promhttp.Handler())
return mux
}
-7
View File
@@ -1,7 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package channels
import "github.com/absmach/magistrala/pkg/roles"
const BuiltInRoleAdmin roles.BuiltInRoleName = "admin"
-82
View File
@@ -1,82 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package cache
import (
"context"
"time"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
"github.com/redis/go-redis/v9"
)
var (
ErrEmptyDomainID = errors.New("domain ID is empty")
ErrEmptyChannelID = errors.New("channel ID is empty")
ErrEmptyChannelRoute = errors.New("channel route is empty")
)
type channelsCache struct {
client *redis.Client
duration time.Duration
}
func NewChannelsCache(client *redis.Client, duration time.Duration) channels.Cache {
return &channelsCache{
client: client,
duration: duration,
}
}
func (cc *channelsCache) Save(ctx context.Context, route, domainID, channelID string) error {
key, err := encodeKey(domainID, route)
if err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
if channelID == "" {
return errors.Wrap(repoerr.ErrCreateEntity, ErrEmptyChannelID)
}
if err := cc.client.Set(ctx, key, channelID, cc.duration).Err(); err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
return nil
}
func (cc *channelsCache) ID(ctx context.Context, channelRoute, domainID string) (string, error) {
key, err := encodeKey(domainID, channelRoute)
if err != nil {
return "", errors.Wrap(repoerr.ErrNotFound, err)
}
id, err := cc.client.Get(ctx, key).Result()
if err != nil {
return "", errors.Wrap(repoerr.ErrNotFound, err)
}
return id, nil
}
func (cc *channelsCache) Remove(ctx context.Context, channelRoute, domainID string) error {
key, err := encodeKey(domainID, channelRoute)
if err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
if err := cc.client.Del(ctx, key).Err(); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
return nil
}
func encodeKey(domainID, channelRoute string) (string, error) {
if domainID == "" {
return "", ErrEmptyDomainID
}
if channelRoute == "" {
return "", ErrEmptyChannelRoute
}
return domainID + ":" + channelRoute, nil
}
-186
View File
@@ -1,186 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package cache_test
import (
"context"
"fmt"
"testing"
"time"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/channels/cache"
"github.com/absmach/magistrala/internal/testsutil"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
)
var (
testRoute = "test-route"
nonExistent = "non-existing"
)
func setupChannelsClient(t *testing.T) channels.Cache {
opts, err := redis.ParseURL(redisURL)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on parsing redis URL: %s", err))
redisClient := redis.NewClient(opts)
return cache.NewChannelsCache(redisClient, 10*time.Minute)
}
func TestSave(t *testing.T) {
cc := setupChannelsClient(t)
route := testRoute
domainID := testsutil.GenerateUUID(t)
cases := []struct {
desc string
domainID string
channelID string
channelRoute string
err error
}{
{
desc: "Save successfully",
domainID: domainID,
channelID: testsutil.GenerateUUID(t),
channelRoute: route,
err: nil,
},
{
desc: "Save with empty domain ID",
domainID: "",
channelID: testsutil.GenerateUUID(t),
channelRoute: route,
err: cache.ErrEmptyDomainID,
},
{
desc: "Save with empty channel ID",
domainID: domainID,
channelID: "",
channelRoute: route,
err: cache.ErrEmptyChannelID,
},
{
desc: "Save with empty channel route",
domainID: domainID,
channelID: testsutil.GenerateUUID(t),
channelRoute: "",
err: cache.ErrEmptyChannelRoute,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := cc.Save(context.Background(), tc.channelRoute, tc.domainID, tc.channelID)
assert.True(t, errors.Contains(err, tc.err))
})
}
}
func TestID(t *testing.T) {
cc := setupChannelsClient(t)
domainID := testsutil.GenerateUUID(t)
route := testRoute
id := testsutil.GenerateUUID(t)
err := cc.Save(context.Background(), route, domainID, id)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on saving channel ID: %s", err))
cases := []struct {
desc string
domainID string
channelRoute string
channelID string
err error
}{
{
desc: "Retrieve existing channel",
domainID: domainID,
channelRoute: route,
channelID: id,
err: nil,
},
{
desc: "Retrieve non-existing channel",
domainID: domainID,
channelRoute: nonExistent,
channelID: "",
err: repoerr.ErrNotFound,
},
{
desc: "Retrieve with empty domain ID",
domainID: "",
channelRoute: route,
channelID: "",
err: cache.ErrEmptyDomainID,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
id, err := cc.ID(context.Background(), tc.channelRoute, tc.domainID)
assert.Equal(t, tc.channelID, id, fmt.Sprintf("expected channel ID '%s' got '%s'", tc.channelID, id))
assert.True(t, errors.Contains(err, tc.err))
})
}
}
func TestRemove(t *testing.T) {
cc := setupChannelsClient(t)
domainID := testsutil.GenerateUUID(t)
route := testRoute
id := testsutil.GenerateUUID(t)
err := cc.Save(context.Background(), domainID, route, id)
assert.Nil(t, err, fmt.Sprintf("got unexpected error on saving channel ID: %s", err))
cases := []struct {
desc string
domainID string
channelRoute string
err error
}{
{
desc: "Remove existing channel",
domainID: domainID,
channelRoute: route,
err: nil,
},
{
desc: "Remove non-existing channel",
domainID: domainID,
channelRoute: nonExistent,
err: nil,
},
{
desc: "Remove with empty domain ID",
domainID: "",
channelRoute: route,
err: cache.ErrEmptyDomainID,
},
{
desc: "Remove with empty channel route",
domainID: domainID,
channelRoute: "",
err: cache.ErrEmptyChannelRoute,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := cc.Remove(context.Background(), tc.channelRoute, tc.domainID)
assert.True(t, errors.Contains(err, tc.err))
if tc.err == nil {
id, err := cc.ID(context.Background(), tc.channelRoute, tc.domainID)
assert.Equal(t, "", id, fmt.Sprintf("expected channel ID to be empty after removal, got '%s'", id))
assert.True(t, errors.Contains(err, repoerr.ErrNotFound))
}
})
}
}
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package cache contains the domain concept definitions needed to
// support Magistrala Channels cache service functionality.
package cache
-61
View File
@@ -1,61 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package cache_test
import (
"context"
"fmt"
"log"
"os"
"testing"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/redis/go-redis/v9"
)
var (
redisClient *redis.Client
redisURL string
)
func TestMain(m *testing.M) {
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
container, err := pool.RunWithOptions(&dockertest.RunOptions{
Repository: "redis",
Tag: "7.2.4-alpine",
}, func(config *docker.HostConfig) {
config.AutoRemove = true
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
})
if err != nil {
log.Fatalf("Could not start container: %s", err)
}
redisURL = fmt.Sprintf("redis://localhost:%s/0", container.GetPort("6379/tcp"))
opts, err := redis.ParseURL(redisURL)
if err != nil {
log.Fatalf("Could not parse redis URL: %s", err)
}
if err := pool.Retry(func() error {
redisClient = redis.NewClient(opts)
return redisClient.Ping(context.Background()).Err()
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
code := m.Run()
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(code)
}
-240
View File
@@ -1,240 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package channels
import (
"context"
"strings"
"time"
"github.com/absmach/magistrala/internal/nullable"
"github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/roles"
)
// Metadata represents arbitrary JSON.
type Metadata map[string]any
// Channel represents a Magistrala "communication topic". This topic
// contains the clients that can exchange messages between each other.
type Channel struct {
ID string `json:"id"`
Name string `json:"name,omitempty"`
Tags []string `json:"tags,omitempty"`
ParentGroup string `json:"parent_group_id,omitempty"`
Domain string `json:"domain_id,omitempty"`
Route string `json:"route,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
CreatedBy string `json:"created_by,omitempty"`
CreatedAt time.Time `json:"created_at,omitempty"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
UpdatedBy string `json:"updated_by,omitempty"`
Status Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled
// Extended
ParentGroupPath string `json:"parent_group_path,omitempty"`
RoleID string `json:"role_id,omitempty"`
RoleName string `json:"role_name,omitempty"`
Actions []string `json:"actions,omitempty"`
AccessType string `json:"access_type,omitempty"`
AccessProviderId string `json:"access_provider_id,omitempty"`
AccessProviderRoleId string `json:"access_provider_role_id,omitempty"`
AccessProviderRoleName string `json:"access_provider_role_name,omitempty"`
AccessProviderRoleActions []string `json:"access_provider_role_actions,omitempty"`
ConnectionTypes []connections.ConnType `json:"connection_types,omitempty"`
MemberId string `json:"member_id,omitempty"`
Roles []roles.MemberRoleActions `json:"roles,omitempty"`
}
type Operator uint8
const (
OrOp Operator = iota
AndOp
)
type TagsQuery struct {
Elements []string
Operator Operator
}
func ToTagsQuery(s string) TagsQuery {
switch {
case strings.Contains(s, "+"):
elements := strings.Split(s, "+")
for i := range elements {
elements[i] = strings.TrimSpace(elements[i])
}
return TagsQuery{Elements: elements, Operator: AndOp}
case strings.Contains(s, ","):
elements := strings.Split(s, ",")
for i := range elements {
elements[i] = strings.TrimSpace(elements[i])
}
return TagsQuery{Elements: elements, Operator: OrOp}
default:
return TagsQuery{Elements: []string{s}, Operator: OrOp}
}
}
type Page struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
OnlyTotal bool `json:"only_total"`
Order string `json:"order,omitempty"`
Dir string `json:"dir,omitempty"`
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Metadata Metadata `json:"metadata,omitempty"`
Domain string `json:"domain,omitempty"`
Tags TagsQuery `json:"tags,omitempty"`
Status Status `json:"status,omitempty"`
Group nullable.Value[string] `json:"group,omitempty"`
Client string `json:"client,omitempty"`
ConnectionType string `json:"connection_type,omitempty"`
RoleName string `json:"role_name,omitempty"`
RoleID string `json:"role_id,omitempty"`
Actions []string `json:"actions,omitempty"`
AccessType string `json:"access_type,omitempty"`
IDs []string `json:"-"`
CreatedFrom time.Time `json:"created_from,omitempty"`
CreatedTo time.Time `json:"created_to,omitempty"`
}
// ChannelsPage contains page related metadata as well as list of channels that
// belong to this page.
type ChannelsPage struct {
Page
Channels []Channel
}
type Connection struct {
ClientID string
ChannelID string
DomainID string
Type connections.ConnType
}
type AuthzReq struct {
DomainID string
ChannelID string
ClientID string
ClientType string
Type connections.ConnType
}
type Service interface {
// CreateChannels adds channels to the user.
CreateChannels(ctx context.Context, session authn.Session, channels ...Channel) ([]Channel, []roles.RoleProvision, error)
// ViewChannel retrieves data about the channel identified by the provided
// ID, that belongs to the user.
ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (Channel, error)
// UpdateChannel updates the channel identified by the provided ID, that
// belongs to the user.
UpdateChannel(ctx context.Context, session authn.Session, channel Channel) (Channel, error)
// UpdateChannelTags updates the channel's tags.
UpdateChannelTags(ctx context.Context, session authn.Session, channel Channel) (Channel, error)
EnableChannel(ctx context.Context, session authn.Session, id string) (Channel, error)
DisableChannel(ctx context.Context, session authn.Session, id string) (Channel, error)
// ListChannels retrieves data about subset of channels that belongs to the user.
ListChannels(ctx context.Context, session authn.Session, pm Page) (ChannelsPage, error)
// ListUserChannels retrieves data about subset of channels that belong to the specified user.
ListUserChannels(ctx context.Context, session authn.Session, userID string, pm Page) (ChannelsPage, error)
// RemoveChannel removes the client identified by the provided ID, that
// belongs to the user.
RemoveChannel(ctx context.Context, session authn.Session, id string) error
// Connect adds clients to the channels list of connected clients.
Connect(ctx context.Context, session authn.Session, chIDs, clIDs []string, connType []connections.ConnType) error
// Disconnect removes clients from the channels list of connected clients.
Disconnect(ctx context.Context, session authn.Session, chIDs, clIDs []string, connType []connections.ConnType) error
SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error
RemoveParentGroup(ctx context.Context, session authn.Session, id string) error
roles.RoleManager
}
// ChannelRepository specifies a channel persistence API.
type Repository interface {
// Save persists multiple channels. Channels are saved using a transaction. If one channel
// fails then none will be saved. Successful operation is indicated by non-nil error response.
Save(ctx context.Context, chs ...Channel) ([]Channel, error)
// Update performs an update to the existing channel.
Update(ctx context.Context, c Channel) (Channel, error)
UpdateTags(ctx context.Context, ch Channel) (Channel, error)
ChangeStatus(ctx context.Context, channel Channel) (Channel, error)
// RetrieveUserChannels retrieves the channel of given domainID and userID.
RetrieveUserChannels(ctx context.Context, domainID, userID string, pm Page) (ChannelsPage, error)
// RetrieveByID retrieves the channel having the provided identifier
RetrieveByID(ctx context.Context, id string) (Channel, error)
// RetrieveByRoute retrieves the channel having the provided route
RetrieveByRoute(ctx context.Context, route, domainID string) (Channel, error)
// RetrieveByIDWithRoles retrieves channel by its unique ID along with member roles.
RetrieveByIDWithRoles(ctx context.Context, id, memberID string) (Channel, error)
// RetrieveAll retrieves the subset of channels.
RetrieveAll(ctx context.Context, pm Page) (ChannelsPage, error)
// Remove removes the channel having the provided identifier
Remove(ctx context.Context, ids ...string) error
// SetParentGroup set parent group id to a given channel id
SetParentGroup(ctx context.Context, ch Channel) error
// RemoveParentGroup remove parent group id fr given chanel id
RemoveParentGroup(ctx context.Context, ch Channel) error
AddConnections(ctx context.Context, conns []Connection) error
RemoveConnections(ctx context.Context, conns []Connection) error
CheckConnection(ctx context.Context, conn Connection) error
ClientAuthorize(ctx context.Context, conn Connection) error
ChannelConnectionsCount(ctx context.Context, id string) (uint64, error)
DoesChannelHaveConnections(ctx context.Context, id string) (bool, error)
RemoveClientConnections(ctx context.Context, clientID string) error
RemoveChannelConnections(ctx context.Context, channelID string) error
RetrieveParentGroupChannels(ctx context.Context, parentGroupID string) ([]Channel, error)
UnsetParentGroupFromChannels(ctx context.Context, parentGroupID string) error
roles.Repository
}
// Cache contains channels caching interface.
type Cache interface {
// Save stores the channelID for the given domain ID and channel route.
Save(ctx context.Context, channelRoute, domainID, channelID string) error
// ID retrieves the channelID for the given domain ID and channel route.
ID(ctx context.Context, channelRoute, domainID string) (string, error)
// Remove removes the channel ID for the given domain ID and channel route.
Remove(ctx context.Context, channelRoute, domainID string) error
}
-17
View File
@@ -1,17 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package channels
import "errors"
var (
// ErrInvalidStatus indicates invalid status.
ErrInvalidStatus = errors.New("invalid channels status")
// ErrEnableChannel indicates error in enabling channel.
ErrEnableChannel = errors.New("failed to enable channel")
// ErrDisableChannel indicates error in disabling channel.
ErrDisableChannel = errors.New("failed to disable channel")
)
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package events provides the domain concept definitions
// needed to support clients events functionality.
package events
-384
View File
@@ -1,384 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package events
import (
"time"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/roles"
)
const (
channelPrefix = "channel."
channelCreate = channelPrefix + "create"
channelUpdate = channelPrefix + "update"
channelUpdateTags = channelPrefix + "update_tags"
channelEnable = channelPrefix + "enable"
channelDisable = channelPrefix + "disable"
channelRemove = channelPrefix + "remove"
channelView = channelPrefix + "view"
channelList = channelPrefix + "list"
channelListByUser = channelPrefix + "list_by_user"
channelConnect = channelPrefix + "connect"
channelDisconnect = channelPrefix + "disconnect"
channelSetParent = channelPrefix + "set_parent"
channelRemoveParent = channelPrefix + "remove_parent"
)
var (
_ events.Event = (*createChannelEvent)(nil)
_ events.Event = (*updateChannelEvent)(nil)
_ events.Event = (*changeChannelStatusEvent)(nil)
_ events.Event = (*viewChannelEvent)(nil)
_ events.Event = (*listChannelEvent)(nil)
_ events.Event = (*removeChannelEvent)(nil)
_ events.Event = (*connectEvent)(nil)
_ events.Event = (*disconnectEvent)(nil)
)
type createChannelEvent struct {
channels.Channel
rolesProvisioned []roles.RoleProvision
authn.Session
requestID string
}
func (cce createChannelEvent) Encode() (map[string]any, error) {
val := map[string]any{
"operation": channelCreate,
"id": cce.ID,
"roles_provisioned": cce.rolesProvisioned,
"route": cce.Route,
"status": cce.Status.String(),
"created_at": cce.CreatedAt,
"domain": cce.DomainID,
"user_id": cce.UserID,
"token_type": cce.Type.String(),
"super_admin": cce.SuperAdmin,
"request_id": cce.requestID,
}
if cce.Name != "" {
val["name"] = cce.Name
}
if len(cce.Tags) > 0 {
val["tags"] = cce.Tags
}
if cce.Metadata != nil {
val["metadata"] = cce.Metadata
}
return val, nil
}
type updateChannelEvent struct {
channels.Channel
authn.Session
operation string
requestID string
}
func (uce updateChannelEvent) Encode() (map[string]any, error) {
val := map[string]any{
"operation": uce.operation,
"updated_at": uce.UpdatedAt,
"updated_by": uce.UpdatedBy,
"domain": uce.DomainID,
"user_id": uce.UserID,
"token_type": uce.Type.String(),
"super_admin": uce.SuperAdmin,
"request_id": uce.requestID,
}
if uce.ID != "" {
val["id"] = uce.ID
}
if uce.Route != "" {
val["route"] = uce.Route
}
if uce.Name != "" {
val["name"] = uce.Name
}
if len(uce.Tags) > 0 {
val["tags"] = uce.Tags
}
if uce.Metadata != nil {
val["metadata"] = uce.Metadata
}
if !uce.CreatedAt.IsZero() {
val["created_at"] = uce.CreatedAt
}
if uce.Status.String() != "" {
val["status"] = uce.Status.String()
}
return val, nil
}
type changeChannelStatusEvent struct {
id string
operation string
status string
updatedAt time.Time
updatedBy string
authn.Session
requestID string
}
func (cse changeChannelStatusEvent) Encode() (map[string]any, error) {
return map[string]any{
"operation": cse.operation,
"id": cse.id,
"status": cse.status,
"updated_at": cse.updatedAt,
"updated_by": cse.updatedBy,
"domain": cse.DomainID,
"user_id": cse.UserID,
"token_type": cse.Type.String(),
"super_admin": cse.SuperAdmin,
"request_id": cse.requestID,
}, nil
}
type viewChannelEvent struct {
channels.Channel
authn.Session
requestID string
}
func (vce viewChannelEvent) Encode() (map[string]any, error) {
val := map[string]any{
"operation": channelView,
"id": vce.ID,
"domain": vce.DomainID,
"user_id": vce.UserID,
"token_type": vce.Type.String(),
"super_admin": vce.SuperAdmin,
"request_id": vce.requestID,
}
if vce.Name != "" {
val["name"] = vce.Name
}
if vce.Route != "" {
val["route"] = vce.Route
}
if len(vce.Tags) > 0 {
val["tags"] = vce.Tags
}
if vce.Metadata != nil {
val["metadata"] = vce.Metadata
}
if !vce.CreatedAt.IsZero() {
val["created_at"] = vce.CreatedAt
}
if !vce.UpdatedAt.IsZero() {
val["updated_at"] = vce.UpdatedAt
}
if vce.UpdatedBy != "" {
val["updated_by"] = vce.UpdatedBy
}
if vce.Status.String() != "" {
val["status"] = vce.Status.String()
}
return val, nil
}
type listChannelEvent struct {
channels.Page
authn.Session
requestID string
}
func (lce listChannelEvent) Encode() (map[string]any, error) {
val := map[string]any{
"operation": channelList,
"total": lce.Total,
"offset": lce.Offset,
"limit": lce.Limit,
"domain": lce.DomainID,
"user_id": lce.UserID,
"token_type": lce.Type.String(),
"super_admin": lce.SuperAdmin,
"request_id": lce.requestID,
}
if lce.Name != "" {
val["name"] = lce.Name
}
if lce.Order != "" {
val["order"] = lce.Order
}
if lce.Dir != "" {
val["dir"] = lce.Dir
}
if lce.Metadata != nil {
val["metadata"] = lce.Metadata
}
if len(lce.Tags.Elements) > 0 {
val["tag"] = lce.Tags.Elements
}
if lce.Status.String() != "" {
val["status"] = lce.Status.String()
}
if len(lce.IDs) > 0 {
val["ids"] = lce.IDs
}
return val, nil
}
type listUserChannelsEvent struct {
userID string
channels.Page
authn.Session
requestID string
}
func (luce listUserChannelsEvent) Encode() (map[string]any, error) {
val := map[string]any{
"operation": channelListByUser,
"req_user_id": luce.userID,
"total": luce.Total,
"offset": luce.Offset,
"limit": luce.Limit,
"domain": luce.DomainID,
"user_id": luce.UserID,
"token_type": luce.Type.String(),
"super_admin": luce.SuperAdmin,
"request_id": luce.requestID,
}
if luce.Name != "" {
val["name"] = luce.Name
}
if luce.Order != "" {
val["order"] = luce.Order
}
if luce.Dir != "" {
val["dir"] = luce.Dir
}
if luce.Metadata != nil {
val["metadata"] = luce.Metadata
}
if luce.Domain != "" {
val["domain"] = luce.Domain
}
if len(luce.Tags.Elements) > 0 {
val["tag"] = luce.Tags.Elements
}
if luce.Status.String() != "" {
val["status"] = luce.Status.String()
}
if len(luce.IDs) > 0 {
val["ids"] = luce.IDs
}
return val, nil
}
type removeChannelEvent struct {
id string
authn.Session
requestID string
}
func (dce removeChannelEvent) Encode() (map[string]any, error) {
return map[string]any{
"operation": channelRemove,
"id": dce.id,
"domain": dce.DomainID,
"user_id": dce.UserID,
"token_type": dce.Type.String(),
"super_admin": dce.SuperAdmin,
"request_id": dce.requestID,
}, nil
}
type connectEvent struct {
chIDs []string
thIDs []string
types []connections.ConnType
authn.Session
requestID string
}
func (ce connectEvent) Encode() (map[string]any, error) {
return map[string]any{
"operation": channelConnect,
"client_ids": ce.thIDs,
"channel_ids": ce.chIDs,
"types": ce.types,
"domain": ce.DomainID,
"user_id": ce.UserID,
"token_type": ce.Type.String(),
"super_admin": ce.SuperAdmin,
"request_id": ce.requestID,
}, nil
}
type disconnectEvent struct {
chIDs []string
thIDs []string
types []connections.ConnType
authn.Session
requestID string
}
func (de disconnectEvent) Encode() (map[string]any, error) {
return map[string]any{
"operation": channelDisconnect,
"client_ids": de.thIDs,
"channel_ids": de.chIDs,
"types": de.types,
"domain": de.DomainID,
"user_id": de.UserID,
"token_type": de.Type.String(),
"super_admin": de.SuperAdmin,
"request_id": de.requestID,
}, nil
}
type setParentGroupEvent struct {
id string
parentGroupID string
authn.Session
requestID string
}
func (spge setParentGroupEvent) Encode() (map[string]any, error) {
return map[string]any{
"operation": channelSetParent,
"id": spge.id,
"parent_group_id": spge.parentGroupID,
"domain": spge.DomainID,
"user_id": spge.UserID,
"token_type": spge.Type.String(),
"super_admin": spge.SuperAdmin,
"request_id": spge.requestID,
}, nil
}
type removeParentGroupEvent struct {
id string
authn.Session
requestID string
}
func (rpge removeParentGroupEvent) Encode() (map[string]any, error) {
return map[string]any{
"operation": channelRemoveParent,
"id": rpge.id,
"domain": rpge.DomainID,
"user_id": rpge.UserID,
"token_type": rpge.Type.String(),
"super_admin": rpge.SuperAdmin,
"request_id": rpge.requestID,
}, nil
}
-300
View File
@@ -1,300 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package events
import (
"context"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/events"
"github.com/absmach/magistrala/pkg/events/store"
"github.com/absmach/magistrala/pkg/roles"
rmEvents "github.com/absmach/magistrala/pkg/roles/rolemanager/events"
"github.com/go-chi/chi/v5/middleware"
)
const (
magistralaPrefix = "magistrala."
createStream = magistralaPrefix + channelCreate
updateStream = magistralaPrefix + channelUpdate
updateTagsStream = magistralaPrefix + channelUpdateTags
enableStream = magistralaPrefix + channelEnable
disableStream = magistralaPrefix + channelDisable
removeStream = magistralaPrefix + channelRemove
viewStream = magistralaPrefix + channelView
listStream = magistralaPrefix + channelList
listByUserStream = magistralaPrefix + channelListByUser
connectStream = magistralaPrefix + channelConnect
disconnectStream = magistralaPrefix + channelDisconnect
setParentStream = magistralaPrefix + channelSetParent
removeParentStream = magistralaPrefix + channelRemoveParent
)
var _ channels.Service = (*eventStore)(nil)
type eventStore struct {
events.Publisher
svc channels.Service
rmEvents.RoleManagerEventStore
}
// NewEventStoreMiddleware returns wrapper around clients service that sends
// events to event store.
func NewEventStoreMiddleware(ctx context.Context, svc channels.Service, url string) (channels.Service, error) {
publisher, err := store.NewPublisher(ctx, url, "channels-es-pub")
if err != nil {
return nil, err
}
rolesSvcEventStoreMiddleware := rmEvents.NewRoleManagerEventStore("channels", channelPrefix, svc, publisher)
return &eventStore{
svc: svc,
Publisher: publisher,
RoleManagerEventStore: rolesSvcEventStoreMiddleware,
}, nil
}
func (es *eventStore) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
chs, rps, err := es.svc.CreateChannels(ctx, session, chs...)
if err != nil {
return chs, rps, err
}
for _, ch := range chs {
event := createChannelEvent{
Channel: ch,
rolesProvisioned: rps,
Session: session,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, createStream, event); err != nil {
return chs, rps, err
}
}
return chs, rps, nil
}
func (es *eventStore) UpdateChannel(ctx context.Context, session authn.Session, ch channels.Channel) (channels.Channel, error) {
ch, err := es.svc.UpdateChannel(ctx, session, ch)
if err != nil {
return ch, err
}
event := updateChannelEvent{
Channel: ch,
Session: session,
operation: channelUpdate,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, updateStream, event); err != nil {
return ch, err
}
return ch, nil
}
func (es *eventStore) UpdateChannelTags(ctx context.Context, session authn.Session, ch channels.Channel) (channels.Channel, error) {
ch, err := es.svc.UpdateChannelTags(ctx, session, ch)
if err != nil {
return ch, err
}
event := updateChannelEvent{
Channel: ch,
Session: session,
operation: channelUpdateTags,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, updateTagsStream, event); err != nil {
return ch, err
}
return ch, nil
}
func (es *eventStore) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (channels.Channel, error) {
chann, err := es.svc.ViewChannel(ctx, session, id, withRoles)
if err != nil {
return chann, err
}
event := viewChannelEvent{
Channel: chann,
Session: session,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, viewStream, event); err != nil {
return chann, err
}
return chann, nil
}
func (es *eventStore) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
cp, err := es.svc.ListChannels(ctx, session, pm)
if err != nil {
return cp, err
}
event := listChannelEvent{
Page: pm,
Session: session,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, listStream, event); err != nil {
return cp, err
}
return cp, nil
}
func (es *eventStore) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
cp, err := es.svc.ListUserChannels(ctx, session, userID, pm)
if err != nil {
return cp, err
}
event := listUserChannelsEvent{
userID: userID,
Page: pm,
Session: session,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, listByUserStream, event); err != nil {
return cp, err
}
return cp, nil
}
func (es *eventStore) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
ch, err := es.svc.EnableChannel(ctx, session, id)
if err != nil {
return ch, err
}
return es.changeStatus(ctx, session, channelEnable, enableStream, ch)
}
func (es *eventStore) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
ch, err := es.svc.DisableChannel(ctx, session, id)
if err != nil {
return ch, err
}
return es.changeStatus(ctx, session, channelDisable, disableStream, ch)
}
func (es *eventStore) changeStatus(ctx context.Context, session authn.Session, operation, stream string, ch channels.Channel) (channels.Channel, error) {
event := changeChannelStatusEvent{
id: ch.ID,
operation: operation,
updatedAt: ch.UpdatedAt,
updatedBy: ch.UpdatedBy,
status: ch.Status.String(),
Session: session,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, stream, event); err != nil {
return ch, err
}
return ch, nil
}
func (es *eventStore) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
if err := es.svc.RemoveChannel(ctx, session, id); err != nil {
return err
}
event := removeChannelEvent{
id: id,
Session: session,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, removeStream, event); err != nil {
return err
}
return nil
}
func (es *eventStore) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
if err := es.svc.Connect(ctx, session, chIDs, thIDs, connTypes); err != nil {
return err
}
event := connectEvent{
chIDs: chIDs,
thIDs: thIDs,
types: connTypes,
Session: session,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, connectStream, event); err != nil {
return err
}
return nil
}
func (es *eventStore) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
if err := es.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes); err != nil {
return err
}
event := disconnectEvent{
chIDs: chIDs,
thIDs: thIDs,
types: connTypes,
Session: session,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, disconnectStream, event); err != nil {
return err
}
return nil
}
func (es *eventStore) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) (err error) {
if err := es.svc.SetParentGroup(ctx, session, parentGroupID, id); err != nil {
return err
}
event := setParentGroupEvent{
parentGroupID: parentGroupID,
id: id,
Session: session,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, setParentStream, event); err != nil {
return err
}
return nil
}
func (es *eventStore) RemoveParentGroup(ctx context.Context, session authn.Session, id string) (err error) {
if err := es.svc.RemoveParentGroup(ctx, session, id); err != nil {
return err
}
event := removeParentGroupEvent{
id: id,
Session: session,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, removeParentStream, event); err != nil {
return err
}
return nil
}
-669
View File
@@ -1,669 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package events_test
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/channels/events"
"github.com/absmach/magistrala/channels/mocks"
"github.com/absmach/magistrala/internal/testsutil"
"github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/roles"
"github.com/go-chi/chi/v5/middleware"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
storeClient *redis.Client
storeURL string
validSession = authn.Session{
DomainID: testsutil.GenerateUUID(&testing.T{}),
UserID: testsutil.GenerateUUID(&testing.T{}),
}
validChannel = generateTestChannel(&testing.T{})
validChannelsPage = channels.ChannelsPage{
Page: channels.Page{
Limit: 10,
Offset: 0,
Total: 1,
},
Channels: []channels.Channel{validChannel},
}
)
func newEventStoreMiddleware(t *testing.T) (*mocks.Service, channels.Service) {
svc := new(mocks.Service)
nsvc, err := events.NewEventStoreMiddleware(context.Background(), svc, storeURL)
require.Nil(t, err, fmt.Sprintf("create events store middleware failed with unexpected error: %s", err))
return svc, nsvc
}
func TestMain(m *testing.M) {
code := testsutil.RunRedisTest(m, &storeClient, &storeURL)
os.Exit(code)
}
func TestCreateChannels(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validID := testsutil.GenerateUUID(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, validID)
cases := []struct {
desc string
session authn.Session
channels []channels.Channel
svcRes []channels.Channel
svcRoleRes []roles.RoleProvision
svcErr error
resp []channels.Channel
respRoleRes []roles.RoleProvision
err error
}{
{
desc: "publish successfully",
session: validSession,
channels: []channels.Channel{validChannel},
svcRes: []channels.Channel{validChannel},
svcRoleRes: []roles.RoleProvision{},
svcErr: nil,
resp: []channels.Channel{validChannel},
respRoleRes: []roles.RoleProvision{},
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
channels: []channels.Channel{validChannel},
svcRes: []channels.Channel{},
svcRoleRes: []roles.RoleProvision{},
svcErr: svcerr.ErrCreateEntity,
resp: []channels.Channel{},
respRoleRes: []roles.RoleProvision{},
err: svcerr.ErrCreateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("CreateChannels", validCtx, tc.session, tc.channels).Return(tc.svcRes, tc.svcRoleRes, tc.svcErr)
resp, respRoleRes, err := nsvc.CreateChannels(validCtx, tc.session, tc.channels...)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
assert.Equal(t, tc.respRoleRes, respRoleRes, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.respRoleRes, respRoleRes))
svcCall.Unset()
})
}
}
func TestViewChannel(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
cases := []struct {
desc string
session authn.Session
channelID string
withRoles bool
svcRes channels.Channel
svcErr error
resp channels.Channel
err error
}{
{
desc: "publish successfully",
session: validSession,
channelID: validChannel.ID,
withRoles: false,
svcRes: validChannel,
svcErr: nil,
resp: validChannel,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
channelID: validChannel.ID,
withRoles: false,
svcRes: channels.Channel{},
svcErr: svcerr.ErrViewEntity,
resp: channels.Channel{},
err: svcerr.ErrViewEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("ViewChannel", validCtx, tc.session, tc.channelID, tc.withRoles).Return(tc.svcRes, tc.svcErr)
resp, err := nsvc.ViewChannel(validCtx, tc.session, tc.channelID, tc.withRoles)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
svcCall.Unset()
})
}
}
func TestUpdateChannel(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
updatedChannel := validChannel
updatedChannel.Name = "updatedName"
cases := []struct {
desc string
session authn.Session
channel channels.Channel
svcRes channels.Channel
svcErr error
resp channels.Channel
err error
}{
{
desc: "publish successfully",
session: validSession,
channel: updatedChannel,
svcRes: updatedChannel,
svcErr: nil,
resp: updatedChannel,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
channel: updatedChannel,
svcRes: channels.Channel{},
svcErr: svcerr.ErrUpdateEntity,
resp: channels.Channel{},
err: svcerr.ErrUpdateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("UpdateChannel", validCtx, tc.session, tc.channel).Return(tc.svcRes, tc.svcErr)
resp, err := nsvc.UpdateChannel(validCtx, tc.session, tc.channel)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
svcCall.Unset()
})
}
}
func TestUpdateChannelTags(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
updatedChannel := validChannel
updatedChannel.Tags = []string{"newTag1", "newTag2"}
cases := []struct {
desc string
session authn.Session
channel channels.Channel
svcRes channels.Channel
svcErr error
resp channels.Channel
err error
}{
{
desc: "publish successfully",
session: validSession,
channel: updatedChannel,
svcRes: updatedChannel,
svcErr: nil,
resp: updatedChannel,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
channel: updatedChannel,
svcRes: channels.Channel{},
svcErr: svcerr.ErrUpdateEntity,
resp: channels.Channel{},
err: svcerr.ErrUpdateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("UpdateChannelTags", validCtx, tc.session, tc.channel).Return(tc.svcRes, tc.svcErr)
resp, err := nsvc.UpdateChannelTags(validCtx, tc.session, tc.channel)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
svcCall.Unset()
})
}
}
func TestEnableChannel(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
cases := []struct {
desc string
session authn.Session
channelID string
svcRes channels.Channel
svcErr error
resp channels.Channel
err error
}{
{
desc: "publish successfully",
session: validSession,
channelID: validChannel.ID,
svcRes: validChannel,
svcErr: nil,
resp: validChannel,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
channelID: validChannel.ID,
svcRes: channels.Channel{},
svcErr: svcerr.ErrUpdateEntity,
resp: channels.Channel{},
err: svcerr.ErrUpdateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("EnableChannel", validCtx, tc.session, tc.channelID).Return(tc.svcRes, tc.svcErr)
resp, err := nsvc.EnableChannel(validCtx, tc.session, tc.channelID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
svcCall.Unset()
})
}
}
func TestDisableChannel(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
cases := []struct {
desc string
session authn.Session
channelID string
svcRes channels.Channel
svcErr error
resp channels.Channel
err error
}{
{
desc: "publish successfully",
session: validSession,
channelID: validChannel.ID,
svcRes: validChannel,
svcErr: nil,
resp: validChannel,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
channelID: validChannel.ID,
svcRes: channels.Channel{},
svcErr: svcerr.ErrUpdateEntity,
resp: channels.Channel{},
err: svcerr.ErrUpdateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("DisableChannel", validCtx, tc.session, tc.channelID).Return(tc.svcRes, tc.svcErr)
resp, err := nsvc.DisableChannel(validCtx, tc.session, tc.channelID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
svcCall.Unset()
})
}
}
func TestListChannels(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
cases := []struct {
desc string
session authn.Session
pageMeta channels.Page
svcRes channels.ChannelsPage
svcErr error
resp channels.ChannelsPage
err error
}{
{
desc: "publish successfully",
session: validSession,
pageMeta: channels.Page{
Limit: 10,
Offset: 0,
},
svcRes: validChannelsPage,
svcErr: nil,
resp: validChannelsPage,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
pageMeta: channels.Page{
Limit: 10,
Offset: 0,
},
svcRes: channels.ChannelsPage{},
svcErr: svcerr.ErrViewEntity,
resp: channels.ChannelsPage{},
err: svcerr.ErrViewEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("ListChannels", validCtx, tc.session, tc.pageMeta).Return(tc.svcRes, tc.svcErr)
resp, err := nsvc.ListChannels(validCtx, tc.session, tc.pageMeta)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
svcCall.Unset()
})
}
}
func TestListUserChannels(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
cases := []struct {
desc string
session authn.Session
userID string
pageMeta channels.Page
svcRes channels.ChannelsPage
svcErr error
resp channels.ChannelsPage
err error
}{
{
desc: "publish successfully",
session: validSession,
userID: validSession.UserID,
pageMeta: channels.Page{
Limit: 10,
Offset: 0,
},
svcRes: validChannelsPage,
svcErr: nil,
resp: validChannelsPage,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
userID: validSession.UserID,
pageMeta: channels.Page{
Limit: 10,
Offset: 0,
},
svcRes: channels.ChannelsPage{},
svcErr: svcerr.ErrViewEntity,
resp: channels.ChannelsPage{},
err: svcerr.ErrViewEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("ListUserChannels", validCtx, tc.session, tc.userID, tc.pageMeta).Return(tc.svcRes, tc.svcErr)
resp, err := nsvc.ListUserChannels(validCtx, tc.session, tc.userID, tc.pageMeta)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
svcCall.Unset()
})
}
}
func TestRemoveChannel(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
cases := []struct {
desc string
session authn.Session
channelID string
svcErr error
err error
}{
{
desc: "publish successfully",
session: validSession,
channelID: validChannel.ID,
svcErr: nil,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
channelID: validChannel.ID,
svcErr: svcerr.ErrRemoveEntity,
err: svcerr.ErrRemoveEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("RemoveChannel", validCtx, tc.session, tc.channelID).Return(tc.svcErr)
err := nsvc.RemoveChannel(validCtx, tc.session, tc.channelID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}
func TestConnect(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
cases := []struct {
desc string
session authn.Session
chIDs []string
clIDs []string
connTypes []connections.ConnType
svcErr error
err error
}{
{
desc: "publish successfully",
session: validSession,
chIDs: []string{validChannel.ID},
clIDs: []string{testsutil.GenerateUUID(t)},
connTypes: []connections.ConnType{connections.Publish},
svcErr: nil,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
chIDs: []string{validChannel.ID},
clIDs: []string{testsutil.GenerateUUID(t)},
connTypes: []connections.ConnType{connections.Publish},
svcErr: svcerr.ErrCreateEntity,
err: svcerr.ErrCreateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("Connect", validCtx, tc.session, tc.chIDs, tc.clIDs, tc.connTypes).Return(tc.svcErr)
err := nsvc.Connect(validCtx, tc.session, tc.chIDs, tc.clIDs, tc.connTypes)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}
func TestDisconnect(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
cases := []struct {
desc string
session authn.Session
chIDs []string
clIDs []string
connTypes []connections.ConnType
svcErr error
err error
}{
{
desc: "publish successfully",
session: validSession,
chIDs: []string{validChannel.ID},
clIDs: []string{testsutil.GenerateUUID(t)},
connTypes: []connections.ConnType{connections.Publish},
svcErr: nil,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
chIDs: []string{validChannel.ID},
clIDs: []string{testsutil.GenerateUUID(t)},
connTypes: []connections.ConnType{connections.Publish},
svcErr: svcerr.ErrRemoveEntity,
err: svcerr.ErrRemoveEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("Disconnect", validCtx, tc.session, tc.chIDs, tc.clIDs, tc.connTypes).Return(tc.svcErr)
err := nsvc.Disconnect(validCtx, tc.session, tc.chIDs, tc.clIDs, tc.connTypes)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}
func TestSetParentGroup(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
cases := []struct {
desc string
session authn.Session
parentGroupID string
channelID string
svcErr error
err error
}{
{
desc: "publish successfully",
session: validSession,
parentGroupID: testsutil.GenerateUUID(t),
channelID: validChannel.ID,
svcErr: nil,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
parentGroupID: testsutil.GenerateUUID(t),
channelID: validChannel.ID,
svcErr: svcerr.ErrUpdateEntity,
err: svcerr.ErrUpdateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("SetParentGroup", validCtx, tc.session, tc.parentGroupID, tc.channelID).Return(tc.svcErr)
err := nsvc.SetParentGroup(validCtx, tc.session, tc.parentGroupID, tc.channelID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}
func TestRemoveParentGroup(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
cases := []struct {
desc string
session authn.Session
channelID string
svcErr error
err error
}{
{
desc: "publish successfully",
session: validSession,
channelID: validChannel.ID,
svcErr: nil,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
channelID: validChannel.ID,
svcErr: svcerr.ErrUpdateEntity,
err: svcerr.ErrUpdateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("RemoveParentGroup", validCtx, tc.session, tc.channelID).Return(tc.svcErr)
err := nsvc.RemoveParentGroup(validCtx, tc.session, tc.channelID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}
func generateTestChannel(t *testing.T) channels.Channel {
createdAt, err := time.Parse(time.RFC3339, "2024-01-01T00:00:00Z")
assert.Nil(t, err, fmt.Sprintf("Unexpected error parsing time: %v", err))
return channels.Channel{
ID: testsutil.GenerateUUID(t),
Name: "channelname",
Domain: testsutil.GenerateUUID(t),
Tags: []string{"tag1", "tag2"},
Metadata: channels.Metadata{"key1": "value1"},
CreatedAt: createdAt,
UpdatedAt: createdAt,
Status: channels.EnabledStatus,
}
}
-383
View File
@@ -1,383 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"fmt"
"github.com/absmach/magistrala/auth"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/channels/operations"
cOperations "github.com/absmach/magistrala/clients/operations"
dOperations "github.com/absmach/magistrala/domains/operations"
gOperations "github.com/absmach/magistrala/groups/operations"
"github.com/absmach/magistrala/pkg/authn"
smqauthz "github.com/absmach/magistrala/pkg/authz"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/permissions"
"github.com/absmach/magistrala/pkg/policies"
"github.com/absmach/magistrala/pkg/roles"
rolemgr "github.com/absmach/magistrala/pkg/roles/rolemanager/middleware"
)
var (
errView = errors.New("not authorized to view channel")
errList = errors.New("not authorized to list user channels")
errUpdate = errors.New("not authorized to update channel")
errUpdateTags = errors.New("not authorized to update channel tags")
errEnable = errors.New("not authorized to enable channel")
errDisable = errors.New("not authorized to disable channel")
errDelete = errors.New("not authorized to delete channel")
errConnect = errors.New("not authorized to connect to channel")
errDisconnect = errors.New("not authorized to disconnect from channel")
errSetParentGroup = errors.New("not authorized to set parent group to channel")
errRemoveParentGroup = errors.New("not authorized to remove parent group from channel")
errDomainCreateChannels = errors.New("not authorized to create channel in domain")
errGroupSetChildChannels = errors.New("not authorized to set child channel for group")
errGroupRemoveChildChannels = errors.New("not authorized to remove child channel for group")
errClientDisConnectChannels = errors.New("not authorized to disconnect channel for client")
errClientConnectChannels = errors.New("not authorized to connect channel for client")
)
var _ channels.Service = (*authorizationMiddleware)(nil)
type authorizationMiddleware struct {
svc channels.Service
repo channels.Repository
authz smqauthz.Authorization
entitiesOps permissions.EntitiesOperations[permissions.Operation]
rolemgr.RoleManagerAuthorizationMiddleware
}
// NewAuthorization adds authorization to the channels service.
func NewAuthorization(
entityType string,
svc channels.Service,
authz smqauthz.Authorization,
repo channels.Repository,
entitiesOps permissions.EntitiesOperations[permissions.Operation],
roleOps permissions.Operations[permissions.RoleOperation],
) (channels.Service, error) {
if err := entitiesOps.Validate(); err != nil {
return nil, err
}
ram, err := rolemgr.NewAuthorization(policies.ChannelType, svc, authz, roleOps)
if err != nil {
return nil, err
}
return &authorizationMiddleware{
svc: svc,
authz: authz,
repo: repo,
entitiesOps: entitiesOps,
RoleManagerAuthorizationMiddleware: ram,
}, nil
}
func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
if err := am.authorize(ctx, session, policies.DomainType, dOperations.OpCreateDomainChannels, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.DomainType,
Object: session.DomainID,
}); err != nil {
return []channels.Channel{}, []roles.RoleProvision{}, errors.Wrap(err, errDomainCreateChannels)
}
for _, ch := range chs {
if ch.ParentGroup != "" {
if err := am.authorize(ctx, session, policies.GroupType, gOperations.OpGroupSetChildChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.GroupType,
Object: ch.ParentGroup,
}); err != nil {
return []channels.Channel{}, []roles.RoleProvision{}, errors.Wrap(err, errors.Wrap(errGroupSetChildChannels, fmt.Errorf("channel name %s parent group id %s", ch.Name, ch.ParentGroup)))
}
}
}
return am.svc.CreateChannels(ctx, session, chs...)
}
func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (channels.Channel, error) {
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpViewChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ChannelType,
Object: id,
}); err != nil {
return channels.Channel{}, errors.Wrap(err, errView)
}
return am.svc.ViewChannel(ctx, session, id, withRoles)
}
func (am *authorizationMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
switch err := am.checkSuperAdmin(ctx, session); {
case err == nil:
session.SuperAdmin = true
case errors.Contains(err, svcerr.ErrSuperAdminAction):
default:
return channels.ChannelsPage{}, err
}
return am.svc.ListChannels(ctx, session, pm)
}
func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
if err := am.checkSuperAdmin(ctx, session); err != nil {
return channels.ChannelsPage{}, errors.Wrap(err, errList)
}
return am.svc.ListUserChannels(ctx, session, userID, pm)
}
func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpUpdateChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ChannelType,
Object: channel.ID,
}); err != nil {
return channels.Channel{}, errors.Wrap(err, errUpdate)
}
return am.svc.UpdateChannel(ctx, session, channel)
}
func (am *authorizationMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpUpdateChannelTags, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ChannelType,
Object: channel.ID,
}); err != nil {
return channels.Channel{}, errors.Wrap(err, errUpdateTags)
}
return am.svc.UpdateChannelTags(ctx, session, channel)
}
func (am *authorizationMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpEnableChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ChannelType,
Object: id,
}); err != nil {
return channels.Channel{}, errors.Wrap(err, errEnable)
}
return am.svc.EnableChannel(ctx, session, id)
}
func (am *authorizationMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpDisableChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ChannelType,
Object: id,
}); err != nil {
return channels.Channel{}, errors.Wrap(err, errDisable)
}
return am.svc.DisableChannel(ctx, session, id)
}
func (am *authorizationMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpDeleteChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ChannelType,
Object: id,
}); err != nil {
return errors.Wrap(err, errDelete)
}
return am.svc.RemoveChannel(ctx, session, id)
}
func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
for _, chID := range chIDs {
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpConnectClient, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ChannelType,
Object: chID,
}); err != nil {
return errors.Wrap(err, errConnect)
}
}
for _, thID := range thIDs {
if err := am.authorize(ctx, session, policies.ClientType, cOperations.OpConnectToChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ClientType,
Object: thID,
}); err != nil {
return errors.Wrap(err, errClientConnectChannels)
}
}
return am.svc.Connect(ctx, session, chIDs, thIDs, connTypes)
}
func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
for _, chID := range chIDs {
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpDisconnectClient, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ChannelType,
Object: chID,
}); err != nil {
return errors.Wrap(err, errDisconnect)
}
}
for _, thID := range thIDs {
if err := am.authorize(ctx, session, policies.ClientType, cOperations.OpDisconnectFromChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ClientType,
Object: thID,
}); err != nil {
return errors.Wrap(err, errClientDisConnectChannels)
}
}
return am.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes)
}
func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpSetParentGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ChannelType,
Object: id,
}); err != nil {
return errors.Wrap(err, errSetParentGroup)
}
if err := am.authorize(ctx, session, policies.GroupType, gOperations.OpGroupSetChildChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.GroupType,
Object: parentGroupID,
}); err != nil {
return errors.Wrap(err, errGroupSetChildChannels)
}
return am.svc.SetParentGroup(ctx, session, parentGroupID, id)
}
func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpSetParentGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ChannelType,
Object: id,
}); err != nil {
return errors.Wrap(err, errRemoveParentGroup)
}
ch, err := am.repo.RetrieveByID(ctx, id)
if err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
if ch.ParentGroup != "" {
if err := am.authorize(ctx, session, policies.GroupType, gOperations.OpGroupRemoveChildChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.GroupType,
Object: ch.ParentGroup,
}); err != nil {
return errors.Wrap(err, errGroupRemoveChildChannels)
}
return am.svc.RemoveParentGroup(ctx, session, id)
}
return nil
}
func (am *authorizationMiddleware) authorize(ctx context.Context, session authn.Session, entityType string, op permissions.Operation, req smqauthz.PolicyReq) error {
req.Domain = session.DomainID
perm, err := am.entitiesOps.GetPermission(entityType, op)
if err != nil {
return err
}
req.Permission = perm.String()
var pat *smqauthz.PATReq
if session.PatID != "" {
entityID := req.Object
opName := am.entitiesOps.OperationName(entityType, op)
if op == operations.OpListUserChannels || op == dOperations.OpCreateDomainChannels || op == dOperations.OpListDomainChannels {
entityID = auth.AnyIDs
}
pat = &smqauthz.PATReq{
UserID: session.UserID,
PatID: session.PatID,
EntityID: entityID,
EntityType: patEntityType(entityType),
Operation: opName,
Domain: session.DomainID,
}
}
if err := am.authz.Authorize(ctx, req, pat); err != nil {
return err
}
return nil
}
func patEntityType(entityType string) string {
switch entityType {
case policies.ClientType:
return auth.ClientsType.String()
default:
return auth.ChannelsType.String()
}
}
func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session authn.Session) error {
if session.Role != authn.SuperAdminRole {
return svcerr.ErrSuperAdminAction
}
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
SubjectType: policies.UserType,
Subject: session.UserID,
Permission: policies.AdminPermission,
ObjectType: policies.PlatformType,
Object: policies.MagistralaObject,
}, nil); err != nil {
return err
}
return nil
}
-247
View File
@@ -1,247 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"time"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/channels/operations"
dOperations "github.com/absmach/magistrala/domains/operations"
"github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/callout"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/errors"
svcerr "github.com/absmach/magistrala/pkg/errors/service"
"github.com/absmach/magistrala/pkg/permissions"
"github.com/absmach/magistrala/pkg/policies"
"github.com/absmach/magistrala/pkg/roles"
rolemw "github.com/absmach/magistrala/pkg/roles/rolemanager/middleware"
)
var _ channels.Service = (*calloutMiddleware)(nil)
type calloutMiddleware struct {
svc channels.Service
repo channels.Repository
callout callout.Callout
entitiesOps permissions.EntitiesOperations[permissions.Operation]
rolemw.RoleManagerCalloutMiddleware
}
func NewCallout(svc channels.Service, repo channels.Repository, entitiesOps permissions.EntitiesOperations[permissions.Operation], roleOps permissions.Operations[permissions.RoleOperation], callout callout.Callout) (channels.Service, error) {
call, err := rolemw.NewCallout(policies.ChannelType, svc, callout, roleOps)
if err != nil {
return nil, err
}
if err := entitiesOps.Validate(); err != nil {
return nil, err
}
return &calloutMiddleware{
svc: svc,
repo: repo,
callout: callout,
entitiesOps: entitiesOps,
RoleManagerCalloutMiddleware: call,
}, nil
}
func (cm *calloutMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
params := map[string]any{
"entities": chs,
"count": len(chs),
}
if err := cm.callOut(ctx, session, policies.DomainType, dOperations.OpCreateDomainChannels, params); err != nil {
return []channels.Channel{}, []roles.RoleProvision{}, err
}
return cm.svc.CreateChannels(ctx, session, chs...)
}
func (cm *calloutMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (channels.Channel, error) {
params := map[string]any{
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpViewChannel, params); err != nil {
return channels.Channel{}, err
}
return cm.svc.ViewChannel(ctx, session, id, withRoles)
}
func (cm *calloutMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
params := map[string]any{
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, policies.DomainType, dOperations.OpListDomainChannels, params); err != nil {
return channels.ChannelsPage{}, err
}
return cm.svc.ListChannels(ctx, session, pm)
}
func (cm *calloutMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
params := map[string]any{
"user_id": userID,
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpListUserChannels, params); err != nil {
return channels.ChannelsPage{}, err
}
return cm.svc.ListUserChannels(ctx, session, userID, pm)
}
func (cm *calloutMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
params := map[string]any{
"entity_id": channel.ID,
}
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpUpdateChannel, params); err != nil {
return channels.Channel{}, err
}
return cm.svc.UpdateChannel(ctx, session, channel)
}
func (cm *calloutMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
params := map[string]any{
"entity_id": channel.ID,
}
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpUpdateChannelTags, params); err != nil {
return channels.Channel{}, err
}
return cm.svc.UpdateChannelTags(ctx, session, channel)
}
func (cm *calloutMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
params := map[string]any{
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpEnableChannel, params); err != nil {
return channels.Channel{}, err
}
return cm.svc.EnableChannel(ctx, session, id)
}
func (cm *calloutMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
params := map[string]any{
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpDisableChannel, params); err != nil {
return channels.Channel{}, err
}
return cm.svc.DisableChannel(ctx, session, id)
}
func (cm *calloutMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
params := map[string]any{
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpDeleteChannel, params); err != nil {
return err
}
return cm.svc.RemoveChannel(ctx, session, id)
}
func (cm *calloutMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
params := map[string]any{
"channel_ids": chIDs,
"client_ids": thIDs,
"connection_types": connTypes,
}
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpConnectClient, params); err != nil {
return err
}
return cm.svc.Connect(ctx, session, chIDs, thIDs, connTypes)
}
func (cm *calloutMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
params := map[string]any{
"channel_ids": chIDs,
"client_ids": thIDs,
"connection_types": connTypes,
}
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpDisconnectClient, params); err != nil {
return err
}
return cm.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes)
}
func (cm *calloutMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
params := map[string]any{
"entity_id": id,
"parent_group_id": parentGroupID,
}
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpSetParentGroup, params); err != nil {
return err
}
return cm.svc.SetParentGroup(ctx, session, parentGroupID, id)
}
func (cm *calloutMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
ch, err := cm.repo.RetrieveByID(ctx, id)
if err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
if ch.ParentGroup != "" {
params := map[string]any{
"entity_id": id,
"parent_group_id": ch.ParentGroup,
}
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpRemoveParentGroup, params); err != nil {
return err
}
}
return cm.svc.RemoveParentGroup(ctx, session, id)
}
func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session, entityType string, op permissions.Operation, pld map[string]any) error {
var entityID string
if id, ok := pld["entity_id"].(string); ok {
entityID = id
}
req := callout.Request{
BaseRequest: callout.BaseRequest{
Operation: cm.entitiesOps.OperationName(entityType, op),
EntityType: entityType,
EntityID: entityID,
CallerID: session.UserID,
CallerType: policies.UserType,
DomainID: session.DomainID,
Time: time.Now().UTC(),
},
Payload: pld,
}
if err := cm.callout.Callout(ctx, req); err != nil {
return err
}
return nil
}
-9
View File
@@ -1,9 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package middleware provides authorization, logging, metrics and tracing middleware
// for Magistrala Channels Service.
//
// For more details about tracing instrumentation for Magistrala refer to the
// documentation at https://magistrala.absmach.eu/docs/.
package middleware
-294
View File
@@ -1,294 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/roles"
rolemw "github.com/absmach/magistrala/pkg/roles/rolemanager/middleware"
"github.com/go-chi/chi/v5/middleware"
)
var _ channels.Service = (*loggingMiddleware)(nil)
type loggingMiddleware struct {
logger *slog.Logger
svc channels.Service
rolemw.RoleManagerLoggingMiddleware
}
// NewLogging adds logging facilities to the channels service.
func NewLogging(svc channels.Service, logger *slog.Logger) channels.Service {
return &loggingMiddleware{logger, svc, rolemw.NewLogging("channels", svc, logger)}
}
func (lm *loggingMiddleware) CreateChannels(ctx context.Context, session authn.Session, clients ...channels.Channel) (cs []channels.Channel, rps []roles.RoleProvision, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn(fmt.Sprintf("Create %d channels failed", len(clients)), args...)
return
}
lm.logger.Info(fmt.Sprintf("Create %d channel completed successfully", len(clients)), args...)
}(time.Now())
return lm.svc.CreateChannels(ctx, session, clients...)
}
func (lm *loggingMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (c channels.Channel, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.Group("channel",
slog.String("id", c.ID),
slog.String("name", c.Name),
slog.Bool("with_roles", withRoles),
),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("View channel failed", args...)
return
}
lm.logger.Info("View channel completed successfully", args...)
}(time.Now())
return lm.svc.ViewChannel(ctx, session, id, withRoles)
}
func (lm *loggingMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (cp channels.ChannelsPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.Group("page",
slog.Uint64("limit", pm.Limit),
slog.Uint64("offset", pm.Offset),
slog.Uint64("total", cp.Total),
),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("List channels failed", args...)
return
}
lm.logger.Info("List channels completed successfully", args...)
}(time.Now())
return lm.svc.ListChannels(ctx, session, pm)
}
func (lm *loggingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (cp channels.ChannelsPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.String("user_id", userID),
slog.Group("page",
slog.Uint64("limit", pm.Limit),
slog.Uint64("offset", pm.Offset),
slog.Uint64("total", cp.Total),
),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("List user channels failed", args...)
return
}
lm.logger.Info("List user channels completed successfully", args...)
}(time.Now())
return lm.svc.ListUserChannels(ctx, session, userID, pm)
}
func (lm *loggingMiddleware) UpdateChannel(ctx context.Context, session authn.Session, client channels.Channel) (c channels.Channel, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.Group("channel",
slog.String("id", client.ID),
slog.String("name", client.Name),
slog.Any("metadata", client.Metadata),
),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("Update channel failed", args...)
return
}
lm.logger.Info("Update channel completed successfully", args...)
}(time.Now())
return lm.svc.UpdateChannel(ctx, session, client)
}
func (lm *loggingMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, client channels.Channel) (c channels.Channel, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.Group("channel",
slog.String("id", c.ID),
slog.String("name", c.Name),
slog.Any("tags", c.Tags),
),
}
if err != nil {
args := append(args, slog.String("error", err.Error()))
lm.logger.Warn("Update channel tags failed", args...)
return
}
lm.logger.Info("Update channel tags completed successfully", args...)
}(time.Now())
return lm.svc.UpdateChannelTags(ctx, session, client)
}
func (lm *loggingMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (c channels.Channel, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.Group("channel",
slog.String("id", id),
slog.String("name", c.Name),
),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("Enable channel failed", args...)
return
}
lm.logger.Info("Enable channel completed successfully", args...)
}(time.Now())
return lm.svc.EnableChannel(ctx, session, id)
}
func (lm *loggingMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (c channels.Channel, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.Group("channel",
slog.String("id", id),
slog.String("name", c.Name),
),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("Disable channel failed", args...)
return
}
lm.logger.Info("Disable channel completed successfully", args...)
}(time.Now())
return lm.svc.DisableChannel(ctx, session, id)
}
func (lm *loggingMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.String("channel_id", id),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("Delete channel failed", args...)
return
}
lm.logger.Info("Delete channel completed successfully", args...)
}(time.Now())
return lm.svc.RemoveChannel(ctx, session, id)
}
func (lm *loggingMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, clIDs []string, connTypes []connections.ConnType) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.Any("channel_ids", chIDs),
slog.Any("client_ids", clIDs),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("Connect channels and clients failed", args...)
return
}
lm.logger.Info("Connect channels and clients completed successfully", args...)
}(time.Now())
return lm.svc.Connect(ctx, session, chIDs, clIDs, connTypes)
}
func (lm *loggingMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, clIDs []string, connTypes []connections.ConnType) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.Any("channel_ids", chIDs),
slog.Any("client_ids", clIDs),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("Disconnect channels and clients failed", args...)
return
}
lm.logger.Info("Disconnect channels and clients completed successfully", args...)
}(time.Now())
return lm.svc.Disconnect(ctx, session, chIDs, clIDs, connTypes)
}
func (lm *loggingMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.String("parent_group_id", parentGroupID),
slog.String("channel_id", id),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("Set parent group to channel failed", args...)
return
}
lm.logger.Info("Set parent group to channel completed successfully", args...)
}(time.Now())
return lm.svc.SetParentGroup(ctx, session, parentGroupID, id)
}
func (lm *loggingMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", session.DomainID),
slog.String("request_id", middleware.GetReqID(ctx)),
slog.String("channel_id", id),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("Remove parent group from channel failed", args...)
return
}
lm.logger.Info("Remove parent group from channel completed successfully", args...)
}(time.Now())
return lm.svc.RemoveParentGroup(ctx, session, id)
}
-139
View File
@@ -1,139 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"time"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/roles"
rolemw "github.com/absmach/magistrala/pkg/roles/rolemanager/middleware"
"github.com/go-kit/kit/metrics"
)
var _ channels.Service = (*metricsMiddleware)(nil)
type metricsMiddleware struct {
counter metrics.Counter
latency metrics.Histogram
svc channels.Service
rolemw.RoleManagerMetricsMiddleware
}
// NewMetrics returns a new metrics middleware wrapper.
func NewMetrics(svc channels.Service, counter metrics.Counter, latency metrics.Histogram) channels.Service {
return &metricsMiddleware{
counter: counter,
latency: latency,
svc: svc,
RoleManagerMetricsMiddleware: rolemw.NewMetrics("channels", svc, counter, latency),
}
}
func (ms *metricsMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
defer func(begin time.Time) {
ms.counter.With("method", "register_channels").Add(1)
ms.latency.With("method", "register_channels").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.CreateChannels(ctx, session, chs...)
}
func (ms *metricsMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (channels.Channel, error) {
defer func(begin time.Time) {
ms.counter.With("method", "view_channel").Add(1)
ms.latency.With("method", "view_channel").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ViewChannel(ctx, session, id, withRoles)
}
func (ms *metricsMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
defer func(begin time.Time) {
ms.counter.With("method", "list_channels").Add(1)
ms.latency.With("method", "list_channels").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ListChannels(ctx, session, pm)
}
func (ms *metricsMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
defer func(begin time.Time) {
ms.counter.With("method", "list_user_channels").Add(1)
ms.latency.With("method", "list_user_channels").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ListUserChannels(ctx, session, userID, pm)
}
func (ms *metricsMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
defer func(begin time.Time) {
ms.counter.With("method", "update_channel").Add(1)
ms.latency.With("method", "update_channel").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.UpdateChannel(ctx, session, channel)
}
func (ms *metricsMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
defer func(begin time.Time) {
ms.counter.With("method", "update_channel_tags").Add(1)
ms.latency.With("method", "update_channel_tags").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.UpdateChannelTags(ctx, session, channel)
}
func (ms *metricsMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
defer func(begin time.Time) {
ms.counter.With("method", "enable_channel").Add(1)
ms.latency.With("method", "enable_channel").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.EnableChannel(ctx, session, id)
}
func (ms *metricsMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
defer func(begin time.Time) {
ms.counter.With("method", "disable_channel").Add(1)
ms.latency.With("method", "disable_channel").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.DisableChannel(ctx, session, id)
}
func (ms *metricsMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
defer func(begin time.Time) {
ms.counter.With("method", "delete_channel").Add(1)
ms.latency.With("method", "delete_channel").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.RemoveChannel(ctx, session, id)
}
func (ms *metricsMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
defer func(begin time.Time) {
ms.counter.With("method", "connect").Add(1)
ms.latency.With("method", "connect").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.Connect(ctx, session, chIDs, thIDs, connTypes)
}
func (ms *metricsMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
defer func(begin time.Time) {
ms.counter.With("method", "disconnect").Add(1)
ms.latency.With("method", "disconnect").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes)
}
func (ms *metricsMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) (err error) {
defer func(begin time.Time) {
ms.counter.With("method", "set_parent_group").Add(1)
ms.latency.With("method", "set_parent_group").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.SetParentGroup(ctx, session, parentGroupID, id)
}
func (ms *metricsMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) (err error) {
defer func(begin time.Time) {
ms.counter.With("method", "remove_parent_group").Add(1)
ms.latency.With("method", "remove_parent_group").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.RemoveParentGroup(ctx, session, id)
}
-135
View File
@@ -1,135 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"github.com/absmach/magistrala/channels"
"github.com/absmach/magistrala/pkg/authn"
"github.com/absmach/magistrala/pkg/connections"
"github.com/absmach/magistrala/pkg/roles"
rolemw "github.com/absmach/magistrala/pkg/roles/rolemanager/middleware"
"github.com/absmach/magistrala/pkg/tracing"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
var _ channels.Service = (*tracingMiddleware)(nil)
type tracingMiddleware struct {
tracer trace.Tracer
svc channels.Service
rolemw.RoleManagerTracing
}
// NewTracing returns a new channels service with tracing capabilities.
func NewTracing(svc channels.Service, tracer trace.Tracer) channels.Service {
return &tracingMiddleware{tracer, svc, rolemw.NewTracing("channels", svc, tracer)}
}
// CreateChannels traces the "CreateChannels" operation of the wrapped policies.Service.
func (tm *tracingMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_create_channel")
defer span.End()
return tm.svc.CreateChannels(ctx, session, chs...)
}
// ViewChannel traces the "ViewChannel" operation of the wrapped policies.Service.
func (tm *tracingMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (channels.Channel, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_view_channel", trace.WithAttributes(attribute.String("id", id)))
defer span.End()
return tm.svc.ViewChannel(ctx, session, id, withRoles)
}
// ListChannels traces the "ListChannels" operation of the wrapped policies.Service.
func (tm *tracingMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_list_channels")
defer span.End()
return tm.svc.ListChannels(ctx, session, pm)
}
func (tm *tracingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_list_user_channels")
defer span.End()
return tm.svc.ListUserChannels(ctx, session, userID, pm)
}
// UpdateChannel traces the "UpdateChannel" operation of the wrapped policies.Service.
func (tm *tracingMiddleware) UpdateChannel(ctx context.Context, session authn.Session, cli channels.Channel) (channels.Channel, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_update_channel", trace.WithAttributes(attribute.String("id", cli.ID)))
defer span.End()
return tm.svc.UpdateChannel(ctx, session, cli)
}
// UpdateChannelTags traces the "UpdateChannelTags" operation of the wrapped policies.Service.
func (tm *tracingMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, cli channels.Channel) (channels.Channel, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_update_channel_tags", trace.WithAttributes(
attribute.String("id", cli.ID),
attribute.StringSlice("tags", cli.Tags),
))
defer span.End()
return tm.svc.UpdateChannelTags(ctx, session, cli)
}
// EnableChannel traces the "EnableChannel" operation of the wrapped policies.Service.
func (tm *tracingMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_enable_channel", trace.WithAttributes(attribute.String("id", id)))
defer span.End()
return tm.svc.EnableChannel(ctx, session, id)
}
// DisableChannel traces the "DisableChannel" operation of the wrapped policies.Service.
func (tm *tracingMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_disable_channel", trace.WithAttributes(attribute.String("id", id)))
defer span.End()
return tm.svc.DisableChannel(ctx, session, id)
}
// DeleteChannel traces the "DeleteChannel" operation of the wrapped channels.Service.
func (tm *tracingMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "delete_channel", trace.WithAttributes(attribute.String("id", id)))
defer span.End()
return tm.svc.RemoveChannel(ctx, session, id)
}
func (tm *tracingMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "connect", trace.WithAttributes(
attribute.StringSlice("channel_ids", chIDs),
attribute.StringSlice("client_ids", thIDs),
))
defer span.End()
return tm.svc.Connect(ctx, session, chIDs, thIDs, connTypes)
}
func (tm *tracingMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "disconnect", trace.WithAttributes(
attribute.StringSlice("channel_ids", chIDs),
attribute.StringSlice("client_ids", thIDs),
))
defer span.End()
return tm.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes)
}
func (tm *tracingMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "set_parent_group", trace.WithAttributes(
attribute.String("parent_group_id", parentGroupID),
attribute.String("id", id),
))
defer span.End()
return tm.svc.SetParentGroup(ctx, session, parentGroupID, id)
}
func (tm *tracingMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "remove_parent_group", trace.WithAttributes(
attribute.String("id", id),
))
defer span.End()
return tm.svc.RemoveParentGroup(ctx, session, id)
}
-246
View File
@@ -1,246 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
"context"
mock "github.com/stretchr/testify/mock"
)
// NewCache creates a new instance of Cache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewCache(t interface {
mock.TestingT
Cleanup(func())
}) *Cache {
mock := &Cache{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// Cache is an autogenerated mock type for the Cache type
type Cache struct {
mock.Mock
}
type Cache_Expecter struct {
mock *mock.Mock
}
func (_m *Cache) EXPECT() *Cache_Expecter {
return &Cache_Expecter{mock: &_m.Mock}
}
// ID provides a mock function for the type Cache
func (_mock *Cache) ID(ctx context.Context, channelRoute string, domainID string) (string, error) {
ret := _mock.Called(ctx, channelRoute, domainID)
if len(ret) == 0 {
panic("no return value specified for ID")
}
var r0 string
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (string, error)); ok {
return returnFunc(ctx, channelRoute, domainID)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) string); ok {
r0 = returnFunc(ctx, channelRoute, domainID)
} else {
r0 = ret.Get(0).(string)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = returnFunc(ctx, channelRoute, domainID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Cache_ID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ID'
type Cache_ID_Call struct {
*mock.Call
}
// ID is a helper method to define mock.On call
// - ctx context.Context
// - channelRoute string
// - domainID string
func (_e *Cache_Expecter) ID(ctx interface{}, channelRoute interface{}, domainID interface{}) *Cache_ID_Call {
return &Cache_ID_Call{Call: _e.mock.On("ID", ctx, channelRoute, domainID)}
}
func (_c *Cache_ID_Call) Run(run func(ctx context.Context, channelRoute string, domainID string)) *Cache_ID_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
run(
arg0,
arg1,
arg2,
)
})
return _c
}
func (_c *Cache_ID_Call) Return(s string, err error) *Cache_ID_Call {
_c.Call.Return(s, err)
return _c
}
func (_c *Cache_ID_Call) RunAndReturn(run func(ctx context.Context, channelRoute string, domainID string) (string, error)) *Cache_ID_Call {
_c.Call.Return(run)
return _c
}
// Remove provides a mock function for the type Cache
func (_mock *Cache) Remove(ctx context.Context, channelRoute string, domainID string) error {
ret := _mock.Called(ctx, channelRoute, domainID)
if len(ret) == 0 {
panic("no return value specified for Remove")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = returnFunc(ctx, channelRoute, domainID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Cache_Remove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remove'
type Cache_Remove_Call struct {
*mock.Call
}
// Remove is a helper method to define mock.On call
// - ctx context.Context
// - channelRoute string
// - domainID string
func (_e *Cache_Expecter) Remove(ctx interface{}, channelRoute interface{}, domainID interface{}) *Cache_Remove_Call {
return &Cache_Remove_Call{Call: _e.mock.On("Remove", ctx, channelRoute, domainID)}
}
func (_c *Cache_Remove_Call) Run(run func(ctx context.Context, channelRoute string, domainID string)) *Cache_Remove_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
run(
arg0,
arg1,
arg2,
)
})
return _c
}
func (_c *Cache_Remove_Call) Return(err error) *Cache_Remove_Call {
_c.Call.Return(err)
return _c
}
func (_c *Cache_Remove_Call) RunAndReturn(run func(ctx context.Context, channelRoute string, domainID string) error) *Cache_Remove_Call {
_c.Call.Return(run)
return _c
}
// Save provides a mock function for the type Cache
func (_mock *Cache) Save(ctx context.Context, channelRoute string, domainID string, channelID string) error {
ret := _mock.Called(ctx, channelRoute, domainID, channelID)
if len(ret) == 0 {
panic("no return value specified for Save")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok {
r0 = returnFunc(ctx, channelRoute, domainID, channelID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Cache_Save_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Save'
type Cache_Save_Call struct {
*mock.Call
}
// Save is a helper method to define mock.On call
// - ctx context.Context
// - channelRoute string
// - domainID string
// - channelID string
func (_e *Cache_Expecter) Save(ctx interface{}, channelRoute interface{}, domainID interface{}, channelID interface{}) *Cache_Save_Call {
return &Cache_Save_Call{Call: _e.mock.On("Save", ctx, channelRoute, domainID, channelID)}
}
func (_c *Cache_Save_Call) Run(run func(ctx context.Context, channelRoute string, domainID string, channelID string)) *Cache_Save_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
var arg3 string
if args[3] != nil {
arg3 = args[3].(string)
}
run(
arg0,
arg1,
arg2,
arg3,
)
})
return _c
}
func (_c *Cache_Save_Call) Return(err error) *Cache_Save_Call {
_c.Call.Return(err)
return _c
}
func (_c *Cache_Save_Call) RunAndReturn(run func(ctx context.Context, channelRoute string, domainID string, channelID string) error) *Cache_Save_Call {
_c.Call.Return(run)
return _c
}
-460
View File
@@ -1,460 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
"context"
"github.com/absmach/magistrala/api/grpc/channels/v1"
v10 "github.com/absmach/magistrala/api/grpc/common/v1"
mock "github.com/stretchr/testify/mock"
"google.golang.org/grpc"
)
// NewChannelsServiceClient creates a new instance of ChannelsServiceClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewChannelsServiceClient(t interface {
mock.TestingT
Cleanup(func())
}) *ChannelsServiceClient {
mock := &ChannelsServiceClient{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// ChannelsServiceClient is an autogenerated mock type for the ChannelsServiceClient type
type ChannelsServiceClient struct {
mock.Mock
}
type ChannelsServiceClient_Expecter struct {
mock *mock.Mock
}
func (_m *ChannelsServiceClient) EXPECT() *ChannelsServiceClient_Expecter {
return &ChannelsServiceClient_Expecter{mock: &_m.Mock}
}
// Authorize provides a mock function for the type ChannelsServiceClient
func (_mock *ChannelsServiceClient) Authorize(ctx context.Context, in *v1.AuthzReq, opts ...grpc.CallOption) (*v1.AuthzRes, error) {
var tmpRet mock.Arguments
if len(opts) > 0 {
tmpRet = _mock.Called(ctx, in, opts)
} else {
tmpRet = _mock.Called(ctx, in)
}
ret := tmpRet
if len(ret) == 0 {
panic("no return value specified for Authorize")
}
var r0 *v1.AuthzRes
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.AuthzReq, ...grpc.CallOption) (*v1.AuthzRes, error)); ok {
return returnFunc(ctx, in, opts...)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.AuthzReq, ...grpc.CallOption) *v1.AuthzRes); ok {
r0 = returnFunc(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*v1.AuthzRes)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.AuthzReq, ...grpc.CallOption) error); ok {
r1 = returnFunc(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ChannelsServiceClient_Authorize_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Authorize'
type ChannelsServiceClient_Authorize_Call struct {
*mock.Call
}
// Authorize is a helper method to define mock.On call
// - ctx context.Context
// - in *v1.AuthzReq
// - opts ...grpc.CallOption
func (_e *ChannelsServiceClient_Expecter) Authorize(ctx interface{}, in interface{}, opts ...interface{}) *ChannelsServiceClient_Authorize_Call {
return &ChannelsServiceClient_Authorize_Call{Call: _e.mock.On("Authorize",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *ChannelsServiceClient_Authorize_Call) Run(run func(ctx context.Context, in *v1.AuthzReq, opts ...grpc.CallOption)) *ChannelsServiceClient_Authorize_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *v1.AuthzReq
if args[1] != nil {
arg1 = args[1].(*v1.AuthzReq)
}
var arg2 []grpc.CallOption
var variadicArgs []grpc.CallOption
if len(args) > 2 {
variadicArgs = args[2].([]grpc.CallOption)
}
arg2 = variadicArgs
run(
arg0,
arg1,
arg2...,
)
})
return _c
}
func (_c *ChannelsServiceClient_Authorize_Call) Return(authzRes *v1.AuthzRes, err error) *ChannelsServiceClient_Authorize_Call {
_c.Call.Return(authzRes, err)
return _c
}
func (_c *ChannelsServiceClient_Authorize_Call) RunAndReturn(run func(ctx context.Context, in *v1.AuthzReq, opts ...grpc.CallOption) (*v1.AuthzRes, error)) *ChannelsServiceClient_Authorize_Call {
_c.Call.Return(run)
return _c
}
// RemoveClientConnections provides a mock function for the type ChannelsServiceClient
func (_mock *ChannelsServiceClient) RemoveClientConnections(ctx context.Context, in *v1.RemoveClientConnectionsReq, opts ...grpc.CallOption) (*v1.RemoveClientConnectionsRes, error) {
var tmpRet mock.Arguments
if len(opts) > 0 {
tmpRet = _mock.Called(ctx, in, opts)
} else {
tmpRet = _mock.Called(ctx, in)
}
ret := tmpRet
if len(ret) == 0 {
panic("no return value specified for RemoveClientConnections")
}
var r0 *v1.RemoveClientConnectionsRes
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RemoveClientConnectionsReq, ...grpc.CallOption) (*v1.RemoveClientConnectionsRes, error)); ok {
return returnFunc(ctx, in, opts...)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RemoveClientConnectionsReq, ...grpc.CallOption) *v1.RemoveClientConnectionsRes); ok {
r0 = returnFunc(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*v1.RemoveClientConnectionsRes)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.RemoveClientConnectionsReq, ...grpc.CallOption) error); ok {
r1 = returnFunc(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ChannelsServiceClient_RemoveClientConnections_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveClientConnections'
type ChannelsServiceClient_RemoveClientConnections_Call struct {
*mock.Call
}
// RemoveClientConnections is a helper method to define mock.On call
// - ctx context.Context
// - in *v1.RemoveClientConnectionsReq
// - opts ...grpc.CallOption
func (_e *ChannelsServiceClient_Expecter) RemoveClientConnections(ctx interface{}, in interface{}, opts ...interface{}) *ChannelsServiceClient_RemoveClientConnections_Call {
return &ChannelsServiceClient_RemoveClientConnections_Call{Call: _e.mock.On("RemoveClientConnections",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *ChannelsServiceClient_RemoveClientConnections_Call) Run(run func(ctx context.Context, in *v1.RemoveClientConnectionsReq, opts ...grpc.CallOption)) *ChannelsServiceClient_RemoveClientConnections_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *v1.RemoveClientConnectionsReq
if args[1] != nil {
arg1 = args[1].(*v1.RemoveClientConnectionsReq)
}
var arg2 []grpc.CallOption
var variadicArgs []grpc.CallOption
if len(args) > 2 {
variadicArgs = args[2].([]grpc.CallOption)
}
arg2 = variadicArgs
run(
arg0,
arg1,
arg2...,
)
})
return _c
}
func (_c *ChannelsServiceClient_RemoveClientConnections_Call) Return(removeClientConnectionsRes *v1.RemoveClientConnectionsRes, err error) *ChannelsServiceClient_RemoveClientConnections_Call {
_c.Call.Return(removeClientConnectionsRes, err)
return _c
}
func (_c *ChannelsServiceClient_RemoveClientConnections_Call) RunAndReturn(run func(ctx context.Context, in *v1.RemoveClientConnectionsReq, opts ...grpc.CallOption) (*v1.RemoveClientConnectionsRes, error)) *ChannelsServiceClient_RemoveClientConnections_Call {
_c.Call.Return(run)
return _c
}
// RetrieveEntity provides a mock function for the type ChannelsServiceClient
func (_mock *ChannelsServiceClient) RetrieveEntity(ctx context.Context, in *v10.RetrieveEntityReq, opts ...grpc.CallOption) (*v10.RetrieveEntityRes, error) {
var tmpRet mock.Arguments
if len(opts) > 0 {
tmpRet = _mock.Called(ctx, in, opts)
} else {
tmpRet = _mock.Called(ctx, in)
}
ret := tmpRet
if len(ret) == 0 {
panic("no return value specified for RetrieveEntity")
}
var r0 *v10.RetrieveEntityRes
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *v10.RetrieveEntityReq, ...grpc.CallOption) (*v10.RetrieveEntityRes, error)); ok {
return returnFunc(ctx, in, opts...)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, *v10.RetrieveEntityReq, ...grpc.CallOption) *v10.RetrieveEntityRes); ok {
r0 = returnFunc(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*v10.RetrieveEntityRes)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, *v10.RetrieveEntityReq, ...grpc.CallOption) error); ok {
r1 = returnFunc(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ChannelsServiceClient_RetrieveEntity_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveEntity'
type ChannelsServiceClient_RetrieveEntity_Call struct {
*mock.Call
}
// RetrieveEntity is a helper method to define mock.On call
// - ctx context.Context
// - in *v10.RetrieveEntityReq
// - opts ...grpc.CallOption
func (_e *ChannelsServiceClient_Expecter) RetrieveEntity(ctx interface{}, in interface{}, opts ...interface{}) *ChannelsServiceClient_RetrieveEntity_Call {
return &ChannelsServiceClient_RetrieveEntity_Call{Call: _e.mock.On("RetrieveEntity",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *ChannelsServiceClient_RetrieveEntity_Call) Run(run func(ctx context.Context, in *v10.RetrieveEntityReq, opts ...grpc.CallOption)) *ChannelsServiceClient_RetrieveEntity_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *v10.RetrieveEntityReq
if args[1] != nil {
arg1 = args[1].(*v10.RetrieveEntityReq)
}
var arg2 []grpc.CallOption
var variadicArgs []grpc.CallOption
if len(args) > 2 {
variadicArgs = args[2].([]grpc.CallOption)
}
arg2 = variadicArgs
run(
arg0,
arg1,
arg2...,
)
})
return _c
}
func (_c *ChannelsServiceClient_RetrieveEntity_Call) Return(retrieveEntityRes *v10.RetrieveEntityRes, err error) *ChannelsServiceClient_RetrieveEntity_Call {
_c.Call.Return(retrieveEntityRes, err)
return _c
}
func (_c *ChannelsServiceClient_RetrieveEntity_Call) RunAndReturn(run func(ctx context.Context, in *v10.RetrieveEntityReq, opts ...grpc.CallOption) (*v10.RetrieveEntityRes, error)) *ChannelsServiceClient_RetrieveEntity_Call {
_c.Call.Return(run)
return _c
}
// RetrieveIDByRoute provides a mock function for the type ChannelsServiceClient
func (_mock *ChannelsServiceClient) RetrieveIDByRoute(ctx context.Context, in *v10.RetrieveIDByRouteReq, opts ...grpc.CallOption) (*v10.RetrieveEntityRes, error) {
var tmpRet mock.Arguments
if len(opts) > 0 {
tmpRet = _mock.Called(ctx, in, opts)
} else {
tmpRet = _mock.Called(ctx, in)
}
ret := tmpRet
if len(ret) == 0 {
panic("no return value specified for RetrieveIDByRoute")
}
var r0 *v10.RetrieveEntityRes
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *v10.RetrieveIDByRouteReq, ...grpc.CallOption) (*v10.RetrieveEntityRes, error)); ok {
return returnFunc(ctx, in, opts...)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, *v10.RetrieveIDByRouteReq, ...grpc.CallOption) *v10.RetrieveEntityRes); ok {
r0 = returnFunc(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*v10.RetrieveEntityRes)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, *v10.RetrieveIDByRouteReq, ...grpc.CallOption) error); ok {
r1 = returnFunc(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ChannelsServiceClient_RetrieveIDByRoute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveIDByRoute'
type ChannelsServiceClient_RetrieveIDByRoute_Call struct {
*mock.Call
}
// RetrieveIDByRoute is a helper method to define mock.On call
// - ctx context.Context
// - in *v10.RetrieveIDByRouteReq
// - opts ...grpc.CallOption
func (_e *ChannelsServiceClient_Expecter) RetrieveIDByRoute(ctx interface{}, in interface{}, opts ...interface{}) *ChannelsServiceClient_RetrieveIDByRoute_Call {
return &ChannelsServiceClient_RetrieveIDByRoute_Call{Call: _e.mock.On("RetrieveIDByRoute",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *ChannelsServiceClient_RetrieveIDByRoute_Call) Run(run func(ctx context.Context, in *v10.RetrieveIDByRouteReq, opts ...grpc.CallOption)) *ChannelsServiceClient_RetrieveIDByRoute_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *v10.RetrieveIDByRouteReq
if args[1] != nil {
arg1 = args[1].(*v10.RetrieveIDByRouteReq)
}
var arg2 []grpc.CallOption
var variadicArgs []grpc.CallOption
if len(args) > 2 {
variadicArgs = args[2].([]grpc.CallOption)
}
arg2 = variadicArgs
run(
arg0,
arg1,
arg2...,
)
})
return _c
}
func (_c *ChannelsServiceClient_RetrieveIDByRoute_Call) Return(retrieveEntityRes *v10.RetrieveEntityRes, err error) *ChannelsServiceClient_RetrieveIDByRoute_Call {
_c.Call.Return(retrieveEntityRes, err)
return _c
}
func (_c *ChannelsServiceClient_RetrieveIDByRoute_Call) RunAndReturn(run func(ctx context.Context, in *v10.RetrieveIDByRouteReq, opts ...grpc.CallOption) (*v10.RetrieveEntityRes, error)) *ChannelsServiceClient_RetrieveIDByRoute_Call {
_c.Call.Return(run)
return _c
}
// UnsetParentGroupFromChannels provides a mock function for the type ChannelsServiceClient
func (_mock *ChannelsServiceClient) UnsetParentGroupFromChannels(ctx context.Context, in *v1.UnsetParentGroupFromChannelsReq, opts ...grpc.CallOption) (*v1.UnsetParentGroupFromChannelsRes, error) {
var tmpRet mock.Arguments
if len(opts) > 0 {
tmpRet = _mock.Called(ctx, in, opts)
} else {
tmpRet = _mock.Called(ctx, in)
}
ret := tmpRet
if len(ret) == 0 {
panic("no return value specified for UnsetParentGroupFromChannels")
}
var r0 *v1.UnsetParentGroupFromChannelsRes
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.UnsetParentGroupFromChannelsReq, ...grpc.CallOption) (*v1.UnsetParentGroupFromChannelsRes, error)); ok {
return returnFunc(ctx, in, opts...)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.UnsetParentGroupFromChannelsReq, ...grpc.CallOption) *v1.UnsetParentGroupFromChannelsRes); ok {
r0 = returnFunc(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*v1.UnsetParentGroupFromChannelsRes)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.UnsetParentGroupFromChannelsReq, ...grpc.CallOption) error); ok {
r1 = returnFunc(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// ChannelsServiceClient_UnsetParentGroupFromChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UnsetParentGroupFromChannels'
type ChannelsServiceClient_UnsetParentGroupFromChannels_Call struct {
*mock.Call
}
// UnsetParentGroupFromChannels is a helper method to define mock.On call
// - ctx context.Context
// - in *v1.UnsetParentGroupFromChannelsReq
// - opts ...grpc.CallOption
func (_e *ChannelsServiceClient_Expecter) UnsetParentGroupFromChannels(ctx interface{}, in interface{}, opts ...interface{}) *ChannelsServiceClient_UnsetParentGroupFromChannels_Call {
return &ChannelsServiceClient_UnsetParentGroupFromChannels_Call{Call: _e.mock.On("UnsetParentGroupFromChannels",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *ChannelsServiceClient_UnsetParentGroupFromChannels_Call) Run(run func(ctx context.Context, in *v1.UnsetParentGroupFromChannelsReq, opts ...grpc.CallOption)) *ChannelsServiceClient_UnsetParentGroupFromChannels_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *v1.UnsetParentGroupFromChannelsReq
if args[1] != nil {
arg1 = args[1].(*v1.UnsetParentGroupFromChannelsReq)
}
var arg2 []grpc.CallOption
var variadicArgs []grpc.CallOption
if len(args) > 2 {
variadicArgs = args[2].([]grpc.CallOption)
}
arg2 = variadicArgs
run(
arg0,
arg1,
arg2...,
)
})
return _c
}
func (_c *ChannelsServiceClient_UnsetParentGroupFromChannels_Call) Return(unsetParentGroupFromChannelsRes *v1.UnsetParentGroupFromChannelsRes, err error) *ChannelsServiceClient_UnsetParentGroupFromChannels_Call {
_c.Call.Return(unsetParentGroupFromChannelsRes, err)
return _c
}
func (_c *ChannelsServiceClient_UnsetParentGroupFromChannels_Call) RunAndReturn(run func(ctx context.Context, in *v1.UnsetParentGroupFromChannelsReq, opts ...grpc.CallOption) (*v1.UnsetParentGroupFromChannelsRes, error)) *ChannelsServiceClient_UnsetParentGroupFromChannels_Call {
_c.Call.Return(run)
return _c
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
-72
View File
@@ -1,72 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package operations
import (
"github.com/absmach/magistrala/pkg/permissions"
)
// Channel Operations.
const (
OpViewChannel permissions.Operation = iota
OpUpdateChannel
OpUpdateChannelTags
OpEnableChannel
OpDisableChannel
OpDeleteChannel
OpSetParentGroup
OpRemoveParentGroup
OpConnectClient
OpDisconnectClient
OpListUserChannels
)
func OperationDetails() map[permissions.Operation]permissions.OperationDetails {
return map[permissions.Operation]permissions.OperationDetails{
OpViewChannel: {
Name: "view",
PermissionRequired: true,
},
OpUpdateChannel: {
Name: "update",
PermissionRequired: true,
},
OpUpdateChannelTags: {
Name: "update_tags",
PermissionRequired: true,
},
OpEnableChannel: {
Name: "enable",
PermissionRequired: true,
},
OpDisableChannel: {
Name: "disable",
PermissionRequired: true,
},
OpDeleteChannel: {
Name: "delete",
PermissionRequired: true,
},
OpSetParentGroup: {
Name: "set_parent_group",
PermissionRequired: true,
},
OpRemoveParentGroup: {
Name: "remove_parent_group",
PermissionRequired: true,
},
OpConnectClient: {
Name: "connect_client",
PermissionRequired: true,
},
OpDisconnectClient: {
Name: "disconnect_client",
PermissionRequired: true,
},
OpListUserChannels: {
Name: "list_user_channels",
PermissionRequired: false, // hardcoded to superadmin
},
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
-26
View File
@@ -1,26 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import "github.com/absmach/magistrala/pkg/errors"
var _ errors.Mapper = (*duplicateErrors)(nil)
type duplicateErrors struct{}
// GetError maps constraint names to known errors.
func (d duplicateErrors) GetError(constraint string) (error, bool) {
switch constraint {
case "unique_domain_route_not_null":
return errors.ErrRouteNotAvailable, true
case "channels_pkey":
return errors.NewRequestError("channel id already exists"), true
default:
return nil, false
}
}
func NewDuplicateErrors() errors.Mapper {
return duplicateErrors{}
}
-121
View File
@@ -1,121 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import (
gpostgres "github.com/absmach/magistrala/groups/postgres"
"github.com/absmach/magistrala/pkg/errors"
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
rolesPostgres "github.com/absmach/magistrala/pkg/roles/repo/postgres"
_ "github.com/jackc/pgx/v5/stdlib" // required for SQL access
migrate "github.com/rubenv/sql-migrate"
)
func Migration() (*migrate.MemoryMigrationSource, error) {
rolesMigration, err := rolesPostgres.Migration(rolesTableNamePrefix, entityTableName, entityIDColumnName)
if err != nil {
return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err)
}
channelsMigration := &migrate.MemoryMigrationSource{
Migrations: []*migrate.Migration{
{
Id: "channels_01",
// VARCHAR(36) for colums with IDs as UUIDS have a maximum of 36 characters
// STATUS 0 to imply enabled and 1 to imply disabled
Up: []string{
`CREATE TABLE IF NOT EXISTS channels (
id VARCHAR(36) PRIMARY KEY,
name VARCHAR(1024),
domain_id VARCHAR(36) NOT NULL,
parent_group_id VARCHAR(36) DEFAULT NULL,
tags TEXT[],
metadata JSONB,
created_by VARCHAR(254),
created_at TIMESTAMP,
updated_at TIMESTAMP,
updated_by VARCHAR(254),
status SMALLINT NOT NULL DEFAULT 0 CHECK (status >= 0),
UNIQUE (id, domain_id),
UNIQUE (domain_id, name)
)`,
`CREATE TABLE IF NOT EXISTS connections (
channel_id VARCHAR(36),
domain_id VARCHAR(36),
client_id VARCHAR(36),
type SMALLINT NOT NULL CHECK (type IN (1, 2)),
FOREIGN KEY (channel_id, domain_id) REFERENCES channels (id, domain_id) ON DELETE CASCADE ON UPDATE CASCADE,
PRIMARY KEY (channel_id, domain_id, client_id, type)
)`,
},
Down: []string{
`DROP TABLE IF EXISTS channels`,
`DROP TABLE IF EXISTS connections`,
},
},
{
Id: "channels_02",
Up: []string{
`ALTER TABLE channels DROP CONSTRAINT IF EXISTS channels_domain_id_name_key`,
},
Down: []string{
`ALTER TABLE channels ADD CONSTRAINT channels_domain_id_name_key UNIQUE (domain_id, name)`,
},
},
{
Id: "channels_03",
Up: []string{
`ALTER TABLE channels ADD COLUMN IF NOT EXISTS route VARCHAR(36);`,
`CREATE UNIQUE INDEX IF NOT EXISTS unique_domain_route_not_null ON channels (domain_id, route) WHERE route IS NOT NULL;`,
},
Down: []string{
`DROP INDEX IF EXISTS unique_domain_route_not_null;`,
`ALTER TABLE channels DROP COLUMN IF EXISTS route;`,
},
},
{
Id: "channels_04",
Up: []string{
`ALTER TABLE channels ALTER COLUMN created_at TYPE TIMESTAMPTZ;`,
`ALTER TABLE channels ALTER COLUMN updated_at TYPE TIMESTAMPTZ;`,
},
Down: []string{
`ALTER TABLE channels ALTER COLUMN created_at TYPE TIMESTAMP;`,
`ALTER TABLE channels ALTER COLUMN updated_at TYPE TIMESTAMP;`,
},
},
{
Id: "channels_05",
Up: []string{
`UPDATE channels
SET metadata = (COALESCE(metadata, '{}'::jsonb) || COALESCE(metadata->'ui', '{}'::jsonb)) - 'ui'
WHERE metadata ? 'ui' AND jsonb_typeof(metadata->'ui') = 'object'`,
},
Down: []string{
`SELECT 1`,
},
},
{
Id: "channels_06",
Up: []string{
`CREATE INDEX IF NOT EXISTS idx_channels_domain_id_status ON channels(domain_id, status);`,
`CREATE INDEX IF NOT EXISTS idx_channels_parent_group_id ON channels(parent_group_id);`,
},
Down: []string{
`DROP INDEX IF EXISTS idx_channels_domain_id_status;`,
`DROP INDEX IF EXISTS idx_channels_parent_group_id;`,
},
},
},
}
channelsMigration.Migrations = append(channelsMigration.Migrations, rolesMigration.Migrations...)
groupsMigration, err := gpostgres.Migration()
if err != nil {
return &migrate.MemoryMigrationSource{}, err
}
channelsMigration.Migrations = append(channelsMigration.Migrations, groupsMigration.Migrations...)
return channelsMigration, nil
}

Some files were not shown because too many files have changed in this diff Show More