mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
inital integration with ATOM
Signed-off-by: Arvindh <arvindh91@gmail.com>
This commit is contained in:
@@ -10,7 +10,9 @@ on:
|
||||
paths:
|
||||
- ".github/workflows/api-tests.yaml"
|
||||
- "api/**"
|
||||
- "auth/api/http/**"
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "internal/atom/**"
|
||||
- "channels/api/http/**"
|
||||
- "clients/api/http/**"
|
||||
- "domains/api/http/**"
|
||||
@@ -20,9 +22,9 @@ on:
|
||||
- "bootstrap/api/**"
|
||||
- "certs/api/http/**"
|
||||
- "readers/api/http/**"
|
||||
- "re/api/**"
|
||||
- "alarms/api/**"
|
||||
- "reports/api/**"
|
||||
- "re/**"
|
||||
- "alarms/**"
|
||||
- "reports/**"
|
||||
- "apidocs/openapi/**"
|
||||
pull_request:
|
||||
branches:
|
||||
@@ -30,7 +32,9 @@ on:
|
||||
paths:
|
||||
- ".github/workflows/api-tests.yaml"
|
||||
- "api/**"
|
||||
- "auth/api/http/**"
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "internal/atom/**"
|
||||
- "channels/api/http/**"
|
||||
- "clients/api/http/**"
|
||||
- "domains/api/http/**"
|
||||
@@ -40,9 +44,9 @@ on:
|
||||
- "bootstrap/api/**"
|
||||
- "certs/api/http/**"
|
||||
- "readers/api/http/**"
|
||||
- "re/api/**"
|
||||
- "alarms/api/**"
|
||||
- "reports/api/**"
|
||||
- "re/**"
|
||||
- "alarms/**"
|
||||
- "reports/**"
|
||||
- "apidocs/openapi/**"
|
||||
|
||||
concurrency:
|
||||
@@ -50,17 +54,16 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
TOKENS_URL: http://localhost:9002/users/tokens/issue
|
||||
CREATE_DOMAINS_URL: http://localhost:9003/domains
|
||||
TOKENS_URL: http://localhost:9000/users/tokens/issue
|
||||
CREATE_DOMAINS_URL: http://localhost:9000/domains
|
||||
USER_IDENTITY: admin@example.com
|
||||
USER_SECRET: 12345678
|
||||
DOMAIN_NAME: demo-test
|
||||
USERS_URL: http://localhost:9002
|
||||
DOMAIN_URL: http://localhost:9003
|
||||
CLIENTS_URL: http://localhost:9006
|
||||
CHANNELS_URL: http://localhost:9005
|
||||
GROUPS_URL: http://localhost:9004
|
||||
AUTH_URL: http://localhost:9001
|
||||
USERS_URL: http://localhost:9000
|
||||
DOMAIN_URL: http://localhost:9000
|
||||
CLIENTS_URL: http://localhost:9000
|
||||
CHANNELS_URL: http://localhost:9000
|
||||
GROUPS_URL: http://localhost:9000
|
||||
JOURNAL_URL: http://localhost:9021
|
||||
BOOTSTRAP_URL: http://localhost:9013
|
||||
CERTS_URL: http://localhost:9019
|
||||
@@ -96,28 +99,39 @@ jobs:
|
||||
- "apidocs/openapi/journal.yaml"
|
||||
- "journal/api/**"
|
||||
|
||||
auth:
|
||||
- "apidocs/openapi/auth.yaml"
|
||||
- "auth/api/http/**"
|
||||
|
||||
domains:
|
||||
- "apidocs/openapi/domains.yaml"
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "internal/atom/**"
|
||||
- "domains/api/http/**"
|
||||
|
||||
clients:
|
||||
- "apidocs/openapi/clients.yaml"
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "internal/atom/**"
|
||||
- "clients/api/http/**"
|
||||
|
||||
channels:
|
||||
- "apidocs/openapi/channels.yaml"
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "internal/atom/**"
|
||||
- "channels/api/http/**"
|
||||
|
||||
groups:
|
||||
- "apidocs/openapi/groups.yaml"
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "internal/atom/**"
|
||||
- "groups/api/http/**"
|
||||
|
||||
users:
|
||||
- "apidocs/openapi/users.yaml"
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "internal/atom/**"
|
||||
- "users/api/**"
|
||||
|
||||
bootstrap:
|
||||
@@ -134,15 +148,21 @@ jobs:
|
||||
|
||||
re:
|
||||
- "apidocs/openapi/rules.yaml"
|
||||
- "re/api/**"
|
||||
- "re/**"
|
||||
- "cmd/re/**"
|
||||
- "internal/atom/**"
|
||||
|
||||
alarms:
|
||||
- "apidocs/openapi/alarms.yaml"
|
||||
- "alarms/api/**"
|
||||
- "alarms/**"
|
||||
- "cmd/alarms/**"
|
||||
- "internal/atom/**"
|
||||
|
||||
reports:
|
||||
- "apidocs/openapi/reports.yaml"
|
||||
- "reports/api/**"
|
||||
- "reports/**"
|
||||
- "cmd/reports/**"
|
||||
- "internal/atom/**"
|
||||
|
||||
- name: Build images
|
||||
run: make all -j $(nproc) && make dockers_dev -j $(nproc)
|
||||
@@ -157,7 +177,7 @@ jobs:
|
||||
|
||||
# Check if services are responding
|
||||
for i in {1..30}; do
|
||||
if curl -f -s http://localhost:9002/health > /dev/null 2>&1; then
|
||||
if curl -f -s http://localhost:9000/health > /dev/null 2>&1; then
|
||||
echo "Services are ready!"
|
||||
break
|
||||
fi
|
||||
@@ -209,15 +229,6 @@ jobs:
|
||||
checks: all
|
||||
args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples'
|
||||
|
||||
- name: Run Auth API tests
|
||||
if: steps.changes.outputs.auth == 'true' || steps.changes.outputs.workflow == 'true'
|
||||
uses: schemathesis/action@v3.0.0
|
||||
with:
|
||||
schema: apidocs/openapi/auth.yaml
|
||||
base-url: ${{ env.AUTH_URL }}
|
||||
checks: all
|
||||
args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples'
|
||||
|
||||
- name: Run Domains API tests
|
||||
if: steps.changes.outputs.domains == 'true' || steps.changes.outputs.workflow == 'true'
|
||||
uses: schemathesis/action@v3.0.0
|
||||
|
||||
@@ -66,28 +66,10 @@ jobs:
|
||||
workflow:
|
||||
- ".github/workflows/tests.yaml"
|
||||
|
||||
auth:
|
||||
- "auth/**"
|
||||
- "cmd/auth/**"
|
||||
- "auth.proto"
|
||||
- "auth.pb.go"
|
||||
- "auth_grpc.pb.go"
|
||||
- "pkg/ulid/**"
|
||||
- "pkg/uuid/**"
|
||||
|
||||
bootstrap:
|
||||
- "bootstrap/**"
|
||||
- "cmd/bootstrap/**"
|
||||
- "pkg/bootstrap/**"
|
||||
- "provision/**"
|
||||
- "pkg/sdk/**"
|
||||
|
||||
channels:
|
||||
- "channels/**"
|
||||
- "cmd/channels/**"
|
||||
- "auth.pb.go"
|
||||
- "auth_grpc.pb.go"
|
||||
- "auth/**"
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "pkg/sdk/**"
|
||||
- "clients/api/grpc/**"
|
||||
- "groups/api/grpc/**"
|
||||
@@ -101,10 +83,8 @@ jobs:
|
||||
|
||||
clients:
|
||||
- "clients/**"
|
||||
- "cmd/clients/**"
|
||||
- "auth.pb.go"
|
||||
- "auth_grpc.pb.go"
|
||||
- "auth/**"
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "pkg/ulid/**"
|
||||
- "pkg/uuid/**"
|
||||
- "pkg/events/**"
|
||||
@@ -115,18 +95,14 @@ jobs:
|
||||
|
||||
domains:
|
||||
- "domains/**"
|
||||
- "cmd/domains/**"
|
||||
- "auth.pb.go"
|
||||
- "auth_grpc.pb.go"
|
||||
- "auth/**"
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "internal/grpc/**"
|
||||
|
||||
groups:
|
||||
- "groups/**"
|
||||
- "cmd/groups/**"
|
||||
- "auth.pb.go"
|
||||
- "auth_grpc.pb.go"
|
||||
- "auth/**"
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "pkg/ulid/**"
|
||||
- "pkg/uuid/**"
|
||||
- "clients/api/grpc/**"
|
||||
@@ -140,9 +116,6 @@ jobs:
|
||||
journal:
|
||||
- "journal/**"
|
||||
- "cmd/journal/**"
|
||||
- "auth.pb.go"
|
||||
- "auth_grpc.pb.go"
|
||||
- "auth/**"
|
||||
- "pkg/events/**"
|
||||
|
||||
logger:
|
||||
@@ -165,7 +138,6 @@ jobs:
|
||||
- "pkg/sdk/**"
|
||||
- "pkg/errors/**"
|
||||
- "pkg/groups/**"
|
||||
- "auth/**"
|
||||
- "internal/*"
|
||||
- "clients/**"
|
||||
- "users/**"
|
||||
@@ -189,10 +161,8 @@ jobs:
|
||||
|
||||
users:
|
||||
- "users/**"
|
||||
- "cmd/users/**"
|
||||
- "auth.pb.go"
|
||||
- "auth_grpc.pb.go"
|
||||
- "auth/**"
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "pkg/ulid/**"
|
||||
- "pkg/uuid/**"
|
||||
- "pkg/events/**"
|
||||
@@ -200,8 +170,6 @@ jobs:
|
||||
notifications:
|
||||
- "notifications/**"
|
||||
- "cmd/notifications/**"
|
||||
- "auth.pb.go"
|
||||
- "auth_grpc.pb.go"
|
||||
- "consumers/notifier.go"
|
||||
- "pkg/events/**"
|
||||
|
||||
@@ -232,6 +200,10 @@ jobs:
|
||||
reports:
|
||||
- "reports/**"
|
||||
- "cmd/reports/**"
|
||||
core:
|
||||
- "core/**"
|
||||
- "cmd/core/**"
|
||||
- "internal/atom/**"
|
||||
|
||||
- name: Set matrix for changed modules
|
||||
id: set-matrix
|
||||
@@ -240,11 +212,11 @@ jobs:
|
||||
|
||||
if [[ "${{ steps.changes.outputs.workflow }}" == "true" || "${{ steps.changes.outputs.pkg-errors }}" == "true" ]]; then
|
||||
# If workflow or pkg/errors changed, test everything
|
||||
modules=("auth" "bootstrap" "channels" "cli" "clients" "domains" "groups" "internal" "journal" "logger" "pkg-errors" "pkg-events" "pkg-grpcclient" "pkg-messaging" "pkg-sdk" "pkg-transformers" "pkg-ulid" "pkg-uuid" "users" "notifications" "api" "consumers" "readers" "re" "alarms" "reports")
|
||||
modules=("auth" "core" "channels" "cli" "clients" "domains" "groups" "internal" "journal" "logger" "pkg-errors" "pkg-events" "pkg-grpcclient" "pkg-messaging" "pkg-sdk" "pkg-transformers" "pkg-ulid" "pkg-uuid" "users" "notifications" "api" "consumers" "readers" "re" "alarms" "reports")
|
||||
else
|
||||
# Add only changed modules
|
||||
[[ "${{ steps.changes.outputs.auth }}" == "true" ]] && modules+=("auth")
|
||||
[[ "${{ steps.changes.outputs.bootstrap }}" == "true" ]] && modules+=("bootstrap")
|
||||
[[ "${{ steps.changes.outputs.core }}" == "true" ]] && modules+=("core")
|
||||
[[ "${{ steps.changes.outputs.channels }}" == "true" ]] && modules+=("channels")
|
||||
[[ "${{ steps.changes.outputs.cli }}" == "true" ]] && modules+=("cli")
|
||||
[[ "${{ steps.changes.outputs.clients }}" == "true" ]] && modules+=("clients")
|
||||
|
||||
@@ -21,3 +21,6 @@ docker/addons/certs/openbao/
|
||||
|
||||
# Ignore SeaweedFS data directory as it contains runtime-generated data
|
||||
docker/data/*
|
||||
|
||||
demo-ui
|
||||
node_modules
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
override MG_DOCKER_IMAGE_NAME_PREFIX := ghcr.io/absmach/magistrala
|
||||
MG_DOCKER_VOLUME_NAME_PREFIX ?= magistrala
|
||||
BUILD_DIR ?= build
|
||||
SERVICES = auth users clients groups channels domains notifications certs re postgres-writer postgres-reader timescale-writer timescale-reader cli alarms reports bootstrap provision journal fluxmq
|
||||
TEST_API_SERVICES = journal auth certs clients users channels groups domains
|
||||
SERVICES = notifications certs re postgres-writer postgres-reader timescale-writer timescale-reader alarms reports journal fluxmq
|
||||
TEST_API_SERVICES = journal certs clients users channels groups domains
|
||||
TEST_API = $(addprefix test_api_,$(TEST_API_SERVICES))
|
||||
DOCKERS = $(addprefix docker_,$(SERVICES))
|
||||
DOCKERS_DEV = $(addprefix docker_dev_,$(SERVICES))
|
||||
@@ -199,12 +199,11 @@ define test_api_service
|
||||
--phases=examples,stateful
|
||||
endef
|
||||
|
||||
test_api_users: TEST_API_URL := http://localhost:9002
|
||||
test_api_clients: TEST_API_URL := http://localhost:9006
|
||||
test_api_domains: TEST_API_URL := http://localhost:9003
|
||||
test_api_channels: TEST_API_URL := http://localhost:9005
|
||||
test_api_groups: TEST_API_URL := http://localhost:9004
|
||||
test_api_auth: TEST_API_URL := http://localhost:9001
|
||||
test_api_users: TEST_API_URL := http://localhost:9000
|
||||
test_api_clients: TEST_API_URL := http://localhost:9000
|
||||
test_api_domains: TEST_API_URL := http://localhost:9000
|
||||
test_api_channels: TEST_API_URL := http://localhost:9000
|
||||
test_api_groups: TEST_API_URL := http://localhost:9000
|
||||
test_api_certs: TEST_API_URL := http://localhost:9019
|
||||
test_api_journal: TEST_API_URL := http://localhost:9021
|
||||
|
||||
@@ -262,7 +261,7 @@ rundev:
|
||||
cd scripts && ./run.sh
|
||||
|
||||
grpc_mtls_certs:
|
||||
$(MAKE) -C docker/ssl auth_grpc_certs clients_grpc_certs
|
||||
$(MAKE) -C docker/ssl clients_grpc_certs
|
||||
|
||||
check_tls:
|
||||
ifeq ($(GRPC_TLS),true)
|
||||
@@ -284,7 +283,7 @@ check_certs: check_mtls check_tls
|
||||
ifeq ($(GRPC_MTLS_CERT_FILES_EXISTS),0)
|
||||
ifeq ($(filter true,$(GRPC_MTLS) $(GRPC_TLS)),true)
|
||||
ifeq ($(filter $(DEFAULT_DOCKER_COMPOSE_COMMAND),$(DOCKER_COMPOSE_COMMAND)),$(DEFAULT_DOCKER_COMPOSE_COMMAND))
|
||||
$(MAKE) -C docker/ssl auth_grpc_certs clients_grpc_certs
|
||||
$(MAKE) -C docker/ssl clients_grpc_certs
|
||||
endif
|
||||
endif
|
||||
endif
|
||||
@@ -312,7 +311,7 @@ run_stable: check_certs
|
||||
|
||||
run_addons: check_certs
|
||||
$(foreach SVC,$(RUN_ADDON_ARGS),$(if $(filter $(SVC),$(ADDON_SERVICES) $(EXTERNAL_SERVICES)),,$(error Invalid Service $(SVC))))
|
||||
@$(DOCKER_PLATFORM) docker compose -f docker/docker-compose.yaml --env-file ./docker/.env -p $(DOCKER_PROJECT) up -d auth domains jaeger
|
||||
@$(DOCKER_PLATFORM) docker compose -f docker/docker-compose.yaml --env-file ./docker/.env -p $(DOCKER_PROJECT) up -d atom jaeger
|
||||
@for SVC in $(RUN_ADDON_ARGS); do \
|
||||
MG_ADDONS_CERTS_PATH_PREFIX="../" $(DOCKER_PLATFORM) docker compose -f docker/addons/$$SVC/docker-compose.yaml -p $(DOCKER_PROJECT) --env-file ./docker/.env $(DOCKER_COMPOSE_COMMAND) $(args) & \
|
||||
done
|
||||
|
||||
+6
-11
@@ -26,16 +26,11 @@ The service is configured using the following environment variables (values show
|
||||
| `MG_MESSAGE_BROKER_URL` | Message broker URL for alarm ingestion | `nats://nats:4222` |
|
||||
| `MG_JAEGER_URL` | Jaeger collector endpoint | `http://jaeger:4318/v1/traces` |
|
||||
| `MG_JAEGER_TRACE_RATIO` | Trace sampling ratio | `1.0` |
|
||||
| `MG_AUTH_GRPC_URL` | Auth gRPC endpoint | `auth:7001` |
|
||||
| `MG_AUTH_GRPC_TIMEOUT` | Auth gRPC timeout | `300s` |
|
||||
| `MG_AUTH_GRPC_CLIENT_CERT` | Auth gRPC client cert path | `${GRPC_MTLS:+./ssl/certs/auth-grpc-client.crt}` |
|
||||
| `MG_AUTH_GRPC_CLIENT_KEY` | Auth gRPC client key path | `${GRPC_MTLS:+./ssl/certs/auth-grpc-client.key}` |
|
||||
| `MG_AUTH_GRPC_SERVER_CA_CERTS` | Auth gRPC server CA path | `${GRPC_MTLS:+./ssl/certs/ca.crt}` |
|
||||
| `MG_DOMAINS_GRPC_URL` | Domains gRPC endpoint | `domains:7003` |
|
||||
| `MG_DOMAINS_GRPC_TIMEOUT` | Domains gRPC timeout | `300s` |
|
||||
| `MG_DOMAINS_GRPC_CLIENT_CERT` | Domains gRPC client cert path | `${GRPC_MTLS:+./ssl/certs/domains-grpc-client.crt}` |
|
||||
| `MG_DOMAINS_GRPC_CLIENT_KEY` | Domains gRPC client key path | `${GRPC_MTLS:+./ssl/certs/domains-grpc-client.key}` |
|
||||
| `MG_DOMAINS_GRPC_SERVER_CA_CERTS` | Domains gRPC server CA path | `${GRPC_MTLS:+./ssl/certs/ca.crt}` |
|
||||
| `ATOM_URL` | Atom HTTP endpoint | `http://atom:8080` |
|
||||
| `ATOM_JWKS_URL` | Atom JWKS endpoint for JWT verification | `http://atom:8080/.well-known/jwks.json` |
|
||||
| `ATOM_ADMIN_USERNAME` | Atom admin login for service projections | `atom-admin` |
|
||||
| `ATOM_ADMIN_SECRET` | Atom admin secret for service projections | `change-me` |
|
||||
| `ATOM_TIMEOUT` | Atom request timeout | `5s` |
|
||||
| `MG_ALLOW_UNVERIFIED_USER` | Allow unverified users to access | `true` |
|
||||
|
||||
## Features
|
||||
@@ -44,7 +39,7 @@ The service is configured using the following environment variables (values show
|
||||
- **Stateful updates**: Updates assignee, acknowledgment, resolution, and metadata fields.
|
||||
- **Filtering and paging**: Lists alarms by domain, rule, channel, client, subtopic, status, severity, and time range.
|
||||
- **Observability**: `/metrics` Prometheus endpoint and Jaeger tracing support.
|
||||
- **Auth and authorization**: Authn/authz enforced via gRPC auth and domains services.
|
||||
- **Auth and authorization**: Authn/authz enforced through Atom JWT verification and PDP checks.
|
||||
|
||||
## Architecture
|
||||
|
||||
|
||||
+1
-2
@@ -106,7 +106,7 @@ func (a Alarm) Validate() error {
|
||||
|
||||
// Service specifies an API that must be fulfilled by the domain service.
|
||||
type Service interface {
|
||||
CreateAlarm(ctx context.Context, alarm Alarm) error
|
||||
CreateAlarm(ctx context.Context, alarm Alarm) (Alarm, error)
|
||||
UpdateAlarm(ctx context.Context, session authn.Session, alarm Alarm) (Alarm, error)
|
||||
ViewAlarm(ctx context.Context, session authn.Session, id string) (Alarm, error)
|
||||
ListAlarms(ctx context.Context, session authn.Session, pm PageMetadata) (AlarmsPage, error)
|
||||
@@ -118,6 +118,5 @@ type Repository interface {
|
||||
UpdateAlarm(ctx context.Context, alarm Alarm) (Alarm, error)
|
||||
ViewAlarm(ctx context.Context, alarmID, domainID string) (Alarm, error)
|
||||
ListAllAlarms(ctx context.Context, pm PageMetadata) (AlarmsPage, error)
|
||||
ListUserAlarms(ctx context.Context, userID string, pm PageMetadata) (AlarmsPage, error)
|
||||
DeleteAlarm(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package alarms
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/internal/atom"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
)
|
||||
|
||||
type atomService struct {
|
||||
Service
|
||||
projector atom.Projector
|
||||
}
|
||||
|
||||
func WithAtom(svc Service, projector atom.Projector) Service {
|
||||
if projector == nil {
|
||||
return svc
|
||||
}
|
||||
return atomService{Service: svc, projector: projector}
|
||||
}
|
||||
|
||||
func (svc atomService) CreateAlarm(ctx context.Context, alarm Alarm) (Alarm, error) {
|
||||
created, err := svc.Service.CreateAlarm(ctx, alarm)
|
||||
if err != nil {
|
||||
return created, err
|
||||
}
|
||||
if created.ID == "" {
|
||||
return created, nil
|
||||
}
|
||||
if err := svc.projector.UpsertResource(ctx, alarmProjection(created)); err != nil {
|
||||
return created, nil
|
||||
}
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (svc atomService) UpdateAlarm(ctx context.Context, session authn.Session, alarm Alarm) (Alarm, error) {
|
||||
updated, err := svc.Service.UpdateAlarm(ctx, session, alarm)
|
||||
if err != nil {
|
||||
return updated, err
|
||||
}
|
||||
if err := svc.projector.UpsertResource(ctx, alarmProjection(updated)); err != nil {
|
||||
return updated, nil
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func (svc atomService) DeleteAlarm(ctx context.Context, session authn.Session, id string) error {
|
||||
if err := svc.Service.DeleteAlarm(ctx, session, id); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = svc.projector.DeleteResource(ctx, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func alarmProjection(a Alarm) atom.Resource {
|
||||
res := atom.ResourceFromFields(atom.ObjectFields{
|
||||
ID: a.ID,
|
||||
Kind: atom.KindAlarm,
|
||||
Name: a.Cause,
|
||||
TenantID: a.DomainID,
|
||||
OwnerID: a.AssigneeID,
|
||||
Status: a.Status.String(),
|
||||
Metadata: map[string]any(a.Metadata),
|
||||
UpdatedBy: a.UpdatedBy,
|
||||
CreatedAt: a.CreatedAt,
|
||||
UpdatedAt: a.UpdatedAt,
|
||||
})
|
||||
res.Attributes["rule_id"] = a.RuleID
|
||||
res.Attributes["channel_id"] = a.ChannelID
|
||||
res.Attributes["client_id"] = a.ClientID
|
||||
res.Attributes["severity"] = a.Severity
|
||||
res.Attributes["measurement"] = a.Measurement
|
||||
res.Attributes["assignee_id"] = a.AssigneeID
|
||||
return res
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package alarms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/internal/atom"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
)
|
||||
|
||||
func TestAtomServiceCreateAlarmProjectsCreatedAlarm(t *testing.T) {
|
||||
projector := &alarmProjector{}
|
||||
svc := WithAtom(alarmService{
|
||||
create: Alarm{
|
||||
ID: "alarm-1",
|
||||
RuleID: "rule-1",
|
||||
DomainID: "domain-1",
|
||||
ChannelID: "channel-1",
|
||||
ClientID: "client-1",
|
||||
Cause: "high temperature",
|
||||
Measurement: "temperature",
|
||||
Severity: 90,
|
||||
Status: ActiveStatus,
|
||||
},
|
||||
}, projector)
|
||||
|
||||
created, err := svc.CreateAlarm(context.Background(), Alarm{RuleID: "rule-1"})
|
||||
if err != nil {
|
||||
t.Fatalf("create alarm: %v", err)
|
||||
}
|
||||
if created.ID != "alarm-1" {
|
||||
t.Fatalf("unexpected created alarm: %#v", created)
|
||||
}
|
||||
if projector.resource.ID != "alarm-1" || projector.resource.Kind != atom.KindAlarm {
|
||||
t.Fatalf("unexpected projection: %#v", projector.resource)
|
||||
}
|
||||
if projector.resource.Attributes["rule_id"] != "rule-1" {
|
||||
t.Fatalf("missing rule projection: %#v", projector.resource.Attributes)
|
||||
}
|
||||
}
|
||||
|
||||
type alarmService struct {
|
||||
create Alarm
|
||||
}
|
||||
|
||||
func (svc alarmService) CreateAlarm(context.Context, Alarm) (Alarm, error) {
|
||||
return svc.create, nil
|
||||
}
|
||||
|
||||
func (svc alarmService) UpdateAlarm(context.Context, authn.Session, Alarm) (Alarm, error) {
|
||||
return Alarm{}, nil
|
||||
}
|
||||
|
||||
func (svc alarmService) ViewAlarm(context.Context, authn.Session, string) (Alarm, error) {
|
||||
return Alarm{}, nil
|
||||
}
|
||||
|
||||
func (svc alarmService) ListAlarms(context.Context, authn.Session, PageMetadata) (AlarmsPage, error) {
|
||||
return AlarmsPage{}, nil
|
||||
}
|
||||
|
||||
func (svc alarmService) DeleteAlarm(context.Context, authn.Session, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type alarmProjector struct {
|
||||
atom.Projector
|
||||
resource atom.Resource
|
||||
}
|
||||
|
||||
func (p *alarmProjector) UpsertResource(_ context.Context, resource atom.Resource) error {
|
||||
p.resource = resource
|
||||
return nil
|
||||
}
|
||||
@@ -48,7 +48,8 @@ func (h handler) Handle(msg *messaging.Message) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.svc.CreateAlarm(context.Background(), alarm)
|
||||
_, err = h.svc.CreateAlarm(context.Background(), alarm)
|
||||
return err
|
||||
}
|
||||
|
||||
func (h handler) Cancel() error {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/absmach/magistrala/alarms"
|
||||
"github.com/absmach/magistrala/alarms/operations"
|
||||
"github.com/absmach/magistrala/auth"
|
||||
"github.com/absmach/magistrala/internal/atom"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
smqauthz "github.com/absmach/magistrala/pkg/authz"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
@@ -26,6 +27,7 @@ var (
|
||||
type authorizationMiddleware struct {
|
||||
svc alarms.Service
|
||||
authz smqauthz.Authorization
|
||||
atomAuthz atom.Authorizer
|
||||
entitiesOps permissions.EntitiesOperations[permissions.Operation]
|
||||
}
|
||||
|
||||
@@ -43,7 +45,19 @@ func NewAuthorizationMiddleware(svc alarms.Service, authz smqauthz.Authorization
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error {
|
||||
func NewAtomAuthorizationMiddleware(svc alarms.Service, authz atom.Authorizer, entitiesOps permissions.EntitiesOperations[permissions.Operation]) (alarms.Service, error) {
|
||||
if err := entitiesOps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &authorizationMiddleware{
|
||||
svc: svc,
|
||||
atomAuthz: authz,
|
||||
entitiesOps: entitiesOps,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) {
|
||||
return am.svc.CreateAlarm(ctx, alarm)
|
||||
}
|
||||
|
||||
@@ -58,17 +72,19 @@ func (am *authorizationMiddleware) UpdateAlarm(ctx context.Context, session auth
|
||||
if err := am.authorize(ctx, operations.OpAssignAlarm, session, policies.DomainType, session.DomainID); err != nil {
|
||||
return alarms.Alarm{}, errors.Wrap(errDomainUpdateAlarms, err)
|
||||
}
|
||||
domainUserID := auth.EncodeDomainUserID(session.DomainID, alarm.AssigneeID)
|
||||
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
SubjectKind: policies.UsersKind,
|
||||
Subject: domainUserID,
|
||||
Permission: policies.MembershipPermission,
|
||||
ObjectType: policies.DomainType,
|
||||
Object: session.DomainID,
|
||||
}, nil); err != nil {
|
||||
return alarms.Alarm{}, err
|
||||
if am.atomAuthz == nil {
|
||||
domainUserID := auth.EncodeDomainUserID(session.DomainID, alarm.AssigneeID)
|
||||
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
SubjectKind: policies.UsersKind,
|
||||
Subject: domainUserID,
|
||||
Permission: policies.MembershipPermission,
|
||||
ObjectType: policies.DomainType,
|
||||
Object: session.DomainID,
|
||||
}, nil); err != nil {
|
||||
return alarms.Alarm{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -124,6 +140,9 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if am.atomAuthz != nil {
|
||||
return atom.Authorize(ctx, am.atomAuthz, session, perm.String(), objType, obj, atom.KindAlarm)
|
||||
}
|
||||
|
||||
pr := smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
@@ -159,6 +178,9 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session
|
||||
if session.Role != authn.SuperAdminRole {
|
||||
return svcerr.ErrSuperAdminAction
|
||||
}
|
||||
if am.atomAuthz != nil {
|
||||
return atom.Authorize(ctx, am.atomAuthz, session, policies.AdminPermission, policies.PlatformType, policies.MagistralaObject, policies.PlatformType)
|
||||
}
|
||||
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.UserID,
|
||||
|
||||
@@ -27,7 +27,7 @@ func NewLoggingMiddleware(logger *slog.Logger, service alarms.Service) alarms.Se
|
||||
}
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (err error) {
|
||||
func (lm *loggingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (created alarms.Alarm, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
@@ -52,7 +52,7 @@ func (lm *loggingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm
|
||||
lm.logger.Warn("Create alarm failed", args...)
|
||||
return
|
||||
}
|
||||
if alarm.ID != "" {
|
||||
if created.ID != "" {
|
||||
lm.logger.Info("Create alarm completed successfully", args...)
|
||||
}
|
||||
}(time.Now())
|
||||
|
||||
@@ -28,7 +28,7 @@ func NewMetricsMiddleware(counter metrics.Counter, latency metrics.Histogram, se
|
||||
}
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error {
|
||||
func (mm *metricsMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "create_alarm").Add(1)
|
||||
mm.latency.With("method", "create_alarm").Observe(time.Since(begin).Seconds())
|
||||
|
||||
@@ -27,7 +27,7 @@ func NewTracingMiddleware(tracer trace.Tracer, svc alarms.Service) alarms.Servic
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error {
|
||||
func (tm *tracingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) {
|
||||
ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "create_alarm", trace.WithAttributes(
|
||||
attribute.String("rule_id", alarm.RuleID),
|
||||
attribute.String("measurement", alarm.Measurement),
|
||||
|
||||
@@ -231,78 +231,6 @@ func (_c *Repository_ListAllAlarms_Call) RunAndReturn(run func(ctx context.Conte
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListUserAlarms provides a mock function for the type Repository
|
||||
func (_mock *Repository) ListUserAlarms(ctx context.Context, userID string, pm alarms.PageMetadata) (alarms.AlarmsPage, error) {
|
||||
ret := _mock.Called(ctx, userID, pm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListUserAlarms")
|
||||
}
|
||||
|
||||
var r0 alarms.AlarmsPage
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, alarms.PageMetadata) (alarms.AlarmsPage, error)); ok {
|
||||
return returnFunc(ctx, userID, pm)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, alarms.PageMetadata) alarms.AlarmsPage); ok {
|
||||
r0 = returnFunc(ctx, userID, pm)
|
||||
} else {
|
||||
r0 = ret.Get(0).(alarms.AlarmsPage)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string, alarms.PageMetadata) error); ok {
|
||||
r1 = returnFunc(ctx, userID, pm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Repository_ListUserAlarms_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserAlarms'
|
||||
type Repository_ListUserAlarms_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListUserAlarms is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - userID string
|
||||
// - pm alarms.PageMetadata
|
||||
func (_e *Repository_Expecter) ListUserAlarms(ctx interface{}, userID interface{}, pm interface{}) *Repository_ListUserAlarms_Call {
|
||||
return &Repository_ListUserAlarms_Call{Call: _e.mock.On("ListUserAlarms", ctx, userID, pm)}
|
||||
}
|
||||
|
||||
func (_c *Repository_ListUserAlarms_Call) Run(run func(ctx context.Context, userID string, pm alarms.PageMetadata)) *Repository_ListUserAlarms_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 alarms.PageMetadata
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(alarms.PageMetadata)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_ListUserAlarms_Call) Return(alarmsPage alarms.AlarmsPage, err error) *Repository_ListUserAlarms_Call {
|
||||
_c.Call.Return(alarmsPage, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_ListUserAlarms_Call) RunAndReturn(run func(ctx context.Context, userID string, pm alarms.PageMetadata) (alarms.AlarmsPage, error)) *Repository_ListUserAlarms_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateAlarm provides a mock function for the type Repository
|
||||
func (_mock *Repository) UpdateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) {
|
||||
ret := _mock.Called(ctx, alarm)
|
||||
|
||||
+17
-8
@@ -44,20 +44,29 @@ func (_m *Service) EXPECT() *Service_Expecter {
|
||||
}
|
||||
|
||||
// CreateAlarm provides a mock function for the type Service
|
||||
func (_mock *Service) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error {
|
||||
func (_mock *Service) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) {
|
||||
ret := _mock.Called(ctx, alarm)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for CreateAlarm")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) error); ok {
|
||||
var r0 alarms.Alarm
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) (alarms.Alarm, error)); ok {
|
||||
return returnFunc(ctx, alarm)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) alarms.Alarm); ok {
|
||||
r0 = returnFunc(ctx, alarm)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
r0 = ret.Get(0).(alarms.Alarm)
|
||||
}
|
||||
return r0
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, alarms.Alarm) error); ok {
|
||||
r1 = returnFunc(ctx, alarm)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Service_CreateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateAlarm'
|
||||
@@ -90,12 +99,12 @@ func (_c *Service_CreateAlarm_Call) Run(run func(ctx context.Context, alarm alar
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_CreateAlarm_Call) Return(err error) *Service_CreateAlarm_Call {
|
||||
_c.Call.Return(err)
|
||||
func (_c *Service_CreateAlarm_Call) Return(alarm1 alarms.Alarm, err error) *Service_CreateAlarm_Call {
|
||||
_c.Call.Return(alarm1, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_CreateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) error) *Service_CreateAlarm_Call {
|
||||
func (_c *Service_CreateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error)) *Service_CreateAlarm_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -198,35 +198,6 @@ func (r *repository) ListAllAlarms(ctx context.Context, pm alarms.PageMetadata)
|
||||
return r.alarmsPage(ctx, comQuery, pm)
|
||||
}
|
||||
|
||||
func (r *repository) ListUserAlarms(ctx context.Context, userID string, pm alarms.PageMetadata) (alarms.AlarmsPage, error) {
|
||||
clauses := []string{
|
||||
`(
|
||||
EXISTS (
|
||||
SELECT 1
|
||||
FROM rules_roles rr
|
||||
JOIN rules_role_members rrm ON rrm.role_id = rr.id
|
||||
WHERE rr.entity_id = alarms.rule_id AND rrm.member_id = :user_id
|
||||
)
|
||||
OR EXISTS (
|
||||
SELECT 1
|
||||
FROM domains_roles dr
|
||||
JOIN domains_role_members drm ON drm.role_id = dr.id
|
||||
JOIN domains_role_actions dra ON dra.role_id = dr.id
|
||||
WHERE dr.entity_id = alarms.domain_id
|
||||
AND drm.member_id = :user_id
|
||||
AND dra.action LIKE 'alarm%'
|
||||
)
|
||||
)`,
|
||||
}
|
||||
|
||||
clauses = append(clauses, pageQueryConditions(pm)...)
|
||||
query := fmt.Sprintf("WHERE %s", strings.Join(clauses, " AND "))
|
||||
pm.UserID = userID
|
||||
comQuery := fmt.Sprintf(`SELECT DISTINCT %s FROM alarms %s`, alarmColumns, query)
|
||||
|
||||
return r.alarmsPage(ctx, comQuery, pm)
|
||||
}
|
||||
|
||||
func (r *repository) alarmsPage(ctx context.Context, comQuery string, pm alarms.PageMetadata) (alarms.AlarmsPage, error) {
|
||||
dir := api.DescDir
|
||||
if pm.Dir == api.AscDir {
|
||||
|
||||
@@ -415,215 +415,6 @@ func TestListAlarms(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserAlarms(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM domains_role_actions")
|
||||
require.Nil(t, err, fmt.Sprintf("clean domains_role_actions unexpected error: %s", err))
|
||||
_, err = db.Exec("DELETE FROM domains_role_members")
|
||||
require.Nil(t, err, fmt.Sprintf("clean domains_role_members unexpected error: %s", err))
|
||||
_, err = db.Exec("DELETE FROM domains_roles")
|
||||
require.Nil(t, err, fmt.Sprintf("clean domains_roles unexpected error: %s", err))
|
||||
_, err = db.Exec("DELETE FROM domains")
|
||||
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
|
||||
_, err = db.Exec("DELETE FROM alarms")
|
||||
require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err))
|
||||
_, err = db.Exec("DELETE FROM rules")
|
||||
require.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err))
|
||||
})
|
||||
|
||||
repo := postgres.NewAlarmsRepo(db)
|
||||
|
||||
domainID := generateUUID(t)
|
||||
domainRoute := generateUUID(t)
|
||||
userID := generateUUID(t)
|
||||
otherUserID := generateUUID(t)
|
||||
adminUserID := generateUUID(t)
|
||||
domainUserID := generateUUID(t)
|
||||
|
||||
_, err := db.Exec(`INSERT INTO domains (id, name, route, status) VALUES ($1, $2, $3, $4)`, domainID, namegen.Generate(), domainRoute, 0)
|
||||
require.Nil(t, err, fmt.Sprintf("insert domains unexpected error: %s", err))
|
||||
|
||||
// Create 10 rules and 10 alarms referencing them.
|
||||
// Assign userID to the first 6 rules via role membership.
|
||||
var ruleIDs []string
|
||||
var createdAlarms []alarms.Alarm
|
||||
for i := range 10 {
|
||||
ruleID := generateUUID(t)
|
||||
_, err := db.Exec(`INSERT INTO rules (id, name, domain_id, status, logic_type, logic_value) VALUES ($1, $2, $3, 0, 0, '')`,
|
||||
ruleID, fmt.Sprintf("rule-%d", i), domainID)
|
||||
require.Nil(t, err, fmt.Sprintf("insert rule unexpected error: %s", err))
|
||||
ruleIDs = append(ruleIDs, ruleID)
|
||||
|
||||
alarm := alarms.Alarm{
|
||||
ID: generateUUID(t),
|
||||
RuleID: ruleID,
|
||||
DomainID: domainID,
|
||||
ChannelID: generateUUID(t),
|
||||
ClientID: generateUUID(t),
|
||||
Measurement: namegen.Generate(),
|
||||
Value: namegen.Generate(),
|
||||
Unit: namegen.Generate(),
|
||||
Threshold: namegen.Generate(),
|
||||
Cause: namegen.Generate(),
|
||||
Status: 0,
|
||||
AssigneeID: generateUUID(t),
|
||||
CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute),
|
||||
}
|
||||
alarm, err = repo.CreateAlarm(context.Background(), alarm)
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
|
||||
createdAlarms = append(createdAlarms, alarm)
|
||||
}
|
||||
|
||||
// Assign userID to the first 6 rules via rules_roles + rules_role_members.
|
||||
userRoleIDs := make([]string, 6)
|
||||
for i := range 6 {
|
||||
roleID := generateUUID(t)
|
||||
userRoleIDs[i] = roleID
|
||||
_, err := db.Exec(`INSERT INTO rules_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", ruleIDs[i])
|
||||
require.Nil(t, err, fmt.Sprintf("insert rules_roles unexpected error: %s", err))
|
||||
_, err = db.Exec(`INSERT INTO rules_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, userID, ruleIDs[i])
|
||||
require.Nil(t, err, fmt.Sprintf("insert rules_role_members unexpected error: %s", err))
|
||||
}
|
||||
|
||||
for i := range 10 {
|
||||
var roleID string
|
||||
if i < 6 {
|
||||
roleID = userRoleIDs[i]
|
||||
} else {
|
||||
roleID = generateUUID(t)
|
||||
_, err := db.Exec(`INSERT INTO rules_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", ruleIDs[i])
|
||||
require.Nil(t, err, fmt.Sprintf("insert rules_roles unexpected error: %s", err))
|
||||
}
|
||||
_, err := db.Exec(`INSERT INTO rules_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, adminUserID, ruleIDs[i])
|
||||
require.Nil(t, err, fmt.Sprintf("insert rules_role_members unexpected error: %s", err))
|
||||
}
|
||||
|
||||
domainRoleID := generateUUID(t)
|
||||
_, err = db.Exec(`INSERT INTO domains_roles (id, name, entity_id) VALUES ($1, $2, $3)`, domainRoleID, "admin", domainID)
|
||||
require.Nil(t, err, fmt.Sprintf("insert domains_roles unexpected error: %s", err))
|
||||
_, err = db.Exec(`INSERT INTO domains_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, domainRoleID, domainUserID, domainID)
|
||||
require.Nil(t, err, fmt.Sprintf("insert domains_role_members unexpected error: %s", err))
|
||||
_, err = db.Exec(`INSERT INTO domains_role_actions (role_id, action) VALUES ($1, $2)`, domainRoleID, "alarm_read")
|
||||
require.Nil(t, err, fmt.Sprintf("insert domains_role_actions unexpected error: %s", err))
|
||||
|
||||
_ = createdAlarms
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userID string
|
||||
pm alarms.PageMetadata
|
||||
count int
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list user alarms returns only accessible alarms",
|
||||
userID: userID,
|
||||
pm: alarms.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
count: 6,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user alarms with limit",
|
||||
userID: userID,
|
||||
pm: alarms.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 3,
|
||||
},
|
||||
count: 3,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user alarms with offset",
|
||||
userID: userID,
|
||||
pm: alarms.PageMetadata{
|
||||
Offset: 4,
|
||||
Limit: 100,
|
||||
},
|
||||
count: 2,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user alarms with domain filter",
|
||||
userID: userID,
|
||||
pm: alarms.PageMetadata{
|
||||
DomainID: domainID,
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
count: 6,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user alarms with non-existing domain returns 0",
|
||||
userID: userID,
|
||||
pm: alarms.PageMetadata{
|
||||
DomainID: generateUUID(t),
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
count: 0,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list alarms for user with no role assignments returns 0",
|
||||
userID: otherUserID,
|
||||
pm: alarms.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
count: 0,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list alarms for admin user with role on all rules returns all alarms",
|
||||
userID: adminUserID,
|
||||
pm: alarms.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
count: 10,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list alarms for user with domain-level rule access returns all alarms",
|
||||
userID: domainUserID,
|
||||
pm: alarms.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
},
|
||||
count: 10,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user alarms ordered by created_at ascending",
|
||||
userID: userID,
|
||||
pm: alarms.PageMetadata{
|
||||
Offset: 0,
|
||||
Limit: 100,
|
||||
Order: "created_at",
|
||||
Dir: "asc",
|
||||
},
|
||||
count: 6,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
page, err := repo.ListUserAlarms(context.Background(), tc.userID, tc.pm)
|
||||
if tc.err != nil {
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
return
|
||||
}
|
||||
require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err))
|
||||
assert.Equal(t, tc.count, len(page.Alarms), fmt.Sprintf("%s: expected %d alarms, got %d", tc.desc, tc.count, len(page.Alarms)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteAlarm(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM alarms")
|
||||
|
||||
@@ -4,9 +4,6 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
rpostgres "github.com/absmach/magistrala/re/postgres"
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // required for SQL access
|
||||
migrate "github.com/rubenv/sql-migrate"
|
||||
)
|
||||
@@ -54,12 +51,5 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
|
||||
},
|
||||
}
|
||||
|
||||
rulesMigration, err := rpostgres.Migration()
|
||||
if err != nil {
|
||||
return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err)
|
||||
}
|
||||
|
||||
alarmsMigration.Migrations = append(alarmsMigration.Migrations, rulesMigration.Migrations...)
|
||||
|
||||
return alarmsMigration, nil
|
||||
}
|
||||
|
||||
+12
-10
@@ -26,10 +26,10 @@ func NewService(idp magistrala.IDProvider, repo Repository) Service {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) CreateAlarm(ctx context.Context, alarm Alarm) error {
|
||||
func (s *service) CreateAlarm(ctx context.Context, alarm Alarm) (Alarm, error) {
|
||||
id, err := s.idp.ID()
|
||||
if err != nil {
|
||||
return err
|
||||
return Alarm{}, err
|
||||
}
|
||||
alarm.ID = id
|
||||
if alarm.CreatedAt.IsZero() {
|
||||
@@ -37,14 +37,18 @@ func (s *service) CreateAlarm(ctx context.Context, alarm Alarm) error {
|
||||
}
|
||||
|
||||
if err := alarm.Validate(); err != nil {
|
||||
return err
|
||||
return Alarm{}, err
|
||||
}
|
||||
|
||||
if _, err = s.repo.CreateAlarm(ctx, alarm); err != nil && err != repoerr.ErrNotFound {
|
||||
return err
|
||||
created, err := s.repo.CreateAlarm(ctx, alarm)
|
||||
if err != nil && err != repoerr.ErrNotFound {
|
||||
return Alarm{}, err
|
||||
}
|
||||
if err == repoerr.ErrNotFound {
|
||||
return Alarm{}, nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return created, nil
|
||||
}
|
||||
|
||||
func (s *service) ViewAlarm(ctx context.Context, session authn.Session, alarmID string) (Alarm, error) {
|
||||
@@ -52,10 +56,8 @@ func (s *service) ViewAlarm(ctx context.Context, session authn.Session, alarmID
|
||||
}
|
||||
|
||||
func (s *service) ListAlarms(ctx context.Context, session authn.Session, pm PageMetadata) (AlarmsPage, error) {
|
||||
if session.SuperAdmin {
|
||||
return s.repo.ListAllAlarms(ctx, pm)
|
||||
}
|
||||
return s.repo.ListUserAlarms(ctx, session.UserID, pm)
|
||||
pm.DomainID = session.DomainID
|
||||
return s.repo.ListAllAlarms(ctx, pm)
|
||||
}
|
||||
|
||||
func (s *service) DeleteAlarm(ctx context.Context, session authn.Session, alarmID string) error {
|
||||
|
||||
@@ -72,7 +72,7 @@ func TestCreateAlarm(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := repo.On("CreateAlarm", context.Background(), mock.Anything).Return(tc.alarm, tc.err)
|
||||
err := svc.CreateAlarm(context.Background(), tc.alarm)
|
||||
_, err := svc.CreateAlarm(context.Background(), tc.alarm)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
repoCall.Unset()
|
||||
})
|
||||
@@ -205,7 +205,7 @@ func TestListAlarms(t *testing.T) {
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
s := authn.Session{DomainID: tc.pm.DomainID}
|
||||
repoCall := repo.On("ListUserAlarms", context.Background(), s.UserID, tc.pm).Return(tc.page, tc.err)
|
||||
repoCall := repo.On("ListAllAlarms", context.Background(), tc.pm).Return(tc.page, tc.err)
|
||||
_, err := svc.ListAlarms(context.Background(), s, tc.pm)
|
||||
if tc.err != nil {
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
|
||||
+3
-6
@@ -13,10 +13,7 @@ import (
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/clients"
|
||||
"github.com/absmach/magistrala/groups"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/absmach/magistrala/users"
|
||||
"github.com/gofrs/uuid/v5"
|
||||
)
|
||||
|
||||
@@ -80,9 +77,9 @@ const (
|
||||
DefStartLevel = 1
|
||||
DefEndLevel = 0
|
||||
DefStatus = "enabled"
|
||||
DefClientStatus = clients.Enabled
|
||||
DefUserStatus = users.Enabled
|
||||
DefGroupStatus = groups.Enabled
|
||||
DefClientStatus = "enabled"
|
||||
DefUserStatus = "enabled"
|
||||
DefGroupStatus = "enabled"
|
||||
|
||||
// ContentType represents JSON content type.
|
||||
ContentType = "application/json"
|
||||
|
||||
@@ -1,401 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
grpcAuthV1 "github.com/absmach/magistrala/api/grpc/auth/v1"
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/auth"
|
||||
grpcapi "github.com/absmach/magistrala/auth/api/grpc/auth"
|
||||
"github.com/absmach/magistrala/internal/testsutil"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/absmach/magistrala/pkg/policies"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
const (
|
||||
port = 8081
|
||||
id = "testID"
|
||||
usersType = "users"
|
||||
adminPermission = "admin"
|
||||
authoritiesObj = "authorities"
|
||||
memberRelation = "member"
|
||||
validToken = "valid"
|
||||
inValidToken = "invalid"
|
||||
validPATToken = "valid"
|
||||
)
|
||||
|
||||
var (
|
||||
domainID = testsutil.GenerateUUID(&testing.T{})
|
||||
authAddr = fmt.Sprintf("localhost:%d", port)
|
||||
clientID = testsutil.GenerateUUID(&testing.T{})
|
||||
)
|
||||
|
||||
func startGRPCServer(svc auth.Service, port int) *grpc.Server {
|
||||
listener, _ := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
server := grpc.NewServer()
|
||||
grpcAuthV1.RegisterAuthServiceServer(server, grpcapi.NewAuthServer(svc))
|
||||
go func() {
|
||||
err := server.Serve(listener)
|
||||
assert.Nil(&testing.T{}, err, fmt.Sprintf(`"Unexpected error creating auth server %s"`, err))
|
||||
}()
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
func TestIdentify(t *testing.T) {
|
||||
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
|
||||
defer conn.Close()
|
||||
grpcClient := grpcapi.NewAuthClient(conn, time.Second)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
key auth.Key
|
||||
idt *grpcAuthV1.AuthNRes
|
||||
svcErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "authenticate user with valid user token",
|
||||
token: validToken,
|
||||
key: auth.Key{ID: "", Subject: id, Role: auth.UserRole},
|
||||
idt: &grpcAuthV1.AuthNRes{UserId: id, UserRole: uint32(auth.UserRole)},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "authenticate user with invalid user token",
|
||||
token: "invalid",
|
||||
key: auth.Key{},
|
||||
idt: &grpcAuthV1.AuthNRes{},
|
||||
svcErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "authenticate user with empty token",
|
||||
token: "",
|
||||
idt: &grpcAuthV1.AuthNRes{},
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "authenticate user with valid PAT token",
|
||||
token: "pat_" + validPATToken,
|
||||
key: auth.Key{ID: id, Type: auth.PersonalAccessToken, Subject: clientID, Role: auth.UserRole},
|
||||
idt: &grpcAuthV1.AuthNRes{Id: id, UserId: clientID, UserRole: uint32(auth.UserRole), TokenType: uint32(auth.PersonalAccessToken)},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "authenticate user with invalid PAT token",
|
||||
token: "pat_invalid",
|
||||
key: auth.Key{},
|
||||
idt: &grpcAuthV1.AuthNRes{},
|
||||
svcErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("Identify", mock.Anything, tc.token).Return(tc.key, tc.svcErr)
|
||||
idt, err := grpcClient.Authenticate(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token})
|
||||
if idt != nil {
|
||||
assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorize(t *testing.T) {
|
||||
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
|
||||
defer conn.Close()
|
||||
|
||||
grpcClient := grpcapi.NewAuthClient(conn, time.Second)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
authRequest *grpcAuthV1.AuthZReq
|
||||
authResponse *grpcAuthV1.AuthZRes
|
||||
expectedReq *policies.Policy
|
||||
expectedPAT *auth.PATAuthz
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "authorize user with authorized token",
|
||||
token: validToken,
|
||||
authRequest: &grpcAuthV1.AuthZReq{
|
||||
PolicyReq: &grpcAuthV1.PolicyReq{
|
||||
Subject: id,
|
||||
SubjectType: usersType,
|
||||
Object: authoritiesObj,
|
||||
ObjectType: usersType,
|
||||
Relation: memberRelation,
|
||||
Permission: adminPermission,
|
||||
},
|
||||
},
|
||||
authResponse: &grpcAuthV1.AuthZRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "authorize user with unauthorized token",
|
||||
token: inValidToken,
|
||||
authRequest: &grpcAuthV1.AuthZReq{
|
||||
PolicyReq: &grpcAuthV1.PolicyReq{
|
||||
Subject: id,
|
||||
SubjectType: usersType,
|
||||
Object: authoritiesObj,
|
||||
ObjectType: usersType,
|
||||
Relation: memberRelation,
|
||||
Permission: adminPermission,
|
||||
},
|
||||
},
|
||||
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "authorize user with empty subject",
|
||||
token: validToken,
|
||||
authRequest: &grpcAuthV1.AuthZReq{
|
||||
PolicyReq: &grpcAuthV1.PolicyReq{
|
||||
Subject: "",
|
||||
SubjectType: usersType,
|
||||
Object: authoritiesObj,
|
||||
ObjectType: usersType,
|
||||
Relation: memberRelation,
|
||||
Permission: adminPermission,
|
||||
},
|
||||
},
|
||||
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
|
||||
err: apiutil.ErrMissingPolicySub,
|
||||
},
|
||||
{
|
||||
desc: "authorize user with empty subject type",
|
||||
token: validToken,
|
||||
authRequest: &grpcAuthV1.AuthZReq{
|
||||
PolicyReq: &grpcAuthV1.PolicyReq{
|
||||
Subject: id,
|
||||
SubjectType: "",
|
||||
Object: authoritiesObj,
|
||||
ObjectType: usersType,
|
||||
Relation: memberRelation,
|
||||
Permission: adminPermission,
|
||||
},
|
||||
},
|
||||
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
|
||||
err: apiutil.ErrMissingPolicySub,
|
||||
},
|
||||
{
|
||||
desc: "authorize user with empty object",
|
||||
token: validToken,
|
||||
authRequest: &grpcAuthV1.AuthZReq{
|
||||
PolicyReq: &grpcAuthV1.PolicyReq{
|
||||
Subject: id,
|
||||
SubjectType: usersType,
|
||||
Object: "",
|
||||
ObjectType: usersType,
|
||||
Relation: memberRelation,
|
||||
Permission: adminPermission,
|
||||
},
|
||||
},
|
||||
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
|
||||
err: apiutil.ErrMissingPolicyObj,
|
||||
},
|
||||
{
|
||||
desc: "authorize user with empty object type",
|
||||
token: validToken,
|
||||
authRequest: &grpcAuthV1.AuthZReq{
|
||||
PolicyReq: &grpcAuthV1.PolicyReq{
|
||||
Subject: id,
|
||||
SubjectType: usersType,
|
||||
Object: authoritiesObj,
|
||||
ObjectType: "",
|
||||
Relation: memberRelation,
|
||||
Permission: adminPermission,
|
||||
},
|
||||
},
|
||||
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
|
||||
err: apiutil.ErrMissingPolicyObj,
|
||||
},
|
||||
{
|
||||
desc: "authorize user with empty permission",
|
||||
token: validToken,
|
||||
authRequest: &grpcAuthV1.AuthZReq{
|
||||
PolicyReq: &grpcAuthV1.PolicyReq{
|
||||
Subject: id,
|
||||
SubjectType: usersType,
|
||||
Object: authoritiesObj,
|
||||
ObjectType: usersType,
|
||||
Relation: memberRelation,
|
||||
Permission: "",
|
||||
},
|
||||
},
|
||||
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
|
||||
err: apiutil.ErrMalformedPolicyPer,
|
||||
},
|
||||
{
|
||||
desc: "authorize user with valid PAT token",
|
||||
token: validPATToken,
|
||||
authRequest: &grpcAuthV1.AuthZReq{
|
||||
PolicyReq: &grpcAuthV1.PolicyReq{
|
||||
Subject: id,
|
||||
SubjectType: policies.UserType,
|
||||
SubjectKind: policies.UsersKind,
|
||||
Permission: policies.ViewPermission,
|
||||
ObjectType: policies.ClientType,
|
||||
Domain: domainID,
|
||||
Object: clientID,
|
||||
},
|
||||
PatReq: &grpcAuthV1.PATReq{
|
||||
PatId: id,
|
||||
Domain: domainID,
|
||||
Operation: "view",
|
||||
UserId: id,
|
||||
EntityId: clientID,
|
||||
EntityType: auth.ClientsScopeStr,
|
||||
},
|
||||
},
|
||||
authResponse: &grpcAuthV1.AuthZRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "authorize bootstrap PAT keeps PAT domain when policy domain is empty",
|
||||
token: validPATToken,
|
||||
authRequest: &grpcAuthV1.AuthZReq{
|
||||
PolicyReq: &grpcAuthV1.PolicyReq{
|
||||
Subject: id,
|
||||
SubjectType: policies.UserType,
|
||||
SubjectKind: policies.UsersKind,
|
||||
Permission: policies.MembershipPermission,
|
||||
ObjectType: policies.DomainType,
|
||||
Object: domainID,
|
||||
},
|
||||
PatReq: &grpcAuthV1.PATReq{
|
||||
PatId: id,
|
||||
Domain: domainID,
|
||||
Operation: "create",
|
||||
UserId: id,
|
||||
EntityId: auth.AnyIDs,
|
||||
EntityType: auth.BootstrapStr,
|
||||
},
|
||||
},
|
||||
authResponse: &grpcAuthV1.AuthZRes{Authorized: true},
|
||||
expectedReq: &policies.Policy{
|
||||
Domain: domainID,
|
||||
SubjectType: policies.UserType,
|
||||
SubjectKind: policies.UsersKind,
|
||||
Subject: id,
|
||||
Permission: policies.MembershipPermission,
|
||||
ObjectType: policies.DomainType,
|
||||
Object: domainID,
|
||||
},
|
||||
expectedPAT: &auth.PATAuthz{
|
||||
PatID: id,
|
||||
UserID: id,
|
||||
EntityType: auth.BootstrapType,
|
||||
EntityID: auth.AnyIDs,
|
||||
Operation: "create",
|
||||
Domain: domainID,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "authorize user with unauthorized PAT token",
|
||||
token: inValidToken,
|
||||
authRequest: &grpcAuthV1.AuthZReq{
|
||||
PolicyReq: &grpcAuthV1.PolicyReq{
|
||||
Subject: id,
|
||||
SubjectType: policies.UserType,
|
||||
SubjectKind: policies.UsersKind,
|
||||
Permission: policies.ViewPermission,
|
||||
ObjectType: policies.ClientType,
|
||||
Domain: domainID,
|
||||
Object: clientID,
|
||||
},
|
||||
PatReq: &grpcAuthV1.PATReq{
|
||||
PatId: id,
|
||||
Domain: domainID,
|
||||
Operation: "view",
|
||||
UserId: id,
|
||||
EntityId: clientID,
|
||||
EntityType: auth.ClientsScopeStr,
|
||||
},
|
||||
},
|
||||
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "authorize PAT with missing user id",
|
||||
token: validPATToken,
|
||||
authRequest: &grpcAuthV1.AuthZReq{
|
||||
PolicyReq: &grpcAuthV1.PolicyReq{
|
||||
Subject: id,
|
||||
SubjectType: policies.UserType,
|
||||
SubjectKind: policies.UsersKind,
|
||||
Permission: policies.ViewPermission,
|
||||
ObjectType: policies.ClientType,
|
||||
Domain: domainID,
|
||||
Object: clientID,
|
||||
},
|
||||
PatReq: &grpcAuthV1.PATReq{
|
||||
PatId: id,
|
||||
Domain: domainID,
|
||||
Operation: "view",
|
||||
EntityId: clientID,
|
||||
EntityType: auth.ClientsScopeStr,
|
||||
},
|
||||
},
|
||||
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
|
||||
err: apiutil.ErrMissingUserID,
|
||||
},
|
||||
{
|
||||
desc: "authorize PAT with missing entity id",
|
||||
token: validPATToken,
|
||||
authRequest: &grpcAuthV1.AuthZReq{
|
||||
PolicyReq: &grpcAuthV1.PolicyReq{
|
||||
Subject: id,
|
||||
SubjectType: policies.UserType,
|
||||
SubjectKind: policies.UsersKind,
|
||||
Permission: policies.ViewPermission,
|
||||
ObjectType: policies.ClientType,
|
||||
Domain: domainID,
|
||||
Object: clientID,
|
||||
},
|
||||
PatReq: &grpcAuthV1.PATReq{
|
||||
PatId: id,
|
||||
Domain: domainID,
|
||||
Operation: "view",
|
||||
UserId: id,
|
||||
EntityType: auth.ClientsScopeStr,
|
||||
},
|
||||
},
|
||||
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("Authorize", mock.Anything, mock.Anything, mock.Anything).Return(tc.err)
|
||||
ar, err := grpcClient.Authorize(context.Background(), tc.authRequest)
|
||||
if ar != nil {
|
||||
assert.Equal(t, tc.authResponse, ar, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.authResponse, ar))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/auth/mocks"
|
||||
)
|
||||
|
||||
var svc *mocks.Service
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
svc = new(mocks.Service)
|
||||
server := startGRPCServer(svc, port)
|
||||
|
||||
code := m.Run()
|
||||
|
||||
server.GracefulStop()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
@@ -1,245 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package token_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
grpcTokenV1 "github.com/absmach/magistrala/api/grpc/token/v1"
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/auth"
|
||||
grpcapi "github.com/absmach/magistrala/auth/api/grpc/token"
|
||||
"github.com/absmach/magistrala/internal/testsutil"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
const (
|
||||
port = 8082
|
||||
validToken = "valid"
|
||||
inValidToken = "invalid"
|
||||
invalidID = "invalid"
|
||||
)
|
||||
|
||||
var (
|
||||
validID = testsutil.GenerateUUID(&testing.T{})
|
||||
authAddr = fmt.Sprintf("localhost:%d", port)
|
||||
)
|
||||
|
||||
func startGRPCServer(svc auth.Service, port int) *grpc.Server {
|
||||
listener, _ := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
server := grpc.NewServer()
|
||||
grpcTokenV1.RegisterTokenServiceServer(server, grpcapi.NewTokenServer(svc))
|
||||
go func() {
|
||||
err := server.Serve(listener)
|
||||
assert.Nil(&testing.T{}, err, fmt.Sprintf(`"Unexpected error creating auth server %s"`, err))
|
||||
}()
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
func TestIssue(t *testing.T) {
|
||||
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
|
||||
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
|
||||
defer conn.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userId string
|
||||
kind auth.KeyType
|
||||
issueResponse auth.Token
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "issue for user with valid token",
|
||||
userId: validID,
|
||||
kind: auth.AccessKey,
|
||||
issueResponse: auth.Token{
|
||||
AccessToken: validToken,
|
||||
RefreshToken: validToken,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue recovery key",
|
||||
userId: validID,
|
||||
kind: auth.RecoveryKey,
|
||||
issueResponse: auth.Token{
|
||||
AccessToken: validToken,
|
||||
RefreshToken: validToken,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue API key unauthenticated",
|
||||
userId: validID,
|
||||
kind: auth.APIKey,
|
||||
issueResponse: auth.Token{},
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "issue for invalid key type",
|
||||
userId: validID,
|
||||
kind: 32,
|
||||
issueResponse: auth.Token{},
|
||||
err: errors.ErrMalformedEntity,
|
||||
},
|
||||
{
|
||||
desc: "issue for user that does notexist",
|
||||
userId: "",
|
||||
kind: auth.APIKey,
|
||||
issueResponse: auth.Token{},
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err)
|
||||
_, err := grpcClient.Issue(context.Background(), &grpcTokenV1.IssueReq{UserId: tc.userId, Type: uint32(tc.kind)})
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefresh(t *testing.T) {
|
||||
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
|
||||
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
|
||||
defer conn.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
issueResponse auth.Token
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "refresh token with valid token",
|
||||
token: validToken,
|
||||
issueResponse: auth.Token{
|
||||
AccessToken: validToken,
|
||||
RefreshToken: validToken,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "refresh token with invalid token",
|
||||
token: inValidToken,
|
||||
issueResponse: auth.Token{},
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "refresh token with empty token",
|
||||
token: "",
|
||||
issueResponse: auth.Token{},
|
||||
err: apiutil.ErrMissingSecret,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err)
|
||||
_, err := grpcClient.Refresh(context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: tc.token})
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevoke(t *testing.T) {
|
||||
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
|
||||
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
|
||||
defer conn.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
id string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "revoke token with valid id",
|
||||
id: validID,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "revoke token with invalid id",
|
||||
id: invalidID,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "revoke token with empty id",
|
||||
id: "",
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "revoke already revoked token",
|
||||
id: validID,
|
||||
err: svcerr.ErrConflict,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
svcCall := svc.On("RevokeToken", mock.Anything, mock.Anything, tc.id).Return(tc.err)
|
||||
_, err := grpcClient.Revoke(context.Background(), &grpcTokenV1.RevokeReq{TokenId: tc.id})
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserRefreshTokens(t *testing.T) {
|
||||
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
|
||||
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
|
||||
defer conn.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userID string
|
||||
listResponse []auth.TokenInfo
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list tokens for user with valid id",
|
||||
userID: validID,
|
||||
listResponse: []auth.TokenInfo{
|
||||
{ID: testsutil.GenerateUUID(&testing.T{}), Description: "Token 1"},
|
||||
{ID: testsutil.GenerateUUID(&testing.T{}), Description: "Token 2"},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list tokens for user with empty list",
|
||||
userID: validID,
|
||||
listResponse: []auth.TokenInfo{},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list tokens with invalid user id",
|
||||
userID: invalidID,
|
||||
listResponse: nil,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "list tokens with empty user id",
|
||||
userID: "",
|
||||
listResponse: nil,
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
svcCall := svc.On("ListUserRefreshTokens", mock.Anything, tc.userID).Return(tc.listResponse, tc.err)
|
||||
_, err := grpcClient.ListUserRefreshTokens(context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.userID})
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package token_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/auth/mocks"
|
||||
)
|
||||
|
||||
var svc *mocks.Service
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
svc = new(mocks.Service)
|
||||
server := startGRPCServer(svc, port)
|
||||
|
||||
code := m.Run()
|
||||
|
||||
server.GracefulStop()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
@@ -1,3 +1,6 @@
|
||||
//go:build oldservices
|
||||
// +build oldservices
|
||||
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
# BOOTSTRAP SERVICE
|
||||
|
||||
New devices need to be configured properly and connected to the Magistrala. Bootstrap service is used in order to accomplish that. This service provides the following features:
|
||||
|
||||
1. Creating new Magistrala Clients
|
||||
2. Providing basic configuration for the newly created Clients
|
||||
3. Enabling/disabling bootstrap enrollments
|
||||
|
||||
Pre-provisioning a new Client is as simple as sending Configuration data to the Bootstrap service. Once the Client is online, it sends a request for initial config to Bootstrap service. Bootstrap service provides an API for enabling and disabling bootstrap enrollments. Bootstrapping does not implicitly enable an enrollment; it has to be done manually.
|
||||
|
||||
In order to bootstrap successfully, the Client needs to send bootstrapping request to the specific URL, as well as a secret key. This key and URL are pre-provisioned during the manufacturing process. If the Client is provisioned on the Bootstrap service side, the corresponding configuration will be sent as a response. Otherwise, the Client will be saved so that it can be provisioned later.
|
||||
|
||||
## Client Configuration Entity
|
||||
|
||||
Client Configuration consists of two logical parts: the custom configuration that can be interpreted by the Client itself and Magistrala-related configuration. Magistrala config contains:
|
||||
|
||||
1. corresponding Magistrala Client ID
|
||||
2. corresponding Magistrala Client key
|
||||
3. list of the Magistrala channels the Client is connected to
|
||||
|
||||
> Note: list of channels contains IDs of the Magistrala channels. These channels are _pre-provisioned_ on the Magistrala side and, unlike corresponding Magistrala Client, Bootstrap service is not able to create Magistrala Channels.
|
||||
|
||||
Enabling and disabling a bootstrap enrollment is an enrollment toggle. Configuration keeps a _status_:
|
||||
|
||||
| Status | What it means |
|
||||
| -------- | ----------------------------------------------------------- |
|
||||
| disabled | Enrollment exists, but bootstrap is not allowed |
|
||||
| enabled | Enrollment can be used to fetch bootstrap configuration |
|
||||
|
||||
Switching between statuses `enabled` and `disabled` enables and disables the enrollment, respectively.
|
||||
|
||||
Client configuration also contains the so-called `external ID` and `external key`. An external ID is a unique identifier of corresponding Client. For example, a device MAC address is a good choice for external ID. External key is a secret key that is used for authentication during the bootstrapping procedure.
|
||||
|
||||
## Configuration
|
||||
|
||||
The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values.
|
||||
|
||||
| Variable | Description | Default |
|
||||
| ------------------------------ | -------------------------------------------------------------------------------- | --------------------------------- |
|
||||
| MG_BOOTSTRAP_LOG_LEVEL | Log level for Bootstrap (debug, info, warn, error) | info |
|
||||
| MG_BOOTSTRAP_DB_HOST | Database host address | localhost |
|
||||
| MG_BOOTSTRAP_DB_PORT | Database host port | 5432 |
|
||||
| MG_BOOTSTRAP_DB_USER | Database user | magistrala |
|
||||
| MG_BOOTSTRAP_DB_PASS | Database password | magistrala |
|
||||
| MG_BOOTSTRAP_DB_NAME | Name of the database used by the service | bootstrap |
|
||||
| MG_BOOTSTRAP_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable |
|
||||
| MG_BOOTSTRAP_DB_SSL_CERT | Path to the PEM encoded certificate file | "" |
|
||||
| MG_BOOTSTRAP_DB_SSL_KEY | Path to the PEM encoded key file | "" |
|
||||
| MG_BOOTSTRAP_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" |
|
||||
| MG_BOOTSTRAP_ENCRYPT_KEY | Secret key for secure bootstrapping encryption | 12345678910111213141516171819202 |
|
||||
| MG_BOOTSTRAP_HTTP_HOST | Bootstrap service HTTP host | "" |
|
||||
| MG_BOOTSTRAP_HTTP_PORT | Bootstrap service HTTP port | 9013 |
|
||||
| MG_BOOTSTRAP_HTTP_SERVER_CERT | Path to server certificate in pem format | "" |
|
||||
| MG_BOOTSTRAP_HTTP_SERVER_KEY | Path to server key in pem format | "" |
|
||||
| MG_BOOTSTRAP_EVENT_CONSUMER | Bootstrap service event source consumer name | bootstrap |
|
||||
| MG_ES_URL | Event store URL | <nats://localhost:4222> |
|
||||
| MG_AUTH_GRPC_URL | Auth service Auth gRPC URL | <localhost:8181> |
|
||||
| MG_AUTH_GRPC_TIMEOUT | Auth service Auth gRPC request timeout in seconds | 1s |
|
||||
| MG_AUTH_GRPC_CLIENT_CERT | Path to the PEM encoded auth service Auth gRPC client certificate file | "" |
|
||||
| MG_AUTH_GRPC_CLIENT_KEY | Path to the PEM encoded auth service Auth gRPC client key file | "" |
|
||||
| MG_AUTH_GRPC_SERVER_CERTS | Path to the PEM encoded auth server Auth gRPC server trusted CA certificate file | "" |
|
||||
| MG_CLIENTS_URL | Base URL for Magistrala Clients | <http://localhost:9000> |
|
||||
| MG_JAEGER_URL | Jaeger server URL | <http://localhost:4318/v1/traces> |
|
||||
| MG_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 |
|
||||
| MG_SEND_TELEMETRY | Send telemetry to magistrala call home server | true |
|
||||
| MG_BOOTSTRAP_INSTANCE_ID | Bootstrap service instance ID | "" |
|
||||
|
||||
## Deployment
|
||||
|
||||
The service itself is distributed as Docker container. Check the [`bootstrap`](https://github.com/absmach/magistrala/blob/main/docker/addons/bootstrap/docker-compose.yaml) service section in docker-compose file to see how service is deployed.
|
||||
|
||||
To start the service outside of the container, execute the following shell script:
|
||||
|
||||
```bash
|
||||
# download the latest version of the service
|
||||
git clone https://github.com/absmach/magistrala
|
||||
|
||||
cd magistrala
|
||||
|
||||
# compile the servic e
|
||||
make bootstrap
|
||||
|
||||
# copy binary to bin
|
||||
make install
|
||||
|
||||
# set the environment variables and run the service
|
||||
MG_BOOTSTRAP_LOG_LEVEL=info \
|
||||
MG_BOOTSTRAP_DB_HOST=localhost \
|
||||
MG_BOOTSTRAP_DB_PORT=5432 \
|
||||
MG_BOOTSTRAP_DB_USER=magistrala \
|
||||
MG_BOOTSTRAP_DB_PASS=magistrala \
|
||||
MG_BOOTSTRAP_DB_NAME=bootstrap \
|
||||
MG_BOOTSTRAP_DB_SSL_MODE=disable \
|
||||
MG_BOOTSTRAP_DB_SSL_CERT="" \
|
||||
MG_BOOTSTRAP_DB_SSL_KEY="" \
|
||||
MG_BOOTSTRAP_DB_SSL_ROOT_CERT="" \
|
||||
MG_BOOTSTRAP_HTTP_HOST=localhost \
|
||||
MG_BOOTSTRAP_HTTP_PORT=9013 \
|
||||
MG_BOOTSTRAP_HTTP_SERVER_CERT="" \
|
||||
MG_BOOTSTRAP_HTTP_SERVER_KEY="" \
|
||||
MG_BOOTSTRAP_EVENT_CONSUMER=bootstrap \
|
||||
MG_ES_URL=nats://localhost:4222 \
|
||||
MG_AUTH_GRPC_URL=localhost:8181 \
|
||||
MG_AUTH_GRPC_TIMEOUT=1s \
|
||||
MG_AUTH_GRPC_CLIENT_CERT="" \
|
||||
MG_AUTH_GRPC_CLIENT_KEY="" \
|
||||
MG_AUTH_GRPC_SERVER_CERTS="" \
|
||||
MG_CLIENTS_URL=http://localhost:9000 \
|
||||
MG_JAEGER_URL=http://localhost:14268/api/traces \
|
||||
MG_JAEGER_TRACE_RATIO=1.0 \
|
||||
MG_SEND_TELEMETRY=true \
|
||||
MG_BOOTSTRAP_INSTANCE_ID="" \
|
||||
$GOBIN/magistrala-bootstrap
|
||||
```
|
||||
|
||||
Setting `MG_BOOTSTRAP_HTTP_SERVER_CERT` and `MG_BOOTSTRAP_HTTP_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key.
|
||||
|
||||
Setting `MG_AUTH_GRPC_CLIENT_CERT` and `MG_AUTH_GRPC_CLIENT_KEY` will enable TLS against the auth service. The service expects a file in PEM format for both the certificate and the key. Setting `MG_AUTH_GRPC_SERVER_CERTS` will enable TLS against the auth service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs.
|
||||
|
||||
## Usage
|
||||
|
||||
For more information about service capabilities and its usage, please check out the [API documentation](https://docs.api.magistrala.absmach.eu/?urls.primaryName=bootstrap.yaml).
|
||||
@@ -1,5 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package api contains implementation of bootstrap service HTTP API.
|
||||
package api
|
||||
@@ -1,506 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
)
|
||||
|
||||
func addEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(addReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
config := bootstrap.Config{
|
||||
ExternalID: req.ExternalID,
|
||||
ExternalKey: req.ExternalKey,
|
||||
Name: req.Name,
|
||||
ClientCert: req.ClientCert,
|
||||
ClientKey: req.ClientKey,
|
||||
CACert: req.CACert,
|
||||
Content: req.Content,
|
||||
ProfileID: req.ProfileID,
|
||||
RenderContext: req.RenderContext,
|
||||
}
|
||||
|
||||
saved, err := svc.Add(ctx, session, req.token, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := configRes{
|
||||
ID: saved.ID,
|
||||
ExternalID: saved.ExternalID,
|
||||
Name: saved.Name,
|
||||
Content: saved.Content,
|
||||
Status: saved.Status,
|
||||
ProfileID: saved.ProfileID,
|
||||
RenderContext: saved.RenderContext,
|
||||
ClientCert: saved.ClientCert,
|
||||
CACert: saved.CACert,
|
||||
ClientKey: saved.ClientKey,
|
||||
created: true,
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
func updateCertEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(updateCertReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
cfg, err := svc.UpdateCert(ctx, session, req.configID, req.ClientCert, req.ClientKey, req.CACert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := updateConfigRes{
|
||||
ID: cfg.ID,
|
||||
ClientCert: cfg.ClientCert,
|
||||
CACert: cfg.CACert,
|
||||
ClientKey: cfg.ClientKey,
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
func viewEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(entityReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
config, err := svc.View(ctx, session, req.id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := viewRes{
|
||||
ID: config.ID,
|
||||
ExternalID: config.ExternalID,
|
||||
Name: config.Name,
|
||||
Content: config.Content,
|
||||
Status: config.Status,
|
||||
ProfileID: config.ProfileID,
|
||||
RenderContext: config.RenderContext,
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
func updateEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(updateReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
config := bootstrap.Config{
|
||||
ID: req.id,
|
||||
Name: req.Name,
|
||||
Content: req.Content,
|
||||
RenderContext: req.RenderContext,
|
||||
}
|
||||
|
||||
if err := svc.Update(ctx, session, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return updateRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func listEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(listReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
page, err := svc.List(ctx, session, req.filter, req.offset, req.limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res := listRes{
|
||||
Total: page.Total,
|
||||
Offset: page.Offset,
|
||||
Limit: page.Limit,
|
||||
Configs: []viewRes{},
|
||||
}
|
||||
|
||||
for _, cfg := range page.Configs {
|
||||
view := viewRes{
|
||||
ID: cfg.ID,
|
||||
ExternalID: cfg.ExternalID,
|
||||
Name: cfg.Name,
|
||||
Content: cfg.Content,
|
||||
Status: cfg.Status,
|
||||
ProfileID: cfg.ProfileID,
|
||||
RenderContext: cfg.RenderContext,
|
||||
}
|
||||
res.Configs = append(res.Configs, view)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
func removeEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(entityReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return removeRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
if err := svc.Remove(ctx, session, req.id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return removeRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func bootstrapEndpoint(svc bootstrap.Service, reader bootstrap.ConfigReader, secure bool) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(bootstrapReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
cfg, err := svc.Bootstrap(ctx, req.key, req.id, secure)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reader.ReadConfig(cfg, secure)
|
||||
}
|
||||
}
|
||||
|
||||
func enableConfigEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(changeConfigStatusReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
cfg, err := svc.EnableConfig(ctx, session, req.id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return changeConfigStatusRes{Config: cfg}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func disableConfigEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(changeConfigStatusReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
cfg, err := svc.DisableConfig(ctx, session, req.id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return changeConfigStatusRes{Config: cfg}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func createProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(createProfileReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
saved, err := svc.CreateProfile(ctx, session, req.Profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return profileRes{Profile: saved, created: true}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func uploadProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(uploadProfileReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
saved, err := svc.CreateProfile(ctx, session, req.Profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return profileRes{Profile: saved, created: true}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func viewProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(viewProfileReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
p, err := svc.ViewProfile(ctx, session, req.profileID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return profileRes{Profile: p}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func profileSlotsEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(viewProfileReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
p, err := svc.ViewProfile(ctx, session, req.profileID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return profileSlotsRes{BindingSlots: p.BindingSlots}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func renderPreviewEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(renderPreviewReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
p, err := svc.ViewProfile(ctx, session, req.profileID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := req.Config
|
||||
bindings := req.Bindings
|
||||
|
||||
if req.ConfigID != "" {
|
||||
stored, err := svc.View(ctx, session, req.ConfigID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg = stored
|
||||
bindings, err = svc.ListBindings(ctx, session, req.ConfigID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cfg.DomainID = session.DomainID
|
||||
cfg.ProfileID = p.ID
|
||||
if cfg.RenderContext == nil {
|
||||
cfg.RenderContext = req.RenderContext
|
||||
}
|
||||
|
||||
rendered, err := bootstrap.NewRenderer().Render(p, cfg, bindings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return renderPreviewRes{Content: string(rendered)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func updateProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(updateProfileReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
req.Profile.ID = req.profileID
|
||||
updated, err := svc.UpdateProfile(ctx, session, req.Profile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return profileRes{Profile: updated}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func deleteProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(deleteProfileReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
if err := svc.DeleteProfile(ctx, session, req.profileID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return removeRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func listProfilesEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(listProfilesReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
page, err := svc.ListProfiles(ctx, session, req.offset, req.limit, req.name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return profilesPageRes{ProfilesPage: page}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func assignProfileEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(assignProfileReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
if err := svc.AssignProfile(ctx, session, req.configID, req.ProfileID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return removeRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func bindResourcesEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(bindResourcesReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
if err := svc.BindResources(ctx, session, req.token, req.configID, req.Bindings); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return removeRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func listBindingsEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(listBindingsReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
snapshots, err := svc.ListBindings(ctx, session, req.configID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bindingsRes{Bindings: snapshots}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func refreshBindingsEndpoint(svc bootstrap.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(refreshBindingsReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
if err := svc.RefreshBindings(ctx, session, req.token, req.configID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return removeRes{}, nil
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,280 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
)
|
||||
|
||||
const maxLimitSize = 100
|
||||
|
||||
type addReq struct {
|
||||
token string
|
||||
ExternalID string `json:"external_id"`
|
||||
ExternalKey string `json:"external_key"`
|
||||
Name string `json:"name"`
|
||||
Content string `json:"content"`
|
||||
ClientCert string `json:"client_cert"`
|
||||
ClientKey string `json:"client_key"`
|
||||
CACert string `json:"ca_cert"`
|
||||
ProfileID string `json:"profile_id"`
|
||||
RenderContext map[string]any `json:"render_context"`
|
||||
}
|
||||
|
||||
func (req addReq) validate() error {
|
||||
if req.token == "" {
|
||||
return apiutil.ErrBearerToken
|
||||
}
|
||||
|
||||
if req.ExternalID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
if req.ExternalKey == "" {
|
||||
return apiutil.ErrBearerKey
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type entityReq struct {
|
||||
id string
|
||||
}
|
||||
|
||||
func (req entityReq) validate() error {
|
||||
if req.id == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type updateReq struct {
|
||||
id string
|
||||
Name string `json:"name"`
|
||||
Content string `json:"content"`
|
||||
RenderContext map[string]any `json:"render_context"`
|
||||
}
|
||||
|
||||
func (req updateReq) validate() error {
|
||||
if req.id == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type updateCertReq struct {
|
||||
configID string
|
||||
ClientCert string `json:"client_cert"`
|
||||
ClientKey string `json:"client_key"`
|
||||
CACert string `json:"ca_cert"`
|
||||
}
|
||||
|
||||
func (req updateCertReq) validate() error {
|
||||
if req.configID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type listReq struct {
|
||||
filter bootstrap.Filter
|
||||
offset uint64
|
||||
limit uint64
|
||||
}
|
||||
|
||||
func (req listReq) validate() error {
|
||||
if req.limit > maxLimitSize {
|
||||
return apiutil.ErrLimitSize
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type bootstrapReq struct {
|
||||
key string
|
||||
id string
|
||||
}
|
||||
|
||||
func (req bootstrapReq) validate() error {
|
||||
if req.key == "" {
|
||||
return apiutil.ErrBearerKey
|
||||
}
|
||||
|
||||
if req.id == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type changeConfigStatusReq struct {
|
||||
token string
|
||||
id string
|
||||
}
|
||||
|
||||
func (req changeConfigStatusReq) validate() error {
|
||||
if req.token == "" {
|
||||
return apiutil.ErrBearerToken
|
||||
}
|
||||
|
||||
if req.id == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Profile requests ---
|
||||
|
||||
type createProfileReq struct {
|
||||
bootstrap.Profile
|
||||
}
|
||||
|
||||
func (req createProfileReq) validate() error {
|
||||
if req.Name == "" {
|
||||
return apiutil.ErrMissingName
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type uploadProfileReq struct {
|
||||
bootstrap.Profile
|
||||
}
|
||||
|
||||
func (req uploadProfileReq) validate() error {
|
||||
if req.Name == "" {
|
||||
return apiutil.ErrMissingName
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type viewProfileReq struct {
|
||||
profileID string
|
||||
}
|
||||
|
||||
func (req viewProfileReq) validate() error {
|
||||
if req.profileID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type updateProfileReq struct {
|
||||
profileID string
|
||||
bootstrap.Profile
|
||||
}
|
||||
|
||||
func (req updateProfileReq) validate() error {
|
||||
if req.profileID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type renderPreviewReq struct {
|
||||
profileID string
|
||||
ConfigID string `json:"config_id,omitempty"`
|
||||
Config bootstrap.Config `json:"config"`
|
||||
RenderContext map[string]any `json:"render_context,omitempty"`
|
||||
Bindings []bootstrap.BindingSnapshot `json:"bindings,omitempty"`
|
||||
}
|
||||
|
||||
func (req renderPreviewReq) validate() error {
|
||||
if req.profileID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type deleteProfileReq struct {
|
||||
profileID string
|
||||
}
|
||||
|
||||
func (req deleteProfileReq) validate() error {
|
||||
if req.profileID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type listProfilesReq struct {
|
||||
offset uint64
|
||||
limit uint64
|
||||
name string
|
||||
}
|
||||
|
||||
func (req listProfilesReq) validate() error {
|
||||
if req.limit == 0 || req.limit > maxLimitSize {
|
||||
return apiutil.ErrLimitSize
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Enrollment binding requests ---
|
||||
|
||||
type assignProfileReq struct {
|
||||
configID string
|
||||
ProfileID string `json:"profile_id"`
|
||||
}
|
||||
|
||||
func (req assignProfileReq) validate() error {
|
||||
if req.configID == "" || req.ProfileID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type bindResourcesReq struct {
|
||||
token string
|
||||
configID string
|
||||
Bindings []bootstrap.BindingRequest `json:"bindings"`
|
||||
}
|
||||
|
||||
func (req bindResourcesReq) validate() error {
|
||||
if req.token == "" {
|
||||
return apiutil.ErrBearerToken
|
||||
}
|
||||
if req.configID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
if len(req.Bindings) == 0 {
|
||||
return apiutil.ErrEmptyList
|
||||
}
|
||||
for _, b := range req.Bindings {
|
||||
if b.Slot == "" || b.Type == "" || b.ResourceID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type listBindingsReq struct {
|
||||
configID string
|
||||
}
|
||||
|
||||
func (req listBindingsReq) validate() error {
|
||||
if req.configID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type refreshBindingsReq struct {
|
||||
token string
|
||||
configID string
|
||||
}
|
||||
|
||||
func (req refreshBindingsReq) validate() error {
|
||||
if req.token == "" {
|
||||
return apiutil.ErrBearerToken
|
||||
}
|
||||
if req.configID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,245 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAddReqValidation(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
externalID string
|
||||
externalKey string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
token: "token",
|
||||
externalID: "external-id",
|
||||
externalKey: "external-key",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty token",
|
||||
token: "",
|
||||
externalID: "external-id",
|
||||
externalKey: "external-key",
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "empty external ID",
|
||||
token: "token",
|
||||
externalID: "",
|
||||
externalKey: "external-key",
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "empty external key",
|
||||
token: "token",
|
||||
externalID: "external-id",
|
||||
externalKey: "",
|
||||
err: apiutil.ErrBearerKey,
|
||||
},
|
||||
{
|
||||
desc: "empty external key and external ID",
|
||||
token: "token",
|
||||
externalID: "",
|
||||
externalKey: "",
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := addReq{
|
||||
token: tc.token,
|
||||
ExternalID: tc.externalID,
|
||||
ExternalKey: tc.externalKey,
|
||||
}
|
||||
|
||||
err := req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntityReqValidation(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
id string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "empty id",
|
||||
id: "",
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := entityReq{
|
||||
id: tc.id,
|
||||
}
|
||||
|
||||
err := req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateReqValidation(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
id string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
id: "id",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty id",
|
||||
id: "",
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := updateReq{
|
||||
id: tc.id,
|
||||
}
|
||||
|
||||
err := req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateCertReqValidation(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
configID string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "empty config id",
|
||||
configID: "",
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := updateCertReq{
|
||||
configID: tc.configID,
|
||||
}
|
||||
|
||||
err := req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListReqValidation(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
offset uint64
|
||||
limit uint64
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "too large limit",
|
||||
offset: 0,
|
||||
limit: maxLimitSize + 1,
|
||||
err: apiutil.ErrLimitSize,
|
||||
},
|
||||
{
|
||||
desc: "default limit",
|
||||
offset: 0,
|
||||
limit: defLimit,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := listReq{
|
||||
offset: tc.offset,
|
||||
limit: tc.limit,
|
||||
}
|
||||
|
||||
err := req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrapReqValidation(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
externKey string
|
||||
externID string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "empty external key",
|
||||
externKey: "",
|
||||
externID: "id",
|
||||
err: apiutil.ErrBearerKey,
|
||||
},
|
||||
{
|
||||
desc: "empty external id",
|
||||
externKey: "key",
|
||||
externID: "",
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := bootstrapReq{
|
||||
id: tc.externID,
|
||||
key: tc.externKey,
|
||||
}
|
||||
|
||||
err := req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangeConfigStatusReqValidation(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
id string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "empty token",
|
||||
token: "",
|
||||
id: "id",
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "empty id",
|
||||
token: "token",
|
||||
id: "",
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "valid request",
|
||||
token: "token",
|
||||
id: "id",
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
req := changeConfigStatusReq{
|
||||
token: tc.token,
|
||||
id: tc.id,
|
||||
}
|
||||
|
||||
err := req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
@@ -1,223 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
)
|
||||
|
||||
var (
|
||||
_ magistrala.Response = (*removeRes)(nil)
|
||||
_ magistrala.Response = (*configRes)(nil)
|
||||
_ magistrala.Response = (*changeConfigStatusRes)(nil)
|
||||
_ magistrala.Response = (*viewRes)(nil)
|
||||
_ magistrala.Response = (*listRes)(nil)
|
||||
)
|
||||
|
||||
type removeRes struct{}
|
||||
|
||||
func (res removeRes) Code() int {
|
||||
return http.StatusNoContent
|
||||
}
|
||||
|
||||
func (res removeRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res removeRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type updateRes struct{}
|
||||
|
||||
func (res updateRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res updateRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res updateRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type configRes struct {
|
||||
ID string `json:"id"`
|
||||
ExternalID string `json:"external_id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Status bootstrap.Status `json:"status"`
|
||||
ProfileID string `json:"profile_id,omitempty"`
|
||||
RenderContext map[string]any `json:"render_context,omitempty"`
|
||||
ClientCert string `json:"client_cert,omitempty"`
|
||||
CACert string `json:"ca_cert,omitempty"`
|
||||
ClientKey string `json:"client_key,omitempty"`
|
||||
created bool
|
||||
}
|
||||
|
||||
func (res configRes) Code() int {
|
||||
if res.created {
|
||||
return http.StatusCreated
|
||||
}
|
||||
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res configRes) Headers() map[string]string {
|
||||
if res.created {
|
||||
return map[string]string{
|
||||
"Location": fmt.Sprintf("/clients/configs/%s", res.ID),
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res configRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type viewRes struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
ExternalID string `json:"external_id"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Status bootstrap.Status `json:"status"`
|
||||
ProfileID string `json:"profile_id,omitempty"`
|
||||
RenderContext map[string]any `json:"render_context,omitempty"`
|
||||
ClientCert string `json:"client_cert,omitempty"`
|
||||
CACert string `json:"ca_cert,omitempty"`
|
||||
ClientKey string `json:"client_key,omitempty"`
|
||||
}
|
||||
|
||||
func (res viewRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res viewRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res viewRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type listRes struct {
|
||||
Total uint64 `json:"total"`
|
||||
Offset uint64 `json:"offset"`
|
||||
Limit uint64 `json:"limit"`
|
||||
Configs []viewRes `json:"configs"`
|
||||
}
|
||||
|
||||
func (res listRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res listRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res listRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type changeConfigStatusRes struct {
|
||||
bootstrap.Config
|
||||
}
|
||||
|
||||
func (res changeConfigStatusRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res changeConfigStatusRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res changeConfigStatusRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type updateConfigRes struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
CACert string `json:"ca_cert,omitempty"`
|
||||
ClientCert string `json:"client_cert,omitempty"`
|
||||
ClientKey string `json:"client_key,omitempty"`
|
||||
}
|
||||
|
||||
func (res updateConfigRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res updateConfigRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res updateConfigRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// profileRes is returned on create (201) or update (200).
|
||||
type profileRes struct {
|
||||
bootstrap.Profile
|
||||
created bool
|
||||
}
|
||||
|
||||
func (res profileRes) Code() int {
|
||||
if res.created {
|
||||
return http.StatusCreated
|
||||
}
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res profileRes) Headers() map[string]string {
|
||||
if res.created {
|
||||
return map[string]string{
|
||||
"Location": fmt.Sprintf("/bootstrap/profiles/%s", res.ID),
|
||||
}
|
||||
}
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res profileRes) Empty() bool { return false }
|
||||
|
||||
// profilesPageRes is returned by ListProfiles.
|
||||
type profilesPageRes struct {
|
||||
bootstrap.ProfilesPage
|
||||
}
|
||||
|
||||
func (res profilesPageRes) Code() int { return http.StatusOK }
|
||||
func (res profilesPageRes) Headers() map[string]string { return map[string]string{} }
|
||||
func (res profilesPageRes) Empty() bool { return false }
|
||||
|
||||
// profileSlotsRes is returned by profile slots endpoint.
|
||||
type profileSlotsRes struct {
|
||||
BindingSlots []bootstrap.BindingSlot `json:"binding_slots"`
|
||||
}
|
||||
|
||||
func (res profileSlotsRes) Code() int { return http.StatusOK }
|
||||
func (res profileSlotsRes) Headers() map[string]string { return map[string]string{} }
|
||||
func (res profileSlotsRes) Empty() bool { return false }
|
||||
|
||||
// renderPreviewRes is returned by profile render-preview endpoint.
|
||||
type renderPreviewRes struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func (res renderPreviewRes) Code() int { return http.StatusOK }
|
||||
func (res renderPreviewRes) Headers() map[string]string { return map[string]string{} }
|
||||
func (res renderPreviewRes) Empty() bool { return false }
|
||||
|
||||
// bindingsRes is returned by ListBindings.
|
||||
type bindingsRes struct {
|
||||
Bindings []bootstrap.BindingSnapshot `json:"bindings"`
|
||||
}
|
||||
|
||||
func (res bindingsRes) Code() int { return http.StatusOK }
|
||||
func (res bindingsRes) Headers() map[string]string { return map[string]string{} }
|
||||
func (res bindingsRes) Empty() bool { return false }
|
||||
@@ -1,512 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
api "github.com/absmach/magistrala/api/http"
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
smqauthn "github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/go-chi/chi/v5"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
"github.com/pelletier/go-toml/v2"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
contentType = "application/json"
|
||||
yamlContentType = "yaml"
|
||||
tomlContentType = "toml"
|
||||
byteContentType = "application/octet-stream"
|
||||
offsetKey = "offset"
|
||||
limitKey = "limit"
|
||||
defOffset = 0
|
||||
defLimit = 10
|
||||
)
|
||||
|
||||
var (
|
||||
fullMatch = []string{"status", "external_id", "id"}
|
||||
partialMatch = []string{"name"}
|
||||
// ErrBootstrap indicates error in getting bootstrap configuration.
|
||||
ErrBootstrap = errors.New("failed to read bootstrap configuration")
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP handler for API endpoints.
|
||||
func MakeHandler(svc bootstrap.Service, authn smqauthn.AuthNMiddleware, reader bootstrap.ConfigReader, logger *slog.Logger, instanceID string) http.Handler {
|
||||
opts := []kithttp.ServerOption{
|
||||
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
|
||||
}
|
||||
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Route("/{domainID}/clients", func(r chi.Router) {
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware())
|
||||
r.Route("/configs", func(r chi.Router) {
|
||||
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
addEndpoint(svc),
|
||||
decodeAddRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "add").ServeHTTP)
|
||||
|
||||
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
listEndpoint(svc),
|
||||
decodeListRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "list").ServeHTTP)
|
||||
|
||||
r.Get("/{configID}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
viewEndpoint(svc),
|
||||
decodeEntityRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "view").ServeHTTP)
|
||||
|
||||
r.Patch("/{configID}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
updateEndpoint(svc),
|
||||
decodeUpdateRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "update").ServeHTTP)
|
||||
|
||||
r.Delete("/{configID}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
removeEndpoint(svc),
|
||||
decodeEntityRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "remove").ServeHTTP)
|
||||
|
||||
r.Patch("/certs/{configID}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
updateCertEndpoint(svc),
|
||||
decodeUpdateCertRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "update_cert").ServeHTTP)
|
||||
|
||||
r.Post("/{configID}/enable", otelhttp.NewHandler(kithttp.NewServer(
|
||||
enableConfigEndpoint(svc),
|
||||
decodeChangeConfigStatusRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "enable_config").ServeHTTP)
|
||||
|
||||
r.Post("/{configID}/disable", otelhttp.NewHandler(kithttp.NewServer(
|
||||
disableConfigEndpoint(svc),
|
||||
decodeChangeConfigStatusRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "disable_config").ServeHTTP)
|
||||
})
|
||||
})
|
||||
|
||||
// Profile and enrollment binding endpoints.
|
||||
r.Route("/bootstrap", func(r chi.Router) {
|
||||
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware())
|
||||
|
||||
r.Route("/profiles", func(r chi.Router) {
|
||||
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
createProfileEndpoint(svc),
|
||||
decodeCreateProfileRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "create_profile").ServeHTTP)
|
||||
|
||||
r.Post("/upload", otelhttp.NewHandler(kithttp.NewServer(
|
||||
uploadProfileEndpoint(svc),
|
||||
decodeUploadProfileRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "upload_profile").ServeHTTP)
|
||||
|
||||
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
listProfilesEndpoint(svc),
|
||||
decodeListProfilesRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "list_profiles").ServeHTTP)
|
||||
|
||||
r.Get("/{profileID}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
viewProfileEndpoint(svc),
|
||||
decodeProfileEntityRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "view_profile").ServeHTTP)
|
||||
|
||||
r.Get("/{profileID}/slots", otelhttp.NewHandler(kithttp.NewServer(
|
||||
profileSlotsEndpoint(svc),
|
||||
decodeProfileEntityRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "profile_slots").ServeHTTP)
|
||||
|
||||
r.Post("/{profileID}/render-preview", otelhttp.NewHandler(kithttp.NewServer(
|
||||
renderPreviewEndpoint(svc),
|
||||
decodeRenderPreviewRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "render_preview").ServeHTTP)
|
||||
|
||||
r.Patch("/{profileID}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
updateProfileEndpoint(svc),
|
||||
decodeUpdateProfileRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "update_profile").ServeHTTP)
|
||||
|
||||
r.Delete("/{profileID}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
deleteProfileEndpoint(svc),
|
||||
decodeDeleteProfileRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "delete_profile").ServeHTTP)
|
||||
})
|
||||
|
||||
r.Route("/enrollments", func(r chi.Router) {
|
||||
r.Patch("/{configID}/profile", otelhttp.NewHandler(kithttp.NewServer(
|
||||
assignProfileEndpoint(svc),
|
||||
decodeAssignProfileRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "assign_profile").ServeHTTP)
|
||||
|
||||
r.Put("/{configID}/bindings", otelhttp.NewHandler(kithttp.NewServer(
|
||||
bindResourcesEndpoint(svc),
|
||||
decodeBindResourcesRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "bind_resources").ServeHTTP)
|
||||
|
||||
r.Get("/{configID}/bindings", otelhttp.NewHandler(kithttp.NewServer(
|
||||
listBindingsEndpoint(svc),
|
||||
decodeEnrollmentEntityRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "list_bindings").ServeHTTP)
|
||||
|
||||
r.Post("/{configID}/bindings/refresh", otelhttp.NewHandler(kithttp.NewServer(
|
||||
refreshBindingsEndpoint(svc),
|
||||
decodeRefreshBindingsRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "refresh_bindings").ServeHTTP)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/clients/bootstrap", func(r chi.Router) {
|
||||
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
bootstrapEndpoint(svc, reader, false),
|
||||
decodeBootstrapRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "bootstrap").ServeHTTP)
|
||||
r.Get("/{externalID}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
bootstrapEndpoint(svc, reader, false),
|
||||
decodeBootstrapRequest,
|
||||
api.EncodeResponse,
|
||||
opts...), "bootstrap").ServeHTTP)
|
||||
r.Get("/secure/{externalID}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
bootstrapEndpoint(svc, reader, true),
|
||||
decodeBootstrapRequest,
|
||||
encodeSecureRes,
|
||||
opts...), "bootstrap_secure").ServeHTTP)
|
||||
})
|
||||
|
||||
r.Get("/health", magistrala.Health("bootstrap", instanceID))
|
||||
r.Handle("/metrics", promhttp.Handler())
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func decodeAddRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
|
||||
return nil, apiutil.ErrUnsupportedContentType
|
||||
}
|
||||
|
||||
req := addReq{
|
||||
token: apiutil.ExtractBearerToken(r),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeUpdateRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
|
||||
return nil, apiutil.ErrUnsupportedContentType
|
||||
}
|
||||
|
||||
req := updateReq{
|
||||
id: chi.URLParam(r, "configID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeUpdateCertRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
|
||||
return nil, apiutil.ErrUnsupportedContentType
|
||||
}
|
||||
|
||||
req := updateCertReq{
|
||||
configID: chi.URLParam(r, "configID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeListRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
o, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
l, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
q, err := url.ParseQuery(r.URL.RawQuery)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidQueryParams)
|
||||
}
|
||||
|
||||
req := listReq{
|
||||
filter: parseFilter(q),
|
||||
offset: o,
|
||||
limit: l,
|
||||
}
|
||||
|
||||
rawStatus := q.Get("status")
|
||||
parsed, err := bootstrap.ToStatus(rawStatus)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidQueryParams)
|
||||
}
|
||||
if parsed == bootstrap.AllStatus {
|
||||
delete(req.filter.FullMatch, "status")
|
||||
} else {
|
||||
req.filter.FullMatch["status"] = parsed.String()
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeBootstrapRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
req := bootstrapReq{
|
||||
id: chi.URLParam(r, "externalID"),
|
||||
key: apiutil.ExtractClientSecret(r),
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeChangeConfigStatusRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
return changeConfigStatusReq{
|
||||
token: apiutil.ExtractBearerToken(r),
|
||||
id: chi.URLParam(r, "configID"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeEntityRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
req := entityReq{
|
||||
id: chi.URLParam(r, "configID"),
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func encodeSecureRes(_ context.Context, w http.ResponseWriter, response any) error {
|
||||
w.Header().Set("Content-Type", byteContentType)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if b, ok := response.([]byte); ok {
|
||||
if _, err := w.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseFilter(values url.Values) bootstrap.Filter {
|
||||
ret := bootstrap.Filter{
|
||||
FullMatch: make(map[string]string),
|
||||
PartialMatch: make(map[string]string),
|
||||
}
|
||||
for k := range values {
|
||||
if contains(fullMatch, k) {
|
||||
ret.FullMatch[k] = values.Get(k)
|
||||
}
|
||||
if contains(partialMatch, k) {
|
||||
ret.PartialMatch[k] = strings.ToLower(values.Get(k))
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func contains(l []string, s string) bool {
|
||||
for _, v := range l {
|
||||
if v == s {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func decodeCreateProfileRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
|
||||
return nil, apiutil.ErrUnsupportedContentType
|
||||
}
|
||||
var req createProfileReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&req.Profile); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeUploadProfileRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
var req uploadProfileReq
|
||||
var inferredFormat bootstrap.ContentFormat
|
||||
|
||||
switch {
|
||||
case strings.Contains(contentType, "json"):
|
||||
inferredFormat = bootstrap.ContentFormatJSON
|
||||
if err := json.NewDecoder(r.Body).Decode(&req.Profile); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
case strings.Contains(contentType, yamlContentType):
|
||||
inferredFormat = bootstrap.ContentFormatYAML
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
if err := decodeYAMLProfile(body, &req.Profile); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
case strings.Contains(contentType, tomlContentType):
|
||||
inferredFormat = bootstrap.ContentFormatTOML
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
if err := decodeTOMLProfile(body, &req.Profile); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
default:
|
||||
return nil, apiutil.ErrUnsupportedContentType
|
||||
}
|
||||
|
||||
if req.Profile.ContentFormat == "" {
|
||||
req.Profile.ContentFormat = inferredFormat
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeYAMLProfile(body []byte, profile *bootstrap.Profile) error {
|
||||
var raw map[string]any
|
||||
if err := yaml.Unmarshal(body, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
return decodeProfileMap(raw, profile)
|
||||
}
|
||||
|
||||
func decodeTOMLProfile(body []byte, profile *bootstrap.Profile) error {
|
||||
var raw map[string]any
|
||||
if err := toml.Unmarshal(body, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
return decodeProfileMap(raw, profile)
|
||||
}
|
||||
|
||||
func decodeProfileMap(raw map[string]any, profile *bootstrap.Profile) error {
|
||||
body, err := json.Marshal(raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(body, profile)
|
||||
}
|
||||
|
||||
func decodeListProfilesRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
o, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
l, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
n, err := apiutil.ReadStringQuery(r, api.NameKey, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
return listProfilesReq{offset: o, limit: l, name: n}, nil
|
||||
}
|
||||
|
||||
func decodeProfileEntityRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
return viewProfileReq{profileID: chi.URLParam(r, "profileID")}, nil
|
||||
}
|
||||
|
||||
func decodeDeleteProfileRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
return deleteProfileReq{profileID: chi.URLParam(r, "profileID")}, nil
|
||||
}
|
||||
|
||||
func decodeUpdateProfileRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
|
||||
return nil, apiutil.ErrUnsupportedContentType
|
||||
}
|
||||
req := updateProfileReq{profileID: chi.URLParam(r, "profileID")}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req.Profile); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeRenderPreviewRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
|
||||
return nil, apiutil.ErrUnsupportedContentType
|
||||
}
|
||||
req := renderPreviewReq{profileID: chi.URLParam(r, "profileID")}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeAssignProfileRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
|
||||
return nil, apiutil.ErrUnsupportedContentType
|
||||
}
|
||||
req := assignProfileReq{configID: chi.URLParam(r, "configID")}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeBindResourcesRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
|
||||
return nil, apiutil.ErrUnsupportedContentType
|
||||
}
|
||||
req := bindResourcesReq{
|
||||
token: apiutil.ExtractBearerToken(r),
|
||||
configID: chi.URLParam(r, "configID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeEnrollmentEntityRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
return listBindingsReq{configID: chi.URLParam(r, "configID")}, nil
|
||||
}
|
||||
|
||||
func decodeRefreshBindingsRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
return refreshBindingsReq{
|
||||
token: apiutil.ExtractBearerToken(r),
|
||||
configID: chi.URLParam(r, "configID"),
|
||||
}, nil
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package bootstrap
|
||||
|
||||
import "context"
|
||||
|
||||
// Config represents a bootstrap enrollment.
|
||||
type Config struct {
|
||||
ID string `json:"id"`
|
||||
DomainID string `json:"domain_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ClientCert string `json:"client_cert,omitempty"`
|
||||
ClientKey string `json:"client_key,omitempty"`
|
||||
CACert string `json:"ca_cert,omitempty"`
|
||||
ExternalID string `json:"external_id"`
|
||||
ExternalKey string `json:"external_key"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Status Status `json:"status"`
|
||||
ProfileID string `json:"profile_id,omitempty"`
|
||||
RenderContext map[string]any `json:"render_context,omitempty"`
|
||||
}
|
||||
|
||||
// Filter is used for the search filters.
|
||||
type Filter struct {
|
||||
FullMatch map[string]string
|
||||
PartialMatch map[string]string
|
||||
}
|
||||
|
||||
// ConfigsPage contains page related metadata as well as list of Configs that
|
||||
// belong to this page.
|
||||
type ConfigsPage struct {
|
||||
Total uint64 `json:"total"`
|
||||
Offset uint64 `json:"offset"`
|
||||
Limit uint64 `json:"limit"`
|
||||
Configs []Config `json:"configs"`
|
||||
}
|
||||
|
||||
// ConfigRepository specifies a Config persistence API.
|
||||
type ConfigRepository interface {
|
||||
// Save persists the Config. Successful operation is indicated by non-nil
|
||||
// error response.
|
||||
Save(ctx context.Context, cfg Config) (string, error)
|
||||
|
||||
// RetrieveByID retrieves the Config having the provided identifier, that is owned
|
||||
// by the specified user.
|
||||
RetrieveByID(ctx context.Context, domainID, id string) (Config, error)
|
||||
|
||||
// RetrieveAll retrieves a subset of Configs that belong to the given domain,
|
||||
// with given filter parameters.
|
||||
RetrieveAll(ctx context.Context, domainID string, filter Filter, offset, limit uint64) ConfigsPage
|
||||
|
||||
// RetrieveByExternalID returns Config for given external ID.
|
||||
RetrieveByExternalID(ctx context.Context, externalID string) (Config, error)
|
||||
|
||||
// Update updates an existing Config. A non-nil error is returned
|
||||
// to indicate operation failure.
|
||||
Update(ctx context.Context, cfg Config) error
|
||||
|
||||
// AssignProfile sets the profile reference for the given Config.
|
||||
AssignProfile(ctx context.Context, domainID, id, profileID string) error
|
||||
|
||||
// UpdateCerts updates and returns an existing Config certificate and domainID.
|
||||
// A non-nil error is returned to indicate operation failure.
|
||||
UpdateCert(ctx context.Context, domainID, id, clientCert, clientKey, caCert string) (Config, error)
|
||||
|
||||
// Remove removes the Config having the provided identifier, that is owned
|
||||
// by the specified user.
|
||||
Remove(ctx context.Context, domainID, id string) error
|
||||
|
||||
// ChangeStatus changes the Status of the Config owned by the specific user.
|
||||
ChangeStatus(ctx context.Context, domainID, id string, status Status) error
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package bootstrap contains the domain concept definitions needed to support
|
||||
// Magistrala bootstrap service functionality.
|
||||
package bootstrap
|
||||
@@ -1,6 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package events provides the domain concept definitions needed to support
|
||||
// bootstrap events functionality.
|
||||
package events
|
||||
@@ -1,6 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package producer contains the domain events needed to support
|
||||
// event sourcing of Bootstrap service actions.
|
||||
package producer
|
||||
@@ -1,288 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package producer
|
||||
|
||||
import (
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
)
|
||||
|
||||
const (
|
||||
configPrefix = "bootstrap.config."
|
||||
configCreate = configPrefix + "create"
|
||||
configUpdate = configPrefix + "update"
|
||||
configRemove = configPrefix + "remove"
|
||||
configView = configPrefix + "view"
|
||||
configList = configPrefix + "list"
|
||||
clientPrefix = "bootstrap.client."
|
||||
clientBootstrap = clientPrefix + "bootstrap"
|
||||
configEnable = configPrefix + "enable"
|
||||
configDisable = configPrefix + "disable"
|
||||
certUpdate = "bootstrap.cert.update"
|
||||
|
||||
profilePrefix = "bootstrap.profile."
|
||||
profileCreate = profilePrefix + "create"
|
||||
profileView = profilePrefix + "view"
|
||||
profileUpdate = profilePrefix + "update"
|
||||
profileList = profilePrefix + "list"
|
||||
profileDelete = profilePrefix + "delete"
|
||||
profileAssign = profilePrefix + "assign"
|
||||
bindingsPrefix = "bootstrap.bindings."
|
||||
bindingsBind = bindingsPrefix + "bind"
|
||||
bindingsList = bindingsPrefix + "list"
|
||||
bindingsRefresh = bindingsPrefix + "refresh"
|
||||
)
|
||||
|
||||
var (
|
||||
_ events.Event = (*configEvent)(nil)
|
||||
_ events.Event = (*removeConfigEvent)(nil)
|
||||
_ events.Event = (*bootstrapEvent)(nil)
|
||||
_ events.Event = (*enableConfigEvent)(nil)
|
||||
_ events.Event = (*disableConfigEvent)(nil)
|
||||
_ events.Event = (*updateCertEvent)(nil)
|
||||
_ events.Event = (*listConfigsEvent)(nil)
|
||||
_ events.Event = (*profileEvent)(nil)
|
||||
_ events.Event = (*deleteProfileEvent)(nil)
|
||||
_ events.Event = (*assignProfileEvent)(nil)
|
||||
_ events.Event = (*bindResourcesEvent)(nil)
|
||||
_ events.Event = (*listBindingsEvent)(nil)
|
||||
_ events.Event = (*refreshBindingsEvent)(nil)
|
||||
)
|
||||
|
||||
type configEvent struct {
|
||||
bootstrap.Config
|
||||
operation string
|
||||
}
|
||||
|
||||
func (ce configEvent) Encode() (map[string]any, error) {
|
||||
val := map[string]any{
|
||||
"status": ce.Status.String(),
|
||||
"operation": ce.operation,
|
||||
}
|
||||
if ce.ID != "" {
|
||||
val["config_id"] = ce.ID
|
||||
}
|
||||
if ce.Content != "" {
|
||||
val["content"] = ce.Content
|
||||
}
|
||||
if ce.DomainID != "" {
|
||||
val["domain_id"] = ce.DomainID
|
||||
}
|
||||
if ce.Name != "" {
|
||||
val["name"] = ce.Name
|
||||
}
|
||||
if ce.ExternalID != "" {
|
||||
val["external_id"] = ce.ExternalID
|
||||
}
|
||||
if ce.ClientCert != "" {
|
||||
val["client_cert"] = ce.ClientCert
|
||||
}
|
||||
if ce.ClientKey != "" {
|
||||
val["client_key"] = ce.ClientKey
|
||||
}
|
||||
if ce.CACert != "" {
|
||||
val["ca_cert"] = ce.CACert
|
||||
}
|
||||
if ce.Content != "" {
|
||||
val["content"] = ce.Content
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type removeConfigEvent struct {
|
||||
config string
|
||||
}
|
||||
|
||||
func (rce removeConfigEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"config_id": rce.config,
|
||||
"operation": configRemove,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type listConfigsEvent struct {
|
||||
offset uint64
|
||||
limit uint64
|
||||
fullMatch map[string]string
|
||||
partialMatch map[string]string
|
||||
}
|
||||
|
||||
func (rce listConfigsEvent) Encode() (map[string]any, error) {
|
||||
val := map[string]any{
|
||||
"offset": rce.offset,
|
||||
"limit": rce.limit,
|
||||
"operation": configList,
|
||||
}
|
||||
if len(rce.fullMatch) > 0 {
|
||||
val["full_match"] = rce.fullMatch
|
||||
}
|
||||
|
||||
if len(rce.partialMatch) > 0 {
|
||||
val["full_match"] = rce.partialMatch
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type bootstrapEvent struct {
|
||||
bootstrap.Config
|
||||
externalID string
|
||||
success bool
|
||||
}
|
||||
|
||||
func (be bootstrapEvent) Encode() (map[string]any, error) {
|
||||
val := map[string]any{
|
||||
"external_id": be.externalID,
|
||||
"success": be.success,
|
||||
"operation": clientBootstrap,
|
||||
}
|
||||
|
||||
if be.ID != "" {
|
||||
val["config_id"] = be.ID
|
||||
}
|
||||
if be.Content != "" {
|
||||
val["content"] = be.Content
|
||||
}
|
||||
if be.DomainID != "" {
|
||||
val["domain_id"] = be.DomainID
|
||||
}
|
||||
if be.Name != "" {
|
||||
val["name"] = be.Name
|
||||
}
|
||||
if be.ExternalID != "" {
|
||||
val["external_id"] = be.ExternalID
|
||||
}
|
||||
if be.ClientCert != "" {
|
||||
val["client_cert"] = be.ClientCert
|
||||
}
|
||||
if be.ClientKey != "" {
|
||||
val["client_key"] = be.ClientKey
|
||||
}
|
||||
if be.CACert != "" {
|
||||
val["ca_cert"] = be.CACert
|
||||
}
|
||||
if be.Content != "" {
|
||||
val["content"] = be.Content
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type enableConfigEvent struct {
|
||||
configID string
|
||||
}
|
||||
|
||||
func (e enableConfigEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"config_id": e.configID,
|
||||
"operation": configEnable,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type disableConfigEvent struct {
|
||||
configID string
|
||||
}
|
||||
|
||||
func (e disableConfigEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"config_id": e.configID,
|
||||
"operation": configDisable,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type updateCertEvent struct {
|
||||
configID string
|
||||
clientCert string
|
||||
clientKey string
|
||||
caCert string
|
||||
}
|
||||
|
||||
func (uce updateCertEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"config_id": uce.configID,
|
||||
"client_cert": uce.clientCert,
|
||||
"client_key": uce.clientKey,
|
||||
"ca_cert": uce.caCert,
|
||||
"operation": certUpdate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type profileEvent struct {
|
||||
bootstrap.Profile
|
||||
operation string
|
||||
}
|
||||
|
||||
func (pe profileEvent) Encode() (map[string]any, error) {
|
||||
val := map[string]any{
|
||||
"operation": pe.operation,
|
||||
}
|
||||
if pe.ID != "" {
|
||||
val["profile_id"] = pe.ID
|
||||
}
|
||||
if pe.DomainID != "" {
|
||||
val["domain_id"] = pe.DomainID
|
||||
}
|
||||
if pe.Name != "" {
|
||||
val["name"] = pe.Name
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type deleteProfileEvent struct {
|
||||
profileID string
|
||||
}
|
||||
|
||||
func (dpe deleteProfileEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"profile_id": dpe.profileID,
|
||||
"operation": profileDelete,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type assignProfileEvent struct {
|
||||
configID string
|
||||
profileID string
|
||||
}
|
||||
|
||||
func (ape assignProfileEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"config_id": ape.configID,
|
||||
"profile_id": ape.profileID,
|
||||
"operation": profileAssign,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type bindResourcesEvent struct {
|
||||
configID string
|
||||
slots []string
|
||||
}
|
||||
|
||||
func (bre bindResourcesEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"config_id": bre.configID,
|
||||
"slots": bre.slots,
|
||||
"operation": bindingsBind,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type listBindingsEvent struct {
|
||||
configID string
|
||||
}
|
||||
|
||||
func (lbe listBindingsEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"config_id": lbe.configID,
|
||||
"operation": bindingsList,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type refreshBindingsEvent struct {
|
||||
configID string
|
||||
}
|
||||
|
||||
func (rbe refreshBindingsEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"config_id": rbe.configID,
|
||||
"operation": bindingsRefresh,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package producer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/ory/dockertest/v3/docker"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
redisClient *redis.Client
|
||||
redisURL string
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
container, err := pool.RunWithOptions(&dockertest.RunOptions{
|
||||
Repository: "redis",
|
||||
Tag: "7.2.4-alpine",
|
||||
}, func(config *docker.HostConfig) {
|
||||
config.AutoRemove = true
|
||||
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Could not start container: %s", err)
|
||||
}
|
||||
|
||||
redisURL = fmt.Sprintf("redis://localhost:%s/0", container.GetPort("6379/tcp"))
|
||||
opts, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not parse redis URL: %s", err)
|
||||
}
|
||||
|
||||
if err := pool.Retry(func() error {
|
||||
redisClient = redis.NewClient(opts)
|
||||
|
||||
return redisClient.Ping(context.Background()).Err()
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
@@ -1,284 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package producer
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
smqauthn "github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
)
|
||||
|
||||
var _ bootstrap.Service = (*eventStore)(nil)
|
||||
|
||||
const (
|
||||
magistralaPrefix = "magistrala."
|
||||
createStream = magistralaPrefix + configCreate
|
||||
listStream = magistralaPrefix + configList
|
||||
removeStream = magistralaPrefix + configRemove
|
||||
updateCertStream = magistralaPrefix + certUpdate
|
||||
bootstrapStream = magistralaPrefix + clientBootstrap
|
||||
enableConfigStream = magistralaPrefix + configEnable
|
||||
disableConfigStream = magistralaPrefix + configDisable
|
||||
createProfileStream = magistralaPrefix + profileCreate
|
||||
viewProfileStream = magistralaPrefix + profileView
|
||||
updateProfileStream = magistralaPrefix + profileUpdate
|
||||
listProfilesStream = magistralaPrefix + profileList
|
||||
deleteProfileStream = magistralaPrefix + profileDelete
|
||||
assignProfileStream = magistralaPrefix + profileAssign
|
||||
bindResourcesStream = magistralaPrefix + bindingsBind
|
||||
listBindingsStream = magistralaPrefix + bindingsList
|
||||
refreshBindingsStream = magistralaPrefix + bindingsRefresh
|
||||
)
|
||||
|
||||
type eventStore struct {
|
||||
events.Publisher
|
||||
svc bootstrap.Service
|
||||
}
|
||||
|
||||
// NewEventStoreMiddleware returns wrapper around bootstrap service that sends
|
||||
// events to event store.
|
||||
func NewEventStoreMiddleware(svc bootstrap.Service, publisher events.Publisher) bootstrap.Service {
|
||||
return &eventStore{
|
||||
svc: svc,
|
||||
Publisher: publisher,
|
||||
}
|
||||
}
|
||||
|
||||
func (es *eventStore) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (bootstrap.Config, error) {
|
||||
saved, err := es.svc.Add(ctx, session, token, cfg)
|
||||
if err != nil {
|
||||
return saved, err
|
||||
}
|
||||
|
||||
ev := configEvent{
|
||||
saved, configCreate,
|
||||
}
|
||||
|
||||
if err := es.Publish(ctx, createStream, ev); err != nil {
|
||||
return saved, err
|
||||
}
|
||||
|
||||
return saved, err
|
||||
}
|
||||
|
||||
func (es *eventStore) View(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
|
||||
cfg, err := es.svc.View(ctx, session, id)
|
||||
if err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
ev := configEvent{
|
||||
cfg, configView,
|
||||
}
|
||||
|
||||
if err := es.Publish(ctx, magistralaPrefix+configView, ev); err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
func (es *eventStore) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) error {
|
||||
if err := es.svc.Update(ctx, session, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ev := configEvent{
|
||||
cfg, configUpdate,
|
||||
}
|
||||
|
||||
return es.Publish(ctx, magistralaPrefix+configUpdate, ev)
|
||||
}
|
||||
|
||||
func (es eventStore) UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (bootstrap.Config, error) {
|
||||
cfg, err := es.svc.UpdateCert(ctx, session, id, clientCert, clientKey, caCert)
|
||||
if err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
ev := updateCertEvent{
|
||||
configID: id,
|
||||
clientCert: clientCert,
|
||||
clientKey: clientKey,
|
||||
caCert: caCert,
|
||||
}
|
||||
|
||||
if err := es.Publish(ctx, updateCertStream, ev); err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (bootstrap.ConfigsPage, error) {
|
||||
bp, err := es.svc.List(ctx, session, filter, offset, limit)
|
||||
if err != nil {
|
||||
return bp, err
|
||||
}
|
||||
|
||||
ev := listConfigsEvent{
|
||||
offset: offset,
|
||||
limit: limit,
|
||||
fullMatch: filter.FullMatch,
|
||||
partialMatch: filter.PartialMatch,
|
||||
}
|
||||
|
||||
if err := es.Publish(ctx, listStream, ev); err != nil {
|
||||
return bp, err
|
||||
}
|
||||
|
||||
return bp, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) Remove(ctx context.Context, session smqauthn.Session, id string) error {
|
||||
if err := es.svc.Remove(ctx, session, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ev := removeConfigEvent{
|
||||
config: id,
|
||||
}
|
||||
|
||||
return es.Publish(ctx, removeStream, ev)
|
||||
}
|
||||
|
||||
func (es *eventStore) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (bootstrap.Config, error) {
|
||||
cfg, err := es.svc.Bootstrap(ctx, externalKey, externalID, secure)
|
||||
|
||||
ev := bootstrapEvent{
|
||||
cfg,
|
||||
externalID,
|
||||
true,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
ev.success = false
|
||||
}
|
||||
|
||||
if err := es.Publish(ctx, bootstrapStream, ev); err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
func (es *eventStore) EnableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
|
||||
cfg, err := es.svc.EnableConfig(ctx, session, id)
|
||||
if err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
ev := enableConfigEvent{configID: id}
|
||||
if err := es.Publish(ctx, enableConfigStream, ev); err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) DisableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
|
||||
cfg, err := es.svc.DisableConfig(ctx, session, id)
|
||||
if err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
|
||||
ev := disableConfigEvent{configID: id}
|
||||
if err := es.Publish(ctx, disableConfigStream, ev); err != nil {
|
||||
return cfg, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) CreateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
|
||||
saved, err := es.svc.CreateProfile(ctx, session, p)
|
||||
if err != nil {
|
||||
return saved, err
|
||||
}
|
||||
ev := profileEvent{saved, profileCreate}
|
||||
if err := es.Publish(ctx, createProfileStream, ev); err != nil {
|
||||
return saved, err
|
||||
}
|
||||
return saved, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (bootstrap.Profile, error) {
|
||||
p, err := es.svc.ViewProfile(ctx, session, profileID)
|
||||
if err != nil {
|
||||
return p, err
|
||||
}
|
||||
ev := profileEvent{p, profileView}
|
||||
if err := es.Publish(ctx, viewProfileStream, ev); err != nil {
|
||||
return p, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
|
||||
updated, err := es.svc.UpdateProfile(ctx, session, p)
|
||||
if err != nil {
|
||||
return bootstrap.Profile{}, err
|
||||
}
|
||||
ev := profileEvent{updated, profileUpdate}
|
||||
return updated, es.Publish(ctx, updateProfileStream, ev)
|
||||
}
|
||||
|
||||
func (es *eventStore) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
|
||||
pp, err := es.svc.ListProfiles(ctx, session, offset, limit, name)
|
||||
if err != nil {
|
||||
return pp, err
|
||||
}
|
||||
ev := profileEvent{operation: profileList}
|
||||
if err := es.Publish(ctx, listProfilesStream, ev); err != nil {
|
||||
return pp, err
|
||||
}
|
||||
return pp, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
|
||||
if err := es.svc.DeleteProfile(ctx, session, profileID); err != nil {
|
||||
return err
|
||||
}
|
||||
ev := deleteProfileEvent{profileID: profileID}
|
||||
return es.Publish(ctx, deleteProfileStream, ev)
|
||||
}
|
||||
|
||||
func (es *eventStore) AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) error {
|
||||
if err := es.svc.AssignProfile(ctx, session, configID, profileID); err != nil {
|
||||
return err
|
||||
}
|
||||
ev := assignProfileEvent{configID: configID, profileID: profileID}
|
||||
return es.Publish(ctx, assignProfileStream, ev)
|
||||
}
|
||||
|
||||
func (es *eventStore) BindResources(ctx context.Context, session smqauthn.Session, token, configID string, bindings []bootstrap.BindingRequest) error {
|
||||
if err := es.svc.BindResources(ctx, session, token, configID, bindings); err != nil {
|
||||
return err
|
||||
}
|
||||
slots := make([]string, len(bindings))
|
||||
for i, b := range bindings {
|
||||
slots[i] = b.Slot
|
||||
}
|
||||
ev := bindResourcesEvent{configID: configID, slots: slots}
|
||||
return es.Publish(ctx, bindResourcesStream, ev)
|
||||
}
|
||||
|
||||
func (es *eventStore) ListBindings(ctx context.Context, session smqauthn.Session, configID string) ([]bootstrap.BindingSnapshot, error) {
|
||||
bs, err := es.svc.ListBindings(ctx, session, configID)
|
||||
if err != nil {
|
||||
return bs, err
|
||||
}
|
||||
ev := listBindingsEvent{configID: configID}
|
||||
if err := es.Publish(ctx, listBindingsStream, ev); err != nil {
|
||||
return bs, err
|
||||
}
|
||||
return bs, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) error {
|
||||
if err := es.svc.RefreshBindings(ctx, session, token, configID); err != nil {
|
||||
return err
|
||||
}
|
||||
ev := refreshBindingsEvent{configID: configID}
|
||||
return es.Publish(ctx, refreshBindingsStream, ev)
|
||||
}
|
||||
@@ -1,914 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package producer_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
"github.com/absmach/magistrala/bootstrap/events/producer"
|
||||
bootstraphasher "github.com/absmach/magistrala/bootstrap/hasher"
|
||||
"github.com/absmach/magistrala/bootstrap/mocks"
|
||||
"github.com/absmach/magistrala/internal/testsutil"
|
||||
smqauthn "github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/absmach/magistrala/pkg/events/store"
|
||||
sdkmocks "github.com/absmach/magistrala/pkg/sdk/mocks"
|
||||
"github.com/absmach/magistrala/pkg/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
streamID = "magistrala.bootstrap"
|
||||
validToken = "validToken"
|
||||
unknownID = "unknown"
|
||||
|
||||
configPrefix = "config."
|
||||
configCreate = configPrefix + "create"
|
||||
configView = configPrefix + "view"
|
||||
configUpdate = configPrefix + "update"
|
||||
configRemove = configPrefix + "remove"
|
||||
configList = configPrefix + "list"
|
||||
clientPrefix = "client."
|
||||
clientBootstrap = clientPrefix + "bootstrap"
|
||||
configEnable = configPrefix + "enable"
|
||||
configDisable = configPrefix + "disable"
|
||||
|
||||
certUpdate = "cert.update"
|
||||
)
|
||||
|
||||
var (
|
||||
encKey = []byte("1234567891011121")
|
||||
|
||||
domainID = testsutil.GenerateUUID(&testing.T{})
|
||||
validID = testsutil.GenerateUUID(&testing.T{})
|
||||
|
||||
config = bootstrap.Config{
|
||||
ID: testsutil.GenerateUUID(&testing.T{}),
|
||||
ExternalID: testsutil.GenerateUUID(&testing.T{}),
|
||||
ExternalKey: testsutil.GenerateUUID(&testing.T{}),
|
||||
Content: "config",
|
||||
Status: bootstrap.EnabledStatus,
|
||||
}
|
||||
)
|
||||
|
||||
type testVariable struct {
|
||||
svc bootstrap.Service
|
||||
boot *mocks.ConfigRepository
|
||||
sdk *sdkmocks.SDK
|
||||
}
|
||||
|
||||
func newTestVariable(t *testing.T, redisURL string) testVariable {
|
||||
boot := new(mocks.ConfigRepository)
|
||||
sdk := new(sdkmocks.SDK)
|
||||
idp := uuid.NewMock()
|
||||
svc := bootstrap.New(boot, nil, nil, nil, nil, sdk, bootstraphasher.New(), encKey, idp)
|
||||
publisher, err := store.NewPublisher(context.Background(), redisURL, "bootstrap-es-pub-test")
|
||||
require.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
svc = producer.NewEventStoreMiddleware(svc, publisher)
|
||||
return testVariable{
|
||||
svc: svc,
|
||||
boot: boot,
|
||||
sdk: sdk,
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdd(t *testing.T) {
|
||||
err := redisClient.FlushAll(context.Background()).Err()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
tv := newTestVariable(t, redisURL)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
config bootstrap.Config
|
||||
token string
|
||||
session smqauthn.Session
|
||||
id string
|
||||
domainID string
|
||||
saveErr error
|
||||
err error
|
||||
event map[string]any
|
||||
}{
|
||||
{
|
||||
desc: "create config successfully",
|
||||
config: config,
|
||||
token: validToken,
|
||||
id: validID,
|
||||
domainID: domainID,
|
||||
event: map[string]any{
|
||||
"config_id": "1",
|
||||
"domain_id": domainID,
|
||||
"name": config.Name,
|
||||
"external_id": config.ExternalID,
|
||||
"content": config.Content,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"operation": configCreate,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "create config with failed to save",
|
||||
config: config,
|
||||
token: validToken,
|
||||
id: validID,
|
||||
domainID: domainID,
|
||||
event: nil,
|
||||
saveErr: svcerr.ErrCreateEntity,
|
||||
err: svcerr.ErrCreateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
lastID := "0"
|
||||
for _, tc := range cases {
|
||||
tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID}
|
||||
repoCall := tv.boot.On("Save", context.Background(), mock.Anything).Return(mock.Anything, tc.saveErr)
|
||||
|
||||
_, err := tv.svc.Add(context.Background(), tc.session, tc.token, tc.config)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
|
||||
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
|
||||
Streams: []string{streamID, lastID},
|
||||
Count: 1,
|
||||
Block: time.Second,
|
||||
}).Val()
|
||||
|
||||
var event map[string]any
|
||||
if len(streams) > 0 && len(streams[0].Messages) > 0 {
|
||||
event := streams[0].Messages
|
||||
lastID = event[0].ID
|
||||
}
|
||||
|
||||
test(t, tc.event, event, tc.desc)
|
||||
|
||||
repoCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestView(t *testing.T) {
|
||||
err := redisClient.FlushAll(context.Background()).Err()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
tv := newTestVariable(t, redisURL)
|
||||
|
||||
nonExisting := config
|
||||
nonExisting.ID = unknownID
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
config bootstrap.Config
|
||||
token string
|
||||
session smqauthn.Session
|
||||
id string
|
||||
domainID string
|
||||
retrieveErr error
|
||||
err error
|
||||
event map[string]any
|
||||
}{
|
||||
{
|
||||
desc: "view successfully",
|
||||
config: config,
|
||||
token: validToken,
|
||||
id: validID,
|
||||
domainID: domainID,
|
||||
err: nil,
|
||||
event: map[string]any{
|
||||
"config_id": config.ID,
|
||||
"domain_id": config.DomainID,
|
||||
"name": config.Name,
|
||||
"external_id": config.ExternalID,
|
||||
"content": config.Content,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"operation": configView,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "view with failed retrieve",
|
||||
config: nonExisting,
|
||||
token: validToken,
|
||||
id: validID,
|
||||
domainID: domainID,
|
||||
retrieveErr: svcerr.ErrViewEntity,
|
||||
err: svcerr.ErrViewEntity,
|
||||
event: nil,
|
||||
},
|
||||
}
|
||||
|
||||
lastID := "0"
|
||||
for _, tc := range cases {
|
||||
tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID}
|
||||
repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.config.ID).Return(config, tc.retrieveErr)
|
||||
_, err := tv.svc.View(context.Background(), tc.session, tc.config.ID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
|
||||
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
|
||||
Streams: []string{streamID, lastID},
|
||||
Count: 1,
|
||||
Block: time.Second,
|
||||
}).Val()
|
||||
|
||||
var event map[string]any
|
||||
if len(streams) > 0 && len(streams[0].Messages) > 0 {
|
||||
msg := streams[0].Messages[0]
|
||||
event = msg.Values
|
||||
event["timestamp"] = msg.ID
|
||||
lastID = msg.ID
|
||||
}
|
||||
|
||||
test(t, tc.event, event, tc.desc)
|
||||
repoCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
err := redisClient.FlushAll(context.Background()).Err()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
tv := newTestVariable(t, redisURL)
|
||||
|
||||
modified := config
|
||||
modified.Content = "new-config"
|
||||
modified.Name = "new name"
|
||||
|
||||
nonExisting := config
|
||||
nonExisting.ID = unknownID
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
config bootstrap.Config
|
||||
token string
|
||||
session smqauthn.Session
|
||||
id string
|
||||
domainID string
|
||||
updateErr error
|
||||
err error
|
||||
event map[string]any
|
||||
}{
|
||||
{
|
||||
desc: "update config successfully",
|
||||
config: modified,
|
||||
token: validToken,
|
||||
id: validID,
|
||||
domainID: domainID,
|
||||
err: nil,
|
||||
event: map[string]any{
|
||||
"name": modified.Name,
|
||||
"content": modified.Content,
|
||||
"timestamp": time.Now().UnixNano(),
|
||||
"operation": configUpdate,
|
||||
"external_id": modified.ExternalID,
|
||||
"config_id": modified.ID,
|
||||
"domain_id": domainID,
|
||||
"status": bootstrap.Disabled,
|
||||
"occurred_at": time.Now().UnixNano(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "update with failed update",
|
||||
config: nonExisting,
|
||||
token: validToken,
|
||||
id: validID,
|
||||
domainID: domainID,
|
||||
updateErr: svcerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
event: nil,
|
||||
},
|
||||
}
|
||||
|
||||
lastID := "0"
|
||||
for _, tc := range cases {
|
||||
tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID}
|
||||
repoCall := tv.boot.On("Update", context.Background(), mock.Anything).Return(tc.updateErr)
|
||||
err := tv.svc.Update(context.Background(), tc.session, tc.config)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
|
||||
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
|
||||
Streams: []string{streamID, lastID},
|
||||
Count: 1,
|
||||
Block: time.Second,
|
||||
}).Val()
|
||||
|
||||
var event map[string]any
|
||||
if len(streams) > 0 && len(streams[0].Messages) > 0 {
|
||||
msg := streams[0].Messages[0]
|
||||
event = msg.Values
|
||||
event["timestamp"] = msg.ID
|
||||
lastID = msg.ID
|
||||
}
|
||||
|
||||
test(t, tc.event, event, tc.desc)
|
||||
repoCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateCert(t *testing.T) {
|
||||
err := redisClient.FlushAll(context.Background()).Err()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
tv := newTestVariable(t, redisURL)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
configID string
|
||||
userID string
|
||||
domainID string
|
||||
token string
|
||||
session smqauthn.Session
|
||||
clientCert string
|
||||
clientKey string
|
||||
caCert string
|
||||
updateErr error
|
||||
err error
|
||||
event map[string]any
|
||||
}{
|
||||
{
|
||||
desc: "update cert successfully",
|
||||
configID: config.ID,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
token: validToken,
|
||||
clientCert: "clientCert",
|
||||
clientKey: "clientKey",
|
||||
caCert: "caCert",
|
||||
err: nil,
|
||||
event: map[string]any{
|
||||
"client_cert": "clientCert",
|
||||
"client_key": "clientKey",
|
||||
"ca_cert": "caCert",
|
||||
"operation": certUpdate,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "update cert with failed update",
|
||||
configID: "clientID",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
clientCert: "clientCert",
|
||||
clientKey: "clientKey",
|
||||
caCert: "caCert",
|
||||
updateErr: svcerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
event: nil,
|
||||
},
|
||||
{
|
||||
desc: "update cert with empty client certificate",
|
||||
configID: config.ID,
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
clientCert: "",
|
||||
clientKey: "clientKey",
|
||||
caCert: "caCert",
|
||||
err: nil,
|
||||
event: nil,
|
||||
},
|
||||
{
|
||||
desc: "update cert with empty client key",
|
||||
configID: config.ID,
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
clientCert: "clientCert",
|
||||
clientKey: "",
|
||||
caCert: "caCert",
|
||||
err: nil,
|
||||
event: nil,
|
||||
},
|
||||
{
|
||||
desc: "update cert with empty CA certificate",
|
||||
configID: config.ID,
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
clientCert: "clientCert",
|
||||
clientKey: "clientKey",
|
||||
caCert: "",
|
||||
err: nil,
|
||||
event: nil,
|
||||
},
|
||||
}
|
||||
|
||||
lastID := "0"
|
||||
for _, tc := range cases {
|
||||
tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID}
|
||||
repoCall := tv.boot.On("UpdateCert", context.Background(), tc.domainID, tc.configID, tc.clientCert, tc.clientKey, tc.caCert).Return(config, tc.updateErr)
|
||||
_, err := tv.svc.UpdateCert(context.Background(), tc.session, tc.configID, tc.clientCert, tc.clientKey, tc.caCert)
|
||||
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
|
||||
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
|
||||
Streams: []string{streamID, lastID},
|
||||
Count: 1,
|
||||
Block: time.Second,
|
||||
}).Val()
|
||||
|
||||
var event map[string]any
|
||||
if len(streams) > 0 && len(streams[0].Messages) > 0 {
|
||||
event := streams[0].Messages
|
||||
lastID = event[0].ID
|
||||
}
|
||||
|
||||
test(t, tc.event, event, tc.desc)
|
||||
|
||||
repoCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
tv := newTestVariable(t, redisURL)
|
||||
|
||||
numClients := 101
|
||||
var c bootstrap.Config
|
||||
saved := make([]bootstrap.Config, 0)
|
||||
for i := 0; i < numClients; i++ {
|
||||
c = config
|
||||
c.ExternalID = testsutil.GenerateUUID(t)
|
||||
c.ExternalKey = testsutil.GenerateUUID(t)
|
||||
c.Name = fmt.Sprintf("%s-%d", config.Name, i)
|
||||
if i == 41 {
|
||||
c.Status = bootstrap.Active
|
||||
}
|
||||
saved = append(saved, c)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
session smqauthn.Session
|
||||
userID string
|
||||
domainID string
|
||||
config bootstrap.ConfigsPage
|
||||
filter bootstrap.Filter
|
||||
offset uint64
|
||||
limit uint64
|
||||
retrieveErr error
|
||||
err error
|
||||
event map[string]any
|
||||
}{
|
||||
{
|
||||
desc: "list successfully as super admin",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true},
|
||||
config: bootstrap.ConfigsPage{
|
||||
Total: uint64(len(saved)),
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
Configs: saved[0:10],
|
||||
},
|
||||
filter: bootstrap.Filter{},
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
err: nil,
|
||||
event: map[string]any{
|
||||
"config_id": c.ID,
|
||||
"domain_id": c.DomainID,
|
||||
"name": c.Name,
|
||||
"external_id": c.ExternalID,
|
||||
"content": c.Content,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"operation": configList,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "list successfully as domain admin",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true},
|
||||
config: bootstrap.ConfigsPage{
|
||||
Total: uint64(len(saved)),
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
Configs: saved[0:10],
|
||||
},
|
||||
filter: bootstrap.Filter{},
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
err: nil,
|
||||
event: map[string]any{
|
||||
"config_id": c.ID,
|
||||
"domain_id": c.DomainID,
|
||||
"name": c.Name,
|
||||
"external_id": c.ExternalID,
|
||||
"content": c.Content,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"operation": configList,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "list successfully as non admin",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID},
|
||||
config: bootstrap.ConfigsPage{
|
||||
Total: uint64(len(saved)),
|
||||
Offset: 0,
|
||||
Limit: 10,
|
||||
Configs: saved[0:10],
|
||||
},
|
||||
filter: bootstrap.Filter{},
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
err: nil,
|
||||
event: map[string]any{
|
||||
"config_id": c.ID,
|
||||
"domain_id": c.DomainID,
|
||||
"name": c.Name,
|
||||
"external_id": c.ExternalID,
|
||||
"content": c.Content,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"operation": configList,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "list as super admin with failed retrieve all",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true},
|
||||
filter: bootstrap.Filter{},
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
retrieveErr: nil,
|
||||
err: nil,
|
||||
event: nil,
|
||||
},
|
||||
{
|
||||
desc: "list as domain admin with failed retrieve all",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true},
|
||||
filter: bootstrap.Filter{},
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
retrieveErr: nil,
|
||||
err: nil,
|
||||
event: nil,
|
||||
},
|
||||
{
|
||||
desc: "list as non admin with failed retrieve all",
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID},
|
||||
filter: bootstrap.Filter{},
|
||||
offset: 0,
|
||||
limit: 10,
|
||||
retrieveErr: nil,
|
||||
err: nil,
|
||||
event: nil,
|
||||
},
|
||||
}
|
||||
|
||||
lastID := "0"
|
||||
for _, tc := range cases {
|
||||
repoCall := tv.boot.On("RetrieveAll", context.Background(), mock.Anything, tc.filter, tc.offset, tc.limit).Return(tc.config, tc.retrieveErr)
|
||||
|
||||
_, err := tv.svc.List(context.Background(), tc.session, tc.filter, tc.offset, tc.limit)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
|
||||
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
|
||||
Streams: []string{streamID, lastID},
|
||||
Count: 1,
|
||||
Block: time.Second,
|
||||
}).Val()
|
||||
|
||||
var event map[string]any
|
||||
if len(streams) > 0 && len(streams[0].Messages) > 0 {
|
||||
event := streams[0].Messages
|
||||
lastID = event[0].ID
|
||||
}
|
||||
|
||||
test(t, tc.event, event, tc.desc)
|
||||
|
||||
repoCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
err := redisClient.FlushAll(context.Background()).Err()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
tv := newTestVariable(t, redisURL)
|
||||
|
||||
nonExisting := config
|
||||
nonExisting.ID = unknownID
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
configID string
|
||||
userID string
|
||||
domainID string
|
||||
token string
|
||||
session smqauthn.Session
|
||||
removeErr error
|
||||
err error
|
||||
event map[string]any
|
||||
}{
|
||||
{
|
||||
desc: "remove config successfully",
|
||||
configID: config.ID,
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
err: nil,
|
||||
event: map[string]any{
|
||||
"config_id": config.ID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"operation": configRemove,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "remove config with failed removal",
|
||||
configID: nonExisting.ID,
|
||||
token: validToken,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
removeErr: svcerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
event: nil,
|
||||
},
|
||||
}
|
||||
|
||||
lastID := "0"
|
||||
for _, tc := range cases {
|
||||
tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID}
|
||||
repoCall := tv.boot.On("Remove", context.Background(), mock.Anything, mock.Anything).Return(tc.removeErr)
|
||||
err := tv.svc.Remove(context.Background(), tc.session, tc.configID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
|
||||
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
|
||||
Streams: []string{streamID, lastID},
|
||||
Count: 1,
|
||||
Block: time.Second,
|
||||
}).Val()
|
||||
|
||||
var event map[string]any
|
||||
if len(streams) > 0 && len(streams[0].Messages) > 0 {
|
||||
event := streams[0].Messages
|
||||
lastID = event[0].ID
|
||||
}
|
||||
|
||||
test(t, tc.event, event, tc.desc)
|
||||
repoCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestBootstrap(t *testing.T) {
|
||||
err := redisClient.FlushAll(context.Background()).Err()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
tv := newTestVariable(t, redisURL)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
externalID string
|
||||
externalKey string
|
||||
err error
|
||||
retrieveErr error
|
||||
event map[string]any
|
||||
}{
|
||||
{
|
||||
desc: "bootstrap successfully",
|
||||
externalID: config.ExternalID,
|
||||
externalKey: config.ExternalKey,
|
||||
err: nil,
|
||||
event: map[string]any{
|
||||
"external_id": config.ExternalID,
|
||||
"success": "1",
|
||||
"timestamp": time.Now().Unix(),
|
||||
"operation": clientBootstrap,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "bootstrap with an error",
|
||||
externalID: "external_id1",
|
||||
externalKey: "external_id",
|
||||
retrieveErr: bootstrap.ErrBootstrap,
|
||||
err: bootstrap.ErrBootstrap,
|
||||
event: map[string]any{
|
||||
"external_id": "external_id",
|
||||
"success": "0",
|
||||
"timestamp": time.Now().Unix(),
|
||||
"operation": clientBootstrap,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
lastID := "0"
|
||||
for _, tc := range cases {
|
||||
repoCall := tv.boot.On("RetrieveByExternalID", context.Background(), mock.Anything).Return(config, tc.retrieveErr)
|
||||
_, err = tv.svc.Bootstrap(context.Background(), tc.externalKey, tc.externalID, false)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
|
||||
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
|
||||
Streams: []string{streamID, lastID},
|
||||
Count: 1,
|
||||
Block: time.Second,
|
||||
}).Val()
|
||||
|
||||
var event map[string]any
|
||||
if len(streams) > 0 && len(streams[0].Messages) > 0 {
|
||||
event := streams[0].Messages
|
||||
lastID = event[0].ID
|
||||
}
|
||||
test(t, tc.event, event, tc.desc)
|
||||
repoCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableConfig(t *testing.T) {
|
||||
err := redisClient.FlushAll(context.Background()).Err()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
tv := newTestVariable(t, redisURL)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
id string
|
||||
userID string
|
||||
domainID string
|
||||
session smqauthn.Session
|
||||
retrieveErr error
|
||||
statusErr error
|
||||
err error
|
||||
event map[string]any
|
||||
}{
|
||||
{
|
||||
desc: "enable config",
|
||||
id: config.ID,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
err: nil,
|
||||
event: map[string]any{
|
||||
"config_id": config.ID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"operation": configEnable,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "enable with failed retrieve by ID",
|
||||
id: "",
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
retrieveErr: svcerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
event: nil,
|
||||
},
|
||||
{
|
||||
desc: "enable with repo status error",
|
||||
id: config.ID,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
statusErr: svcerr.ErrUpdateEntity,
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
event: nil,
|
||||
},
|
||||
}
|
||||
|
||||
disabledConfig := config
|
||||
disabledConfig.Status = bootstrap.DisabledStatus
|
||||
|
||||
lastID := "0"
|
||||
for _, tc := range cases {
|
||||
tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID}
|
||||
repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(disabledConfig, tc.retrieveErr)
|
||||
repoCall1 := tv.boot.On("ChangeStatus", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(tc.statusErr)
|
||||
_, err := tv.svc.EnableConfig(context.Background(), tc.session, tc.id)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
|
||||
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
|
||||
Streams: []string{streamID, lastID},
|
||||
Count: 1,
|
||||
Block: time.Second,
|
||||
}).Val()
|
||||
|
||||
var event map[string]any
|
||||
if len(streams) > 0 && len(streams[0].Messages) > 0 {
|
||||
event := streams[0].Messages
|
||||
lastID = event[0].ID
|
||||
}
|
||||
|
||||
test(t, tc.event, event, tc.desc)
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisableConfig(t *testing.T) {
|
||||
err := redisClient.FlushAll(context.Background()).Err()
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err))
|
||||
|
||||
tv := newTestVariable(t, redisURL)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
id string
|
||||
userID string
|
||||
domainID string
|
||||
session smqauthn.Session
|
||||
retrieveErr error
|
||||
statusErr error
|
||||
err error
|
||||
event map[string]any
|
||||
}{
|
||||
{
|
||||
desc: "disable config",
|
||||
id: config.ID,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
err: nil,
|
||||
event: map[string]any{
|
||||
"config_id": config.ID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
"operation": configDisable,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "disable with failed retrieve by ID",
|
||||
id: "",
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
retrieveErr: svcerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
event: nil,
|
||||
},
|
||||
{
|
||||
desc: "disable with repo status error",
|
||||
id: config.ID,
|
||||
userID: validID,
|
||||
domainID: domainID,
|
||||
statusErr: svcerr.ErrUpdateEntity,
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
event: nil,
|
||||
},
|
||||
}
|
||||
|
||||
lastID := "0"
|
||||
for _, tc := range cases {
|
||||
tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID}
|
||||
repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(config, tc.retrieveErr)
|
||||
repoCall1 := tv.boot.On("ChangeStatus", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(tc.statusErr)
|
||||
_, err := tv.svc.DisableConfig(context.Background(), tc.session, tc.id)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
|
||||
streams := redisClient.XRead(context.Background(), &redis.XReadArgs{
|
||||
Streams: []string{streamID, lastID},
|
||||
Count: 1,
|
||||
Block: time.Second,
|
||||
}).Val()
|
||||
|
||||
var event map[string]any
|
||||
if len(streams) > 0 && len(streams[0].Messages) > 0 {
|
||||
event := streams[0].Messages
|
||||
lastID = event[0].ID
|
||||
}
|
||||
|
||||
test(t, tc.event, event, tc.desc)
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func test(t *testing.T, expected, actual map[string]any, description string) {
|
||||
if expected != nil && actual != nil {
|
||||
ts1 := expected["timestamp"].(int64)
|
||||
ats := actual["timestamp"].(string)
|
||||
ts2, err := strconv.ParseInt(strings.Split(ats, "-")[0], 10, 64)
|
||||
require.Nil(t, err, fmt.Sprintf("%s: expected to get a valid timestamp, got %s", description, err))
|
||||
ts1 = ts1 / 1e9
|
||||
ts2 = ts2 / 1e3
|
||||
if assert.WithinDuration(t, time.Unix(ts1, 0), time.Unix(ts2, 0), time.Second, fmt.Sprintf("%s: timestamp is not in valid range of 1 second", description)) {
|
||||
delete(expected, "timestamp")
|
||||
delete(actual, "timestamp")
|
||||
}
|
||||
|
||||
oa1 := expected["occurred_at"].(int64)
|
||||
aoa := actual["occurred_at"].(string)
|
||||
oa2, err := strconv.ParseInt(aoa, 10, 64)
|
||||
require.Nil(t, err, fmt.Sprintf("%s: expected to get a valid occurred_at, got %s", description, err))
|
||||
oa1 = oa1 / 1e9
|
||||
oa2 = oa2 / 1e9
|
||||
if assert.WithinDuration(t, time.Unix(oa1, 0), time.Unix(oa2, 0), time.Second, fmt.Sprintf("%s: occurred_at is not in valid range of 1 second", description)) {
|
||||
delete(expected, "occurred_at")
|
||||
delete(actual, "occurred_at")
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, actual, fmt.Sprintf("%s: got incorrect event\n", description))
|
||||
}
|
||||
}
|
||||
@@ -1,219 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/auth"
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
smqauthn "github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/authz"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/absmach/magistrala/pkg/policies"
|
||||
)
|
||||
|
||||
const (
|
||||
createOperation = "create"
|
||||
viewOperation = "view"
|
||||
updateOperation = "update"
|
||||
updateCertOperation = "update_cert"
|
||||
listOperation = "list"
|
||||
removeOperation = "remove"
|
||||
changeStateOperation = "change_state"
|
||||
)
|
||||
|
||||
var _ bootstrap.Service = (*authorizationMiddleware)(nil)
|
||||
|
||||
type authorizationMiddleware struct {
|
||||
svc bootstrap.Service
|
||||
authz authz.Authorization
|
||||
}
|
||||
|
||||
// AuthorizationMiddleware adds authorization to the clients service.
|
||||
func AuthorizationMiddleware(svc bootstrap.Service, authz authz.Authorization) bootstrap.Service {
|
||||
return &authorizationMiddleware{
|
||||
svc: svc,
|
||||
authz: authz,
|
||||
}
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (bootstrap.Config, error) {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, createOperation, auth.AnyIDs); err != nil {
|
||||
return bootstrap.Config{}, err
|
||||
}
|
||||
|
||||
return am.svc.Add(ctx, session, token, cfg)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) View(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, viewOperation, id); err != nil {
|
||||
return bootstrap.Config{}, err
|
||||
}
|
||||
|
||||
return am.svc.View(ctx, session, id)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) error {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, cfg.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return am.svc.Update(ctx, session, cfg)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (bootstrap.Config, error) {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateCertOperation, id); err != nil {
|
||||
return bootstrap.Config{}, err
|
||||
}
|
||||
|
||||
return am.svc.UpdateCert(ctx, session, id, clientCert, clientKey, caCert)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (bootstrap.ConfigsPage, error) {
|
||||
if err := am.checkSuperAdmin(ctx, session); err == nil {
|
||||
session.SuperAdmin = true
|
||||
}
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.AdminPermission, policies.DomainType, session.DomainID, listOperation, auth.AnyIDs); err == nil {
|
||||
session.SuperAdmin = true
|
||||
}
|
||||
|
||||
return am.svc.List(ctx, session, filter, offset, limit)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) Remove(ctx context.Context, session smqauthn.Session, id string) error {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, removeOperation, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return am.svc.Remove(ctx, session, id)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (bootstrap.Config, error) {
|
||||
return am.svc.Bootstrap(ctx, externalKey, externalID, secure)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) EnableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, changeStateOperation, id); err != nil {
|
||||
return bootstrap.Config{}, err
|
||||
}
|
||||
|
||||
return am.svc.EnableConfig(ctx, session, id)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) DisableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, changeStateOperation, id); err != nil {
|
||||
return bootstrap.Config{}, err
|
||||
}
|
||||
|
||||
return am.svc.DisableConfig(ctx, session, id)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) CreateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, createOperation, auth.AnyIDs); err != nil {
|
||||
return bootstrap.Profile{}, err
|
||||
}
|
||||
return am.svc.CreateProfile(ctx, session, p)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (bootstrap.Profile, error) {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, viewOperation, auth.AnyIDs); err != nil {
|
||||
return bootstrap.Profile{}, err
|
||||
}
|
||||
return am.svc.ViewProfile(ctx, session, profileID)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, auth.AnyIDs); err != nil {
|
||||
return bootstrap.Profile{}, err
|
||||
}
|
||||
return am.svc.UpdateProfile(ctx, session, p)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, listOperation, auth.AnyIDs); err != nil {
|
||||
return bootstrap.ProfilesPage{}, err
|
||||
}
|
||||
return am.svc.ListProfiles(ctx, session, offset, limit, name)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, removeOperation, auth.AnyIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
return am.svc.DeleteProfile(ctx, session, profileID)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) error {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, configID); err != nil {
|
||||
return err
|
||||
}
|
||||
return am.svc.AssignProfile(ctx, session, configID, profileID)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) BindResources(ctx context.Context, session smqauthn.Session, token, configID string, bindings []bootstrap.BindingRequest) error {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, configID); err != nil {
|
||||
return err
|
||||
}
|
||||
return am.svc.BindResources(ctx, session, token, configID, bindings)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ListBindings(ctx context.Context, session smqauthn.Session, configID string) ([]bootstrap.BindingSnapshot, error) {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, viewOperation, configID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.svc.ListBindings(ctx, session, configID)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) error {
|
||||
if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID, updateOperation, configID); err != nil {
|
||||
return err
|
||||
}
|
||||
return am.svc.RefreshBindings(ctx, session, token, configID)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session smqauthn.Session) error {
|
||||
if session.Role != smqauthn.SuperAdminRole {
|
||||
return svcerr.ErrSuperAdminAction
|
||||
}
|
||||
if err := am.authz.Authorize(ctx, authz.PolicyReq{
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.UserID,
|
||||
Permission: policies.AdminPermission,
|
||||
ObjectType: policies.PlatformType,
|
||||
Object: policies.MagistralaObject,
|
||||
}, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) authorize(ctx context.Context, session smqauthn.Session, domain, subjType, subjKind, subj, perm, objType, obj, operation, entityID string) error {
|
||||
req := authz.PolicyReq{
|
||||
Domain: domain,
|
||||
SubjectType: subjType,
|
||||
SubjectKind: subjKind,
|
||||
Subject: subj,
|
||||
Permission: perm,
|
||||
ObjectType: objType,
|
||||
Object: obj,
|
||||
}
|
||||
|
||||
var pat *authz.PATReq
|
||||
if session.PatID != "" {
|
||||
pat = &authz.PATReq{
|
||||
UserID: session.UserID,
|
||||
PatID: session.PatID,
|
||||
EntityID: entityID,
|
||||
EntityType: auth.BootstrapType.String(),
|
||||
Operation: operation,
|
||||
Domain: session.DomainID,
|
||||
}
|
||||
}
|
||||
|
||||
if err := am.authz.Authorize(ctx, req, pat); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,355 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !test
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
smqauthn "github.com/absmach/magistrala/pkg/authn"
|
||||
)
|
||||
|
||||
var _ bootstrap.Service = (*loggingMiddleware)(nil)
|
||||
|
||||
type loggingMiddleware struct {
|
||||
logger *slog.Logger
|
||||
svc bootstrap.Service
|
||||
}
|
||||
|
||||
// LoggingMiddleware adds logging facilities to the bootstrap service.
|
||||
func LoggingMiddleware(svc bootstrap.Service, logger *slog.Logger) bootstrap.Service {
|
||||
return &loggingMiddleware{logger, svc}
|
||||
}
|
||||
|
||||
// Add logs the add request. It logs the client ID and the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (saved bootstrap.Config, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("config_id", saved.ID),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Add new bootstrap failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Add new bootstrap completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.Add(ctx, session, token, cfg)
|
||||
}
|
||||
|
||||
// View logs the view request. It logs the client ID and the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) View(ctx context.Context, session smqauthn.Session, id string) (saved bootstrap.Config, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("config_id", id),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("View client config failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("View client config completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.View(ctx, session, id)
|
||||
}
|
||||
|
||||
// Update logs the update request. It logs bootstrap client ID and the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.Group("config",
|
||||
slog.String("config_id", cfg.ID),
|
||||
slog.String("name", cfg.Name),
|
||||
),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Update bootstrap config failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Update bootstrap config completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.Update(ctx, session, cfg)
|
||||
}
|
||||
|
||||
// UpdateCert logs the update_cert request. It logs config ID and the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (cfg bootstrap.Config, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("config_id", cfg.ID),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Update bootstrap config certificate failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Update bootstrap config certificate completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.UpdateCert(ctx, session, id, clientCert, clientKey, caCert)
|
||||
}
|
||||
|
||||
// List logs the list request. It logs offset, limit and the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (res bootstrap.ConfigsPage, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.Group("page",
|
||||
slog.Any("filter", filter),
|
||||
slog.Uint64("offset", offset),
|
||||
slog.Uint64("limit", limit),
|
||||
slog.Uint64("total", res.Total),
|
||||
),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("List configs failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("List configs completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.List(ctx, session, filter, offset, limit)
|
||||
}
|
||||
|
||||
// Remove logs the remove request. It logs bootstrap ID and the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) Remove(ctx context.Context, session smqauthn.Session, id string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("config_id", id),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Remove bootstrap config failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Remove bootstrap config completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.Remove(ctx, session, id)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (cfg bootstrap.Config, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("external_id", externalID),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("View bootstrap config failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("View bootstrap completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.Bootstrap(ctx, externalKey, externalID, secure)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) EnableConfig(ctx context.Context, session smqauthn.Session, id string) (cfg bootstrap.Config, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("id", id),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Enable config failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Enable config completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.EnableConfig(ctx, session, id)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) DisableConfig(ctx context.Context, session smqauthn.Session, id string) (cfg bootstrap.Config, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("id", id),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Disable config failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Disable config completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.DisableConfig(ctx, session, id)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) CreateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (saved bootstrap.Profile, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("profile_id", saved.ID),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Create profile failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Create profile completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.CreateProfile(ctx, session, p)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (p bootstrap.Profile, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("profile_id", profileID),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("View profile failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("View profile completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.ViewProfile(ctx, session, profileID)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (updated bootstrap.Profile, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("profile_id", p.ID),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Update profile failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Update profile completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.UpdateProfile(ctx, session, p)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (page bootstrap.ProfilesPage, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.Uint64("offset", offset),
|
||||
slog.Uint64("limit", limit),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("List profiles failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("List profiles completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.ListProfiles(ctx, session, offset, limit, name)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("profile_id", profileID),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Delete profile failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Delete profile completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.DeleteProfile(ctx, session, profileID)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("config_id", configID),
|
||||
slog.String("profile_id", profileID),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Assign profile failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Assign profile completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.AssignProfile(ctx, session, configID, profileID)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) BindResources(ctx context.Context, session smqauthn.Session, token, configID string, bindings []bootstrap.BindingRequest) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("config_id", configID),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Bind resources failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Bind resources completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.BindResources(ctx, session, token, configID, bindings)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ListBindings(ctx context.Context, session smqauthn.Session, configID string) (snapshots []bootstrap.BindingSnapshot, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("config_id", configID),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("List bindings failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("List bindings completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.ListBindings(ctx, session, configID)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("config_id", configID),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.Any("error", err))
|
||||
lm.logger.Warn("Refresh bindings failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Refresh bindings completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.RefreshBindings(ctx, session, token, configID)
|
||||
}
|
||||
@@ -1,192 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
//go:build !test
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
smqauthn "github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/go-kit/kit/metrics"
|
||||
)
|
||||
|
||||
var _ bootstrap.Service = (*metricsMiddleware)(nil)
|
||||
|
||||
type metricsMiddleware struct {
|
||||
counter metrics.Counter
|
||||
latency metrics.Histogram
|
||||
svc bootstrap.Service
|
||||
}
|
||||
|
||||
// MetricsMiddleware instruments core service by tracking request count and latency.
|
||||
func MetricsMiddleware(svc bootstrap.Service, counter metrics.Counter, latency metrics.Histogram) bootstrap.Service {
|
||||
return &metricsMiddleware{
|
||||
counter: counter,
|
||||
latency: latency,
|
||||
svc: svc,
|
||||
}
|
||||
}
|
||||
|
||||
// Add instruments Add method with metrics.
|
||||
func (mm *metricsMiddleware) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (saved bootstrap.Config, err error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "add").Add(1)
|
||||
mm.latency.With("method", "add").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.Add(ctx, session, token, cfg)
|
||||
}
|
||||
|
||||
// View instruments View method with metrics.
|
||||
func (mm *metricsMiddleware) View(ctx context.Context, session smqauthn.Session, id string) (saved bootstrap.Config, err error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "view").Add(1)
|
||||
mm.latency.With("method", "view").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.View(ctx, session, id)
|
||||
}
|
||||
|
||||
// Update instruments Update method with metrics.
|
||||
func (mm *metricsMiddleware) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "update").Add(1)
|
||||
mm.latency.With("method", "update").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.Update(ctx, session, cfg)
|
||||
}
|
||||
|
||||
// UpdateCert instruments UpdateCert method with metrics.
|
||||
func (mm *metricsMiddleware) UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (cfg bootstrap.Config, err error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "update_cert").Add(1)
|
||||
mm.latency.With("method", "update_cert").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.UpdateCert(ctx, session, id, clientCert, clientKey, caCert)
|
||||
}
|
||||
|
||||
// List instruments List method with metrics.
|
||||
func (mm *metricsMiddleware) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (saved bootstrap.ConfigsPage, err error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "list").Add(1)
|
||||
mm.latency.With("method", "list").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.List(ctx, session, filter, offset, limit)
|
||||
}
|
||||
|
||||
// Remove instruments Remove method with metrics.
|
||||
func (mm *metricsMiddleware) Remove(ctx context.Context, session smqauthn.Session, id string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "remove").Add(1)
|
||||
mm.latency.With("method", "remove").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.Remove(ctx, session, id)
|
||||
}
|
||||
|
||||
// Bootstrap instruments Bootstrap method with metrics.
|
||||
func (mm *metricsMiddleware) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (cfg bootstrap.Config, err error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "bootstrap").Add(1)
|
||||
mm.latency.With("method", "bootstrap").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.Bootstrap(ctx, externalKey, externalID, secure)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) EnableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "enable_config").Add(1)
|
||||
mm.latency.With("method", "enable_config").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.EnableConfig(ctx, session, id)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) DisableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "disable_config").Add(1)
|
||||
mm.latency.With("method", "disable_config").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return mm.svc.DisableConfig(ctx, session, id)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) CreateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "create_profile").Add(1)
|
||||
mm.latency.With("method", "create_profile").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return mm.svc.CreateProfile(ctx, session, p)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (bootstrap.Profile, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "view_profile").Add(1)
|
||||
mm.latency.With("method", "view_profile").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return mm.svc.ViewProfile(ctx, session, profileID)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "update_profile").Add(1)
|
||||
mm.latency.With("method", "update_profile").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return mm.svc.UpdateProfile(ctx, session, p)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "list_profiles").Add(1)
|
||||
mm.latency.With("method", "list_profiles").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return mm.svc.ListProfiles(ctx, session, offset, limit, name)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "delete_profile").Add(1)
|
||||
mm.latency.With("method", "delete_profile").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return mm.svc.DeleteProfile(ctx, session, profileID)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "assign_profile").Add(1)
|
||||
mm.latency.With("method", "assign_profile").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return mm.svc.AssignProfile(ctx, session, configID, profileID)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) BindResources(ctx context.Context, session smqauthn.Session, token, configID string, bindings []bootstrap.BindingRequest) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "bind_resources").Add(1)
|
||||
mm.latency.With("method", "bind_resources").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return mm.svc.BindResources(ctx, session, token, configID, bindings)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) ListBindings(ctx context.Context, session smqauthn.Session, configID string) ([]bootstrap.BindingSnapshot, error) {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "list_bindings").Add(1)
|
||||
mm.latency.With("method", "list_bindings").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return mm.svc.ListBindings(ctx, session, configID)
|
||||
}
|
||||
|
||||
func (mm *metricsMiddleware) RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) error {
|
||||
defer func(begin time.Time) {
|
||||
mm.counter.With("method", "refresh_bindings").Add(1)
|
||||
mm.latency.With("method", "refresh_bindings").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return mm.svc.RefreshBindings(ctx, session, token, configID)
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewConfigReader creates a new instance of ConfigReader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewConfigReader(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *ConfigReader {
|
||||
mock := &ConfigReader{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// ConfigReader is an autogenerated mock type for the ConfigReader type
|
||||
type ConfigReader struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type ConfigReader_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *ConfigReader) EXPECT() *ConfigReader_Expecter {
|
||||
return &ConfigReader_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// ReadConfig provides a mock function for the type ConfigReader
|
||||
func (_mock *ConfigReader) ReadConfig(config bootstrap.Config, b bool) (any, error) {
|
||||
ret := _mock.Called(config, b)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ReadConfig")
|
||||
}
|
||||
|
||||
var r0 any
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(bootstrap.Config, bool) (any, error)); ok {
|
||||
return returnFunc(config, b)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(bootstrap.Config, bool) any); ok {
|
||||
r0 = returnFunc(config, b)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(any)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(bootstrap.Config, bool) error); ok {
|
||||
r1 = returnFunc(config, b)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ConfigReader_ReadConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadConfig'
|
||||
type ConfigReader_ReadConfig_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ReadConfig is a helper method to define mock.On call
|
||||
// - config bootstrap.Config
|
||||
// - b bool
|
||||
func (_e *ConfigReader_Expecter) ReadConfig(config interface{}, b interface{}) *ConfigReader_ReadConfig_Call {
|
||||
return &ConfigReader_ReadConfig_Call{Call: _e.mock.On("ReadConfig", config, b)}
|
||||
}
|
||||
|
||||
func (_c *ConfigReader_ReadConfig_Call) Run(run func(config bootstrap.Config, b bool)) *ConfigReader_ReadConfig_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 bootstrap.Config
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(bootstrap.Config)
|
||||
}
|
||||
var arg1 bool
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(bool)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigReader_ReadConfig_Call) Return(v any, err error) *ConfigReader_ReadConfig_Call {
|
||||
_c.Call.Return(v, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigReader_ReadConfig_Call) RunAndReturn(run func(config bootstrap.Config, b bool) (any, error)) *ConfigReader_ReadConfig_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -1,670 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewConfigRepository creates a new instance of ConfigRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewConfigRepository(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *ConfigRepository {
|
||||
mock := &ConfigRepository{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// ConfigRepository is an autogenerated mock type for the ConfigRepository type
|
||||
type ConfigRepository struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type ConfigRepository_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *ConfigRepository) EXPECT() *ConfigRepository_Expecter {
|
||||
return &ConfigRepository_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// AssignProfile provides a mock function for the type ConfigRepository
|
||||
func (_mock *ConfigRepository) AssignProfile(ctx context.Context, domainID string, id string, profileID string) error {
|
||||
ret := _mock.Called(ctx, domainID, id, profileID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for AssignProfile")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok {
|
||||
r0 = returnFunc(ctx, domainID, id, profileID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// ConfigRepository_AssignProfile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AssignProfile'
|
||||
type ConfigRepository_AssignProfile_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// AssignProfile is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - domainID string
|
||||
// - id string
|
||||
// - profileID string
|
||||
func (_e *ConfigRepository_Expecter) AssignProfile(ctx interface{}, domainID interface{}, id interface{}, profileID interface{}) *ConfigRepository_AssignProfile_Call {
|
||||
return &ConfigRepository_AssignProfile_Call{Call: _e.mock.On("AssignProfile", ctx, domainID, id, profileID)}
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_AssignProfile_Call) Run(run func(ctx context.Context, domainID string, id string, profileID string)) *ConfigRepository_AssignProfile_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
var arg3 string
|
||||
if args[3] != nil {
|
||||
arg3 = args[3].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_AssignProfile_Call) Return(err error) *ConfigRepository_AssignProfile_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_AssignProfile_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string, profileID string) error) *ConfigRepository_AssignProfile_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ChangeStatus provides a mock function for the type ConfigRepository
|
||||
func (_mock *ConfigRepository) ChangeStatus(ctx context.Context, domainID string, id string, status bootstrap.Status) error {
|
||||
ret := _mock.Called(ctx, domainID, id, status)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ChangeStatus")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, bootstrap.Status) error); ok {
|
||||
r0 = returnFunc(ctx, domainID, id, status)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// ConfigRepository_ChangeStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ChangeStatus'
|
||||
type ConfigRepository_ChangeStatus_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ChangeStatus is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - domainID string
|
||||
// - id string
|
||||
// - status bootstrap.Status
|
||||
func (_e *ConfigRepository_Expecter) ChangeStatus(ctx interface{}, domainID interface{}, id interface{}, status interface{}) *ConfigRepository_ChangeStatus_Call {
|
||||
return &ConfigRepository_ChangeStatus_Call{Call: _e.mock.On("ChangeStatus", ctx, domainID, id, status)}
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_ChangeStatus_Call) Run(run func(ctx context.Context, domainID string, id string, status bootstrap.Status)) *ConfigRepository_ChangeStatus_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
var arg3 bootstrap.Status
|
||||
if args[3] != nil {
|
||||
arg3 = args[3].(bootstrap.Status)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_ChangeStatus_Call) Return(err error) *ConfigRepository_ChangeStatus_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_ChangeStatus_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string, status bootstrap.Status) error) *ConfigRepository_ChangeStatus_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Remove provides a mock function for the type ConfigRepository
|
||||
func (_mock *ConfigRepository) Remove(ctx context.Context, domainID string, id string) error {
|
||||
ret := _mock.Called(ctx, domainID, id)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Remove")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
|
||||
r0 = returnFunc(ctx, domainID, id)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// ConfigRepository_Remove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remove'
|
||||
type ConfigRepository_Remove_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Remove is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - domainID string
|
||||
// - id string
|
||||
func (_e *ConfigRepository_Expecter) Remove(ctx interface{}, domainID interface{}, id interface{}) *ConfigRepository_Remove_Call {
|
||||
return &ConfigRepository_Remove_Call{Call: _e.mock.On("Remove", ctx, domainID, id)}
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_Remove_Call) Run(run func(ctx context.Context, domainID string, id string)) *ConfigRepository_Remove_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_Remove_Call) Return(err error) *ConfigRepository_Remove_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_Remove_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string) error) *ConfigRepository_Remove_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RetrieveAll provides a mock function for the type ConfigRepository
|
||||
func (_mock *ConfigRepository) RetrieveAll(ctx context.Context, domainID string, filter bootstrap.Filter, offset uint64, limit uint64) bootstrap.ConfigsPage {
|
||||
ret := _mock.Called(ctx, domainID, filter, offset, limit)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveAll")
|
||||
}
|
||||
|
||||
var r0 bootstrap.ConfigsPage
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, bootstrap.Filter, uint64, uint64) bootstrap.ConfigsPage); ok {
|
||||
r0 = returnFunc(ctx, domainID, filter, offset, limit)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bootstrap.ConfigsPage)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// ConfigRepository_RetrieveAll_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveAll'
|
||||
type ConfigRepository_RetrieveAll_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RetrieveAll is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - domainID string
|
||||
// - filter bootstrap.Filter
|
||||
// - offset uint64
|
||||
// - limit uint64
|
||||
func (_e *ConfigRepository_Expecter) RetrieveAll(ctx interface{}, domainID interface{}, filter interface{}, offset interface{}, limit interface{}) *ConfigRepository_RetrieveAll_Call {
|
||||
return &ConfigRepository_RetrieveAll_Call{Call: _e.mock.On("RetrieveAll", ctx, domainID, filter, offset, limit)}
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_RetrieveAll_Call) Run(run func(ctx context.Context, domainID string, filter bootstrap.Filter, offset uint64, limit uint64)) *ConfigRepository_RetrieveAll_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 bootstrap.Filter
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(bootstrap.Filter)
|
||||
}
|
||||
var arg3 uint64
|
||||
if args[3] != nil {
|
||||
arg3 = args[3].(uint64)
|
||||
}
|
||||
var arg4 uint64
|
||||
if args[4] != nil {
|
||||
arg4 = args[4].(uint64)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
arg4,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_RetrieveAll_Call) Return(configsPage bootstrap.ConfigsPage) *ConfigRepository_RetrieveAll_Call {
|
||||
_c.Call.Return(configsPage)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_RetrieveAll_Call) RunAndReturn(run func(ctx context.Context, domainID string, filter bootstrap.Filter, offset uint64, limit uint64) bootstrap.ConfigsPage) *ConfigRepository_RetrieveAll_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RetrieveByExternalID provides a mock function for the type ConfigRepository
|
||||
func (_mock *ConfigRepository) RetrieveByExternalID(ctx context.Context, externalID string) (bootstrap.Config, error) {
|
||||
ret := _mock.Called(ctx, externalID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveByExternalID")
|
||||
}
|
||||
|
||||
var r0 bootstrap.Config
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) (bootstrap.Config, error)); ok {
|
||||
return returnFunc(ctx, externalID)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) bootstrap.Config); ok {
|
||||
r0 = returnFunc(ctx, externalID)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bootstrap.Config)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = returnFunc(ctx, externalID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ConfigRepository_RetrieveByExternalID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveByExternalID'
|
||||
type ConfigRepository_RetrieveByExternalID_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RetrieveByExternalID is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - externalID string
|
||||
func (_e *ConfigRepository_Expecter) RetrieveByExternalID(ctx interface{}, externalID interface{}) *ConfigRepository_RetrieveByExternalID_Call {
|
||||
return &ConfigRepository_RetrieveByExternalID_Call{Call: _e.mock.On("RetrieveByExternalID", ctx, externalID)}
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_RetrieveByExternalID_Call) Run(run func(ctx context.Context, externalID string)) *ConfigRepository_RetrieveByExternalID_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_RetrieveByExternalID_Call) Return(config bootstrap.Config, err error) *ConfigRepository_RetrieveByExternalID_Call {
|
||||
_c.Call.Return(config, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_RetrieveByExternalID_Call) RunAndReturn(run func(ctx context.Context, externalID string) (bootstrap.Config, error)) *ConfigRepository_RetrieveByExternalID_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RetrieveByID provides a mock function for the type ConfigRepository
|
||||
func (_mock *ConfigRepository) RetrieveByID(ctx context.Context, domainID string, id string) (bootstrap.Config, error) {
|
||||
ret := _mock.Called(ctx, domainID, id)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveByID")
|
||||
}
|
||||
|
||||
var r0 bootstrap.Config
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (bootstrap.Config, error)); ok {
|
||||
return returnFunc(ctx, domainID, id)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) bootstrap.Config); ok {
|
||||
r0 = returnFunc(ctx, domainID, id)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bootstrap.Config)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
|
||||
r1 = returnFunc(ctx, domainID, id)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ConfigRepository_RetrieveByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveByID'
|
||||
type ConfigRepository_RetrieveByID_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RetrieveByID is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - domainID string
|
||||
// - id string
|
||||
func (_e *ConfigRepository_Expecter) RetrieveByID(ctx interface{}, domainID interface{}, id interface{}) *ConfigRepository_RetrieveByID_Call {
|
||||
return &ConfigRepository_RetrieveByID_Call{Call: _e.mock.On("RetrieveByID", ctx, domainID, id)}
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_RetrieveByID_Call) Run(run func(ctx context.Context, domainID string, id string)) *ConfigRepository_RetrieveByID_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_RetrieveByID_Call) Return(config bootstrap.Config, err error) *ConfigRepository_RetrieveByID_Call {
|
||||
_c.Call.Return(config, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_RetrieveByID_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string) (bootstrap.Config, error)) *ConfigRepository_RetrieveByID_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Save provides a mock function for the type ConfigRepository
|
||||
func (_mock *ConfigRepository) Save(ctx context.Context, cfg bootstrap.Config) (string, error) {
|
||||
ret := _mock.Called(ctx, cfg)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Save")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Config) (string, error)); ok {
|
||||
return returnFunc(ctx, cfg)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Config) string); ok {
|
||||
r0 = returnFunc(ctx, cfg)
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, bootstrap.Config) error); ok {
|
||||
r1 = returnFunc(ctx, cfg)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ConfigRepository_Save_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Save'
|
||||
type ConfigRepository_Save_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Save is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - cfg bootstrap.Config
|
||||
func (_e *ConfigRepository_Expecter) Save(ctx interface{}, cfg interface{}) *ConfigRepository_Save_Call {
|
||||
return &ConfigRepository_Save_Call{Call: _e.mock.On("Save", ctx, cfg)}
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_Save_Call) Run(run func(ctx context.Context, cfg bootstrap.Config)) *ConfigRepository_Save_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 bootstrap.Config
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(bootstrap.Config)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_Save_Call) Return(s string, err error) *ConfigRepository_Save_Call {
|
||||
_c.Call.Return(s, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_Save_Call) RunAndReturn(run func(ctx context.Context, cfg bootstrap.Config) (string, error)) *ConfigRepository_Save_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Update provides a mock function for the type ConfigRepository
|
||||
func (_mock *ConfigRepository) Update(ctx context.Context, cfg bootstrap.Config) error {
|
||||
ret := _mock.Called(ctx, cfg)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Update")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Config) error); ok {
|
||||
r0 = returnFunc(ctx, cfg)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// ConfigRepository_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update'
|
||||
type ConfigRepository_Update_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Update is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - cfg bootstrap.Config
|
||||
func (_e *ConfigRepository_Expecter) Update(ctx interface{}, cfg interface{}) *ConfigRepository_Update_Call {
|
||||
return &ConfigRepository_Update_Call{Call: _e.mock.On("Update", ctx, cfg)}
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_Update_Call) Run(run func(ctx context.Context, cfg bootstrap.Config)) *ConfigRepository_Update_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 bootstrap.Config
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(bootstrap.Config)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_Update_Call) Return(err error) *ConfigRepository_Update_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_Update_Call) RunAndReturn(run func(ctx context.Context, cfg bootstrap.Config) error) *ConfigRepository_Update_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateCert provides a mock function for the type ConfigRepository
|
||||
func (_mock *ConfigRepository) UpdateCert(ctx context.Context, domainID string, id string, clientCert string, clientKey string, caCert string) (bootstrap.Config, error) {
|
||||
ret := _mock.Called(ctx, domainID, id, clientCert, clientKey, caCert)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateCert")
|
||||
}
|
||||
|
||||
var r0 bootstrap.Config
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) (bootstrap.Config, error)); ok {
|
||||
return returnFunc(ctx, domainID, id, clientCert, clientKey, caCert)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) bootstrap.Config); ok {
|
||||
r0 = returnFunc(ctx, domainID, id, clientCert, clientKey, caCert)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bootstrap.Config)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string, string) error); ok {
|
||||
r1 = returnFunc(ctx, domainID, id, clientCert, clientKey, caCert)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ConfigRepository_UpdateCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCert'
|
||||
type ConfigRepository_UpdateCert_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UpdateCert is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - domainID string
|
||||
// - id string
|
||||
// - clientCert string
|
||||
// - clientKey string
|
||||
// - caCert string
|
||||
func (_e *ConfigRepository_Expecter) UpdateCert(ctx interface{}, domainID interface{}, id interface{}, clientCert interface{}, clientKey interface{}, caCert interface{}) *ConfigRepository_UpdateCert_Call {
|
||||
return &ConfigRepository_UpdateCert_Call{Call: _e.mock.On("UpdateCert", ctx, domainID, id, clientCert, clientKey, caCert)}
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_UpdateCert_Call) Run(run func(ctx context.Context, domainID string, id string, clientCert string, clientKey string, caCert string)) *ConfigRepository_UpdateCert_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
var arg3 string
|
||||
if args[3] != nil {
|
||||
arg3 = args[3].(string)
|
||||
}
|
||||
var arg4 string
|
||||
if args[4] != nil {
|
||||
arg4 = args[4].(string)
|
||||
}
|
||||
var arg5 string
|
||||
if args[5] != nil {
|
||||
arg5 = args[5].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
arg4,
|
||||
arg5,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_UpdateCert_Call) Return(config bootstrap.Config, err error) *ConfigRepository_UpdateCert_Call {
|
||||
_c.Call.Return(config, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ConfigRepository_UpdateCert_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string, clientCert string, clientKey string, caCert string) (bootstrap.Config, error)) *ConfigRepository_UpdateCert_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,420 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
"github.com/absmach/magistrala/pkg/postgres"
|
||||
"github.com/jackc/pgerrcode"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
const jsonNull = "null"
|
||||
|
||||
var _ bootstrap.ConfigRepository = (*configRepository)(nil)
|
||||
|
||||
type configRepository struct {
|
||||
db postgres.Database
|
||||
log *slog.Logger
|
||||
}
|
||||
|
||||
// NewConfigRepository instantiates a PostgreSQL implementation of config
|
||||
// repository.
|
||||
func NewConfigRepository(db postgres.Database, log *slog.Logger) bootstrap.ConfigRepository {
|
||||
return &configRepository{db: db, log: log}
|
||||
}
|
||||
|
||||
func (cr configRepository) Save(ctx context.Context, cfg bootstrap.Config) (string, error) {
|
||||
q := `INSERT INTO configs (id, domain_id, name, client_cert, client_key, ca_cert, external_id, external_key, content, status, profile_id, render_context)
|
||||
VALUES (:id, :domain_id, :name, :client_cert, :client_key, :ca_cert, :external_id, :external_key, :content, :status, :profile_id, :render_context)`
|
||||
|
||||
dbcfg, err := toDBConfig(cfg)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
if _, err := cr.db.NamedExecContext(ctx, q, dbcfg); err != nil {
|
||||
switch pgErr := err.(type) {
|
||||
case *pgconn.PgError:
|
||||
if pgErr.Code == pgerrcode.UniqueViolation {
|
||||
return "", repoerr.ErrConflict
|
||||
}
|
||||
}
|
||||
return "", errors.Wrap(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
return cfg.ID, nil
|
||||
}
|
||||
|
||||
func (cr configRepository) RetrieveByID(ctx context.Context, domainID, id string) (bootstrap.Config, error) {
|
||||
q := `SELECT id, external_id, name, content, status, client_cert, client_key, ca_cert, profile_id, render_context
|
||||
FROM configs
|
||||
WHERE id = :id AND domain_id = :domain_id`
|
||||
|
||||
dbcfg := dbConfig{
|
||||
ID: id,
|
||||
DomainID: domainID,
|
||||
}
|
||||
row, err := cr.db.NamedQueryContext(ctx, q, dbcfg)
|
||||
if err != nil {
|
||||
return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
if !row.Next() {
|
||||
return bootstrap.Config{}, repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
if err := row.StructScan(&dbcfg); err != nil {
|
||||
return bootstrap.Config{}, err
|
||||
}
|
||||
|
||||
cfg, err := toConfig(dbcfg)
|
||||
if err != nil {
|
||||
return bootstrap.Config{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (cr configRepository) RetrieveAll(ctx context.Context, domainID string, filter bootstrap.Filter, offset, limit uint64) bootstrap.ConfigsPage {
|
||||
search, params := buildRetrieveQueryParams(domainID, filter)
|
||||
n := len(params)
|
||||
|
||||
q := `SELECT id, external_id, name, content, status, profile_id, render_context
|
||||
FROM configs %s ORDER BY id LIMIT $%d OFFSET $%d`
|
||||
q = fmt.Sprintf(q, search, n+1, n+2)
|
||||
|
||||
rows, err := cr.db.QueryContext(ctx, q, append(params, limit, offset)...)
|
||||
if err != nil {
|
||||
cr.log.Error(fmt.Sprintf("Failed to retrieve configs due to %s", err))
|
||||
return bootstrap.ConfigsPage{}
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var name, content, profileID sql.NullString
|
||||
var renderContext []byte
|
||||
configs := []bootstrap.Config{}
|
||||
|
||||
for rows.Next() {
|
||||
c := bootstrap.Config{DomainID: domainID}
|
||||
if err := rows.Scan(&c.ID, &c.ExternalID, &name, &content, &c.Status, &profileID, &renderContext); err != nil {
|
||||
cr.log.Error(fmt.Sprintf("Failed to read retrieved config due to %s", err))
|
||||
return bootstrap.ConfigsPage{}
|
||||
}
|
||||
|
||||
c.Name = name.String
|
||||
c.Content = content.String
|
||||
if profileID.Valid {
|
||||
c.ProfileID = profileID.String
|
||||
}
|
||||
if len(renderContext) > 0 && string(renderContext) != jsonNull {
|
||||
if err := json.Unmarshal(renderContext, &c.RenderContext); err != nil {
|
||||
cr.log.Error(fmt.Sprintf("Failed to decode render context due to %s", err))
|
||||
return bootstrap.ConfigsPage{}
|
||||
}
|
||||
}
|
||||
configs = append(configs, c)
|
||||
}
|
||||
|
||||
q = fmt.Sprintf(`SELECT COUNT(*) FROM configs %s`, search)
|
||||
|
||||
var total uint64
|
||||
if err := cr.db.QueryRowxContext(ctx, q, params...).Scan(&total); err != nil {
|
||||
cr.log.Error(fmt.Sprintf("Failed to count configs due to %s", err))
|
||||
return bootstrap.ConfigsPage{}
|
||||
}
|
||||
|
||||
return bootstrap.ConfigsPage{
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
Configs: configs,
|
||||
}
|
||||
}
|
||||
|
||||
func (cr configRepository) RetrieveByExternalID(ctx context.Context, externalID string) (bootstrap.Config, error) {
|
||||
q := `SELECT id, external_key, domain_id, name, client_cert, client_key, ca_cert, content, status, profile_id, render_context
|
||||
FROM configs
|
||||
WHERE external_id = :external_id`
|
||||
dbcfg := dbConfig{
|
||||
ExternalID: externalID,
|
||||
}
|
||||
|
||||
row, err := cr.db.NamedQueryContext(ctx, q, dbcfg)
|
||||
if err != nil {
|
||||
return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
if !row.Next() {
|
||||
return bootstrap.Config{}, repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
if err := row.StructScan(&dbcfg); err != nil {
|
||||
return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
cfg, err := toConfig(dbcfg)
|
||||
if err != nil {
|
||||
return bootstrap.Config{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (cr configRepository) Update(ctx context.Context, cfg bootstrap.Config) error {
|
||||
q := `UPDATE configs SET name = :name, content = :content, render_context = :render_context WHERE id = :id AND domain_id = :domain_id `
|
||||
|
||||
renderContext, err := json.Marshal(cfg.RenderContext)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
dbcfg := dbConfig{
|
||||
Name: nullString(cfg.Name),
|
||||
Content: nullString(cfg.Content),
|
||||
RenderContext: renderContext,
|
||||
ID: cfg.ID,
|
||||
DomainID: cfg.DomainID,
|
||||
}
|
||||
|
||||
res, err := cr.db.NamedExecContext(ctx, q, dbcfg)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
cnt, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
if cnt == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cr configRepository) AssignProfile(ctx context.Context, domainID, id, profileID string) error {
|
||||
q := `UPDATE configs SET profile_id = :profile_id WHERE id = :id AND domain_id = :domain_id`
|
||||
|
||||
dbcfg := dbConfig{
|
||||
ID: id,
|
||||
DomainID: domainID,
|
||||
ProfileID: nullString(profileID),
|
||||
}
|
||||
|
||||
res, err := cr.db.NamedExecContext(ctx, q, dbcfg)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
cnt, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
if cnt == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cr configRepository) UpdateCert(ctx context.Context, domainID, id, clientCert, clientKey, caCert string) (bootstrap.Config, error) {
|
||||
q := `UPDATE configs SET client_cert = :client_cert, client_key = :client_key, ca_cert = :ca_cert WHERE id = :id AND domain_id = :domain_id
|
||||
RETURNING id, client_cert, client_key, ca_cert, domain_id`
|
||||
|
||||
dbcfg := dbConfig{
|
||||
ID: id,
|
||||
ClientCert: nullString(clientCert),
|
||||
DomainID: domainID,
|
||||
ClientKey: nullString(clientKey),
|
||||
CaCert: nullString(caCert),
|
||||
}
|
||||
|
||||
row, err := cr.db.NamedQueryContext(ctx, q, dbcfg)
|
||||
if err != nil {
|
||||
return bootstrap.Config{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
defer row.Close()
|
||||
|
||||
if ok := row.Next(); !ok {
|
||||
return bootstrap.Config{}, errors.Wrap(repoerr.ErrNotFound, row.Err())
|
||||
}
|
||||
|
||||
if err := row.StructScan(&dbcfg); err != nil {
|
||||
return bootstrap.Config{}, err
|
||||
}
|
||||
|
||||
cfg, err := toConfig(dbcfg)
|
||||
if err != nil {
|
||||
return bootstrap.Config{}, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (cr configRepository) Remove(ctx context.Context, domainID, id string) error {
|
||||
q := `DELETE FROM configs WHERE id = :id AND domain_id = :domain_id`
|
||||
dbcfg := dbConfig{
|
||||
ID: id,
|
||||
DomainID: domainID,
|
||||
}
|
||||
|
||||
if _, err := cr.db.NamedExecContext(ctx, q, dbcfg); err != nil {
|
||||
return errors.Wrap(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cr configRepository) ChangeStatus(ctx context.Context, domainID, id string, status bootstrap.Status) error {
|
||||
q := `UPDATE configs SET status = :status WHERE id = :id AND domain_id = :domain_id;`
|
||||
|
||||
dbcfg := dbConfig{
|
||||
ID: id,
|
||||
Status: status,
|
||||
DomainID: domainID,
|
||||
}
|
||||
|
||||
res, err := cr.db.NamedExecContext(ctx, q, dbcfg)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
cnt, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
if cnt == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildRetrieveQueryParams(domainID string, filter bootstrap.Filter) (string, []any) {
|
||||
params := []any{}
|
||||
queries := []string{}
|
||||
|
||||
if domainID != "" {
|
||||
params = append(params, domainID)
|
||||
queries = append(queries, fmt.Sprintf("domain_id = $%d", len(params)))
|
||||
}
|
||||
|
||||
counter := len(params) + 1
|
||||
for k, v := range filter.FullMatch {
|
||||
if k == "status" {
|
||||
status, err := bootstrap.ToStatus(v)
|
||||
if err != nil {
|
||||
return "", nil
|
||||
}
|
||||
if status == bootstrap.AllStatus {
|
||||
continue
|
||||
}
|
||||
params = append(params, status)
|
||||
queries = append(queries, fmt.Sprintf("%s = $%d", k, counter))
|
||||
counter++
|
||||
continue
|
||||
}
|
||||
params = append(params, v)
|
||||
queries = append(queries, fmt.Sprintf("%s = $%d", k, counter))
|
||||
counter++
|
||||
}
|
||||
for k, v := range filter.PartialMatch {
|
||||
params = append(params, v)
|
||||
queries = append(queries, fmt.Sprintf("LOWER(%s) LIKE '%%' || $%d || '%%'", k, counter))
|
||||
counter++
|
||||
}
|
||||
|
||||
if len(queries) > 0 {
|
||||
return "WHERE " + strings.Join(queries, " AND "), params
|
||||
}
|
||||
return "", params
|
||||
}
|
||||
|
||||
func nullString(s string) sql.NullString {
|
||||
if s == "" {
|
||||
return sql.NullString{}
|
||||
}
|
||||
return sql.NullString{String: s, Valid: true}
|
||||
}
|
||||
|
||||
type dbConfig struct {
|
||||
DomainID string `db:"domain_id"`
|
||||
ID string `db:"id"`
|
||||
Name sql.NullString `db:"name"`
|
||||
ClientCert sql.NullString `db:"client_cert"`
|
||||
ClientKey sql.NullString `db:"client_key"`
|
||||
CaCert sql.NullString `db:"ca_cert"`
|
||||
ExternalID string `db:"external_id"`
|
||||
ExternalKey string `db:"external_key"`
|
||||
Content sql.NullString `db:"content"`
|
||||
Status bootstrap.Status `db:"status"`
|
||||
ProfileID sql.NullString `db:"profile_id"`
|
||||
RenderContext []byte `db:"render_context"`
|
||||
}
|
||||
|
||||
func toDBConfig(cfg bootstrap.Config) (dbConfig, error) {
|
||||
renderContext, err := json.Marshal(cfg.RenderContext)
|
||||
if err != nil {
|
||||
return dbConfig{}, err
|
||||
}
|
||||
|
||||
return dbConfig{
|
||||
ID: cfg.ID,
|
||||
DomainID: cfg.DomainID,
|
||||
Name: nullString(cfg.Name),
|
||||
ClientCert: nullString(cfg.ClientCert),
|
||||
ClientKey: nullString(cfg.ClientKey),
|
||||
CaCert: nullString(cfg.CACert),
|
||||
ExternalID: cfg.ExternalID,
|
||||
ExternalKey: cfg.ExternalKey,
|
||||
Content: nullString(cfg.Content),
|
||||
Status: cfg.Status,
|
||||
ProfileID: nullString(cfg.ProfileID),
|
||||
RenderContext: renderContext,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func toConfig(dbcfg dbConfig) (bootstrap.Config, error) {
|
||||
cfg := bootstrap.Config{
|
||||
ID: dbcfg.ID,
|
||||
DomainID: dbcfg.DomainID,
|
||||
ExternalID: dbcfg.ExternalID,
|
||||
ExternalKey: dbcfg.ExternalKey,
|
||||
Status: dbcfg.Status,
|
||||
}
|
||||
if dbcfg.ProfileID.Valid {
|
||||
cfg.ProfileID = dbcfg.ProfileID.String
|
||||
}
|
||||
|
||||
if dbcfg.Name.Valid {
|
||||
cfg.Name = dbcfg.Name.String
|
||||
}
|
||||
if dbcfg.Content.Valid {
|
||||
cfg.Content = dbcfg.Content.String
|
||||
}
|
||||
if len(dbcfg.RenderContext) > 0 && string(dbcfg.RenderContext) != jsonNull {
|
||||
if err := json.Unmarshal(dbcfg.RenderContext, &cfg.RenderContext); err != nil {
|
||||
return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
}
|
||||
if dbcfg.ClientCert.Valid {
|
||||
cfg.ClientCert = dbcfg.ClientCert.String
|
||||
}
|
||||
if dbcfg.ClientKey.Valid {
|
||||
cfg.ClientKey = dbcfg.ClientKey.String
|
||||
}
|
||||
if dbcfg.CaCert.Valid {
|
||||
cfg.CACert = dbcfg.CaCert.String
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
@@ -1,471 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
"github.com/absmach/magistrala/bootstrap/postgres"
|
||||
"github.com/absmach/magistrala/internal/testsutil"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
"github.com/gofrs/uuid/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const numConfigs = 10
|
||||
|
||||
var config = bootstrap.Config{
|
||||
ID: "mg-client",
|
||||
ExternalID: "external-id",
|
||||
ExternalKey: "external-key",
|
||||
DomainID: testsutil.GenerateUUID(&testing.T{}),
|
||||
Content: "content",
|
||||
Status: bootstrap.Inactive,
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
repo := postgres.NewConfigRepository(db, testLog)
|
||||
|
||||
diff := "different"
|
||||
|
||||
duplicateClient := config
|
||||
duplicateClient.ExternalID = diff
|
||||
|
||||
duplicateExternal := config
|
||||
duplicateExternal.ID = diff
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
config bootstrap.Config
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "save a config",
|
||||
config: config,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "save config with same Client ID",
|
||||
config: duplicateClient,
|
||||
err: repoerr.ErrConflict,
|
||||
},
|
||||
{
|
||||
desc: "save config with same external ID",
|
||||
config: duplicateExternal,
|
||||
err: repoerr.ErrConflict,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
id, err := repo.Save(context.Background(), tc.config)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.Equal(t, id, tc.config.ID, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.config.ID, id))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveByID(t *testing.T) {
|
||||
repo := postgres.NewConfigRepository(db, testLog)
|
||||
|
||||
c := config
|
||||
// Use UUID to prevent conflicts.
|
||||
uid, err := uuid.NewV4()
|
||||
require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
|
||||
c.ID = uid.String()
|
||||
c.ExternalID = uid.String()
|
||||
c.ExternalKey = uid.String()
|
||||
id, err := repo.Save(context.Background(), c)
|
||||
require.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
|
||||
nonexistentConfID, err := uuid.NewV4()
|
||||
require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
domainID string
|
||||
id string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "retrieve config",
|
||||
domainID: c.DomainID,
|
||||
id: id,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "retrieve config with wrong domain ID ",
|
||||
domainID: "2",
|
||||
id: id,
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "retrieve a non-existing config",
|
||||
domainID: c.DomainID,
|
||||
id: nonexistentConfID.String(),
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "retrieve a config with invalid ID",
|
||||
domainID: c.DomainID,
|
||||
id: "invalid",
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
_, err := repo.RetrieveByID(context.Background(), tc.domainID, tc.id)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveAll(t *testing.T) {
|
||||
repo := postgres.NewConfigRepository(db, testLog)
|
||||
|
||||
for i := 0; i < numConfigs; i++ {
|
||||
c := config
|
||||
|
||||
// Use UUID to prevent conflict errors.
|
||||
uid, err := uuid.NewV4()
|
||||
require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
|
||||
c.ExternalID = uid.String()
|
||||
c.Name = fmt.Sprintf("name %d", i)
|
||||
c.ID = uid.String()
|
||||
|
||||
if i%2 == 0 {
|
||||
c.Status = bootstrap.Active
|
||||
}
|
||||
|
||||
_, err = repo.Save(context.Background(), c)
|
||||
require.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
}
|
||||
cases := []struct {
|
||||
desc string
|
||||
domainID string
|
||||
offset uint64
|
||||
limit uint64
|
||||
filter bootstrap.Filter
|
||||
size int
|
||||
}{
|
||||
{
|
||||
desc: "retrieve all configs",
|
||||
domainID: config.DomainID,
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
size: numConfigs,
|
||||
},
|
||||
{
|
||||
desc: "retrieve a subset of configs",
|
||||
domainID: config.DomainID,
|
||||
offset: 5,
|
||||
limit: uint64(numConfigs - 5),
|
||||
size: numConfigs - 5,
|
||||
},
|
||||
{
|
||||
desc: "retrieve with wrong domain ID ",
|
||||
domainID: "2",
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
size: 0,
|
||||
},
|
||||
{
|
||||
desc: "retrieve all active configs ",
|
||||
domainID: config.DomainID,
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
filter: bootstrap.Filter{FullMatch: map[string]string{"status": bootstrap.Active.String()}},
|
||||
size: numConfigs / 2,
|
||||
},
|
||||
{
|
||||
desc: "retrieve all with partial match filter",
|
||||
domainID: config.DomainID,
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "1"}},
|
||||
size: 1,
|
||||
},
|
||||
{
|
||||
desc: "retrieve search by name",
|
||||
domainID: config.DomainID,
|
||||
offset: 0,
|
||||
limit: uint64(numConfigs),
|
||||
filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "1"}},
|
||||
size: 1,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
ret := repo.RetrieveAll(context.Background(), tc.domainID, tc.filter, tc.offset, tc.limit)
|
||||
size := len(ret.Configs)
|
||||
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.size, size))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveByExternalID(t *testing.T) {
|
||||
repo := postgres.NewConfigRepository(db, testLog)
|
||||
|
||||
c := config
|
||||
// Use UUID to prevent conflicts.
|
||||
uid, err := uuid.NewV4()
|
||||
assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
|
||||
c.ID = uid.String()
|
||||
c.ExternalID = uid.String()
|
||||
c.ExternalKey = uid.String()
|
||||
_, err = repo.Save(context.Background(), c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
externalID string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "retrieve with invalid external ID",
|
||||
externalID: strconv.Itoa(numConfigs + 1),
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "retrieve with external key",
|
||||
externalID: c.ExternalID,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
_, err := repo.RetrieveByExternalID(context.Background(), tc.externalID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
repo := postgres.NewConfigRepository(db, testLog)
|
||||
|
||||
c := config
|
||||
// Use UUID to prevent conflicts.
|
||||
uid, err := uuid.NewV4()
|
||||
assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
|
||||
c.ID = uid.String()
|
||||
c.ExternalID = uid.String()
|
||||
c.ExternalKey = uid.String()
|
||||
_, err = repo.Save(context.Background(), c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
|
||||
c.Content = "new content"
|
||||
c.Name = "new name"
|
||||
|
||||
withRenderContext := c
|
||||
withRenderContext.RenderContext = map[string]any{
|
||||
"site": "warehouse-2",
|
||||
"region": "mombasa",
|
||||
}
|
||||
|
||||
wrongDomainID := c
|
||||
wrongDomainID.DomainID = "3"
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
config bootstrap.Config
|
||||
renderContext map[string]any
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "update with wrong domainID",
|
||||
config: wrongDomainID,
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "update a config",
|
||||
config: c,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update a config render_context",
|
||||
config: withRenderContext,
|
||||
renderContext: map[string]any{"site": "warehouse-2", "region": "mombasa"},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := repo.Update(context.Background(), tc.config)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if tc.err == nil && tc.renderContext != nil {
|
||||
saved, err := repo.RetrieveByID(context.Background(), tc.config.DomainID, tc.config.ID)
|
||||
require.Nil(t, err, fmt.Sprintf("%s: unexpected retrieve error: %s\n", tc.desc, err))
|
||||
assert.Equal(t, tc.renderContext, saved.RenderContext, fmt.Sprintf("%s: expected render_context %v got %v\n", tc.desc, tc.renderContext, saved.RenderContext))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateCert(t *testing.T) {
|
||||
repo := postgres.NewConfigRepository(db, testLog)
|
||||
|
||||
c := config
|
||||
// Use UUID to prevent conflicts.
|
||||
uid, err := uuid.NewV4()
|
||||
assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
|
||||
c.ID = uid.String()
|
||||
c.ExternalID = uid.String()
|
||||
c.ExternalKey = uid.String()
|
||||
_, err = repo.Save(context.Background(), c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
|
||||
c.Content = "new content"
|
||||
c.Name = "new name"
|
||||
|
||||
wrongDomainID := c
|
||||
wrongDomainID.DomainID = "3"
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
configID string
|
||||
domainID string
|
||||
cert string
|
||||
certKey string
|
||||
ca string
|
||||
expectedConfig bootstrap.Config
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "update with wrong domain ID ",
|
||||
configID: "",
|
||||
cert: "cert",
|
||||
certKey: "certKey",
|
||||
ca: "",
|
||||
domainID: wrongDomainID.DomainID,
|
||||
expectedConfig: bootstrap.Config{},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "update a config",
|
||||
configID: c.ID,
|
||||
cert: "cert",
|
||||
certKey: "certKey",
|
||||
ca: "ca",
|
||||
domainID: c.DomainID,
|
||||
expectedConfig: bootstrap.Config{
|
||||
ID: c.ID,
|
||||
ClientCert: "cert",
|
||||
CACert: "ca",
|
||||
ClientKey: "certKey",
|
||||
DomainID: c.DomainID,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
cfg, err := repo.UpdateCert(context.Background(), tc.domainID, tc.configID, tc.cert, tc.certKey, tc.ca)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.expectedConfig, cfg, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.expectedConfig, cfg))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
repo := postgres.NewConfigRepository(db, testLog)
|
||||
|
||||
c := config
|
||||
// Use UUID to prevent conflicts.
|
||||
uid, err := uuid.NewV4()
|
||||
assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
|
||||
c.ID = uid.String()
|
||||
c.ExternalID = uid.String()
|
||||
c.ExternalKey = uid.String()
|
||||
id, err := repo.Save(context.Background(), c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
|
||||
// Removal works the same for both existing and non-existing
|
||||
// (removed) config
|
||||
for i := 0; i < 2; i++ {
|
||||
err := repo.Remove(context.Background(), c.DomainID, id)
|
||||
assert.Nil(t, err, fmt.Sprintf("%d: failed to remove config due to: %s", i, err))
|
||||
|
||||
_, err = repo.RetrieveByID(context.Background(), c.DomainID, id)
|
||||
assert.True(t, errors.Contains(err, repoerr.ErrNotFound), fmt.Sprintf("%d: expected %s got %s", i, repoerr.ErrNotFound, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangeStatus(t *testing.T) {
|
||||
repo := postgres.NewConfigRepository(db, testLog)
|
||||
|
||||
c := config
|
||||
// Use UUID to prevent conflicts.
|
||||
uid, err := uuid.NewV4()
|
||||
assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
|
||||
c.ID = uid.String()
|
||||
c.ExternalID = uid.String()
|
||||
c.ExternalKey = uid.String()
|
||||
saved, err := repo.Save(context.Background(), c)
|
||||
assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
domainID string
|
||||
id string
|
||||
status bootstrap.Status
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "change status with wrong domain ID ",
|
||||
id: saved,
|
||||
domainID: "2",
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "change status with wrong id",
|
||||
id: "wrong",
|
||||
domainID: c.DomainID,
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "change status to Active",
|
||||
id: saved,
|
||||
domainID: c.DomainID,
|
||||
status: bootstrap.Active,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "change status to Inactive",
|
||||
id: saved,
|
||||
domainID: c.DomainID,
|
||||
status: bootstrap.Inactive,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := repo.ChangeStatus(context.Background(), tc.domainID, tc.id, tc.status)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssignProfile(t *testing.T) {
|
||||
configRepo := postgres.NewConfigRepository(db, testLog)
|
||||
profileRepo := postgres.NewProfileRepository(db, testLog)
|
||||
|
||||
c := config
|
||||
uid, err := uuid.NewV4()
|
||||
require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err))
|
||||
c.ID = uid.String()
|
||||
c.ExternalID = uid.String()
|
||||
c.ExternalKey = uid.String()
|
||||
saved, err := configRepo.Save(context.Background(), c)
|
||||
require.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err))
|
||||
|
||||
profileID := testsutil.GenerateUUID(t)
|
||||
_, err = profileRepo.Save(context.Background(), bootstrap.Profile{
|
||||
ID: profileID,
|
||||
DomainID: c.DomainID,
|
||||
Name: "edge-gateway",
|
||||
ContentFormat: bootstrap.ContentFormatGoTemplate,
|
||||
Version: 1,
|
||||
})
|
||||
require.Nil(t, err, fmt.Sprintf("Saving profile expected to succeed: %s.\n", err))
|
||||
|
||||
err = configRepo.AssignProfile(context.Background(), c.DomainID, saved, profileID)
|
||||
require.Nil(t, err, fmt.Sprintf("Assigning profile expected to succeed: %s.\n", err))
|
||||
|
||||
stored, err := configRepo.RetrieveByID(context.Background(), c.DomainID, saved)
|
||||
require.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err))
|
||||
assert.Equal(t, profileID, stored.ProfileID, "expected profile assignment to round-trip through the repository")
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package postgres contains repository implementations using PostgreSQL as
|
||||
// the underlying database.
|
||||
package postgres
|
||||
@@ -1,329 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import migrate "github.com/rubenv/sql-migrate"
|
||||
|
||||
// Migration of bootstrap service.
|
||||
func Migration() *migrate.MemoryMigrationSource {
|
||||
return &migrate.MemoryMigrationSource{
|
||||
Migrations: []*migrate.Migration{
|
||||
{
|
||||
Id: "configs_1",
|
||||
Up: []string{
|
||||
`CREATE TABLE IF NOT EXISTS configs (
|
||||
mainflux_client TEXT UNIQUE NOT NULL,
|
||||
owner VARCHAR(254),
|
||||
name TEXT,
|
||||
mainflux_key CHAR(36) UNIQUE NOT NULL,
|
||||
external_id TEXT UNIQUE NOT NULL,
|
||||
external_key TEXT NOT NULL,
|
||||
content TEXT,
|
||||
client_cert TEXT,
|
||||
client_key TEXT,
|
||||
ca_cert TEXT,
|
||||
state BIGINT NOT NULL,
|
||||
PRIMARY KEY (mainflux_client, owner)
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS unknown_configs (
|
||||
external_id TEXT UNIQUE NOT NULL,
|
||||
external_key TEXT NOT NULL,
|
||||
PRIMARY KEY (external_id, external_key)
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS channels (
|
||||
mainflux_channel TEXT UNIQUE NOT NULL,
|
||||
owner VARCHAR(254),
|
||||
name TEXT,
|
||||
metadata JSON,
|
||||
PRIMARY KEY (mainflux_channel, owner)
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS connections (
|
||||
channel_id TEXT,
|
||||
channel_owner VARCHAR(256),
|
||||
config_id TEXT,
|
||||
config_owner VARCHAR(256),
|
||||
FOREIGN KEY (channel_id, channel_owner) REFERENCES channels (mainflux_channel, owner) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (config_id, config_owner) REFERENCES configs (mainflux_client, owner) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
PRIMARY KEY (channel_id, channel_owner, config_id, config_owner)
|
||||
)`,
|
||||
},
|
||||
Down: []string{
|
||||
"DROP TABLE connections",
|
||||
"DROP TABLE configs",
|
||||
"DROP TABLE channels",
|
||||
"DROP TABLE unknown_configs",
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_2",
|
||||
Up: []string{
|
||||
"DROP TABLE IF EXISTS unknown_configs",
|
||||
},
|
||||
Down: []string{
|
||||
"CREATE TABLE IF NOT EXISTS unknown_configs",
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_3",
|
||||
Up: []string{
|
||||
`ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS parent_id VARCHAR(36)`,
|
||||
`ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS description VARCHAR(1024)`,
|
||||
`ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS created_at TIMESTAMP`,
|
||||
`ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP`,
|
||||
`ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS updated_by VARCHAR(254)`,
|
||||
`ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS status SMALLINT NOT NULL DEFAULT 0 CHECK (status >= 0)`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_4",
|
||||
Up: []string{
|
||||
`ALTER TABLE IF EXISTS configs RENAME COLUMN mainflux_client TO magistrala_client`,
|
||||
`ALTER TABLE IF EXISTS configs RENAME COLUMN mainflux_key TO magistrala_secret`,
|
||||
`ALTER TABLE IF EXISTS channels RENAME COLUMN mainflux_channel TO magistrala_channel`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_5",
|
||||
Up: []string{
|
||||
`ALTER TABLE IF EXISTS configs RENAME COLUMN owner TO domain_id`,
|
||||
`ALTER TABLE IF EXISTS channels RENAME COLUMN owner TO domain_id`,
|
||||
`ALTER TABLE IF EXISTS configs ADD CONSTRAINT configs_name_domain_id_key UNIQUE (name, domain_id)`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_6",
|
||||
Up: []string{
|
||||
`ALTER TABLE IF EXISTS connections DROP CONSTRAINT IF EXISTS connections_pkey`,
|
||||
`ALTER TABLE IF EXISTS connections DROP COLUMN IF EXISTS channel_owner`,
|
||||
`ALTER TABLE IF EXISTS connections DROP COLUMN IF EXISTS config_owner`,
|
||||
`ALTER TABLE IF EXISTS connections ADD COLUMN IF NOT EXISTS domain_id VARCHAR(256) NOT NULL`,
|
||||
`ALTER TABLE IF EXISTS connections ADD CONSTRAINT connections_pkey PRIMARY KEY (channel_id, config_id, domain_id)`,
|
||||
`ALTER TABLE IF EXISTS connections ADD FOREIGN KEY (channel_id, domain_id) REFERENCES channels (magistrala_channel, domain_id) ON DELETE CASCADE ON UPDATE CASCADE`,
|
||||
`ALTER TABLE IF EXISTS connections ADD FOREIGN KEY (config_id, domain_id) REFERENCES configs (magistrala_client, domain_id) ON DELETE CASCADE ON UPDATE CASCADE`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_7",
|
||||
Up: []string{
|
||||
`ALTER TABLE IF EXISTS configs RENAME COLUMN magistrala_client TO client_id`,
|
||||
`ALTER TABLE IF EXISTS configs RENAME COLUMN magistrala_secret TO client_secret`,
|
||||
`CREATE UNIQUE INDEX IF NOT EXISTS configs_client_id_key ON configs (client_id)`,
|
||||
`CREATE UNIQUE INDEX IF NOT EXISTS configs_client_id_domain_id_key ON configs (client_id, domain_id)`,
|
||||
`DROP TABLE IF EXISTS connections`,
|
||||
`DROP TABLE IF EXISTS channels`,
|
||||
},
|
||||
Down: []string{
|
||||
`ALTER TABLE IF EXISTS configs RENAME COLUMN client_id TO magistrala_client`,
|
||||
`ALTER TABLE IF EXISTS configs RENAME COLUMN client_secret TO magistrala_secret`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_8",
|
||||
Up: []string{
|
||||
`DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'configs' AND column_name = 'client_id'
|
||||
) AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'configs' AND column_name = 'id'
|
||||
) THEN
|
||||
ALTER TABLE configs RENAME COLUMN client_id TO id;
|
||||
END IF;
|
||||
END $$`,
|
||||
`ALTER TABLE IF EXISTS configs DROP COLUMN IF EXISTS client_secret`,
|
||||
},
|
||||
Down: []string{
|
||||
`ALTER TABLE IF EXISTS configs ADD COLUMN IF NOT EXISTS client_secret TEXT`,
|
||||
`DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'configs' AND column_name = 'id'
|
||||
) AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'configs' AND column_name = 'client_id'
|
||||
) THEN
|
||||
ALTER TABLE configs RENAME COLUMN id TO client_id;
|
||||
END IF;
|
||||
END $$`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_10",
|
||||
Up: []string{
|
||||
`CREATE TABLE IF NOT EXISTS profiles (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
domain_id VARCHAR(36) NOT NULL,
|
||||
name VARCHAR(1024) NOT NULL,
|
||||
description TEXT,
|
||||
template_format VARCHAR(64) NOT NULL DEFAULT 'go-template',
|
||||
content_template TEXT,
|
||||
defaults JSONB,
|
||||
binding_slots JSONB,
|
||||
version INT NOT NULL DEFAULT 1,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||
UNIQUE (domain_id, name)
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_profiles_domain_id ON profiles (domain_id)`,
|
||||
},
|
||||
Down: []string{
|
||||
`DROP TABLE IF EXISTS profiles`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_11",
|
||||
Up: []string{
|
||||
`ALTER TABLE IF EXISTS configs ADD COLUMN IF NOT EXISTS profile_id VARCHAR(36) REFERENCES profiles (id) ON DELETE SET NULL`,
|
||||
`ALTER TABLE IF EXISTS configs ADD COLUMN IF NOT EXISTS render_context JSONB`,
|
||||
},
|
||||
Down: []string{
|
||||
`ALTER TABLE IF EXISTS configs DROP COLUMN IF EXISTS render_context`,
|
||||
`ALTER TABLE IF EXISTS configs DROP COLUMN IF EXISTS profile_id`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_12",
|
||||
Up: []string{
|
||||
`CREATE TABLE IF NOT EXISTS bindings (
|
||||
config_id TEXT NOT NULL,
|
||||
slot VARCHAR(256) NOT NULL,
|
||||
type VARCHAR(64) NOT NULL,
|
||||
resource_id TEXT NOT NULL,
|
||||
snapshot JSONB,
|
||||
secret_snapshot BYTEA,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (config_id, slot)
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_bindings_config_id ON bindings (config_id)`,
|
||||
},
|
||||
Down: []string{
|
||||
`DROP TABLE IF EXISTS bindings`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_13",
|
||||
Up: []string{
|
||||
`DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'configs' AND column_name = 'state'
|
||||
) AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'configs' AND column_name = 'status'
|
||||
) THEN
|
||||
ALTER TABLE configs RENAME COLUMN state TO status;
|
||||
END IF;
|
||||
END $$`,
|
||||
},
|
||||
Down: []string{
|
||||
`DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'configs' AND column_name = 'status'
|
||||
) AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = 'configs' AND column_name = 'state'
|
||||
) THEN
|
||||
ALTER TABLE configs RENAME COLUMN status TO state;
|
||||
END IF;
|
||||
END $$`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_14",
|
||||
Up: []string{
|
||||
`DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_name = 'binding_snapshots'
|
||||
) AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_name = 'bindings'
|
||||
) THEN
|
||||
ALTER TABLE binding_snapshots RENAME TO bindings;
|
||||
END IF;
|
||||
END $$`,
|
||||
`DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_class
|
||||
WHERE relname = 'idx_binding_snapshots_config_id'
|
||||
) AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_class
|
||||
WHERE relname = 'idx_bindings_config_id'
|
||||
) THEN
|
||||
ALTER INDEX idx_binding_snapshots_config_id RENAME TO idx_bindings_config_id;
|
||||
END IF;
|
||||
END $$`,
|
||||
},
|
||||
Down: []string{
|
||||
`DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_name = 'bindings'
|
||||
) AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_name = 'binding_snapshots'
|
||||
) THEN
|
||||
ALTER TABLE bindings RENAME TO binding_snapshots;
|
||||
END IF;
|
||||
END $$`,
|
||||
`DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_class
|
||||
WHERE relname = 'idx_bindings_config_id'
|
||||
) AND NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_class
|
||||
WHERE relname = 'idx_binding_snapshots_config_id'
|
||||
) THEN
|
||||
ALTER INDEX idx_bindings_config_id RENAME TO idx_binding_snapshots_config_id;
|
||||
END IF;
|
||||
END $$`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_15",
|
||||
Up: []string{
|
||||
`ALTER TABLE IF EXISTS profiles ADD COLUMN IF NOT EXISTS binding_slots JSONB`,
|
||||
},
|
||||
Down: []string{
|
||||
`ALTER TABLE IF EXISTS profiles DROP COLUMN IF EXISTS binding_slots`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "configs_16",
|
||||
Up: []string{
|
||||
`ALTER TABLE IF EXISTS profiles RENAME COLUMN template_format TO content_format`,
|
||||
},
|
||||
Down: []string{
|
||||
`ALTER TABLE IF EXISTS profiles RENAME COLUMN content_format TO template_format`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala/bootstrap/postgres"
|
||||
mglog "github.com/absmach/magistrala/logger"
|
||||
pgclient "github.com/absmach/magistrala/pkg/postgres"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/ory/dockertest/v3/docker"
|
||||
)
|
||||
|
||||
var (
|
||||
testLog, _ = mglog.New(os.Stdout, "info")
|
||||
db *sqlx.DB
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
testLog.Error(fmt.Sprintf("Could not connect to docker: %s", err))
|
||||
}
|
||||
|
||||
container, err := pool.RunWithOptions(&dockertest.RunOptions{
|
||||
Repository: "postgres",
|
||||
Tag: "16.2-alpine",
|
||||
Env: []string{
|
||||
"POSTGRES_USER=test",
|
||||
"POSTGRES_PASSWORD=test",
|
||||
"POSTGRES_DB=test",
|
||||
"listen_addresses = '*'",
|
||||
},
|
||||
}, func(config *docker.HostConfig) {
|
||||
config.AutoRemove = true
|
||||
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Could not start container: %s", err)
|
||||
}
|
||||
|
||||
port := container.GetPort("5432/tcp")
|
||||
|
||||
if err := pool.Retry(func() error {
|
||||
url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port)
|
||||
db, err = sqlx.Open("pgx", url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return db.Ping()
|
||||
}); err != nil {
|
||||
testLog.Error(fmt.Sprintf("Could not connect to docker: %s", err))
|
||||
}
|
||||
|
||||
dbConfig := pgclient.Config{
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
User: "test",
|
||||
Pass: "test",
|
||||
Name: "test",
|
||||
SSLMode: "disable",
|
||||
SSLCert: "",
|
||||
SSLKey: "",
|
||||
SSLRootCert: "",
|
||||
}
|
||||
|
||||
migration := postgres.Migration()
|
||||
|
||||
if db, err = pgclient.Setup(dbConfig, *migration); err != nil {
|
||||
testLog.Error(fmt.Sprintf("Could not setup test DB connection: %s", err))
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
// Defers will not be run when using os.Exit
|
||||
db.Close()
|
||||
if err := pool.Purge(container); err != nil {
|
||||
testLog.Error(fmt.Sprintf("Could not purge container: %s", err))
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// bootstrapRes represent Magistrala Response to the Bootstrap request.
|
||||
// This is used as a response from ConfigReader and can easily be
|
||||
// replaced with any other response format.
|
||||
type bootstrapRes struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ClientCert string `json:"client_cert,omitempty"`
|
||||
ClientKey string `json:"client_key,omitempty"`
|
||||
CACert string `json:"ca_cert,omitempty"`
|
||||
}
|
||||
|
||||
func (res bootstrapRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res bootstrapRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res bootstrapRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type reader struct {
|
||||
encKey []byte
|
||||
}
|
||||
|
||||
// NewConfigReader return new reader which is used to generate response
|
||||
// from the config.
|
||||
func NewConfigReader(encKey []byte) ConfigReader {
|
||||
return reader{encKey: encKey}
|
||||
}
|
||||
|
||||
func (r reader) ReadConfig(cfg Config, secure bool) (any, error) {
|
||||
res := bootstrapRes{
|
||||
ID: cfg.ID,
|
||||
Content: cfg.Content,
|
||||
ClientCert: cfg.ClientCert,
|
||||
ClientKey: cfg.ClientKey,
|
||||
CACert: cfg.CACert,
|
||||
}
|
||||
if secure {
|
||||
b, err := json.Marshal(res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.encrypt(b)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (r reader) encrypt(in []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(r.encKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ciphertext := make([]byte, aes.BlockSize+len(in))
|
||||
iv := ciphertext[:aes.BlockSize]
|
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream := cipher.NewCFBEncrypter(block, iv)
|
||||
stream.XORKeyStream(ciphertext[aes.BlockSize:], in)
|
||||
return ciphertext, nil
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package bootstrap_test
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type readResp struct {
|
||||
ID string `json:"id"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ClientCert string `json:"client_cert,omitempty"`
|
||||
ClientKey string `json:"client_key,omitempty"`
|
||||
CACert string `json:"ca_cert,omitempty"`
|
||||
}
|
||||
|
||||
func dec(in []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(encKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(in) < aes.BlockSize {
|
||||
return nil, errors.ErrMalformedEntity
|
||||
}
|
||||
iv := in[:aes.BlockSize]
|
||||
in = in[aes.BlockSize:]
|
||||
stream := cipher.NewCFBDecrypter(block, iv)
|
||||
stream.XORKeyStream(in, in)
|
||||
return in, nil
|
||||
}
|
||||
|
||||
func TestReadConfig(t *testing.T) {
|
||||
cfg := bootstrap.Config{
|
||||
ID: "smq_id",
|
||||
ClientCert: "client_cert",
|
||||
ClientKey: "client_key",
|
||||
CACert: "ca_cert",
|
||||
Content: "content",
|
||||
}
|
||||
ret := readResp{
|
||||
ID: "smq_id",
|
||||
Content: "content",
|
||||
ClientCert: "client_cert",
|
||||
ClientKey: "client_key",
|
||||
CACert: "ca_cert",
|
||||
}
|
||||
|
||||
bin, err := json.Marshal(ret)
|
||||
assert.Nil(t, err, fmt.Sprintf("Marshalling expected to succeed: %s.\n", err))
|
||||
|
||||
reader := bootstrap.NewConfigReader(encKey)
|
||||
cases := []struct {
|
||||
desc string
|
||||
config bootstrap.Config
|
||||
enc []byte
|
||||
secret bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "read a config",
|
||||
config: cfg,
|
||||
enc: bin,
|
||||
secret: false,
|
||||
},
|
||||
{
|
||||
desc: "read encrypted config",
|
||||
config: cfg,
|
||||
enc: bin,
|
||||
secret: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
res, err := reader.ReadConfig(tc.config, tc.secret)
|
||||
assert.Nil(t, err, fmt.Sprintf("Reading config to succeed: %s.\n", err))
|
||||
|
||||
if tc.secret {
|
||||
d, err := dec(res.([]byte))
|
||||
assert.Nil(t, err, fmt.Sprintf("Decrypting expected to succeed: %s.\n", err))
|
||||
assert.Equal(t, tc.enc, d, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.enc, d))
|
||||
continue
|
||||
}
|
||||
b, err := json.Marshal(res)
|
||||
assert.Nil(t, err, fmt.Sprintf("Marshalling expected to succeed: %s.\n", err))
|
||||
assert.Equal(t, tc.enc, b, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.enc, b))
|
||||
resp, ok := res.(magistrala.Response)
|
||||
assert.True(t, ok, "If not encrypted, reader should return response.")
|
||||
assert.False(t, resp.Empty(), fmt.Sprintf("Response should not be empty %s.", err))
|
||||
assert.Equal(t, http.StatusOK, resp.Code(), "Default config response code should be 200.")
|
||||
}
|
||||
}
|
||||
@@ -1,532 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
smqauthn "github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
mgsdk "github.com/absmach/magistrala/pkg/sdk"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrExternalKey indicates a non-existent bootstrap configuration for given external key.
|
||||
ErrExternalKey = errors.NewAuthZError("failed to get bootstrap configuration for given external key")
|
||||
|
||||
// ErrExternalKeySecure indicates error in getting bootstrap configuration for given encrypted external key.
|
||||
ErrExternalKeySecure = errors.NewAuthZError("failed to get bootstrap configuration for given encrypted external key")
|
||||
|
||||
// ErrBootstrap indicates error in getting bootstrap configuration.
|
||||
ErrBootstrap = errors.New("failed to read bootstrap configuration")
|
||||
|
||||
// ErrAddBootstrap indicates error in adding bootstrap configuration.
|
||||
ErrAddBootstrap = errors.NewServiceError("failed to add bootstrap configuration")
|
||||
|
||||
// ErrBootstrapStatus indicates an invalid bootstrap status.
|
||||
ErrBootstrapStatus = errors.NewRequestError("invalid bootstrap status")
|
||||
|
||||
errRemoveBootstrap = errors.New("failed to remove bootstrap configuration")
|
||||
errEnableConfig = errors.New("failed to enable bootstrap configuration")
|
||||
errDisableConfig = errors.New("failed to disable bootstrap configuration")
|
||||
errUpdateCert = errors.New("failed to update cert")
|
||||
|
||||
errCreateProfile = errors.New("failed to create profile")
|
||||
errViewProfile = errors.New("failed to view profile")
|
||||
errUpdateProfile = errors.New("failed to update profile")
|
||||
errDeleteProfile = errors.New("failed to delete profile")
|
||||
errListProfiles = errors.New("failed to list profiles")
|
||||
errAssignProfile = errors.New("failed to assign profile to enrollment")
|
||||
errBindResources = errors.New("failed to bind resources")
|
||||
errListBindings = errors.New("failed to list bindings")
|
||||
errRefreshBinding = errors.New("failed to refresh bindings")
|
||||
errRenderBootstrap = errors.New("failed to render bootstrap configuration")
|
||||
)
|
||||
|
||||
var _ Service = (*bootstrapService)(nil)
|
||||
|
||||
// Service specifies an API that must be fulfilled by the domain service
|
||||
// implementation, and all of its decorators (e.g. logging & metrics).
|
||||
type Service interface {
|
||||
// Add adds new Client Config to the user identified by the provided token.
|
||||
Add(ctx context.Context, session smqauthn.Session, token string, cfg Config) (Config, error)
|
||||
|
||||
// View returns Client Config with given ID belonging to the user identified by the given token.
|
||||
View(ctx context.Context, session smqauthn.Session, id string) (Config, error)
|
||||
|
||||
// Update updates editable fields of the provided Config.
|
||||
Update(ctx context.Context, session smqauthn.Session, cfg Config) error
|
||||
|
||||
// UpdateCert updates an existing Config certificate and token.
|
||||
// A non-nil error is returned to indicate operation failure.
|
||||
UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (Config, error)
|
||||
|
||||
// List returns subset of Configs with given search params that belong to the
|
||||
// user identified by the given token.
|
||||
List(ctx context.Context, session smqauthn.Session, filter Filter, offset, limit uint64) (ConfigsPage, error)
|
||||
|
||||
// Remove removes Config with specified token that belongs to the user identified by the given token.
|
||||
Remove(ctx context.Context, session smqauthn.Session, id string) error
|
||||
|
||||
// Bootstrap returns Config to the Client with provided external ID using external key.
|
||||
Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (Config, error)
|
||||
|
||||
// EnableConfig enables the Config so its device can successfully bootstrap.
|
||||
EnableConfig(ctx context.Context, session smqauthn.Session, id string) (Config, error)
|
||||
|
||||
// DisableConfig disables the Config, preventing its device from bootstrapping.
|
||||
DisableConfig(ctx context.Context, session smqauthn.Session, id string) (Config, error)
|
||||
|
||||
// CreateProfile persists a new device Profile.
|
||||
CreateProfile(ctx context.Context, session smqauthn.Session, p Profile) (Profile, error)
|
||||
|
||||
// ViewProfile returns the Profile with the given ID.
|
||||
ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (Profile, error)
|
||||
|
||||
// UpdateProfile updates editable fields of the given Profile and returns the updated Profile.
|
||||
UpdateProfile(ctx context.Context, session smqauthn.Session, p Profile) (Profile, error)
|
||||
|
||||
// ListProfiles returns a page of Profiles belonging to the domain.
|
||||
ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (ProfilesPage, error)
|
||||
|
||||
// DeleteProfile removes the Profile with the given ID.
|
||||
DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error
|
||||
|
||||
// AssignProfile sets the ProfileID on an existing enrollment (Config).
|
||||
AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) error
|
||||
|
||||
// BindResources resolves the requested bindings through their owning services,
|
||||
// stores snapshots, and marks the enrollment renderable when all required slots
|
||||
// are satisfied.
|
||||
BindResources(ctx context.Context, session smqauthn.Session, token, configID string, bindings []BindingRequest) error
|
||||
|
||||
// ListBindings returns all stored binding snapshots for an enrollment.
|
||||
ListBindings(ctx context.Context, session smqauthn.Session, configID string) ([]BindingSnapshot, error)
|
||||
|
||||
// RefreshBindings re-resolves all existing bindings for an enrollment and
|
||||
// updates the stored snapshots.
|
||||
RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) error
|
||||
}
|
||||
|
||||
// ConfigReader is used to parse Config into format which will be encoded
|
||||
// as a JSON and consumed from the client side. The purpose of this interface
|
||||
// is to provide convenient way to generate custom configuration response
|
||||
// based on the specific Config which will be consumed by the client.
|
||||
type ConfigReader interface {
|
||||
ReadConfig(Config, bool) (any, error)
|
||||
}
|
||||
|
||||
type bootstrapService struct {
|
||||
configs ConfigRepository
|
||||
profiles ProfileRepository
|
||||
bindings BindingStore
|
||||
resolver BindingResolver
|
||||
renderer Renderer
|
||||
hasher Hasher
|
||||
sdk mgsdk.SDK
|
||||
encKey []byte
|
||||
idProvider magistrala.IDProvider
|
||||
}
|
||||
|
||||
// New returns new Bootstrap service.
|
||||
func New(
|
||||
configs ConfigRepository,
|
||||
profiles ProfileRepository,
|
||||
bindings BindingStore,
|
||||
resolver BindingResolver,
|
||||
renderer Renderer,
|
||||
sdk mgsdk.SDK,
|
||||
hasher Hasher,
|
||||
encKey []byte,
|
||||
idp magistrala.IDProvider,
|
||||
) Service {
|
||||
return &bootstrapService{
|
||||
configs: configs,
|
||||
profiles: profiles,
|
||||
bindings: bindings,
|
||||
resolver: resolver,
|
||||
renderer: renderer,
|
||||
hasher: hasher,
|
||||
sdk: sdk,
|
||||
encKey: encKey,
|
||||
idProvider: idp,
|
||||
}
|
||||
}
|
||||
|
||||
func (bs bootstrapService) Add(ctx context.Context, session smqauthn.Session, token string, cfg Config) (Config, error) {
|
||||
id, err := bs.idProvider.ID()
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(ErrAddBootstrap, err)
|
||||
}
|
||||
|
||||
hashedKey, err := bs.hasher.Hash(cfg.ExternalKey)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(ErrAddBootstrap, err)
|
||||
}
|
||||
|
||||
cfg.ID = id
|
||||
cfg.DomainID = session.DomainID
|
||||
cfg.Status = Active
|
||||
cfg.ExternalKey = hashedKey
|
||||
|
||||
saved, err := bs.configs.Save(ctx, cfg)
|
||||
if err != nil {
|
||||
if errors.Contains(err, repoerr.ErrConflict) {
|
||||
return Config{}, errors.Wrap(svcerr.ErrConflict, err)
|
||||
}
|
||||
return Config{}, errors.Wrap(ErrAddBootstrap, err)
|
||||
}
|
||||
|
||||
cfg.ID = saved
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) View(ctx context.Context, session smqauthn.Session, id string) (Config, error) {
|
||||
cfg, err := bs.configs.RetrieveByID(ctx, session.DomainID, id)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) Update(ctx context.Context, session smqauthn.Session, cfg Config) error {
|
||||
cfg.DomainID = session.DomainID
|
||||
if err := bs.configs.Update(ctx, cfg); err != nil {
|
||||
return errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (Config, error) {
|
||||
cfg, err := bs.configs.UpdateCert(ctx, session.DomainID, id, clientCert, clientKey, caCert)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(errUpdateCert, err)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) List(ctx context.Context, session smqauthn.Session, filter Filter, offset, limit uint64) (ConfigsPage, error) {
|
||||
return bs.configs.RetrieveAll(ctx, session.DomainID, filter, offset, limit), nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) Remove(ctx context.Context, session smqauthn.Session, id string) error {
|
||||
if err := bs.configs.Remove(ctx, session.DomainID, id); err != nil {
|
||||
return errors.Wrap(errRemoveBootstrap, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (Config, error) {
|
||||
cfg, err := bs.configs.RetrieveByExternalID(ctx, externalID)
|
||||
if err != nil {
|
||||
return cfg, errors.Wrap(ErrBootstrap, err)
|
||||
}
|
||||
if secure {
|
||||
dec, err := bs.dec(externalKey)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(ErrExternalKeySecure, err)
|
||||
}
|
||||
externalKey = dec
|
||||
}
|
||||
|
||||
if err := bs.hasher.Compare(externalKey, cfg.ExternalKey); err != nil {
|
||||
return Config{}, ErrExternalKey
|
||||
}
|
||||
if cfg.Status == DisabledStatus {
|
||||
return Config{}, ErrBootstrap
|
||||
}
|
||||
|
||||
cfg, err = bs.renderBootstrapConfig(ctx, cfg)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(ErrBootstrap, err)
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) renderBootstrapConfig(ctx context.Context, cfg Config) (Config, error) {
|
||||
if cfg.ProfileID == "" {
|
||||
return cfg, nil
|
||||
}
|
||||
if bs.profiles == nil || bs.bindings == nil || bs.renderer == nil {
|
||||
return Config{}, errors.Wrap(errRenderBootstrap, errors.New("profile rendering support not configured"))
|
||||
}
|
||||
|
||||
profile, err := bs.profiles.RetrieveByID(ctx, cfg.DomainID, cfg.ProfileID)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(errRenderBootstrap, err)
|
||||
}
|
||||
|
||||
bindings, err := bs.bindings.Retrieve(ctx, cfg.ID)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(errRenderBootstrap, err)
|
||||
}
|
||||
if err := validateRequiredBindings(profile, bindings); err != nil {
|
||||
return Config{}, errors.Wrap(errRenderBootstrap, err)
|
||||
}
|
||||
bindings, err = bs.decryptSecretSnapshots(bindings)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(errRenderBootstrap, err)
|
||||
}
|
||||
|
||||
rendered, err := bs.renderer.Render(profile, cfg, bindings)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(errRenderBootstrap, err)
|
||||
}
|
||||
|
||||
cfg.Content = string(rendered)
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) EnableConfig(ctx context.Context, session smqauthn.Session, id string) (Config, error) {
|
||||
cfg, err := bs.changeConfigStatus(ctx, session.DomainID, id, EnabledStatus)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(errEnableConfig, err)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) DisableConfig(ctx context.Context, session smqauthn.Session, id string) (Config, error) {
|
||||
cfg, err := bs.changeConfigStatus(ctx, session.DomainID, id, DisabledStatus)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(errDisableConfig, err)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) changeConfigStatus(ctx context.Context, domainID, id string, status Status) (Config, error) {
|
||||
cfg, err := bs.configs.RetrieveByID(ctx, domainID, id)
|
||||
if err != nil {
|
||||
return Config{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
if cfg.Status == status {
|
||||
return cfg, nil
|
||||
}
|
||||
if err := bs.configs.ChangeStatus(ctx, domainID, id, status); err != nil {
|
||||
return Config{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
cfg.Status = status
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// --- Profile management ---
|
||||
|
||||
func (bs bootstrapService) CreateProfile(ctx context.Context, session smqauthn.Session, p Profile) (Profile, error) {
|
||||
if bs.profiles == nil {
|
||||
return Profile{}, errors.Wrap(errCreateProfile, errors.New("profile repository not configured"))
|
||||
}
|
||||
id, err := bs.idProvider.ID()
|
||||
if err != nil {
|
||||
return Profile{}, errors.Wrap(errCreateProfile, err)
|
||||
}
|
||||
p.ID = id
|
||||
p.DomainID = session.DomainID
|
||||
if p.ContentFormat == "" {
|
||||
p.ContentFormat = ContentFormatJSON
|
||||
}
|
||||
p.Version = 1
|
||||
if err := validateProfileBindingSlots(p); err != nil {
|
||||
return Profile{}, errors.Wrap(errCreateProfile, err)
|
||||
}
|
||||
if err := validateProfileTemplate(p); err != nil {
|
||||
return Profile{}, errors.Wrap(errCreateProfile, err)
|
||||
}
|
||||
saved, err := bs.profiles.Save(ctx, p)
|
||||
if err != nil {
|
||||
return Profile{}, errors.Wrap(errCreateProfile, err)
|
||||
}
|
||||
return saved, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (Profile, error) {
|
||||
if bs.profiles == nil {
|
||||
return Profile{}, errors.Wrap(errViewProfile, errors.New("profile repository not configured"))
|
||||
}
|
||||
p, err := bs.profiles.RetrieveByID(ctx, session.DomainID, profileID)
|
||||
if err != nil {
|
||||
return Profile{}, errors.Wrap(errViewProfile, err)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) UpdateProfile(ctx context.Context, session smqauthn.Session, p Profile) (Profile, error) {
|
||||
if bs.profiles == nil {
|
||||
return Profile{}, errors.Wrap(errUpdateProfile, errors.New("profile repository not configured"))
|
||||
}
|
||||
p.DomainID = session.DomainID
|
||||
if err := validateProfileBindingSlots(p); err != nil {
|
||||
return Profile{}, errors.Wrap(errUpdateProfile, err)
|
||||
}
|
||||
if err := validateProfileTemplate(p); err != nil {
|
||||
return Profile{}, errors.Wrap(errUpdateProfile, err)
|
||||
}
|
||||
updated, err := bs.profiles.Update(ctx, p)
|
||||
if err != nil {
|
||||
return Profile{}, errors.Wrap(errUpdateProfile, err)
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (ProfilesPage, error) {
|
||||
if bs.profiles == nil {
|
||||
return ProfilesPage{}, errors.Wrap(errListProfiles, errors.New("profile repository not configured"))
|
||||
}
|
||||
page, err := bs.profiles.RetrieveAll(ctx, session.DomainID, offset, limit, name)
|
||||
if err != nil {
|
||||
return ProfilesPage{}, errors.Wrap(errListProfiles, err)
|
||||
}
|
||||
return page, nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
|
||||
if bs.profiles == nil {
|
||||
return errors.Wrap(errDeleteProfile, errors.New("profile repository not configured"))
|
||||
}
|
||||
if err := bs.profiles.Delete(ctx, session.DomainID, profileID); err != nil {
|
||||
return errors.Wrap(errDeleteProfile, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Enrollment-profile assignment ---
|
||||
|
||||
func (bs bootstrapService) AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) error {
|
||||
if bs.profiles == nil {
|
||||
return errors.Wrap(errAssignProfile, errors.New("profile repository not configured"))
|
||||
}
|
||||
// Validate profile exists in domain.
|
||||
if _, err := bs.profiles.RetrieveByID(ctx, session.DomainID, profileID); err != nil {
|
||||
return errors.Wrap(errAssignProfile, err)
|
||||
}
|
||||
if err := bs.configs.AssignProfile(ctx, session.DomainID, configID, profileID); err != nil {
|
||||
return errors.Wrap(errAssignProfile, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Binding management ---
|
||||
|
||||
func (bs bootstrapService) BindResources(ctx context.Context, session smqauthn.Session, token, configID string, requested []BindingRequest) error {
|
||||
if bs.profiles == nil || bs.bindings == nil || bs.resolver == nil {
|
||||
return errors.Wrap(errBindResources, errors.New("binding support not configured"))
|
||||
}
|
||||
cfg, err := bs.configs.RetrieveByID(ctx, session.DomainID, configID)
|
||||
if err != nil {
|
||||
return errors.Wrap(errBindResources, err)
|
||||
}
|
||||
profile, err := bs.profiles.RetrieveByID(ctx, session.DomainID, cfg.ProfileID)
|
||||
if err != nil {
|
||||
return errors.Wrap(errBindResources, err)
|
||||
}
|
||||
if err := validateRequestedBindings(profile, requested); err != nil {
|
||||
return errors.Wrap(errBindResources, err)
|
||||
}
|
||||
snapshots, err := bs.resolver.Resolve(ctx, ResolveRequest{
|
||||
Enrollment: cfg,
|
||||
Token: token,
|
||||
Requested: requested,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(errBindResources, err)
|
||||
}
|
||||
existing, err := bs.bindings.Retrieve(ctx, configID)
|
||||
if err != nil {
|
||||
return errors.Wrap(errBindResources, err)
|
||||
}
|
||||
if err := validateRequiredBindings(profile, mergeBindingSnapshots(existing, snapshots)); err != nil {
|
||||
return errors.Wrap(errBindResources, err)
|
||||
}
|
||||
snapshots, err = bs.encryptSecretSnapshots(snapshots)
|
||||
if err != nil {
|
||||
return errors.Wrap(errBindResources, err)
|
||||
}
|
||||
if err := bs.bindings.Save(ctx, configID, snapshots); err != nil {
|
||||
return errors.Wrap(errBindResources, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) ListBindings(ctx context.Context, session smqauthn.Session, configID string) ([]BindingSnapshot, error) {
|
||||
if bs.bindings == nil {
|
||||
return nil, errors.Wrap(errListBindings, errors.New("binding support not configured"))
|
||||
}
|
||||
if _, err := bs.configs.RetrieveByID(ctx, session.DomainID, configID); err != nil {
|
||||
return nil, errors.Wrap(errListBindings, err)
|
||||
}
|
||||
snapshots, err := bs.bindings.Retrieve(ctx, configID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errListBindings, err)
|
||||
}
|
||||
return hideSecretSnapshots(snapshots), nil
|
||||
}
|
||||
|
||||
func (bs bootstrapService) RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) error {
|
||||
if bs.profiles == nil || bs.bindings == nil || bs.resolver == nil {
|
||||
return errors.Wrap(errRefreshBinding, errors.New("binding support not configured"))
|
||||
}
|
||||
cfg, err := bs.configs.RetrieveByID(ctx, session.DomainID, configID)
|
||||
if err != nil {
|
||||
return errors.Wrap(errRefreshBinding, err)
|
||||
}
|
||||
profile, err := bs.profiles.RetrieveByID(ctx, session.DomainID, cfg.ProfileID)
|
||||
if err != nil {
|
||||
return errors.Wrap(errRefreshBinding, err)
|
||||
}
|
||||
existing, err := bs.bindings.Retrieve(ctx, configID)
|
||||
if err != nil {
|
||||
return errors.Wrap(errRefreshBinding, err)
|
||||
}
|
||||
if len(existing) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Re-resolve every existing binding to refresh its snapshot.
|
||||
requested := make([]BindingRequest, len(existing))
|
||||
for i, b := range existing {
|
||||
requested[i] = BindingRequest{Slot: b.Slot, Type: b.Type, ResourceID: b.ResourceID}
|
||||
}
|
||||
if err := validateRequestedBindings(profile, requested); err != nil {
|
||||
return errors.Wrap(errRefreshBinding, err)
|
||||
}
|
||||
refreshed, err := bs.resolver.Resolve(ctx, ResolveRequest{
|
||||
Enrollment: cfg,
|
||||
Token: token,
|
||||
Requested: requested,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(errRefreshBinding, err)
|
||||
}
|
||||
if err := validateRequiredBindings(profile, refreshed); err != nil {
|
||||
return errors.Wrap(errRefreshBinding, err)
|
||||
}
|
||||
refreshed, err = bs.encryptSecretSnapshots(refreshed)
|
||||
if err != nil {
|
||||
return errors.Wrap(errRefreshBinding, err)
|
||||
}
|
||||
return bs.bindings.Save(ctx, configID, refreshed)
|
||||
}
|
||||
|
||||
func (bs bootstrapService) dec(in string) (string, error) {
|
||||
ciphertext, err := hex.DecodeString(in)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(bs.encKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(ciphertext) < aes.BlockSize {
|
||||
return "", err
|
||||
}
|
||||
iv := ciphertext[:aes.BlockSize]
|
||||
ciphertext = ciphertext[aes.BlockSize:]
|
||||
stream := cipher.NewCFBDecrypter(block, iv)
|
||||
stream.XORKeyStream(ciphertext, ciphertext)
|
||||
return string(ciphertext), nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,12 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package tracing provides tracing instrumentation for Magistrala Users service.
|
||||
//
|
||||
// This package provides tracing middleware for Magistrala Users service.
|
||||
// It can be used to trace incoming requests and add tracing capabilities to
|
||||
// Magistrala Users service.
|
||||
//
|
||||
// For more details about tracing instrumentation for Magistrala messaging refer
|
||||
// to the documentation at https://magistrala.absmach.eu/docs/.
|
||||
package tracing
|
||||
@@ -1,198 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/bootstrap"
|
||||
smqauthn "github.com/absmach/magistrala/pkg/authn"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
var _ bootstrap.Service = (*tracingMiddleware)(nil)
|
||||
|
||||
type tracingMiddleware struct {
|
||||
tracer trace.Tracer
|
||||
svc bootstrap.Service
|
||||
}
|
||||
|
||||
// New returns a new bootstrap service with tracing capabilities.
|
||||
func New(svc bootstrap.Service, tracer trace.Tracer) bootstrap.Service {
|
||||
return &tracingMiddleware{tracer, svc}
|
||||
}
|
||||
|
||||
// Add traces the "Add" operation of the wrapped bootstrap.Service.
|
||||
func (tm *tracingMiddleware) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (bootstrap.Config, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_register_user", trace.WithAttributes(
|
||||
attribute.String("config_id", cfg.ID),
|
||||
attribute.String("domain_id ", cfg.DomainID),
|
||||
attribute.String("name", cfg.Name),
|
||||
attribute.String("external_id", cfg.ExternalID),
|
||||
attribute.String("content", cfg.Content),
|
||||
attribute.String("status", cfg.Status.String()),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.Add(ctx, session, token, cfg)
|
||||
}
|
||||
|
||||
// View traces the "View" operation of the wrapped bootstrap.Service.
|
||||
func (tm *tracingMiddleware) View(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_view_user", trace.WithAttributes(
|
||||
attribute.String("id", id),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.View(ctx, session, id)
|
||||
}
|
||||
|
||||
// Update traces the "Update" operation of the wrapped bootstrap.Service.
|
||||
func (tm *tracingMiddleware) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) error {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_update_user", trace.WithAttributes(
|
||||
attribute.String("name", cfg.Name),
|
||||
attribute.String("content", cfg.Content),
|
||||
attribute.String("config_id", cfg.ID),
|
||||
attribute.String("domain_id ", cfg.DomainID),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.Update(ctx, session, cfg)
|
||||
}
|
||||
|
||||
// UpdateCert traces the "UpdateCert" operation of the wrapped bootstrap.Service.
|
||||
func (tm *tracingMiddleware) UpdateCert(ctx context.Context, session smqauthn.Session, id, clientCert, clientKey, caCert string) (bootstrap.Config, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_update_cert", trace.WithAttributes(
|
||||
attribute.String("config_id", id),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.UpdateCert(ctx, session, id, clientCert, clientKey, caCert)
|
||||
}
|
||||
|
||||
// List traces the "List" operation of the wrapped bootstrap.Service.
|
||||
func (tm *tracingMiddleware) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (bootstrap.ConfigsPage, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_list_users", trace.WithAttributes(
|
||||
attribute.Int64("offset", int64(offset)),
|
||||
attribute.Int64("limit", int64(limit)),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.List(ctx, session, filter, offset, limit)
|
||||
}
|
||||
|
||||
// Remove traces the "Remove" operation of the wrapped bootstrap.Service.
|
||||
func (tm *tracingMiddleware) Remove(ctx context.Context, session smqauthn.Session, id string) error {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_remove_user", trace.WithAttributes(
|
||||
attribute.String("id", id),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.Remove(ctx, session, id)
|
||||
}
|
||||
|
||||
// Bootstrap traces the "Bootstrap" operation of the wrapped bootstrap.Service.
|
||||
func (tm *tracingMiddleware) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (bootstrap.Config, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_bootstrap_user", trace.WithAttributes(
|
||||
attribute.String("external_id", externalID),
|
||||
attribute.Bool("secure", secure),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.Bootstrap(ctx, externalKey, externalID, secure)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) EnableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_enable_config", trace.WithAttributes(
|
||||
attribute.String("id", id),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.EnableConfig(ctx, session, id)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) DisableConfig(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_disable_config", trace.WithAttributes(
|
||||
attribute.String("id", id),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.DisableConfig(ctx, session, id)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) CreateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_create_profile", trace.WithAttributes(
|
||||
attribute.String("name", p.Name),
|
||||
attribute.String("domain_id", p.DomainID),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.CreateProfile(ctx, session, p)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) ViewProfile(ctx context.Context, session smqauthn.Session, profileID string) (bootstrap.Profile, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_view_profile", trace.WithAttributes(
|
||||
attribute.String("profile_id", profileID),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.ViewProfile(ctx, session, profileID)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) UpdateProfile(ctx context.Context, session smqauthn.Session, p bootstrap.Profile) (bootstrap.Profile, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_update_profile", trace.WithAttributes(
|
||||
attribute.String("profile_id", p.ID),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.UpdateProfile(ctx, session, p)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) ListProfiles(ctx context.Context, session smqauthn.Session, offset, limit uint64, name string) (bootstrap.ProfilesPage, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_list_profiles", trace.WithAttributes(
|
||||
attribute.Int64("offset", int64(offset)),
|
||||
attribute.Int64("limit", int64(limit)),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.ListProfiles(ctx, session, offset, limit, name)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) DeleteProfile(ctx context.Context, session smqauthn.Session, profileID string) error {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_delete_profile", trace.WithAttributes(
|
||||
attribute.String("profile_id", profileID),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.DeleteProfile(ctx, session, profileID)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) AssignProfile(ctx context.Context, session smqauthn.Session, configID, profileID string) error {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_assign_profile", trace.WithAttributes(
|
||||
attribute.String("config_id", configID),
|
||||
attribute.String("profile_id", profileID),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.AssignProfile(ctx, session, configID, profileID)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) BindResources(ctx context.Context, session smqauthn.Session, token, configID string, bindings []bootstrap.BindingRequest) error {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_bind_resources", trace.WithAttributes(
|
||||
attribute.String("config_id", configID),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.BindResources(ctx, session, token, configID, bindings)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) ListBindings(ctx context.Context, session smqauthn.Session, configID string) ([]bootstrap.BindingSnapshot, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_list_bindings", trace.WithAttributes(
|
||||
attribute.String("config_id", configID),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.ListBindings(ctx, session, configID)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) RefreshBindings(ctx context.Context, session smqauthn.Session, token, configID string) error {
|
||||
ctx, span := tm.tracer.Start(ctx, "svc_refresh_bindings", trace.WithAttributes(
|
||||
attribute.String("config_id", configID),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.RefreshBindings(ctx, session, token, configID)
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
# Channels
|
||||
|
||||
The Channels service is a core component of Magistrala that manages communication channels between devices and applications. It handles channel creation, configuration, access control and message routing within the Magistrala ecosystem.
|
||||
|
||||
## Configuration
|
||||
|
||||
The service is configured using the following environment variables (unset variables use default values):
|
||||
|
||||
| Variable | Description | Default |
|
||||
| ------------------------- | --------------------------------------------- | ------------------------------ |
|
||||
| `MG_CHANNELS_LOG_LEVEL` | Log level (debug, info, warn, error) | info |
|
||||
| `MG_CHANNELS_HTTP_HOST` | HTTP host for Channels service | localhost |
|
||||
| `MG_CHANNELS_HTTP_PORT` | HTTP port for Channels service | 9005 |
|
||||
| `MG_CHANNELS_SERVER_CERT` | Path to PEM encoded server certificate | "" |
|
||||
| `MG_CHANNELS_SERVER_KEY` | Path to PEM encoded server key file | "" |
|
||||
| `MG_CHANNELS_GRPC_HOST` | gRPC host for Channels service | localhost |
|
||||
| `MG_CHANNELS_GRPC_PORT` | gRPC port for Channels service | 7005 |
|
||||
| `MG_CHANNELS_DB_HOST` | Database host address | localhost |
|
||||
| `MG_CHANNELS_DB_PORT` | Database port | 5432 |
|
||||
| `MG_CHANNELS_DB_USER` | Database user | magistrala |
|
||||
| `MG_CHANNELS_DB_PASS` | Database password | magistrala |
|
||||
| `MG_CHANNELS_DB_NAME` | Name of the database used by the service | channels |
|
||||
| `MG_CHANNELS_DB_SSL_MODE` | Database connection SSL mode | disable |
|
||||
| `MG_CHANNELS_CACHE_URL` | Cache database URL | <redis://localhost:6379/0> |
|
||||
| `MG_JAEGER_URL` | Jaeger tracing server URL | <http://jaeger:4318/v1/traces> |
|
||||
| `MG_SEND_TELEMETRY` | Send telemetry to Magistrala call-home server | true |
|
||||
|
||||
## Features
|
||||
|
||||
- **Channel Management**: Create, update, delete and list channels
|
||||
- **Access Control**: Manage channel permissions and user access
|
||||
- **Message Routing**: Route messages between connected devices and services
|
||||
- **Channel Groups**: Organize channels into logical groups
|
||||
- **Metadata Support**: Attach custom metadata to channels
|
||||
- **Real-time Updates**: Live channel state synchronization
|
||||
|
||||
## Architecture
|
||||
|
||||
The service is built using:
|
||||
|
||||
- **Go**: Core service implementation
|
||||
- **gRPC**: Inter-service communication
|
||||
- **PostgreSQL**: Primary data storage
|
||||
- **Redis**: Caching and pub/sub messaging
|
||||
- **Docker**: Containerized deployment
|
||||
|
||||
### Channels Table
|
||||
|
||||
| Column | Type | Description |
|
||||
| ----------------- | ------------- | ----------------------------------------------------- |
|
||||
| `id` | VARCHAR(36) | UUID of the channel (primary key) |
|
||||
| `name` | VARCHAR(1024) | Human-readable name |
|
||||
| `domain_id` | VARCHAR(36) | Domain to which the channel belongs |
|
||||
| `parent_group_id` | VARCHAR(36) | Optional group parent |
|
||||
| `tags` | TEXT[] | Array of tags |
|
||||
| `metadata` | JSONB | Free-form structured metadata |
|
||||
| `created_by` | VARCHAR(254) | User that created the channel |
|
||||
| `created_at` | TIMESTAMPTZ | Timestamp of creation |
|
||||
| `updated_at` | TIMESTAMPTZ | Timestamp of last update |
|
||||
| `updated_by` | VARCHAR(254) | User that performed last update |
|
||||
| `status` | SMALLINT | 0 = enabled, 1 = disabled |
|
||||
| `route` | VARCHAR(36) | Optional route identifier unique within domain if set |
|
||||
|
||||
### Connections Table
|
||||
|
||||
| Column | Type | Description |
|
||||
| ------------ | ----------- | ----------------------------------------------- |
|
||||
| `channel_id` | VARCHAR(36) | Channel UUID |
|
||||
| `domain_id` | VARCHAR(36) | Domain of channel and client |
|
||||
| `client_id` | VARCHAR(36) | Client UUID |
|
||||
| `type` | SMALLINT | Connection type: `1 = Publish`, `2 = Subscribe` |
|
||||
|
||||
## Deployment
|
||||
|
||||
The service is available as a Docker container. Refer to the Docker Compose section for the `channels` service in `docker-compose.yaml` for deployment configuration.
|
||||
|
||||
To build and run locally:
|
||||
|
||||
```bash
|
||||
# download the latest version of the service
|
||||
git clone https://github.com/absmach/magistrala
|
||||
cd magistrala
|
||||
|
||||
# compile the channels
|
||||
make channels
|
||||
|
||||
make install
|
||||
|
||||
MG_CHANNELS_HTTP_HOST=localhost \
|
||||
MG_CHANNELS_HTTP_PORT=9005 \
|
||||
MG_CHANNELS_DB_HOST=localhost \
|
||||
MG_CHANNELS_DB_PORT=5432 \
|
||||
MG_CHANNELS_DB_USER=magistrala \MG_CHANNELS_DB_PASS=magistrala \MG_CHANNELS_DB_NAME=channels \
|
||||
$GOBIN/magistrala-channels
|
||||
```
|
||||
|
||||
### Running the Service
|
||||
|
||||
```bash
|
||||
# Set environment variables
|
||||
export MQ_CHANNELS_DB_HOST=localhost
|
||||
export MQ_CHANNELS_DB_PORT=5432
|
||||
|
||||
# Run the service
|
||||
go run cmd/main.go
|
||||
```
|
||||
|
||||
### Docker Deployment
|
||||
|
||||
```bash
|
||||
docker run -p 8180:8180 magistrala/channels
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
# Run unit tests
|
||||
go test ./...
|
||||
|
||||
# Run integration tests
|
||||
make test-integration
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
The Channels service supports the following operations:
|
||||
|
||||
| Operation | Description |
|
||||
| --------------- | -------------------------------------------- |
|
||||
| `create` | Create a new channel |
|
||||
| `list` | Retrieve all channels (paged) |
|
||||
| `get` | Retrieve a single channel by ID |
|
||||
| `update` | Update a channel’s name & metadata |
|
||||
| `delete` | Permanently delete a channel |
|
||||
| `enable` | Enable a previously disabled channel |
|
||||
| `disable` | Disable an active channel |
|
||||
| `set-parent` | Assign a parent group to a channel |
|
||||
| `remove-parent` | Remove parent group from a channel |
|
||||
| `connect` | Connect one or more clients to channels |
|
||||
| `disconnect` | Disconnect one or more clients from channels |
|
||||
|
||||
### Example: Create a Channel
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:9005/<domainID>/channels \
|
||||
-H "Authorization: Bearer <your_access_token>" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"name": "myChannel",
|
||||
"metadata": { "location": "lab" },
|
||||
"route": "sensor-data",
|
||||
"tags": ["sensor","edge"],
|
||||
"status": "enabled"
|
||||
}'
|
||||
```
|
||||
|
||||
### Example: Connect Clients & Channels
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:9005/<domainID>/channels/connect \
|
||||
-H "Authorization: Bearer <your_access_token>" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"channel_ids": ["<chanID1>", "<chanID2>"],
|
||||
"client_ids": ["<clientID1>", "<clientID2>"],
|
||||
"types": ["publish", "subscribe"]
|
||||
}'
|
||||
```
|
||||
|
||||
### Example: Disconnect Clients from a Channel
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:9005/<domainID>/channels/disconnect \
|
||||
-H "Authorization: Bearer <your_access_token>" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"channel_ids": ["<chanID>"],
|
||||
"client_ids": ["<clientID>"],
|
||||
"types": ["publish"]
|
||||
}'
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
- Use tags and metadata to manage and categorize channels (e.g., environment, region, purpose).
|
||||
- Assign `route` thoughtfully when channels need a predictable identifier.
|
||||
- Keep channel hierarchies shallow for easier navigation (avoid deep nesting unless required).
|
||||
- Use `disable` rather than immediate delete when you want to suspend a channel temporarily.
|
||||
- Clean up unused connections: regularly review which clients are connected to channels and remove stale links.
|
||||
- Enforce minimal privileges: only allow clients to connect to channels they truly need.
|
||||
- Monitoring: use the `/health` endpoint and version metadata for service stability.
|
||||
|
||||
## Versioning & Health Check
|
||||
|
||||
The Channels service exposes a `/health` endpoint to provide operational status and version info.
|
||||
|
||||
### Health Check Request
|
||||
|
||||
```bash
|
||||
curl -X GET http://localhost:9005/health \
|
||||
-H "accept: application/health+json"
|
||||
```
|
||||
|
||||
### Example Response
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "pass",
|
||||
"version": "0.18.0",
|
||||
"commit": "<commit-hash>",
|
||||
"description": "channels service",
|
||||
"build_time": "2025-11-19T..."
|
||||
}
|
||||
```
|
||||
@@ -1,224 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
grpcChannelsV1 "github.com/absmach/magistrala/api/grpc/channels/v1"
|
||||
grpcCommonV1 "github.com/absmach/magistrala/api/grpc/common/v1"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
kitgrpc "github.com/go-kit/kit/transport/grpc"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
const svcName = "channels.v1.ChannelsService"
|
||||
|
||||
var _ grpcChannelsV1.ChannelsServiceClient = (*grpcClient)(nil)
|
||||
|
||||
type grpcClient struct {
|
||||
timeout time.Duration
|
||||
authorize endpoint.Endpoint
|
||||
removeClientConnections endpoint.Endpoint
|
||||
unsetParentGroupFromChannels endpoint.Endpoint
|
||||
retrieveEntity endpoint.Endpoint
|
||||
retrieveIDByRoute endpoint.Endpoint
|
||||
}
|
||||
|
||||
// NewClient returns new gRPC client instance.
|
||||
func NewClient(conn *grpc.ClientConn, timeout time.Duration) grpcChannelsV1.ChannelsServiceClient {
|
||||
return &grpcClient{
|
||||
authorize: kitgrpc.NewClient(
|
||||
conn,
|
||||
svcName,
|
||||
"Authorize",
|
||||
encodeAuthorizeRequest,
|
||||
decodeAuthorizeResponse,
|
||||
grpcChannelsV1.AuthzRes{},
|
||||
).Endpoint(),
|
||||
removeClientConnections: kitgrpc.NewClient(
|
||||
conn,
|
||||
svcName,
|
||||
"RemoveClientConnections",
|
||||
encodeRemoveClientConnectionsRequest,
|
||||
decodeRemoveClientConnectionsResponse,
|
||||
grpcChannelsV1.RemoveClientConnectionsRes{},
|
||||
).Endpoint(),
|
||||
unsetParentGroupFromChannels: kitgrpc.NewClient(
|
||||
conn,
|
||||
svcName,
|
||||
"UnsetParentGroupFromChannels",
|
||||
encodeUnsetParentGroupFromChannelsRequest,
|
||||
decodeUnsetParentGroupFromChannelsResponse,
|
||||
grpcChannelsV1.UnsetParentGroupFromChannelsRes{},
|
||||
).Endpoint(),
|
||||
retrieveEntity: kitgrpc.NewClient(
|
||||
conn,
|
||||
svcName,
|
||||
"RetrieveEntity",
|
||||
encodeRetrieveEntityRequest,
|
||||
decodeRetrieveEntityResponse,
|
||||
grpcCommonV1.RetrieveEntityRes{},
|
||||
).Endpoint(),
|
||||
retrieveIDByRoute: kitgrpc.NewClient(
|
||||
conn,
|
||||
svcName,
|
||||
"RetrieveIDByRoute",
|
||||
encodeRetrieveIDByRouteRequest,
|
||||
decodeRetrieveIDByRouteResponse,
|
||||
grpcCommonV1.RetrieveEntityRes{},
|
||||
).Endpoint(),
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (client grpcClient) Authorize(ctx context.Context, req *grpcChannelsV1.AuthzReq, _ ...grpc.CallOption) (r *grpcChannelsV1.AuthzRes, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, client.timeout)
|
||||
defer cancel()
|
||||
|
||||
res, err := client.authorize(ctx, authorizeReq{
|
||||
domainID: req.GetDomainId(),
|
||||
clientID: req.GetClientId(),
|
||||
clientType: req.GetClientType(),
|
||||
channelID: req.GetChannelId(),
|
||||
connType: connections.ConnType(req.GetType()),
|
||||
})
|
||||
if err != nil {
|
||||
return &grpcChannelsV1.AuthzRes{}, decodeError(err)
|
||||
}
|
||||
|
||||
ar := res.(authorizeRes)
|
||||
|
||||
return &grpcChannelsV1.AuthzRes{Authorized: ar.authorized}, nil
|
||||
}
|
||||
|
||||
func encodeAuthorizeRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(authorizeReq)
|
||||
|
||||
return &grpcChannelsV1.AuthzReq{
|
||||
DomainId: req.domainID,
|
||||
ClientId: req.clientID,
|
||||
ClientType: req.clientType,
|
||||
ChannelId: req.channelID,
|
||||
Type: uint32(req.connType),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeAuthorizeResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
res := grpcRes.(*grpcChannelsV1.AuthzRes)
|
||||
|
||||
return authorizeRes{authorized: res.GetAuthorized()}, nil
|
||||
}
|
||||
|
||||
func (client grpcClient) RemoveClientConnections(ctx context.Context, req *grpcChannelsV1.RemoveClientConnectionsReq, _ ...grpc.CallOption) (r *grpcChannelsV1.RemoveClientConnectionsRes, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, client.timeout)
|
||||
defer cancel()
|
||||
|
||||
if _, err := client.removeClientConnections(ctx, req); err != nil {
|
||||
return &grpcChannelsV1.RemoveClientConnectionsRes{}, decodeError(err)
|
||||
}
|
||||
|
||||
return &grpcChannelsV1.RemoveClientConnectionsRes{}, nil
|
||||
}
|
||||
|
||||
func encodeRemoveClientConnectionsRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
return grpcReq.(*grpcChannelsV1.RemoveClientConnectionsReq), nil
|
||||
}
|
||||
|
||||
func decodeRemoveClientConnectionsResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
return grpcRes.(*grpcChannelsV1.RemoveClientConnectionsRes), nil
|
||||
}
|
||||
|
||||
func (client grpcClient) UnsetParentGroupFromChannels(ctx context.Context, req *grpcChannelsV1.UnsetParentGroupFromChannelsReq, _ ...grpc.CallOption) (r *grpcChannelsV1.UnsetParentGroupFromChannelsRes, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, client.timeout)
|
||||
defer cancel()
|
||||
|
||||
if _, err := client.unsetParentGroupFromChannels(ctx, req); err != nil {
|
||||
return &grpcChannelsV1.UnsetParentGroupFromChannelsRes{}, decodeError(err)
|
||||
}
|
||||
|
||||
return &grpcChannelsV1.UnsetParentGroupFromChannelsRes{}, nil
|
||||
}
|
||||
|
||||
func encodeUnsetParentGroupFromChannelsRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
return grpcReq.(*grpcChannelsV1.UnsetParentGroupFromChannelsReq), nil
|
||||
}
|
||||
|
||||
func decodeUnsetParentGroupFromChannelsResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
return grpcRes.(*grpcChannelsV1.UnsetParentGroupFromChannelsRes), nil
|
||||
}
|
||||
|
||||
func (client grpcClient) RetrieveEntity(ctx context.Context, req *grpcCommonV1.RetrieveEntityReq, _ ...grpc.CallOption) (r *grpcCommonV1.RetrieveEntityRes, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, client.timeout)
|
||||
defer cancel()
|
||||
|
||||
res, err := client.retrieveEntity(ctx, req)
|
||||
if err != nil {
|
||||
return &grpcCommonV1.RetrieveEntityRes{}, decodeError(err)
|
||||
}
|
||||
|
||||
return res.(*grpcCommonV1.RetrieveEntityRes), nil
|
||||
}
|
||||
|
||||
func encodeRetrieveEntityRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
return grpcReq.(*grpcCommonV1.RetrieveEntityReq), nil
|
||||
}
|
||||
|
||||
func decodeRetrieveEntityResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
return grpcRes.(*grpcCommonV1.RetrieveEntityRes), nil
|
||||
}
|
||||
|
||||
func (client grpcClient) RetrieveIDByRoute(ctx context.Context, req *grpcCommonV1.RetrieveIDByRouteReq, _ ...grpc.CallOption) (r *grpcCommonV1.RetrieveEntityRes, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, client.timeout)
|
||||
defer cancel()
|
||||
|
||||
res, err := client.retrieveIDByRoute(ctx, req)
|
||||
if err != nil {
|
||||
return &grpcCommonV1.RetrieveEntityRes{}, decodeError(err)
|
||||
}
|
||||
|
||||
return res.(*grpcCommonV1.RetrieveEntityRes), nil
|
||||
}
|
||||
|
||||
func encodeRetrieveIDByRouteRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
return grpcReq.(*grpcCommonV1.RetrieveIDByRouteReq), nil
|
||||
}
|
||||
|
||||
func decodeRetrieveIDByRouteResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
return grpcRes.(*grpcCommonV1.RetrieveEntityRes), nil
|
||||
}
|
||||
|
||||
func decodeError(err error) error {
|
||||
if st, ok := status.FromError(err); ok {
|
||||
switch st.Code() {
|
||||
case codes.Unauthenticated:
|
||||
return errors.Wrap(svcerr.ErrAuthentication, errors.New(st.Message()))
|
||||
case codes.PermissionDenied:
|
||||
return errors.Wrap(svcerr.ErrAuthorization, errors.New(st.Message()))
|
||||
case codes.InvalidArgument:
|
||||
return errors.Wrap(errors.ErrMalformedEntity, errors.New(st.Message()))
|
||||
case codes.FailedPrecondition:
|
||||
return errors.Wrap(errors.ErrMalformedEntity, errors.New(st.Message()))
|
||||
case codes.NotFound:
|
||||
return errors.Wrap(svcerr.ErrNotFound, errors.New(st.Message()))
|
||||
case codes.AlreadyExists:
|
||||
return errors.Wrap(svcerr.ErrConflict, errors.New(st.Message()))
|
||||
case codes.OK:
|
||||
if msg := st.Message(); msg != "" {
|
||||
return errors.Wrap(errors.ErrUnidentified, errors.New(msg))
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return errors.Wrap(fmt.Errorf("unexpected gRPC status: %s (status code:%v)", st.Code().String(), st.Code()), errors.New(st.Message()))
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package grpc contains implementation of Auth service gRPC API.
|
||||
package grpc
|
||||
@@ -1,85 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
ch "github.com/absmach/magistrala/channels"
|
||||
channels "github.com/absmach/magistrala/channels/private"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
)
|
||||
|
||||
func authorizeEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(authorizeReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return authorizeRes{}, err
|
||||
}
|
||||
|
||||
if err := svc.Authorize(ctx, ch.AuthzReq{
|
||||
DomainID: req.domainID,
|
||||
ClientID: req.clientID,
|
||||
ClientType: req.clientType,
|
||||
ChannelID: req.channelID,
|
||||
Type: req.connType,
|
||||
}); err != nil {
|
||||
return authorizeRes{}, err
|
||||
}
|
||||
|
||||
return authorizeRes{authorized: true}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func removeClientConnectionsEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(removeClientConnectionsReq)
|
||||
|
||||
if err := svc.RemoveClientConnections(ctx, req.clientID); err != nil {
|
||||
return removeClientConnectionsRes{}, err
|
||||
}
|
||||
|
||||
return removeClientConnectionsRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func unsetParentGroupFromChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(unsetParentGroupFromChannelsReq)
|
||||
|
||||
if err := svc.UnsetParentGroupFromChannels(ctx, req.parentGroupID); err != nil {
|
||||
return unsetParentGroupFromChannelsRes{}, err
|
||||
}
|
||||
|
||||
return unsetParentGroupFromChannelsRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func retrieveEntityEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(retrieveEntityReq)
|
||||
channel, err := svc.RetrieveByID(ctx, req.Id)
|
||||
if err != nil {
|
||||
return retrieveEntityRes{}, err
|
||||
}
|
||||
|
||||
return retrieveEntityRes{id: channel.ID, domain: channel.Domain, parentGroup: channel.ParentGroup, status: uint8(channel.Status)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func retrieveIDByRouteEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(retrieveIDByRouteReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return retrieveIDByRouteRes{}, err
|
||||
}
|
||||
|
||||
id, err := svc.RetrieveIDByRoute(ctx, req.route, req.domainID)
|
||||
if err != nil {
|
||||
return retrieveIDByRouteRes{}, err
|
||||
}
|
||||
|
||||
return retrieveIDByRouteRes{id: id}, nil
|
||||
}
|
||||
}
|
||||
@@ -1,345 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package grpc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
grpcChannelsV1 "github.com/absmach/magistrala/api/grpc/channels/v1"
|
||||
grpcCommonV1 "github.com/absmach/magistrala/api/grpc/common/v1"
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/channels"
|
||||
ch "github.com/absmach/magistrala/channels"
|
||||
grpcapi "github.com/absmach/magistrala/channels/api/grpc"
|
||||
"github.com/absmach/magistrala/channels/private/mocks"
|
||||
"github.com/absmach/magistrala/internal/testsutil"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/absmach/magistrala/pkg/policies"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
const port = 7005
|
||||
|
||||
var (
|
||||
validID = testsutil.GenerateUUID(&testing.T{})
|
||||
validChannel = ch.Channel{
|
||||
ID: validID,
|
||||
Domain: testsutil.GenerateUUID(&testing.T{}),
|
||||
Status: channels.EnabledStatus,
|
||||
}
|
||||
)
|
||||
|
||||
func startGRPCServer(svc *mocks.Service, port int) *grpc.Server {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to obtain port: %s", err))
|
||||
}
|
||||
server := grpc.NewServer()
|
||||
grpcChannelsV1.RegisterChannelsServiceServer(server, grpcapi.NewServer(svc))
|
||||
go func() {
|
||||
if err := server.Serve(listener); err != nil {
|
||||
panic(fmt.Sprintf("failed to serve: %s", err))
|
||||
}
|
||||
}()
|
||||
return server
|
||||
}
|
||||
|
||||
func TestAuthorize(t *testing.T) {
|
||||
svc := new(mocks.Service)
|
||||
server := startGRPCServer(svc, port)
|
||||
defer server.GracefulStop()
|
||||
authAddr := fmt.Sprintf("localhost:%d", port)
|
||||
conn, _ := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
client := grpcapi.NewClient(conn, time.Second)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
domainID string
|
||||
clientID string
|
||||
clientType string
|
||||
channelID string
|
||||
connType connections.ConnType
|
||||
err error
|
||||
authzErr error
|
||||
res *grpcChannelsV1.AuthzRes
|
||||
code codes.Code
|
||||
}{
|
||||
{
|
||||
desc: "authorize successfully",
|
||||
domainID: validID,
|
||||
clientID: validID,
|
||||
clientType: policies.UserType,
|
||||
channelID: validID,
|
||||
connType: connections.Publish,
|
||||
res: &grpcChannelsV1.AuthzRes{Authorized: true},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "authorize with authorization error",
|
||||
domainID: validID,
|
||||
clientID: validID,
|
||||
clientType: policies.UserType,
|
||||
channelID: validID,
|
||||
connType: connections.Publish,
|
||||
res: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
authzErr: svcerr.ErrAuthorization,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "authorize withnot found error",
|
||||
domainID: validID,
|
||||
clientID: validID,
|
||||
clientType: policies.UserType,
|
||||
channelID: validID,
|
||||
connType: connections.Publish,
|
||||
res: &grpcChannelsV1.AuthzRes{Authorized: false},
|
||||
authzErr: svcerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
authReq := ch.AuthzReq{
|
||||
DomainID: tc.domainID,
|
||||
ClientID: tc.clientID,
|
||||
ClientType: tc.clientType,
|
||||
ChannelID: tc.channelID,
|
||||
Type: tc.connType,
|
||||
}
|
||||
svcCall := svc.On("Authorize", mock.Anything, authReq).Return(tc.authzErr)
|
||||
res, err := client.Authorize(context.Background(), &grpcChannelsV1.AuthzReq{
|
||||
DomainId: tc.domainID,
|
||||
ClientId: tc.clientID,
|
||||
ClientType: tc.clientType,
|
||||
ChannelId: tc.channelID,
|
||||
Type: uint32(tc.connType),
|
||||
})
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.res, res, fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.res, res))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveClientConnections(t *testing.T) {
|
||||
svc := new(mocks.Service)
|
||||
server := startGRPCServer(svc, port)
|
||||
defer server.GracefulStop()
|
||||
authAddr := fmt.Sprintf("localhost:%d", port)
|
||||
conn, _ := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
client := grpcapi.NewClient(conn, time.Second)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
clientID string
|
||||
err error
|
||||
code codes.Code
|
||||
}{
|
||||
{
|
||||
desc: "remove client connections successfully",
|
||||
clientID: validID,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "remove client connections with error",
|
||||
clientID: validID,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("RemoveClientConnections", mock.Anything, tc.clientID).Return(tc.err)
|
||||
res, err := client.RemoveClientConnections(context.Background(), &grpcChannelsV1.RemoveClientConnectionsReq{
|
||||
ClientId: tc.clientID,
|
||||
})
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
|
||||
assert.Equal(t, &grpcChannelsV1.RemoveClientConnectionsRes{}, res)
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsetParentGroupFromChannelsEndpoint(t *testing.T) {
|
||||
svc := new(mocks.Service)
|
||||
server := startGRPCServer(svc, port)
|
||||
defer server.GracefulStop()
|
||||
authAddr := fmt.Sprintf("localhost:%d", port)
|
||||
conn, _ := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
client := grpcapi.NewClient(conn, time.Second)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
parentGroupID string
|
||||
err error
|
||||
code codes.Code
|
||||
}{
|
||||
{
|
||||
desc: "unset parent group from channels successfully",
|
||||
parentGroupID: validID,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "unset parent group from channels authorization error",
|
||||
parentGroupID: validID,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "unset parent group from channels with not found error",
|
||||
parentGroupID: validID,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("UnsetParentGroupFromChannels", mock.Anything, tc.parentGroupID).Return(tc.err)
|
||||
res, err := client.UnsetParentGroupFromChannels(context.Background(), &grpcChannelsV1.UnsetParentGroupFromChannelsReq{
|
||||
ParentGroupId: tc.parentGroupID,
|
||||
})
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
|
||||
assert.Equal(t, &grpcChannelsV1.UnsetParentGroupFromChannelsRes{}, res)
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveEntity(t *testing.T) {
|
||||
svc := new(mocks.Service)
|
||||
server := startGRPCServer(svc, port)
|
||||
defer server.GracefulStop()
|
||||
authAddr := fmt.Sprintf("localhost:%d", port)
|
||||
conn, _ := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
client := grpcapi.NewClient(conn, time.Second)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
id string
|
||||
svcRes ch.Channel
|
||||
resp *grpcCommonV1.RetrieveEntityRes
|
||||
code codes.Code
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "retrieve entity successfully",
|
||||
id: validID,
|
||||
svcRes: validChannel,
|
||||
resp: &grpcCommonV1.RetrieveEntityRes{
|
||||
Entity: &grpcCommonV1.EntityBasic{
|
||||
Id: validChannel.ID,
|
||||
DomainId: validChannel.Domain,
|
||||
ParentGroupId: validChannel.ParentGroup,
|
||||
Status: uint32(validChannel.Status),
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "retrieve entity with error",
|
||||
id: validID,
|
||||
resp: &grpcCommonV1.RetrieveEntityRes{},
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("RetrieveByID", mock.Anything, tc.id).Return(tc.svcRes, tc.err)
|
||||
res, err := client.RetrieveEntity(context.Background(), &grpcCommonV1.RetrieveEntityReq{
|
||||
Id: tc.id,
|
||||
})
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.resp.Entity, res.Entity)
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveIDByRoute(t *testing.T) {
|
||||
svc := new(mocks.Service)
|
||||
server := startGRPCServer(svc, port)
|
||||
defer server.GracefulStop()
|
||||
authAddr := fmt.Sprintf("localhost:%d", port)
|
||||
conn, _ := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
client := grpcapi.NewClient(conn, time.Second)
|
||||
|
||||
validRoute := "validRoute"
|
||||
domainID := testsutil.GenerateUUID(t)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
retrieveReq *grpcCommonV1.RetrieveIDByRouteReq
|
||||
svcRes string
|
||||
svcErr error
|
||||
retrieveRes *grpcCommonV1.RetrieveEntityRes
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "retrieve entity by route successfully",
|
||||
retrieveReq: &grpcCommonV1.RetrieveIDByRouteReq{
|
||||
Route: validRoute,
|
||||
DomainId: domainID,
|
||||
},
|
||||
svcRes: validID,
|
||||
retrieveRes: &grpcCommonV1.RetrieveEntityRes{
|
||||
Entity: &grpcCommonV1.EntityBasic{
|
||||
Id: validID,
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "retrieve entity by route with empty route",
|
||||
retrieveReq: &grpcCommonV1.RetrieveIDByRouteReq{
|
||||
Route: "",
|
||||
DomainId: domainID,
|
||||
},
|
||||
svcRes: "",
|
||||
retrieveRes: &grpcCommonV1.RetrieveEntityRes{},
|
||||
err: apiutil.ErrMissingRoute,
|
||||
},
|
||||
{
|
||||
desc: "retrieve entity by route with empty domain ID",
|
||||
retrieveReq: &grpcCommonV1.RetrieveIDByRouteReq{
|
||||
Route: validRoute,
|
||||
DomainId: "",
|
||||
},
|
||||
svcRes: "",
|
||||
retrieveRes: &grpcCommonV1.RetrieveEntityRes{},
|
||||
err: apiutil.ErrMissingDomainID,
|
||||
},
|
||||
{
|
||||
desc: "retrieve entity by route with invalid route",
|
||||
retrieveReq: &grpcCommonV1.RetrieveIDByRouteReq{
|
||||
Route: "invalidRoute",
|
||||
DomainId: domainID,
|
||||
},
|
||||
svcRes: "",
|
||||
svcErr: svcerr.ErrNotFound,
|
||||
retrieveRes: &grpcCommonV1.RetrieveEntityRes{},
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("RetrieveIDByRoute", mock.Anything, tc.retrieveReq.Route, tc.retrieveReq.DomainId).Return(tc.svcRes, tc.svcErr)
|
||||
res, err := client.RetrieveIDByRoute(context.Background(), tc.retrieveReq)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.retrieveRes.Entity, res.Entity)
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/absmach/magistrala/pkg/policies"
|
||||
)
|
||||
|
||||
var errDomainID = errors.New("domain id required for users")
|
||||
|
||||
type authorizeReq struct {
|
||||
domainID string
|
||||
channelID string
|
||||
clientID string
|
||||
clientType string
|
||||
connType connections.ConnType
|
||||
}
|
||||
|
||||
func (req authorizeReq) validate() error {
|
||||
if req.clientType == policies.UserType && req.domainID == "" {
|
||||
return errDomainID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type removeClientConnectionsReq struct {
|
||||
clientID string
|
||||
}
|
||||
|
||||
type unsetParentGroupFromChannelsReq struct {
|
||||
parentGroupID string
|
||||
}
|
||||
|
||||
type retrieveEntityReq struct {
|
||||
Id string
|
||||
}
|
||||
|
||||
type retrieveIDByRouteReq struct {
|
||||
route string
|
||||
domainID string
|
||||
}
|
||||
|
||||
func (req retrieveIDByRouteReq) validate() error {
|
||||
if req.route == "" {
|
||||
return apiutil.ErrMissingRoute
|
||||
}
|
||||
if req.domainID == "" {
|
||||
return apiutil.ErrMissingDomainID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package grpc
|
||||
|
||||
type authorizeRes struct {
|
||||
authorized bool
|
||||
}
|
||||
|
||||
type removeClientConnectionsRes struct{}
|
||||
|
||||
type unsetParentGroupFromChannelsRes struct{}
|
||||
|
||||
type channelBasic struct {
|
||||
id string
|
||||
domain string
|
||||
parentGroup string
|
||||
status uint8
|
||||
}
|
||||
|
||||
type retrieveEntityRes channelBasic
|
||||
|
||||
type retrieveIDByRouteRes struct {
|
||||
id string
|
||||
}
|
||||
@@ -1,211 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
grpcChannelsV1 "github.com/absmach/magistrala/api/grpc/channels/v1"
|
||||
grpcCommonV1 "github.com/absmach/magistrala/api/grpc/common/v1"
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
smqauth "github.com/absmach/magistrala/auth"
|
||||
channels "github.com/absmach/magistrala/channels/private"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
kitgrpc "github.com/go-kit/kit/transport/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
var _ grpcChannelsV1.ChannelsServiceServer = (*grpcServer)(nil)
|
||||
|
||||
type grpcServer struct {
|
||||
grpcChannelsV1.UnimplementedChannelsServiceServer
|
||||
authorize kitgrpc.Handler
|
||||
removeClientConnections kitgrpc.Handler
|
||||
unsetParentGroupFromChannels kitgrpc.Handler
|
||||
retrieveEntity kitgrpc.Handler
|
||||
retrieveIDByRoute kitgrpc.Handler
|
||||
}
|
||||
|
||||
// NewServer returns new AuthServiceServer instance.
|
||||
func NewServer(svc channels.Service) grpcChannelsV1.ChannelsServiceServer {
|
||||
return &grpcServer{
|
||||
authorize: kitgrpc.NewServer(
|
||||
authorizeEndpoint(svc),
|
||||
decodeAuthorizeRequest,
|
||||
encodeAuthorizeResponse,
|
||||
),
|
||||
removeClientConnections: kitgrpc.NewServer(
|
||||
removeClientConnectionsEndpoint(svc),
|
||||
decodeRemoveClientConnectionsRequest,
|
||||
encodeRemoveClientConnectionsResponse,
|
||||
),
|
||||
unsetParentGroupFromChannels: kitgrpc.NewServer(
|
||||
unsetParentGroupFromChannelsEndpoint(svc),
|
||||
decodeUnsetParentGroupFromChannelsRequest,
|
||||
encodeUnsetParentGroupFromChannelsResponse,
|
||||
),
|
||||
retrieveEntity: kitgrpc.NewServer(
|
||||
retrieveEntityEndpoint(svc),
|
||||
decodeRetrieveEntityRequest,
|
||||
encodeRetrieveEntityResponse,
|
||||
),
|
||||
retrieveIDByRoute: kitgrpc.NewServer(
|
||||
retrieveIDByRouteEndpoint(svc),
|
||||
decodeRetrieveIDByRouteRequest,
|
||||
encodeRetrieveIDByRouteResponse,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *grpcServer) Authorize(ctx context.Context, req *grpcChannelsV1.AuthzReq) (*grpcChannelsV1.AuthzRes, error) {
|
||||
_, res, err := s.authorize.ServeGRPC(ctx, req)
|
||||
if err != nil {
|
||||
return nil, encodeError(err)
|
||||
}
|
||||
return res.(*grpcChannelsV1.AuthzRes), nil
|
||||
}
|
||||
|
||||
func decodeAuthorizeRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*grpcChannelsV1.AuthzReq)
|
||||
|
||||
connType := connections.ConnType(req.GetType())
|
||||
if err := connections.CheckConnType(connType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return authorizeReq{
|
||||
domainID: req.GetDomainId(),
|
||||
clientID: req.GetClientId(),
|
||||
clientType: req.GetClientType(),
|
||||
channelID: req.GetChannelId(),
|
||||
connType: connType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeAuthorizeResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
res := grpcRes.(authorizeRes)
|
||||
return &grpcChannelsV1.AuthzRes{Authorized: res.authorized}, nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) RemoveClientConnections(ctx context.Context, req *grpcChannelsV1.RemoveClientConnectionsReq) (*grpcChannelsV1.RemoveClientConnectionsRes, error) {
|
||||
_, res, err := s.removeClientConnections.ServeGRPC(ctx, req)
|
||||
if err != nil {
|
||||
return nil, encodeError(err)
|
||||
}
|
||||
return res.(*grpcChannelsV1.RemoveClientConnectionsRes), nil
|
||||
}
|
||||
|
||||
func decodeRemoveClientConnectionsRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*grpcChannelsV1.RemoveClientConnectionsReq)
|
||||
|
||||
return removeClientConnectionsReq{
|
||||
clientID: req.GetClientId(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeRemoveClientConnectionsResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
_ = grpcRes.(removeClientConnectionsRes)
|
||||
return &grpcChannelsV1.RemoveClientConnectionsRes{}, nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) UnsetParentGroupFromChannels(ctx context.Context, req *grpcChannelsV1.UnsetParentGroupFromChannelsReq) (*grpcChannelsV1.UnsetParentGroupFromChannelsRes, error) {
|
||||
_, res, err := s.unsetParentGroupFromChannels.ServeGRPC(ctx, req)
|
||||
if err != nil {
|
||||
return nil, encodeError(err)
|
||||
}
|
||||
return res.(*grpcChannelsV1.UnsetParentGroupFromChannelsRes), nil
|
||||
}
|
||||
|
||||
func decodeUnsetParentGroupFromChannelsRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*grpcChannelsV1.UnsetParentGroupFromChannelsReq)
|
||||
|
||||
return unsetParentGroupFromChannelsReq{
|
||||
parentGroupID: req.GetParentGroupId(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeUnsetParentGroupFromChannelsResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
_ = grpcRes.(unsetParentGroupFromChannelsRes)
|
||||
return &grpcChannelsV1.UnsetParentGroupFromChannelsRes{}, nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) RetrieveEntity(ctx context.Context, req *grpcCommonV1.RetrieveEntityReq) (*grpcCommonV1.RetrieveEntityRes, error) {
|
||||
_, res, err := s.retrieveEntity.ServeGRPC(ctx, req)
|
||||
if err != nil {
|
||||
return nil, encodeError(err)
|
||||
}
|
||||
return res.(*grpcCommonV1.RetrieveEntityRes), nil
|
||||
}
|
||||
|
||||
func decodeRetrieveEntityRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*grpcCommonV1.RetrieveEntityReq)
|
||||
return retrieveEntityReq{
|
||||
Id: req.GetId(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeRetrieveEntityResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
res := grpcRes.(retrieveEntityRes)
|
||||
|
||||
return &grpcCommonV1.RetrieveEntityRes{
|
||||
Entity: &grpcCommonV1.EntityBasic{
|
||||
Id: res.id,
|
||||
DomainId: res.domain,
|
||||
ParentGroupId: res.parentGroup,
|
||||
Status: uint32(res.status),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeRetrieveIDByRouteRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*grpcCommonV1.RetrieveIDByRouteReq)
|
||||
return retrieveIDByRouteReq{
|
||||
route: req.GetRoute(),
|
||||
domainID: req.GetDomainId(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func encodeRetrieveIDByRouteResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
res := grpcRes.(retrieveIDByRouteRes)
|
||||
|
||||
return &grpcCommonV1.RetrieveEntityRes{
|
||||
Entity: &grpcCommonV1.EntityBasic{
|
||||
Id: res.id,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *grpcServer) RetrieveIDByRoute(ctx context.Context, req *grpcCommonV1.RetrieveIDByRouteReq) (*grpcCommonV1.RetrieveEntityRes, error) {
|
||||
_, res, err := s.retrieveIDByRoute.ServeGRPC(ctx, req)
|
||||
if err != nil {
|
||||
return nil, encodeError(err)
|
||||
}
|
||||
return res.(*grpcCommonV1.RetrieveEntityRes), nil
|
||||
}
|
||||
|
||||
func encodeError(err error) error {
|
||||
switch {
|
||||
case errors.Contains(err, nil):
|
||||
return nil
|
||||
case errors.Contains(err, errors.ErrMalformedEntity),
|
||||
err == apiutil.ErrInvalidAuthKey,
|
||||
err == apiutil.ErrMissingID,
|
||||
err == apiutil.ErrMissingMemberType,
|
||||
err == apiutil.ErrMissingPolicySub,
|
||||
err == apiutil.ErrMissingPolicyObj,
|
||||
err == apiutil.ErrMalformedPolicyAct:
|
||||
return status.Error(codes.InvalidArgument, err.Error())
|
||||
case errors.Contains(err, svcerr.ErrAuthentication),
|
||||
errors.Contains(err, smqauth.ErrKeyExpired),
|
||||
err == apiutil.ErrMissingEmail,
|
||||
err == apiutil.ErrBearerToken:
|
||||
return status.Error(codes.Unauthenticated, err.Error())
|
||||
case errors.Contains(err, svcerr.ErrAuthorization):
|
||||
return status.Error(codes.PermissionDenied, err.Error())
|
||||
default:
|
||||
return status.Error(codes.Internal, err.Error())
|
||||
}
|
||||
}
|
||||
@@ -1,329 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
api "github.com/absmach/magistrala/api/http"
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/internal/nullable"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func decodeViewChannel(_ context.Context, r *http.Request) (any, error) {
|
||||
roles, err := apiutil.ReadBoolQuery(r, api.RolesKey, false)
|
||||
if err != nil {
|
||||
return viewChannelReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
req := viewChannelReq{
|
||||
id: chi.URLParam(r, "channelID"),
|
||||
roles: roles,
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeCreateChannelReq(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
|
||||
}
|
||||
|
||||
req := createChannelReq{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req.Channel); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeCreateChannelsReq(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
|
||||
}
|
||||
|
||||
req := createChannelsReq{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req.Channels); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeListChannels(_ context.Context, r *http.Request) (any, error) {
|
||||
name, err := apiutil.ReadStringQuery(r, api.NameKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
tags, err := apiutil.ReadStringQuery(r, api.TagsKey, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
var tq channels.TagsQuery
|
||||
if tags != "" {
|
||||
tq = channels.ToTagsQuery(tags)
|
||||
}
|
||||
|
||||
s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefGroupStatus)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
status, err := channels.ToStatus(s)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
meta, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
dir, err := apiutil.ReadStringQuery(r, api.DirKey, api.DefDir)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
order, err := apiutil.ReadStringQuery(r, api.OrderKey, api.DefOrder)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
allActions, err := apiutil.ReadStringQuery(r, api.ActionsKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
actions := []string{}
|
||||
|
||||
allActions = strings.TrimSpace(allActions)
|
||||
if allActions != "" {
|
||||
actions = strings.Split(allActions, ",")
|
||||
}
|
||||
roleID, err := apiutil.ReadStringQuery(r, api.RoleIDKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
roleName, err := apiutil.ReadStringQuery(r, api.RoleNameKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
accessType, err := apiutil.ReadStringQuery(r, api.AccessTypeKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
userID, err := apiutil.ReadStringQuery(r, api.UserKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
groupID, err := nullable.Parse(r.URL.Query(), api.GroupKey, nullable.ParseString)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
clientID, err := apiutil.ReadStringQuery(r, api.ClientKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
id, err := apiutil.ReadStringQuery(r, api.IDOrder, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
ot, err := apiutil.ReadBoolQuery(r, api.OnlyTotal, false)
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
connectionType, err := apiutil.ReadStringQuery(r, api.ConnTypeKey, "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
cfrom, err := apiutil.ReadStringQuery(r, "created_from", "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
cto, err := apiutil.ReadStringQuery(r, "created_to", "")
|
||||
if err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
var createdFrom, createdTo time.Time
|
||||
if cfrom != "" {
|
||||
if createdFrom, err = time.Parse(time.RFC3339, cfrom); err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrInvalidQueryParams, err)
|
||||
}
|
||||
}
|
||||
if cto != "" {
|
||||
if createdTo, err = time.Parse(time.RFC3339, cto); err != nil {
|
||||
return listChannelsReq{}, errors.Wrap(apiutil.ErrInvalidQueryParams, err)
|
||||
}
|
||||
}
|
||||
|
||||
req := listChannelsReq{
|
||||
Page: channels.Page{
|
||||
Name: name,
|
||||
Tags: tq,
|
||||
Status: status,
|
||||
Metadata: meta,
|
||||
RoleName: roleName,
|
||||
RoleID: roleID,
|
||||
Actions: actions,
|
||||
AccessType: accessType,
|
||||
Order: order,
|
||||
Dir: dir,
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
Group: groupID,
|
||||
Client: clientID,
|
||||
ConnectionType: connectionType,
|
||||
ID: id,
|
||||
OnlyTotal: ot,
|
||||
CreatedFrom: createdFrom,
|
||||
CreatedTo: createdTo,
|
||||
},
|
||||
userID: userID,
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeUpdateChannel(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
|
||||
}
|
||||
|
||||
req := updateChannelReq{
|
||||
id: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeUpdateChannelTags(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
|
||||
}
|
||||
|
||||
req := updateChannelTagsReq{
|
||||
id: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeSetChannelParentGroupStatus(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
|
||||
}
|
||||
|
||||
req := setChannelParentGroupReq{
|
||||
id: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeRemoveChannelParentGroupStatus(_ context.Context, r *http.Request) (any, error) {
|
||||
req := removeChannelParentGroupReq{
|
||||
id: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeChangeChannelStatus(_ context.Context, r *http.Request) (any, error) {
|
||||
req := changeChannelStatusReq{
|
||||
id: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeDeleteChannelReq(_ context.Context, r *http.Request) (any, error) {
|
||||
req := deleteChannelReq{
|
||||
id: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeConnectChannelClientRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
|
||||
}
|
||||
req := connectChannelClientsRequest{
|
||||
channelID: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeDisconnectChannelClientsRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
|
||||
}
|
||||
req := disconnectChannelClientsRequest{
|
||||
channelID: chi.URLParam(r, "channelID"),
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeConnectRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
|
||||
}
|
||||
|
||||
req := connectRequest{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeDisconnectRequest(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
|
||||
}
|
||||
|
||||
req := disconnectRequest{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,364 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/go-kit/kit/endpoint"
|
||||
)
|
||||
|
||||
func createChannelEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(createChannelReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
channels, _, err := svc.CreateChannels(ctx, session, req.Channel)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return createChannelRes{
|
||||
Channel: channels[0],
|
||||
created: true,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func createChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(createChannelsReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
channels, _, err := svc.CreateChannels(ctx, session, req.Channels...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := channelsPageRes{
|
||||
pageRes: pageRes{
|
||||
Total: uint64(len(channels)),
|
||||
},
|
||||
Channels: []viewChannelRes{},
|
||||
}
|
||||
for _, c := range channels {
|
||||
res.Channels = append(res.Channels, viewChannelRes{Channel: c})
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
func viewChannelEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(viewChannelReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
c, err := svc.ViewChannel(ctx, session, req.id, req.roles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return viewChannelRes{Channel: c}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func listChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(listChannelsReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
var page channels.ChannelsPage
|
||||
var err error
|
||||
switch req.userID != "" {
|
||||
case true:
|
||||
page, err = svc.ListUserChannels(ctx, session, req.userID, req.Page)
|
||||
default:
|
||||
page, err = svc.ListChannels(ctx, session, req.Page)
|
||||
}
|
||||
if err != nil {
|
||||
return channelsPageRes{}, err
|
||||
}
|
||||
|
||||
res := channelsPageRes{
|
||||
pageRes: pageRes{
|
||||
Total: page.Total,
|
||||
Offset: page.Offset,
|
||||
Limit: page.Limit,
|
||||
},
|
||||
Channels: []viewChannelRes{},
|
||||
}
|
||||
for _, c := range page.Channels {
|
||||
res.Channels = append(res.Channels, viewChannelRes{Channel: c})
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
}
|
||||
|
||||
func updateChannelEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(updateChannelReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
ch := channels.Channel{
|
||||
ID: req.id,
|
||||
Name: req.Name,
|
||||
Metadata: req.Metadata,
|
||||
}
|
||||
ch, err := svc.UpdateChannel(ctx, session, ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return updateChannelRes{Channel: ch}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func updateChannelTagsEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(updateChannelTagsReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
ch := channels.Channel{
|
||||
ID: req.id,
|
||||
Tags: req.Tags,
|
||||
}
|
||||
ch, err := svc.UpdateChannelTags(ctx, session, ch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return updateChannelRes{Channel: ch}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func setChannelParentGroupEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(setChannelParentGroupReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
if err := svc.SetParentGroup(ctx, session, req.ParentGroupID, req.id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return setChannelParentGroupRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func removeChannelParentGroupEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(removeChannelParentGroupReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
if err := svc.RemoveParentGroup(ctx, session, req.id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return removeChannelParentGroupRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func enableChannelEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(changeChannelStatusReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
ch, err := svc.EnableChannel(ctx, session, req.id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return changeChannelStatusRes{Channel: ch}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func disableChannelEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(changeChannelStatusReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
ch, err := svc.DisableChannel(ctx, session, req.id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return changeChannelStatusRes{Channel: ch}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func connectChannelClientEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(connectChannelClientsRequest)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
if err := svc.Connect(ctx, session, []string{req.channelID}, req.ClientIDs, req.Types); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return connectChannelClientsRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func disconnectChannelClientsEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(disconnectChannelClientsRequest)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
if err := svc.Disconnect(ctx, session, []string{req.channelID}, req.ClientIds, req.Types); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return disconnectChannelClientsRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func connectEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(connectRequest)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
if err := svc.Connect(ctx, session, req.ChannelIds, req.ClientIds, req.Types); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return connectRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func disconnectEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(disconnectRequest)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
if err := svc.Disconnect(ctx, session, req.ChannelIds, req.ClientIds, req.Types); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return disconnectRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func deleteChannelEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(deleteChannelReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
if err := svc.RemoveChannel(ctx, session, req.id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return deleteChannelRes{}, nil
|
||||
}
|
||||
}
|
||||
@@ -1,321 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
api "github.com/absmach/magistrala/api/http"
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
)
|
||||
|
||||
type createChannelReq struct {
|
||||
Channel channels.Channel
|
||||
}
|
||||
|
||||
func (req createChannelReq) validate() error {
|
||||
if len(req.Channel.Name) > api.MaxNameSize {
|
||||
return apiutil.ErrNameSize
|
||||
}
|
||||
if req.Channel.ID != "" {
|
||||
if strings.TrimSpace(req.Channel.ID) == "" {
|
||||
return apiutil.ErrMissingChannelID
|
||||
}
|
||||
}
|
||||
if req.Channel.Route != "" {
|
||||
if err := api.ValidateRoute(req.Channel.Route); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := api.ValidateUUID(req.Channel.Route); err == nil {
|
||||
return apiutil.ErrInvalidRouteFormat
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type createChannelsReq struct {
|
||||
Channels []channels.Channel
|
||||
}
|
||||
|
||||
func (req createChannelsReq) validate() error {
|
||||
if len(req.Channels) == 0 {
|
||||
return apiutil.ErrEmptyList
|
||||
}
|
||||
for _, channel := range req.Channels {
|
||||
if channel.ID != "" {
|
||||
if strings.TrimSpace(channel.ID) == "" {
|
||||
return apiutil.ErrMissingChannelID
|
||||
}
|
||||
}
|
||||
if len(channel.Name) > api.MaxNameSize {
|
||||
return apiutil.ErrNameSize
|
||||
}
|
||||
if channel.Route != "" {
|
||||
if err := api.ValidateRoute(channel.Route); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := api.ValidateUUID(channel.Route); err == nil {
|
||||
return apiutil.ErrInvalidRouteFormat
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type viewChannelReq struct {
|
||||
id string
|
||||
roles bool
|
||||
}
|
||||
|
||||
func (req viewChannelReq) validate() error {
|
||||
if req.id == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type listChannelsReq struct {
|
||||
channels.Page
|
||||
userID string
|
||||
}
|
||||
|
||||
func (req listChannelsReq) validate() error {
|
||||
if req.Limit > api.MaxLimitSize || req.Limit < 1 {
|
||||
return apiutil.ErrLimitSize
|
||||
}
|
||||
|
||||
if len(req.Name) > api.MaxNameSize {
|
||||
return apiutil.ErrNameSize
|
||||
}
|
||||
|
||||
switch req.Order {
|
||||
case "", api.NameOrder, api.CreatedAtOrder, api.UpdatedAtOrder:
|
||||
default:
|
||||
return apiutil.ErrInvalidOrder
|
||||
}
|
||||
|
||||
if req.Dir != "" && (req.Dir != api.DescDir && req.Dir != api.AscDir) {
|
||||
return apiutil.ErrInvalidDirection
|
||||
}
|
||||
|
||||
if req.ConnectionType != "" {
|
||||
if _, err := connections.ParseConnType(req.ConnectionType); err != nil {
|
||||
return apiutil.ErrValidation
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type updateChannelReq struct {
|
||||
id string
|
||||
Name string `json:"name,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
func (req updateChannelReq) validate() error {
|
||||
if req.id == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
if len(req.Name) > api.MaxNameSize {
|
||||
return apiutil.ErrNameSize
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type updateChannelTagsReq struct {
|
||||
id string
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
func (req updateChannelTagsReq) validate() error {
|
||||
if req.id == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type setChannelParentGroupReq struct {
|
||||
id string
|
||||
ParentGroupID string `json:"parent_group_id"`
|
||||
}
|
||||
|
||||
func (req setChannelParentGroupReq) validate() error {
|
||||
if req.id == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
if req.ParentGroupID == "" {
|
||||
return apiutil.ErrMissingParentGroupID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type removeChannelParentGroupReq struct {
|
||||
id string
|
||||
}
|
||||
|
||||
func (req removeChannelParentGroupReq) validate() error {
|
||||
if req.id == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type changeChannelStatusReq struct {
|
||||
id string
|
||||
}
|
||||
|
||||
func (req changeChannelStatusReq) validate() error {
|
||||
if req.id == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type connectChannelClientsRequest struct {
|
||||
channelID string
|
||||
ClientIDs []string `json:"client_ids,omitempty"`
|
||||
Types []connections.ConnType `json:"types,omitempty"`
|
||||
}
|
||||
|
||||
func (req *connectChannelClientsRequest) validate() error {
|
||||
if req.channelID == "" || strings.TrimSpace(req.channelID) == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
if len(req.ClientIDs) == 0 {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
for _, tid := range req.ClientIDs {
|
||||
if err := api.ValidateUUID(tid); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.Types) == 0 {
|
||||
return apiutil.ErrMissingConnectionType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type disconnectChannelClientsRequest struct {
|
||||
channelID string
|
||||
ClientIds []string `json:"client_ids,omitempty"`
|
||||
Types []connections.ConnType `json:"types,omitempty"`
|
||||
}
|
||||
|
||||
func (req *disconnectChannelClientsRequest) validate() error {
|
||||
if req.channelID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
if err := api.ValidateUUID(req.channelID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(req.ClientIds) == 0 {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
for _, tid := range req.ClientIds {
|
||||
if err := api.ValidateUUID(tid); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.Types) == 0 {
|
||||
return apiutil.ErrMissingConnectionType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type connectRequest struct {
|
||||
ChannelIds []string `json:"channel_ids,omitempty"`
|
||||
ClientIds []string `json:"client_ids,omitempty"`
|
||||
Types []connections.ConnType `json:"types,omitempty"`
|
||||
}
|
||||
|
||||
func (req *connectRequest) validate() error {
|
||||
if len(req.ChannelIds) == 0 {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
for _, cid := range req.ChannelIds {
|
||||
if strings.TrimSpace(cid) == "" {
|
||||
return apiutil.ErrMissingChannelID
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.ClientIds) == 0 {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
for _, tid := range req.ClientIds {
|
||||
if strings.TrimSpace(tid) == "" {
|
||||
return apiutil.ErrMissingChannelID
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.Types) == 0 {
|
||||
return apiutil.ErrMissingConnectionType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type disconnectRequest struct {
|
||||
ChannelIds []string `json:"channel_ids,omitempty"`
|
||||
ClientIds []string `json:"client_ids,omitempty"`
|
||||
Types []connections.ConnType `json:"types,omitempty"`
|
||||
}
|
||||
|
||||
func (req *disconnectRequest) validate() error {
|
||||
if len(req.ChannelIds) == 0 {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
for _, cid := range req.ChannelIds {
|
||||
if err := api.ValidateUUID(cid); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.ClientIds) == 0 {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
for _, tid := range req.ClientIds {
|
||||
if err := api.ValidateUUID(tid); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.Types) == 0 {
|
||||
return apiutil.ErrMissingConnectionType
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type deleteChannelReq struct {
|
||||
id string
|
||||
}
|
||||
|
||||
func (req deleteChannelReq) validate() error {
|
||||
if req.id == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,628 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
api "github.com/absmach/magistrala/api/http"
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/internal/testsutil"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCreateChannelReqValidation(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req createChannelReq
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: createChannelReq{
|
||||
Channel: channels.Channel{
|
||||
Name: valid,
|
||||
Route: valid,
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "long name",
|
||||
req: createChannelReq{
|
||||
Channel: channels.Channel{
|
||||
Name: strings.Repeat("a", api.MaxNameSize+1),
|
||||
Route: valid,
|
||||
},
|
||||
},
|
||||
err: apiutil.ErrNameSize,
|
||||
},
|
||||
{
|
||||
desc: "invalid route",
|
||||
req: createChannelReq{
|
||||
Channel: channels.Channel{
|
||||
Name: valid,
|
||||
Route: "__invalid",
|
||||
},
|
||||
},
|
||||
err: apiutil.ErrInvalidRouteFormat,
|
||||
},
|
||||
{
|
||||
desc: "uuid as route",
|
||||
req: createChannelReq{
|
||||
Channel: channels.Channel{
|
||||
Name: valid,
|
||||
Route: testsutil.GenerateUUID(t),
|
||||
},
|
||||
},
|
||||
err: apiutil.ErrInvalidRouteFormat,
|
||||
},
|
||||
{
|
||||
desc: "missing channel ID",
|
||||
req: createChannelReq{
|
||||
Channel: channels.Channel{
|
||||
ID: " ",
|
||||
},
|
||||
},
|
||||
err: apiutil.ErrMissingChannelID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateChannelsReqValidation(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req createChannelsReq
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: createChannelsReq{
|
||||
Channels: []channels.Channel{
|
||||
{
|
||||
Name: valid,
|
||||
Route: valid,
|
||||
},
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "long name",
|
||||
req: createChannelsReq{
|
||||
Channels: []channels.Channel{
|
||||
{
|
||||
Name: strings.Repeat("a", api.MaxNameSize+1),
|
||||
Route: valid,
|
||||
},
|
||||
},
|
||||
},
|
||||
err: apiutil.ErrNameSize,
|
||||
},
|
||||
{
|
||||
desc: "missing channel ID",
|
||||
req: createChannelsReq{
|
||||
Channels: []channels.Channel{
|
||||
{
|
||||
ID: " ",
|
||||
},
|
||||
},
|
||||
},
|
||||
err: apiutil.ErrMissingChannelID,
|
||||
},
|
||||
{
|
||||
desc: "empty list",
|
||||
req: createChannelsReq{
|
||||
Channels: []channels.Channel{},
|
||||
},
|
||||
err: apiutil.ErrEmptyList,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewChannelReqValidation(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req viewChannelReq
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: viewChannelReq{
|
||||
id: valid,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "missing ID",
|
||||
req: viewChannelReq{
|
||||
id: "",
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListChannelsReqValidation(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req listChannelsReq
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: listChannelsReq{
|
||||
Page: channels.Page{Limit: 10},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "limit is 0",
|
||||
req: listChannelsReq{
|
||||
Page: channels.Page{Limit: 0},
|
||||
},
|
||||
err: apiutil.ErrLimitSize,
|
||||
},
|
||||
{
|
||||
desc: "limit is greater than max limit",
|
||||
req: listChannelsReq{
|
||||
Page: channels.Page{Limit: api.MaxLimitSize + 1},
|
||||
},
|
||||
err: apiutil.ErrLimitSize,
|
||||
},
|
||||
{
|
||||
desc: "name is too long",
|
||||
req: listChannelsReq{
|
||||
Page: channels.Page{Limit: 10, Name: strings.Repeat("a", api.MaxNameSize+1)},
|
||||
},
|
||||
err: apiutil.ErrNameSize,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateChannelReqValidate(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req updateChannelReq
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: updateChannelReq{
|
||||
id: valid,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "missing ID",
|
||||
req: updateChannelReq{
|
||||
id: "",
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "name is too long",
|
||||
req: updateChannelReq{
|
||||
id: valid,
|
||||
Name: strings.Repeat("a", api.MaxNameSize+1),
|
||||
},
|
||||
err: apiutil.ErrNameSize,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateChannelTagsReqValidate(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req updateChannelTagsReq
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: updateChannelTagsReq{
|
||||
id: valid,
|
||||
Tags: []string{"tag1", "tag2"},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "missing ID",
|
||||
req: updateChannelTagsReq{
|
||||
id: "",
|
||||
Tags: []string{"tag1", "tag2"},
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetChannelsParentGroupReqValidate(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req setChannelParentGroupReq
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: setChannelParentGroupReq{
|
||||
id: valid,
|
||||
ParentGroupID: valid,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "missing ID",
|
||||
req: setChannelParentGroupReq{
|
||||
id: "",
|
||||
ParentGroupID: valid,
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "missing parent group ID",
|
||||
req: setChannelParentGroupReq{
|
||||
id: valid,
|
||||
ParentGroupID: "",
|
||||
},
|
||||
err: apiutil.ErrMissingParentGroupID,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveChannelParentGroupReqValidate(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req removeChannelParentGroupReq
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: removeChannelParentGroupReq{
|
||||
id: valid,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "missing ID",
|
||||
req: removeChannelParentGroupReq{
|
||||
id: "",
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangeChannelStatusReqValidate(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req changeChannelStatusReq
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: changeChannelStatusReq{
|
||||
id: valid,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "missing ID",
|
||||
req: changeChannelStatusReq{
|
||||
id: "",
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectChannelClientsReqValidate(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req connectChannelClientsRequest
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: connectChannelClientsRequest{
|
||||
channelID: valid,
|
||||
ClientIDs: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "missing channel ID",
|
||||
req: connectChannelClientsRequest{
|
||||
channelID: "",
|
||||
ClientIDs: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "missing client IDs",
|
||||
req: connectChannelClientsRequest{
|
||||
channelID: valid,
|
||||
ClientIDs: []string{},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "missing connection types",
|
||||
req: connectChannelClientsRequest{
|
||||
channelID: valid,
|
||||
ClientIDs: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{},
|
||||
},
|
||||
err: apiutil.ErrMissingConnectionType,
|
||||
},
|
||||
{
|
||||
desc: "invalid client ID",
|
||||
req: connectChannelClientsRequest{
|
||||
channelID: valid,
|
||||
ClientIDs: []string{"client1", "invalid"},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrInvalidIDFormat,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisconnectChannelClientReqValidate(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req disconnectChannelClientsRequest
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: disconnectChannelClientsRequest{
|
||||
channelID: testsutil.GenerateUUID(t),
|
||||
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "missing channel ID",
|
||||
req: disconnectChannelClientsRequest{
|
||||
channelID: "",
|
||||
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "invalid channel ID",
|
||||
req: disconnectChannelClientsRequest{
|
||||
channelID: "invalid",
|
||||
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrInvalidIDFormat,
|
||||
},
|
||||
{
|
||||
desc: "missing client IDs",
|
||||
req: disconnectChannelClientsRequest{
|
||||
channelID: testsutil.GenerateUUID(t),
|
||||
ClientIds: []string{},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "missing connection types",
|
||||
req: disconnectChannelClientsRequest{
|
||||
channelID: testsutil.GenerateUUID(t),
|
||||
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{},
|
||||
},
|
||||
err: apiutil.ErrMissingConnectionType,
|
||||
},
|
||||
{
|
||||
desc: "invalid client ID",
|
||||
req: disconnectChannelClientsRequest{
|
||||
channelID: testsutil.GenerateUUID(t),
|
||||
ClientIds: []string{"client1", "invalid"},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrInvalidIDFormat,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectReqValidate(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req connectRequest
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: connectRequest{
|
||||
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "missing channel IDs",
|
||||
req: connectRequest{
|
||||
ChannelIds: []string{},
|
||||
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "missing client IDs",
|
||||
req: connectRequest{
|
||||
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
ClientIds: []string{},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "missing connection types",
|
||||
req: connectRequest{
|
||||
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{},
|
||||
},
|
||||
err: apiutil.ErrMissingConnectionType,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisconnectReqValidate(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req disconnectRequest
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: disconnectRequest{
|
||||
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "missing channel IDs",
|
||||
req: disconnectRequest{
|
||||
ChannelIds: []string{},
|
||||
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "missing client IDs",
|
||||
req: disconnectRequest{
|
||||
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
ClientIds: []string{},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "missing connection types",
|
||||
req: disconnectRequest{
|
||||
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{},
|
||||
},
|
||||
err: apiutil.ErrMissingConnectionType,
|
||||
},
|
||||
{
|
||||
desc: "invalid client ID",
|
||||
req: disconnectRequest{
|
||||
ChannelIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
ClientIds: []string{"client1", "invalid"},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrInvalidIDFormat,
|
||||
},
|
||||
{
|
||||
desc: "invalid channel ID",
|
||||
req: disconnectRequest{
|
||||
ChannelIds: []string{"invalid", testsutil.GenerateUUID(t)},
|
||||
ClientIds: []string{testsutil.GenerateUUID(t), testsutil.GenerateUUID(t)},
|
||||
Types: []connections.ConnType{connections.Publish},
|
||||
},
|
||||
err: apiutil.ErrInvalidIDFormat,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteChannelReqValidate(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
req deleteChannelReq
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "valid request",
|
||||
req: deleteChannelReq{
|
||||
id: valid,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "missing ID",
|
||||
req: deleteChannelReq{
|
||||
id: "",
|
||||
},
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
err := tc.req.validate()
|
||||
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
@@ -1,221 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
"github.com/absmach/magistrala/channels"
|
||||
)
|
||||
|
||||
var (
|
||||
_ magistrala.Response = (*createChannelRes)(nil)
|
||||
_ magistrala.Response = (*viewChannelRes)(nil)
|
||||
_ magistrala.Response = (*channelsPageRes)(nil)
|
||||
_ magistrala.Response = (*updateChannelRes)(nil)
|
||||
_ magistrala.Response = (*deleteChannelRes)(nil)
|
||||
_ magistrala.Response = (*connectChannelClientsRes)(nil)
|
||||
_ magistrala.Response = (*disconnectChannelClientsRes)(nil)
|
||||
_ magistrala.Response = (*connectRes)(nil)
|
||||
_ magistrala.Response = (*disconnectRes)(nil)
|
||||
_ magistrala.Response = (*changeChannelStatusRes)(nil)
|
||||
)
|
||||
|
||||
type pageRes struct {
|
||||
Limit uint64 `json:"limit,omitempty"`
|
||||
Offset uint64 `json:"offset,omitempty"`
|
||||
Total uint64 `json:"total"`
|
||||
}
|
||||
|
||||
type createChannelRes struct {
|
||||
channels.Channel
|
||||
created bool
|
||||
}
|
||||
|
||||
func (res createChannelRes) Code() int {
|
||||
if res.created {
|
||||
return http.StatusCreated
|
||||
}
|
||||
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res createChannelRes) Headers() map[string]string {
|
||||
if res.created {
|
||||
return map[string]string{
|
||||
"Location": fmt.Sprintf("/channels/%s", res.ID),
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res createChannelRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type viewChannelRes struct {
|
||||
channels.Channel
|
||||
}
|
||||
|
||||
func (res viewChannelRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res viewChannelRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res viewChannelRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type channelsPageRes struct {
|
||||
pageRes
|
||||
Channels []viewChannelRes `json:"channels,omitempty"`
|
||||
}
|
||||
|
||||
func (res channelsPageRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res channelsPageRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res channelsPageRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type changeChannelStatusRes struct {
|
||||
channels.Channel
|
||||
}
|
||||
|
||||
func (res changeChannelStatusRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res changeChannelStatusRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res changeChannelStatusRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type updateChannelRes struct {
|
||||
channels.Channel
|
||||
}
|
||||
|
||||
func (res updateChannelRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res updateChannelRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res updateChannelRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type setChannelParentGroupRes struct{}
|
||||
|
||||
func (res setChannelParentGroupRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res setChannelParentGroupRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res setChannelParentGroupRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type removeChannelParentGroupRes struct{}
|
||||
|
||||
func (res removeChannelParentGroupRes) Code() int {
|
||||
return http.StatusNoContent
|
||||
}
|
||||
|
||||
func (res removeChannelParentGroupRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res removeChannelParentGroupRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type deleteChannelRes struct{}
|
||||
|
||||
func (res deleteChannelRes) Code() int {
|
||||
return http.StatusNoContent
|
||||
}
|
||||
|
||||
func (res deleteChannelRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res deleteChannelRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type connectChannelClientsRes struct{}
|
||||
|
||||
func (res connectChannelClientsRes) Code() int {
|
||||
return http.StatusCreated
|
||||
}
|
||||
|
||||
func (res connectChannelClientsRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res connectChannelClientsRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type disconnectChannelClientsRes struct{}
|
||||
|
||||
func (res disconnectChannelClientsRes) Code() int {
|
||||
return http.StatusNoContent
|
||||
}
|
||||
|
||||
func (res disconnectChannelClientsRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res disconnectChannelClientsRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type connectRes struct{}
|
||||
|
||||
func (res connectRes) Code() int {
|
||||
return http.StatusCreated
|
||||
}
|
||||
|
||||
func (res connectRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res connectRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type disconnectRes struct{}
|
||||
|
||||
func (res disconnectRes) Code() int {
|
||||
return http.StatusNoContent
|
||||
}
|
||||
|
||||
func (res disconnectRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res disconnectRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
@@ -1,149 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/absmach/magistrala"
|
||||
api "github.com/absmach/magistrala/api/http"
|
||||
apiutil "github.com/absmach/magistrala/api/http/util"
|
||||
"github.com/absmach/magistrala/channels"
|
||||
smqauthn "github.com/absmach/magistrala/pkg/authn"
|
||||
roleManagerHttp "github.com/absmach/magistrala/pkg/roles/rolemanager/api"
|
||||
"github.com/go-chi/chi/v5"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP handler for Channels API endpoints.
|
||||
func MakeHandler(svc channels.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp magistrala.IDProvider) *chi.Mux {
|
||||
opts := []kithttp.ServerOption{
|
||||
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
|
||||
}
|
||||
|
||||
d := roleManagerHttp.NewDecoder("channelID")
|
||||
|
||||
mux.Route("/{domainID}/channels", func(r chi.Router) {
|
||||
r.Use(authn.Middleware())
|
||||
r.Use(api.RequestIDMiddleware(idp))
|
||||
|
||||
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
createChannelEndpoint(svc),
|
||||
decodeCreateChannelReq,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "create_channel").ServeHTTP)
|
||||
|
||||
r.Post("/bulk", otelhttp.NewHandler(kithttp.NewServer(
|
||||
createChannelsEndpoint(svc),
|
||||
decodeCreateChannelsReq,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "create_channels").ServeHTTP)
|
||||
|
||||
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
listChannelsEndpoint(svc),
|
||||
decodeListChannels,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "list_channels").ServeHTTP)
|
||||
|
||||
r.Post("/connect", otelhttp.NewHandler(kithttp.NewServer(
|
||||
connectEndpoint(svc),
|
||||
decodeConnectRequest,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "connect").ServeHTTP)
|
||||
|
||||
r.Post("/disconnect", otelhttp.NewHandler(kithttp.NewServer(
|
||||
disconnectEndpoint(svc),
|
||||
decodeDisconnectRequest,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "disconnect").ServeHTTP)
|
||||
|
||||
r = roleManagerHttp.EntityAvailableActionsRouter(svc, d, r, opts)
|
||||
|
||||
r.Route("/{channelID}", func(r chi.Router) {
|
||||
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
viewChannelEndpoint(svc),
|
||||
decodeViewChannel,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "view_channel").ServeHTTP)
|
||||
|
||||
r.Patch("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
updateChannelEndpoint(svc),
|
||||
decodeUpdateChannel,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "update_channel_name_and_metadata").ServeHTTP)
|
||||
|
||||
r.Patch("/tags", otelhttp.NewHandler(kithttp.NewServer(
|
||||
updateChannelTagsEndpoint(svc),
|
||||
decodeUpdateChannelTags,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "update_channel_tag").ServeHTTP)
|
||||
|
||||
r.Delete("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
deleteChannelEndpoint(svc),
|
||||
decodeDeleteChannelReq,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "delete_channel").ServeHTTP)
|
||||
|
||||
r.Post("/enable", otelhttp.NewHandler(kithttp.NewServer(
|
||||
enableChannelEndpoint(svc),
|
||||
decodeChangeChannelStatus,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "enable_channel").ServeHTTP)
|
||||
|
||||
r.Post("/disable", otelhttp.NewHandler(kithttp.NewServer(
|
||||
disableChannelEndpoint(svc),
|
||||
decodeChangeChannelStatus,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "disable_channel").ServeHTTP)
|
||||
|
||||
r.Post("/parent", otelhttp.NewHandler(kithttp.NewServer(
|
||||
setChannelParentGroupEndpoint(svc),
|
||||
decodeSetChannelParentGroupStatus,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "set_channel_parent_group").ServeHTTP)
|
||||
|
||||
r.Delete("/parent", otelhttp.NewHandler(kithttp.NewServer(
|
||||
removeChannelParentGroupEndpoint(svc),
|
||||
decodeRemoveChannelParentGroupStatus,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "remove_channel_parent_group").ServeHTTP)
|
||||
|
||||
r.Post("/connect", otelhttp.NewHandler(kithttp.NewServer(
|
||||
connectChannelClientEndpoint(svc),
|
||||
decodeConnectChannelClientRequest,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "connect_channel_client").ServeHTTP)
|
||||
|
||||
r.Post("/disconnect", otelhttp.NewHandler(kithttp.NewServer(
|
||||
disconnectChannelClientsEndpoint(svc),
|
||||
decodeDisconnectChannelClientsRequest,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "disconnect_channel_client").ServeHTTP)
|
||||
|
||||
roleManagerHttp.EntityRoleMangerRouter(svc, d, r, opts)
|
||||
})
|
||||
})
|
||||
|
||||
mux.Get("/health", magistrala.Health("channels", instanceID))
|
||||
mux.Handle("/metrics", promhttp.Handler())
|
||||
|
||||
return mux
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
package channels
|
||||
|
||||
import "github.com/absmach/magistrala/pkg/roles"
|
||||
|
||||
const BuiltInRoleAdmin roles.BuiltInRoleName = "admin"
|
||||
Vendored
-82
@@ -1,82 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyDomainID = errors.New("domain ID is empty")
|
||||
ErrEmptyChannelID = errors.New("channel ID is empty")
|
||||
ErrEmptyChannelRoute = errors.New("channel route is empty")
|
||||
)
|
||||
|
||||
type channelsCache struct {
|
||||
client *redis.Client
|
||||
duration time.Duration
|
||||
}
|
||||
|
||||
func NewChannelsCache(client *redis.Client, duration time.Duration) channels.Cache {
|
||||
return &channelsCache{
|
||||
client: client,
|
||||
duration: duration,
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *channelsCache) Save(ctx context.Context, route, domainID, channelID string) error {
|
||||
key, err := encodeKey(domainID, route)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
if channelID == "" {
|
||||
return errors.Wrap(repoerr.ErrCreateEntity, ErrEmptyChannelID)
|
||||
}
|
||||
if err := cc.client.Set(ctx, key, channelID, cc.duration).Err(); err != nil {
|
||||
return errors.Wrap(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cc *channelsCache) ID(ctx context.Context, channelRoute, domainID string) (string, error) {
|
||||
key, err := encodeKey(domainID, channelRoute)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(repoerr.ErrNotFound, err)
|
||||
}
|
||||
id, err := cc.client.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return "", errors.Wrap(repoerr.ErrNotFound, err)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (cc *channelsCache) Remove(ctx context.Context, channelRoute, domainID string) error {
|
||||
key, err := encodeKey(domainID, channelRoute)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
if err := cc.client.Del(ctx, key).Err(); err != nil {
|
||||
return errors.Wrap(repoerr.ErrRemoveEntity, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeKey(domainID, channelRoute string) (string, error) {
|
||||
if domainID == "" {
|
||||
return "", ErrEmptyDomainID
|
||||
}
|
||||
if channelRoute == "" {
|
||||
return "", ErrEmptyChannelRoute
|
||||
}
|
||||
return domainID + ":" + channelRoute, nil
|
||||
}
|
||||
Vendored
-186
@@ -1,186 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cache_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/channels/cache"
|
||||
"github.com/absmach/magistrala/internal/testsutil"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
testRoute = "test-route"
|
||||
nonExistent = "non-existing"
|
||||
)
|
||||
|
||||
func setupChannelsClient(t *testing.T) channels.Cache {
|
||||
opts, err := redis.ParseURL(redisURL)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on parsing redis URL: %s", err))
|
||||
redisClient := redis.NewClient(opts)
|
||||
|
||||
return cache.NewChannelsCache(redisClient, 10*time.Minute)
|
||||
}
|
||||
|
||||
func TestSave(t *testing.T) {
|
||||
cc := setupChannelsClient(t)
|
||||
|
||||
route := testRoute
|
||||
domainID := testsutil.GenerateUUID(t)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
domainID string
|
||||
channelID string
|
||||
channelRoute string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "Save successfully",
|
||||
domainID: domainID,
|
||||
channelID: testsutil.GenerateUUID(t),
|
||||
channelRoute: route,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "Save with empty domain ID",
|
||||
domainID: "",
|
||||
channelID: testsutil.GenerateUUID(t),
|
||||
channelRoute: route,
|
||||
err: cache.ErrEmptyDomainID,
|
||||
},
|
||||
{
|
||||
desc: "Save with empty channel ID",
|
||||
domainID: domainID,
|
||||
channelID: "",
|
||||
channelRoute: route,
|
||||
err: cache.ErrEmptyChannelID,
|
||||
},
|
||||
{
|
||||
desc: "Save with empty channel route",
|
||||
domainID: domainID,
|
||||
channelID: testsutil.GenerateUUID(t),
|
||||
channelRoute: "",
|
||||
err: cache.ErrEmptyChannelRoute,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
err := cc.Save(context.Background(), tc.channelRoute, tc.domainID, tc.channelID)
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestID(t *testing.T) {
|
||||
cc := setupChannelsClient(t)
|
||||
|
||||
domainID := testsutil.GenerateUUID(t)
|
||||
route := testRoute
|
||||
id := testsutil.GenerateUUID(t)
|
||||
|
||||
err := cc.Save(context.Background(), route, domainID, id)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on saving channel ID: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
domainID string
|
||||
channelRoute string
|
||||
channelID string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "Retrieve existing channel",
|
||||
domainID: domainID,
|
||||
channelRoute: route,
|
||||
channelID: id,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "Retrieve non-existing channel",
|
||||
domainID: domainID,
|
||||
channelRoute: nonExistent,
|
||||
channelID: "",
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "Retrieve with empty domain ID",
|
||||
domainID: "",
|
||||
channelRoute: route,
|
||||
channelID: "",
|
||||
err: cache.ErrEmptyDomainID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
id, err := cc.ID(context.Background(), tc.channelRoute, tc.domainID)
|
||||
assert.Equal(t, tc.channelID, id, fmt.Sprintf("expected channel ID '%s' got '%s'", tc.channelID, id))
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
cc := setupChannelsClient(t)
|
||||
|
||||
domainID := testsutil.GenerateUUID(t)
|
||||
route := testRoute
|
||||
id := testsutil.GenerateUUID(t)
|
||||
|
||||
err := cc.Save(context.Background(), domainID, route, id)
|
||||
assert.Nil(t, err, fmt.Sprintf("got unexpected error on saving channel ID: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
domainID string
|
||||
channelRoute string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "Remove existing channel",
|
||||
domainID: domainID,
|
||||
channelRoute: route,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "Remove non-existing channel",
|
||||
domainID: domainID,
|
||||
channelRoute: nonExistent,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "Remove with empty domain ID",
|
||||
domainID: "",
|
||||
channelRoute: route,
|
||||
err: cache.ErrEmptyDomainID,
|
||||
},
|
||||
{
|
||||
desc: "Remove with empty channel route",
|
||||
domainID: domainID,
|
||||
channelRoute: "",
|
||||
err: cache.ErrEmptyChannelRoute,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
err := cc.Remove(context.Background(), tc.channelRoute, tc.domainID)
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
|
||||
if tc.err == nil {
|
||||
id, err := cc.ID(context.Background(), tc.channelRoute, tc.domainID)
|
||||
assert.Equal(t, "", id, fmt.Sprintf("expected channel ID to be empty after removal, got '%s'", id))
|
||||
assert.True(t, errors.Contains(err, repoerr.ErrNotFound))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Vendored
-6
@@ -1,6 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package cache contains the domain concept definitions needed to
|
||||
// support Magistrala Channels cache service functionality.
|
||||
package cache
|
||||
Vendored
-61
@@ -1,61 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cache_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/ory/dockertest/v3/docker"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var (
|
||||
redisClient *redis.Client
|
||||
redisURL string
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
pool, err := dockertest.NewPool("")
|
||||
if err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
container, err := pool.RunWithOptions(&dockertest.RunOptions{
|
||||
Repository: "redis",
|
||||
Tag: "7.2.4-alpine",
|
||||
}, func(config *docker.HostConfig) {
|
||||
config.AutoRemove = true
|
||||
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Could not start container: %s", err)
|
||||
}
|
||||
|
||||
redisURL = fmt.Sprintf("redis://localhost:%s/0", container.GetPort("6379/tcp"))
|
||||
opts, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not parse redis URL: %s", err)
|
||||
}
|
||||
|
||||
if err := pool.Retry(func() error {
|
||||
redisClient = redis.NewClient(opts)
|
||||
|
||||
return redisClient.Ping(context.Background()).Err()
|
||||
}); err != nil {
|
||||
log.Fatalf("Could not connect to docker: %s", err)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
if err := pool.Purge(container); err != nil {
|
||||
log.Fatalf("Could not purge container: %s", err)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
@@ -1,240 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/internal/nullable"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/roles"
|
||||
)
|
||||
|
||||
// Metadata represents arbitrary JSON.
|
||||
type Metadata map[string]any
|
||||
|
||||
// Channel represents a Magistrala "communication topic". This topic
|
||||
// contains the clients that can exchange messages between each other.
|
||||
type Channel struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
ParentGroup string `json:"parent_group_id,omitempty"`
|
||||
Domain string `json:"domain_id,omitempty"`
|
||||
Route string `json:"route,omitempty"`
|
||||
Metadata Metadata `json:"metadata,omitempty"`
|
||||
CreatedBy string `json:"created_by,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
UpdatedBy string `json:"updated_by,omitempty"`
|
||||
Status Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled
|
||||
// Extended
|
||||
ParentGroupPath string `json:"parent_group_path,omitempty"`
|
||||
RoleID string `json:"role_id,omitempty"`
|
||||
RoleName string `json:"role_name,omitempty"`
|
||||
Actions []string `json:"actions,omitempty"`
|
||||
AccessType string `json:"access_type,omitempty"`
|
||||
AccessProviderId string `json:"access_provider_id,omitempty"`
|
||||
AccessProviderRoleId string `json:"access_provider_role_id,omitempty"`
|
||||
AccessProviderRoleName string `json:"access_provider_role_name,omitempty"`
|
||||
AccessProviderRoleActions []string `json:"access_provider_role_actions,omitempty"`
|
||||
ConnectionTypes []connections.ConnType `json:"connection_types,omitempty"`
|
||||
MemberId string `json:"member_id,omitempty"`
|
||||
Roles []roles.MemberRoleActions `json:"roles,omitempty"`
|
||||
}
|
||||
|
||||
type Operator uint8
|
||||
|
||||
const (
|
||||
OrOp Operator = iota
|
||||
AndOp
|
||||
)
|
||||
|
||||
type TagsQuery struct {
|
||||
Elements []string
|
||||
Operator Operator
|
||||
}
|
||||
|
||||
func ToTagsQuery(s string) TagsQuery {
|
||||
switch {
|
||||
case strings.Contains(s, "+"):
|
||||
elements := strings.Split(s, "+")
|
||||
for i := range elements {
|
||||
elements[i] = strings.TrimSpace(elements[i])
|
||||
}
|
||||
return TagsQuery{Elements: elements, Operator: AndOp}
|
||||
case strings.Contains(s, ","):
|
||||
elements := strings.Split(s, ",")
|
||||
for i := range elements {
|
||||
elements[i] = strings.TrimSpace(elements[i])
|
||||
}
|
||||
return TagsQuery{Elements: elements, Operator: OrOp}
|
||||
default:
|
||||
return TagsQuery{Elements: []string{s}, Operator: OrOp}
|
||||
}
|
||||
}
|
||||
|
||||
type Page struct {
|
||||
Total uint64 `json:"total"`
|
||||
Offset uint64 `json:"offset"`
|
||||
Limit uint64 `json:"limit"`
|
||||
OnlyTotal bool `json:"only_total"`
|
||||
Order string `json:"order,omitempty"`
|
||||
Dir string `json:"dir,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Metadata Metadata `json:"metadata,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Tags TagsQuery `json:"tags,omitempty"`
|
||||
Status Status `json:"status,omitempty"`
|
||||
Group nullable.Value[string] `json:"group,omitempty"`
|
||||
Client string `json:"client,omitempty"`
|
||||
ConnectionType string `json:"connection_type,omitempty"`
|
||||
RoleName string `json:"role_name,omitempty"`
|
||||
RoleID string `json:"role_id,omitempty"`
|
||||
Actions []string `json:"actions,omitempty"`
|
||||
AccessType string `json:"access_type,omitempty"`
|
||||
IDs []string `json:"-"`
|
||||
CreatedFrom time.Time `json:"created_from,omitempty"`
|
||||
CreatedTo time.Time `json:"created_to,omitempty"`
|
||||
}
|
||||
|
||||
// ChannelsPage contains page related metadata as well as list of channels that
|
||||
// belong to this page.
|
||||
type ChannelsPage struct {
|
||||
Page
|
||||
Channels []Channel
|
||||
}
|
||||
|
||||
type Connection struct {
|
||||
ClientID string
|
||||
ChannelID string
|
||||
DomainID string
|
||||
Type connections.ConnType
|
||||
}
|
||||
|
||||
type AuthzReq struct {
|
||||
DomainID string
|
||||
ChannelID string
|
||||
ClientID string
|
||||
ClientType string
|
||||
Type connections.ConnType
|
||||
}
|
||||
|
||||
type Service interface {
|
||||
// CreateChannels adds channels to the user.
|
||||
CreateChannels(ctx context.Context, session authn.Session, channels ...Channel) ([]Channel, []roles.RoleProvision, error)
|
||||
|
||||
// ViewChannel retrieves data about the channel identified by the provided
|
||||
// ID, that belongs to the user.
|
||||
ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (Channel, error)
|
||||
|
||||
// UpdateChannel updates the channel identified by the provided ID, that
|
||||
// belongs to the user.
|
||||
UpdateChannel(ctx context.Context, session authn.Session, channel Channel) (Channel, error)
|
||||
|
||||
// UpdateChannelTags updates the channel's tags.
|
||||
UpdateChannelTags(ctx context.Context, session authn.Session, channel Channel) (Channel, error)
|
||||
|
||||
EnableChannel(ctx context.Context, session authn.Session, id string) (Channel, error)
|
||||
|
||||
DisableChannel(ctx context.Context, session authn.Session, id string) (Channel, error)
|
||||
|
||||
// ListChannels retrieves data about subset of channels that belongs to the user.
|
||||
ListChannels(ctx context.Context, session authn.Session, pm Page) (ChannelsPage, error)
|
||||
|
||||
// ListUserChannels retrieves data about subset of channels that belong to the specified user.
|
||||
ListUserChannels(ctx context.Context, session authn.Session, userID string, pm Page) (ChannelsPage, error)
|
||||
|
||||
// RemoveChannel removes the client identified by the provided ID, that
|
||||
// belongs to the user.
|
||||
RemoveChannel(ctx context.Context, session authn.Session, id string) error
|
||||
|
||||
// Connect adds clients to the channels list of connected clients.
|
||||
Connect(ctx context.Context, session authn.Session, chIDs, clIDs []string, connType []connections.ConnType) error
|
||||
|
||||
// Disconnect removes clients from the channels list of connected clients.
|
||||
Disconnect(ctx context.Context, session authn.Session, chIDs, clIDs []string, connType []connections.ConnType) error
|
||||
|
||||
SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error
|
||||
|
||||
RemoveParentGroup(ctx context.Context, session authn.Session, id string) error
|
||||
|
||||
roles.RoleManager
|
||||
}
|
||||
|
||||
// ChannelRepository specifies a channel persistence API.
|
||||
type Repository interface {
|
||||
// Save persists multiple channels. Channels are saved using a transaction. If one channel
|
||||
// fails then none will be saved. Successful operation is indicated by non-nil error response.
|
||||
Save(ctx context.Context, chs ...Channel) ([]Channel, error)
|
||||
|
||||
// Update performs an update to the existing channel.
|
||||
Update(ctx context.Context, c Channel) (Channel, error)
|
||||
|
||||
UpdateTags(ctx context.Context, ch Channel) (Channel, error)
|
||||
|
||||
ChangeStatus(ctx context.Context, channel Channel) (Channel, error)
|
||||
|
||||
// RetrieveUserChannels retrieves the channel of given domainID and userID.
|
||||
RetrieveUserChannels(ctx context.Context, domainID, userID string, pm Page) (ChannelsPage, error)
|
||||
|
||||
// RetrieveByID retrieves the channel having the provided identifier
|
||||
RetrieveByID(ctx context.Context, id string) (Channel, error)
|
||||
|
||||
// RetrieveByRoute retrieves the channel having the provided route
|
||||
RetrieveByRoute(ctx context.Context, route, domainID string) (Channel, error)
|
||||
|
||||
// RetrieveByIDWithRoles retrieves channel by its unique ID along with member roles.
|
||||
RetrieveByIDWithRoles(ctx context.Context, id, memberID string) (Channel, error)
|
||||
|
||||
// RetrieveAll retrieves the subset of channels.
|
||||
RetrieveAll(ctx context.Context, pm Page) (ChannelsPage, error)
|
||||
|
||||
// Remove removes the channel having the provided identifier
|
||||
Remove(ctx context.Context, ids ...string) error
|
||||
|
||||
// SetParentGroup set parent group id to a given channel id
|
||||
SetParentGroup(ctx context.Context, ch Channel) error
|
||||
|
||||
// RemoveParentGroup remove parent group id fr given chanel id
|
||||
RemoveParentGroup(ctx context.Context, ch Channel) error
|
||||
|
||||
AddConnections(ctx context.Context, conns []Connection) error
|
||||
|
||||
RemoveConnections(ctx context.Context, conns []Connection) error
|
||||
|
||||
CheckConnection(ctx context.Context, conn Connection) error
|
||||
|
||||
ClientAuthorize(ctx context.Context, conn Connection) error
|
||||
|
||||
ChannelConnectionsCount(ctx context.Context, id string) (uint64, error)
|
||||
|
||||
DoesChannelHaveConnections(ctx context.Context, id string) (bool, error)
|
||||
|
||||
RemoveClientConnections(ctx context.Context, clientID string) error
|
||||
|
||||
RemoveChannelConnections(ctx context.Context, channelID string) error
|
||||
|
||||
RetrieveParentGroupChannels(ctx context.Context, parentGroupID string) ([]Channel, error)
|
||||
|
||||
UnsetParentGroupFromChannels(ctx context.Context, parentGroupID string) error
|
||||
|
||||
roles.Repository
|
||||
}
|
||||
|
||||
// Cache contains channels caching interface.
|
||||
type Cache interface {
|
||||
// Save stores the channelID for the given domain ID and channel route.
|
||||
Save(ctx context.Context, channelRoute, domainID, channelID string) error
|
||||
|
||||
// ID retrieves the channelID for the given domain ID and channel route.
|
||||
ID(ctx context.Context, channelRoute, domainID string) (string, error)
|
||||
|
||||
// Remove removes the channel ID for the given domain ID and channel route.
|
||||
Remove(ctx context.Context, channelRoute, domainID string) error
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package channels
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrInvalidStatus indicates invalid status.
|
||||
ErrInvalidStatus = errors.New("invalid channels status")
|
||||
|
||||
// ErrEnableChannel indicates error in enabling channel.
|
||||
ErrEnableChannel = errors.New("failed to enable channel")
|
||||
|
||||
// ErrDisableChannel indicates error in disabling channel.
|
||||
ErrDisableChannel = errors.New("failed to disable channel")
|
||||
)
|
||||
@@ -1,6 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package events provides the domain concept definitions
|
||||
// needed to support clients events functionality.
|
||||
package events
|
||||
@@ -1,384 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package events
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/roles"
|
||||
)
|
||||
|
||||
const (
|
||||
channelPrefix = "channel."
|
||||
channelCreate = channelPrefix + "create"
|
||||
channelUpdate = channelPrefix + "update"
|
||||
channelUpdateTags = channelPrefix + "update_tags"
|
||||
channelEnable = channelPrefix + "enable"
|
||||
channelDisable = channelPrefix + "disable"
|
||||
channelRemove = channelPrefix + "remove"
|
||||
channelView = channelPrefix + "view"
|
||||
channelList = channelPrefix + "list"
|
||||
channelListByUser = channelPrefix + "list_by_user"
|
||||
channelConnect = channelPrefix + "connect"
|
||||
channelDisconnect = channelPrefix + "disconnect"
|
||||
channelSetParent = channelPrefix + "set_parent"
|
||||
channelRemoveParent = channelPrefix + "remove_parent"
|
||||
)
|
||||
|
||||
var (
|
||||
_ events.Event = (*createChannelEvent)(nil)
|
||||
_ events.Event = (*updateChannelEvent)(nil)
|
||||
_ events.Event = (*changeChannelStatusEvent)(nil)
|
||||
_ events.Event = (*viewChannelEvent)(nil)
|
||||
_ events.Event = (*listChannelEvent)(nil)
|
||||
_ events.Event = (*removeChannelEvent)(nil)
|
||||
_ events.Event = (*connectEvent)(nil)
|
||||
_ events.Event = (*disconnectEvent)(nil)
|
||||
)
|
||||
|
||||
type createChannelEvent struct {
|
||||
channels.Channel
|
||||
rolesProvisioned []roles.RoleProvision
|
||||
authn.Session
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (cce createChannelEvent) Encode() (map[string]any, error) {
|
||||
val := map[string]any{
|
||||
"operation": channelCreate,
|
||||
"id": cce.ID,
|
||||
"roles_provisioned": cce.rolesProvisioned,
|
||||
"route": cce.Route,
|
||||
"status": cce.Status.String(),
|
||||
"created_at": cce.CreatedAt,
|
||||
"domain": cce.DomainID,
|
||||
"user_id": cce.UserID,
|
||||
"token_type": cce.Type.String(),
|
||||
"super_admin": cce.SuperAdmin,
|
||||
"request_id": cce.requestID,
|
||||
}
|
||||
|
||||
if cce.Name != "" {
|
||||
val["name"] = cce.Name
|
||||
}
|
||||
if len(cce.Tags) > 0 {
|
||||
val["tags"] = cce.Tags
|
||||
}
|
||||
if cce.Metadata != nil {
|
||||
val["metadata"] = cce.Metadata
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type updateChannelEvent struct {
|
||||
channels.Channel
|
||||
authn.Session
|
||||
operation string
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (uce updateChannelEvent) Encode() (map[string]any, error) {
|
||||
val := map[string]any{
|
||||
"operation": uce.operation,
|
||||
"updated_at": uce.UpdatedAt,
|
||||
"updated_by": uce.UpdatedBy,
|
||||
"domain": uce.DomainID,
|
||||
"user_id": uce.UserID,
|
||||
"token_type": uce.Type.String(),
|
||||
"super_admin": uce.SuperAdmin,
|
||||
"request_id": uce.requestID,
|
||||
}
|
||||
|
||||
if uce.ID != "" {
|
||||
val["id"] = uce.ID
|
||||
}
|
||||
if uce.Route != "" {
|
||||
val["route"] = uce.Route
|
||||
}
|
||||
if uce.Name != "" {
|
||||
val["name"] = uce.Name
|
||||
}
|
||||
if len(uce.Tags) > 0 {
|
||||
val["tags"] = uce.Tags
|
||||
}
|
||||
if uce.Metadata != nil {
|
||||
val["metadata"] = uce.Metadata
|
||||
}
|
||||
if !uce.CreatedAt.IsZero() {
|
||||
val["created_at"] = uce.CreatedAt
|
||||
}
|
||||
if uce.Status.String() != "" {
|
||||
val["status"] = uce.Status.String()
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type changeChannelStatusEvent struct {
|
||||
id string
|
||||
operation string
|
||||
status string
|
||||
updatedAt time.Time
|
||||
updatedBy string
|
||||
authn.Session
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (cse changeChannelStatusEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"operation": cse.operation,
|
||||
"id": cse.id,
|
||||
"status": cse.status,
|
||||
"updated_at": cse.updatedAt,
|
||||
"updated_by": cse.updatedBy,
|
||||
"domain": cse.DomainID,
|
||||
"user_id": cse.UserID,
|
||||
"token_type": cse.Type.String(),
|
||||
"super_admin": cse.SuperAdmin,
|
||||
"request_id": cse.requestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type viewChannelEvent struct {
|
||||
channels.Channel
|
||||
authn.Session
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (vce viewChannelEvent) Encode() (map[string]any, error) {
|
||||
val := map[string]any{
|
||||
"operation": channelView,
|
||||
"id": vce.ID,
|
||||
"domain": vce.DomainID,
|
||||
"user_id": vce.UserID,
|
||||
"token_type": vce.Type.String(),
|
||||
"super_admin": vce.SuperAdmin,
|
||||
"request_id": vce.requestID,
|
||||
}
|
||||
|
||||
if vce.Name != "" {
|
||||
val["name"] = vce.Name
|
||||
}
|
||||
if vce.Route != "" {
|
||||
val["route"] = vce.Route
|
||||
}
|
||||
if len(vce.Tags) > 0 {
|
||||
val["tags"] = vce.Tags
|
||||
}
|
||||
if vce.Metadata != nil {
|
||||
val["metadata"] = vce.Metadata
|
||||
}
|
||||
if !vce.CreatedAt.IsZero() {
|
||||
val["created_at"] = vce.CreatedAt
|
||||
}
|
||||
if !vce.UpdatedAt.IsZero() {
|
||||
val["updated_at"] = vce.UpdatedAt
|
||||
}
|
||||
if vce.UpdatedBy != "" {
|
||||
val["updated_by"] = vce.UpdatedBy
|
||||
}
|
||||
if vce.Status.String() != "" {
|
||||
val["status"] = vce.Status.String()
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type listChannelEvent struct {
|
||||
channels.Page
|
||||
authn.Session
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (lce listChannelEvent) Encode() (map[string]any, error) {
|
||||
val := map[string]any{
|
||||
"operation": channelList,
|
||||
"total": lce.Total,
|
||||
"offset": lce.Offset,
|
||||
"limit": lce.Limit,
|
||||
"domain": lce.DomainID,
|
||||
"user_id": lce.UserID,
|
||||
"token_type": lce.Type.String(),
|
||||
"super_admin": lce.SuperAdmin,
|
||||
"request_id": lce.requestID,
|
||||
}
|
||||
|
||||
if lce.Name != "" {
|
||||
val["name"] = lce.Name
|
||||
}
|
||||
if lce.Order != "" {
|
||||
val["order"] = lce.Order
|
||||
}
|
||||
if lce.Dir != "" {
|
||||
val["dir"] = lce.Dir
|
||||
}
|
||||
if lce.Metadata != nil {
|
||||
val["metadata"] = lce.Metadata
|
||||
}
|
||||
if len(lce.Tags.Elements) > 0 {
|
||||
val["tag"] = lce.Tags.Elements
|
||||
}
|
||||
if lce.Status.String() != "" {
|
||||
val["status"] = lce.Status.String()
|
||||
}
|
||||
if len(lce.IDs) > 0 {
|
||||
val["ids"] = lce.IDs
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type listUserChannelsEvent struct {
|
||||
userID string
|
||||
channels.Page
|
||||
authn.Session
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (luce listUserChannelsEvent) Encode() (map[string]any, error) {
|
||||
val := map[string]any{
|
||||
"operation": channelListByUser,
|
||||
"req_user_id": luce.userID,
|
||||
"total": luce.Total,
|
||||
"offset": luce.Offset,
|
||||
"limit": luce.Limit,
|
||||
"domain": luce.DomainID,
|
||||
"user_id": luce.UserID,
|
||||
"token_type": luce.Type.String(),
|
||||
"super_admin": luce.SuperAdmin,
|
||||
"request_id": luce.requestID,
|
||||
}
|
||||
|
||||
if luce.Name != "" {
|
||||
val["name"] = luce.Name
|
||||
}
|
||||
if luce.Order != "" {
|
||||
val["order"] = luce.Order
|
||||
}
|
||||
if luce.Dir != "" {
|
||||
val["dir"] = luce.Dir
|
||||
}
|
||||
if luce.Metadata != nil {
|
||||
val["metadata"] = luce.Metadata
|
||||
}
|
||||
if luce.Domain != "" {
|
||||
val["domain"] = luce.Domain
|
||||
}
|
||||
if len(luce.Tags.Elements) > 0 {
|
||||
val["tag"] = luce.Tags.Elements
|
||||
}
|
||||
if luce.Status.String() != "" {
|
||||
val["status"] = luce.Status.String()
|
||||
}
|
||||
if len(luce.IDs) > 0 {
|
||||
val["ids"] = luce.IDs
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type removeChannelEvent struct {
|
||||
id string
|
||||
authn.Session
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (dce removeChannelEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"operation": channelRemove,
|
||||
"id": dce.id,
|
||||
"domain": dce.DomainID,
|
||||
"user_id": dce.UserID,
|
||||
"token_type": dce.Type.String(),
|
||||
"super_admin": dce.SuperAdmin,
|
||||
"request_id": dce.requestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type connectEvent struct {
|
||||
chIDs []string
|
||||
thIDs []string
|
||||
types []connections.ConnType
|
||||
authn.Session
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (ce connectEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"operation": channelConnect,
|
||||
"client_ids": ce.thIDs,
|
||||
"channel_ids": ce.chIDs,
|
||||
"types": ce.types,
|
||||
"domain": ce.DomainID,
|
||||
"user_id": ce.UserID,
|
||||
"token_type": ce.Type.String(),
|
||||
"super_admin": ce.SuperAdmin,
|
||||
"request_id": ce.requestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type disconnectEvent struct {
|
||||
chIDs []string
|
||||
thIDs []string
|
||||
types []connections.ConnType
|
||||
authn.Session
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (de disconnectEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"operation": channelDisconnect,
|
||||
"client_ids": de.thIDs,
|
||||
"channel_ids": de.chIDs,
|
||||
"types": de.types,
|
||||
"domain": de.DomainID,
|
||||
"user_id": de.UserID,
|
||||
"token_type": de.Type.String(),
|
||||
"super_admin": de.SuperAdmin,
|
||||
"request_id": de.requestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type setParentGroupEvent struct {
|
||||
id string
|
||||
parentGroupID string
|
||||
authn.Session
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (spge setParentGroupEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"operation": channelSetParent,
|
||||
"id": spge.id,
|
||||
"parent_group_id": spge.parentGroupID,
|
||||
"domain": spge.DomainID,
|
||||
"user_id": spge.UserID,
|
||||
"token_type": spge.Type.String(),
|
||||
"super_admin": spge.SuperAdmin,
|
||||
"request_id": spge.requestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type removeParentGroupEvent struct {
|
||||
id string
|
||||
authn.Session
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (rpge removeParentGroupEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"operation": channelRemoveParent,
|
||||
"id": rpge.id,
|
||||
"domain": rpge.DomainID,
|
||||
"user_id": rpge.UserID,
|
||||
"token_type": rpge.Type.String(),
|
||||
"super_admin": rpge.SuperAdmin,
|
||||
"request_id": rpge.requestID,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,300 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/events"
|
||||
"github.com/absmach/magistrala/pkg/events/store"
|
||||
"github.com/absmach/magistrala/pkg/roles"
|
||||
rmEvents "github.com/absmach/magistrala/pkg/roles/rolemanager/events"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
const (
|
||||
magistralaPrefix = "magistrala."
|
||||
createStream = magistralaPrefix + channelCreate
|
||||
updateStream = magistralaPrefix + channelUpdate
|
||||
updateTagsStream = magistralaPrefix + channelUpdateTags
|
||||
enableStream = magistralaPrefix + channelEnable
|
||||
disableStream = magistralaPrefix + channelDisable
|
||||
removeStream = magistralaPrefix + channelRemove
|
||||
viewStream = magistralaPrefix + channelView
|
||||
listStream = magistralaPrefix + channelList
|
||||
listByUserStream = magistralaPrefix + channelListByUser
|
||||
connectStream = magistralaPrefix + channelConnect
|
||||
disconnectStream = magistralaPrefix + channelDisconnect
|
||||
setParentStream = magistralaPrefix + channelSetParent
|
||||
removeParentStream = magistralaPrefix + channelRemoveParent
|
||||
)
|
||||
|
||||
var _ channels.Service = (*eventStore)(nil)
|
||||
|
||||
type eventStore struct {
|
||||
events.Publisher
|
||||
svc channels.Service
|
||||
rmEvents.RoleManagerEventStore
|
||||
}
|
||||
|
||||
// NewEventStoreMiddleware returns wrapper around clients service that sends
|
||||
// events to event store.
|
||||
func NewEventStoreMiddleware(ctx context.Context, svc channels.Service, url string) (channels.Service, error) {
|
||||
publisher, err := store.NewPublisher(ctx, url, "channels-es-pub")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rolesSvcEventStoreMiddleware := rmEvents.NewRoleManagerEventStore("channels", channelPrefix, svc, publisher)
|
||||
return &eventStore{
|
||||
svc: svc,
|
||||
Publisher: publisher,
|
||||
RoleManagerEventStore: rolesSvcEventStoreMiddleware,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
|
||||
chs, rps, err := es.svc.CreateChannels(ctx, session, chs...)
|
||||
if err != nil {
|
||||
return chs, rps, err
|
||||
}
|
||||
|
||||
for _, ch := range chs {
|
||||
event := createChannelEvent{
|
||||
Channel: ch,
|
||||
rolesProvisioned: rps,
|
||||
Session: session,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
if err := es.Publish(ctx, createStream, event); err != nil {
|
||||
return chs, rps, err
|
||||
}
|
||||
}
|
||||
|
||||
return chs, rps, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) UpdateChannel(ctx context.Context, session authn.Session, ch channels.Channel) (channels.Channel, error) {
|
||||
ch, err := es.svc.UpdateChannel(ctx, session, ch)
|
||||
if err != nil {
|
||||
return ch, err
|
||||
}
|
||||
|
||||
event := updateChannelEvent{
|
||||
Channel: ch,
|
||||
Session: session,
|
||||
operation: channelUpdate,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
if err := es.Publish(ctx, updateStream, event); err != nil {
|
||||
return ch, err
|
||||
}
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) UpdateChannelTags(ctx context.Context, session authn.Session, ch channels.Channel) (channels.Channel, error) {
|
||||
ch, err := es.svc.UpdateChannelTags(ctx, session, ch)
|
||||
if err != nil {
|
||||
return ch, err
|
||||
}
|
||||
|
||||
event := updateChannelEvent{
|
||||
Channel: ch,
|
||||
Session: session,
|
||||
operation: channelUpdateTags,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
if err := es.Publish(ctx, updateTagsStream, event); err != nil {
|
||||
return ch, err
|
||||
}
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (channels.Channel, error) {
|
||||
chann, err := es.svc.ViewChannel(ctx, session, id, withRoles)
|
||||
if err != nil {
|
||||
return chann, err
|
||||
}
|
||||
|
||||
event := viewChannelEvent{
|
||||
Channel: chann,
|
||||
Session: session,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
if err := es.Publish(ctx, viewStream, event); err != nil {
|
||||
return chann, err
|
||||
}
|
||||
|
||||
return chann, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
|
||||
cp, err := es.svc.ListChannels(ctx, session, pm)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
event := listChannelEvent{
|
||||
Page: pm,
|
||||
Session: session,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
if err := es.Publish(ctx, listStream, event); err != nil {
|
||||
return cp, err
|
||||
}
|
||||
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
|
||||
cp, err := es.svc.ListUserChannels(ctx, session, userID, pm)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
event := listUserChannelsEvent{
|
||||
userID: userID,
|
||||
Page: pm,
|
||||
Session: session,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
if err := es.Publish(ctx, listByUserStream, event); err != nil {
|
||||
return cp, err
|
||||
}
|
||||
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
|
||||
ch, err := es.svc.EnableChannel(ctx, session, id)
|
||||
if err != nil {
|
||||
return ch, err
|
||||
}
|
||||
|
||||
return es.changeStatus(ctx, session, channelEnable, enableStream, ch)
|
||||
}
|
||||
|
||||
func (es *eventStore) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
|
||||
ch, err := es.svc.DisableChannel(ctx, session, id)
|
||||
if err != nil {
|
||||
return ch, err
|
||||
}
|
||||
|
||||
return es.changeStatus(ctx, session, channelDisable, disableStream, ch)
|
||||
}
|
||||
|
||||
func (es *eventStore) changeStatus(ctx context.Context, session authn.Session, operation, stream string, ch channels.Channel) (channels.Channel, error) {
|
||||
event := changeChannelStatusEvent{
|
||||
id: ch.ID,
|
||||
operation: operation,
|
||||
updatedAt: ch.UpdatedAt,
|
||||
updatedBy: ch.UpdatedBy,
|
||||
status: ch.Status.String(),
|
||||
Session: session,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
if err := es.Publish(ctx, stream, event); err != nil {
|
||||
return ch, err
|
||||
}
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
|
||||
if err := es.svc.RemoveChannel(ctx, session, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
event := removeChannelEvent{
|
||||
id: id,
|
||||
Session: session,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
|
||||
if err := es.Publish(ctx, removeStream, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventStore) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
|
||||
if err := es.svc.Connect(ctx, session, chIDs, thIDs, connTypes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
event := connectEvent{
|
||||
chIDs: chIDs,
|
||||
thIDs: thIDs,
|
||||
types: connTypes,
|
||||
Session: session,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
|
||||
if err := es.Publish(ctx, connectStream, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventStore) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
|
||||
if err := es.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
event := disconnectEvent{
|
||||
chIDs: chIDs,
|
||||
thIDs: thIDs,
|
||||
types: connTypes,
|
||||
Session: session,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
|
||||
if err := es.Publish(ctx, disconnectStream, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventStore) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) (err error) {
|
||||
if err := es.svc.SetParentGroup(ctx, session, parentGroupID, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
event := setParentGroupEvent{
|
||||
parentGroupID: parentGroupID,
|
||||
id: id,
|
||||
Session: session,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
|
||||
if err := es.Publish(ctx, setParentStream, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (es *eventStore) RemoveParentGroup(ctx context.Context, session authn.Session, id string) (err error) {
|
||||
if err := es.svc.RemoveParentGroup(ctx, session, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
event := removeParentGroupEvent{
|
||||
id: id,
|
||||
Session: session,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
|
||||
if err := es.Publish(ctx, removeParentStream, event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,669 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package events_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/channels/events"
|
||||
"github.com/absmach/magistrala/channels/mocks"
|
||||
"github.com/absmach/magistrala/internal/testsutil"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/absmach/magistrala/pkg/roles"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
storeClient *redis.Client
|
||||
storeURL string
|
||||
validSession = authn.Session{
|
||||
DomainID: testsutil.GenerateUUID(&testing.T{}),
|
||||
UserID: testsutil.GenerateUUID(&testing.T{}),
|
||||
}
|
||||
validChannel = generateTestChannel(&testing.T{})
|
||||
validChannelsPage = channels.ChannelsPage{
|
||||
Page: channels.Page{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
Total: 1,
|
||||
},
|
||||
Channels: []channels.Channel{validChannel},
|
||||
}
|
||||
)
|
||||
|
||||
func newEventStoreMiddleware(t *testing.T) (*mocks.Service, channels.Service) {
|
||||
svc := new(mocks.Service)
|
||||
nsvc, err := events.NewEventStoreMiddleware(context.Background(), svc, storeURL)
|
||||
require.Nil(t, err, fmt.Sprintf("create events store middleware failed with unexpected error: %s", err))
|
||||
|
||||
return svc, nsvc
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
code := testsutil.RunRedisTest(m, &storeClient, &storeURL)
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestCreateChannels(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validID := testsutil.GenerateUUID(t)
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, validID)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
channels []channels.Channel
|
||||
svcRes []channels.Channel
|
||||
svcRoleRes []roles.RoleProvision
|
||||
svcErr error
|
||||
resp []channels.Channel
|
||||
respRoleRes []roles.RoleProvision
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
channels: []channels.Channel{validChannel},
|
||||
svcRes: []channels.Channel{validChannel},
|
||||
svcRoleRes: []roles.RoleProvision{},
|
||||
svcErr: nil,
|
||||
resp: []channels.Channel{validChannel},
|
||||
respRoleRes: []roles.RoleProvision{},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
channels: []channels.Channel{validChannel},
|
||||
svcRes: []channels.Channel{},
|
||||
svcRoleRes: []roles.RoleProvision{},
|
||||
svcErr: svcerr.ErrCreateEntity,
|
||||
resp: []channels.Channel{},
|
||||
respRoleRes: []roles.RoleProvision{},
|
||||
err: svcerr.ErrCreateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("CreateChannels", validCtx, tc.session, tc.channels).Return(tc.svcRes, tc.svcRoleRes, tc.svcErr)
|
||||
resp, respRoleRes, err := nsvc.CreateChannels(validCtx, tc.session, tc.channels...)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
|
||||
assert.Equal(t, tc.respRoleRes, respRoleRes, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.respRoleRes, respRoleRes))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewChannel(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
channelID string
|
||||
withRoles bool
|
||||
svcRes channels.Channel
|
||||
svcErr error
|
||||
resp channels.Channel
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
channelID: validChannel.ID,
|
||||
withRoles: false,
|
||||
svcRes: validChannel,
|
||||
svcErr: nil,
|
||||
resp: validChannel,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
channelID: validChannel.ID,
|
||||
withRoles: false,
|
||||
svcRes: channels.Channel{},
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
resp: channels.Channel{},
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("ViewChannel", validCtx, tc.session, tc.channelID, tc.withRoles).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := nsvc.ViewChannel(validCtx, tc.session, tc.channelID, tc.withRoles)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateChannel(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
|
||||
updatedChannel := validChannel
|
||||
updatedChannel.Name = "updatedName"
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
channel channels.Channel
|
||||
svcRes channels.Channel
|
||||
svcErr error
|
||||
resp channels.Channel
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
channel: updatedChannel,
|
||||
svcRes: updatedChannel,
|
||||
svcErr: nil,
|
||||
resp: updatedChannel,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
channel: updatedChannel,
|
||||
svcRes: channels.Channel{},
|
||||
svcErr: svcerr.ErrUpdateEntity,
|
||||
resp: channels.Channel{},
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("UpdateChannel", validCtx, tc.session, tc.channel).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := nsvc.UpdateChannel(validCtx, tc.session, tc.channel)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateChannelTags(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
|
||||
updatedChannel := validChannel
|
||||
updatedChannel.Tags = []string{"newTag1", "newTag2"}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
channel channels.Channel
|
||||
svcRes channels.Channel
|
||||
svcErr error
|
||||
resp channels.Channel
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
channel: updatedChannel,
|
||||
svcRes: updatedChannel,
|
||||
svcErr: nil,
|
||||
resp: updatedChannel,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
channel: updatedChannel,
|
||||
svcRes: channels.Channel{},
|
||||
svcErr: svcerr.ErrUpdateEntity,
|
||||
resp: channels.Channel{},
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("UpdateChannelTags", validCtx, tc.session, tc.channel).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := nsvc.UpdateChannelTags(validCtx, tc.session, tc.channel)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableChannel(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
channelID string
|
||||
svcRes channels.Channel
|
||||
svcErr error
|
||||
resp channels.Channel
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
channelID: validChannel.ID,
|
||||
svcRes: validChannel,
|
||||
svcErr: nil,
|
||||
resp: validChannel,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
channelID: validChannel.ID,
|
||||
svcRes: channels.Channel{},
|
||||
svcErr: svcerr.ErrUpdateEntity,
|
||||
resp: channels.Channel{},
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("EnableChannel", validCtx, tc.session, tc.channelID).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := nsvc.EnableChannel(validCtx, tc.session, tc.channelID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisableChannel(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
channelID string
|
||||
svcRes channels.Channel
|
||||
svcErr error
|
||||
resp channels.Channel
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
channelID: validChannel.ID,
|
||||
svcRes: validChannel,
|
||||
svcErr: nil,
|
||||
resp: validChannel,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
channelID: validChannel.ID,
|
||||
svcRes: channels.Channel{},
|
||||
svcErr: svcerr.ErrUpdateEntity,
|
||||
resp: channels.Channel{},
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("DisableChannel", validCtx, tc.session, tc.channelID).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := nsvc.DisableChannel(validCtx, tc.session, tc.channelID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListChannels(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
pageMeta channels.Page
|
||||
svcRes channels.ChannelsPage
|
||||
svcErr error
|
||||
resp channels.ChannelsPage
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
pageMeta: channels.Page{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
svcRes: validChannelsPage,
|
||||
svcErr: nil,
|
||||
resp: validChannelsPage,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
pageMeta: channels.Page{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
svcRes: channels.ChannelsPage{},
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
resp: channels.ChannelsPage{},
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("ListChannels", validCtx, tc.session, tc.pageMeta).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := nsvc.ListChannels(validCtx, tc.session, tc.pageMeta)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserChannels(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
userID string
|
||||
pageMeta channels.Page
|
||||
svcRes channels.ChannelsPage
|
||||
svcErr error
|
||||
resp channels.ChannelsPage
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
userID: validSession.UserID,
|
||||
pageMeta: channels.Page{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
svcRes: validChannelsPage,
|
||||
svcErr: nil,
|
||||
resp: validChannelsPage,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
userID: validSession.UserID,
|
||||
pageMeta: channels.Page{
|
||||
Limit: 10,
|
||||
Offset: 0,
|
||||
},
|
||||
svcRes: channels.ChannelsPage{},
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
resp: channels.ChannelsPage{},
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("ListUserChannels", validCtx, tc.session, tc.userID, tc.pageMeta).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := nsvc.ListUserChannels(validCtx, tc.session, tc.userID, tc.pageMeta)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveChannel(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
channelID string
|
||||
svcErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
channelID: validChannel.ID,
|
||||
svcErr: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
channelID: validChannel.ID,
|
||||
svcErr: svcerr.ErrRemoveEntity,
|
||||
err: svcerr.ErrRemoveEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("RemoveChannel", validCtx, tc.session, tc.channelID).Return(tc.svcErr)
|
||||
err := nsvc.RemoveChannel(validCtx, tc.session, tc.channelID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
chIDs []string
|
||||
clIDs []string
|
||||
connTypes []connections.ConnType
|
||||
svcErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
chIDs: []string{validChannel.ID},
|
||||
clIDs: []string{testsutil.GenerateUUID(t)},
|
||||
connTypes: []connections.ConnType{connections.Publish},
|
||||
svcErr: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
chIDs: []string{validChannel.ID},
|
||||
clIDs: []string{testsutil.GenerateUUID(t)},
|
||||
connTypes: []connections.ConnType{connections.Publish},
|
||||
svcErr: svcerr.ErrCreateEntity,
|
||||
err: svcerr.ErrCreateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("Connect", validCtx, tc.session, tc.chIDs, tc.clIDs, tc.connTypes).Return(tc.svcErr)
|
||||
err := nsvc.Connect(validCtx, tc.session, tc.chIDs, tc.clIDs, tc.connTypes)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisconnect(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
chIDs []string
|
||||
clIDs []string
|
||||
connTypes []connections.ConnType
|
||||
svcErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
chIDs: []string{validChannel.ID},
|
||||
clIDs: []string{testsutil.GenerateUUID(t)},
|
||||
connTypes: []connections.ConnType{connections.Publish},
|
||||
svcErr: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
chIDs: []string{validChannel.ID},
|
||||
clIDs: []string{testsutil.GenerateUUID(t)},
|
||||
connTypes: []connections.ConnType{connections.Publish},
|
||||
svcErr: svcerr.ErrRemoveEntity,
|
||||
err: svcerr.ErrRemoveEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("Disconnect", validCtx, tc.session, tc.chIDs, tc.clIDs, tc.connTypes).Return(tc.svcErr)
|
||||
err := nsvc.Disconnect(validCtx, tc.session, tc.chIDs, tc.clIDs, tc.connTypes)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetParentGroup(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
parentGroupID string
|
||||
channelID string
|
||||
svcErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
parentGroupID: testsutil.GenerateUUID(t),
|
||||
channelID: validChannel.ID,
|
||||
svcErr: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
parentGroupID: testsutil.GenerateUUID(t),
|
||||
channelID: validChannel.ID,
|
||||
svcErr: svcerr.ErrUpdateEntity,
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("SetParentGroup", validCtx, tc.session, tc.parentGroupID, tc.channelID).Return(tc.svcErr)
|
||||
err := nsvc.SetParentGroup(validCtx, tc.session, tc.parentGroupID, tc.channelID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveParentGroup(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
channelID string
|
||||
svcErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
channelID: validChannel.ID,
|
||||
svcErr: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
channelID: validChannel.ID,
|
||||
svcErr: svcerr.ErrUpdateEntity,
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("RemoveParentGroup", validCtx, tc.session, tc.channelID).Return(tc.svcErr)
|
||||
err := nsvc.RemoveParentGroup(validCtx, tc.session, tc.channelID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateTestChannel(t *testing.T) channels.Channel {
|
||||
createdAt, err := time.Parse(time.RFC3339, "2024-01-01T00:00:00Z")
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error parsing time: %v", err))
|
||||
return channels.Channel{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Name: "channelname",
|
||||
Domain: testsutil.GenerateUUID(t),
|
||||
Tags: []string{"tag1", "tag2"},
|
||||
Metadata: channels.Metadata{"key1": "value1"},
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: createdAt,
|
||||
Status: channels.EnabledStatus,
|
||||
}
|
||||
}
|
||||
@@ -1,383 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/absmach/magistrala/auth"
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/channels/operations"
|
||||
cOperations "github.com/absmach/magistrala/clients/operations"
|
||||
dOperations "github.com/absmach/magistrala/domains/operations"
|
||||
gOperations "github.com/absmach/magistrala/groups/operations"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
smqauthz "github.com/absmach/magistrala/pkg/authz"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/absmach/magistrala/pkg/permissions"
|
||||
"github.com/absmach/magistrala/pkg/policies"
|
||||
"github.com/absmach/magistrala/pkg/roles"
|
||||
rolemgr "github.com/absmach/magistrala/pkg/roles/rolemanager/middleware"
|
||||
)
|
||||
|
||||
var (
|
||||
errView = errors.New("not authorized to view channel")
|
||||
errList = errors.New("not authorized to list user channels")
|
||||
errUpdate = errors.New("not authorized to update channel")
|
||||
errUpdateTags = errors.New("not authorized to update channel tags")
|
||||
errEnable = errors.New("not authorized to enable channel")
|
||||
errDisable = errors.New("not authorized to disable channel")
|
||||
errDelete = errors.New("not authorized to delete channel")
|
||||
errConnect = errors.New("not authorized to connect to channel")
|
||||
errDisconnect = errors.New("not authorized to disconnect from channel")
|
||||
errSetParentGroup = errors.New("not authorized to set parent group to channel")
|
||||
errRemoveParentGroup = errors.New("not authorized to remove parent group from channel")
|
||||
errDomainCreateChannels = errors.New("not authorized to create channel in domain")
|
||||
errGroupSetChildChannels = errors.New("not authorized to set child channel for group")
|
||||
errGroupRemoveChildChannels = errors.New("not authorized to remove child channel for group")
|
||||
errClientDisConnectChannels = errors.New("not authorized to disconnect channel for client")
|
||||
errClientConnectChannels = errors.New("not authorized to connect channel for client")
|
||||
)
|
||||
|
||||
var _ channels.Service = (*authorizationMiddleware)(nil)
|
||||
|
||||
type authorizationMiddleware struct {
|
||||
svc channels.Service
|
||||
repo channels.Repository
|
||||
authz smqauthz.Authorization
|
||||
entitiesOps permissions.EntitiesOperations[permissions.Operation]
|
||||
rolemgr.RoleManagerAuthorizationMiddleware
|
||||
}
|
||||
|
||||
// NewAuthorization adds authorization to the channels service.
|
||||
func NewAuthorization(
|
||||
entityType string,
|
||||
svc channels.Service,
|
||||
authz smqauthz.Authorization,
|
||||
repo channels.Repository,
|
||||
entitiesOps permissions.EntitiesOperations[permissions.Operation],
|
||||
roleOps permissions.Operations[permissions.RoleOperation],
|
||||
) (channels.Service, error) {
|
||||
if err := entitiesOps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ram, err := rolemgr.NewAuthorization(policies.ChannelType, svc, authz, roleOps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &authorizationMiddleware{
|
||||
svc: svc,
|
||||
authz: authz,
|
||||
repo: repo,
|
||||
entitiesOps: entitiesOps,
|
||||
RoleManagerAuthorizationMiddleware: ram,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
|
||||
if err := am.authorize(ctx, session, policies.DomainType, dOperations.OpCreateDomainChannels, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.DomainType,
|
||||
Object: session.DomainID,
|
||||
}); err != nil {
|
||||
return []channels.Channel{}, []roles.RoleProvision{}, errors.Wrap(err, errDomainCreateChannels)
|
||||
}
|
||||
|
||||
for _, ch := range chs {
|
||||
if ch.ParentGroup != "" {
|
||||
if err := am.authorize(ctx, session, policies.GroupType, gOperations.OpGroupSetChildChannel, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.GroupType,
|
||||
Object: ch.ParentGroup,
|
||||
}); err != nil {
|
||||
return []channels.Channel{}, []roles.RoleProvision{}, errors.Wrap(err, errors.Wrap(errGroupSetChildChannels, fmt.Errorf("channel name %s parent group id %s", ch.Name, ch.ParentGroup)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return am.svc.CreateChannels(ctx, session, chs...)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (channels.Channel, error) {
|
||||
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpViewChannel, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.ChannelType,
|
||||
Object: id,
|
||||
}); err != nil {
|
||||
return channels.Channel{}, errors.Wrap(err, errView)
|
||||
}
|
||||
|
||||
return am.svc.ViewChannel(ctx, session, id, withRoles)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
|
||||
switch err := am.checkSuperAdmin(ctx, session); {
|
||||
case err == nil:
|
||||
session.SuperAdmin = true
|
||||
case errors.Contains(err, svcerr.ErrSuperAdminAction):
|
||||
default:
|
||||
return channels.ChannelsPage{}, err
|
||||
}
|
||||
|
||||
return am.svc.ListChannels(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
|
||||
if err := am.checkSuperAdmin(ctx, session); err != nil {
|
||||
return channels.ChannelsPage{}, errors.Wrap(err, errList)
|
||||
}
|
||||
|
||||
return am.svc.ListUserChannels(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
|
||||
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpUpdateChannel, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.ChannelType,
|
||||
Object: channel.ID,
|
||||
}); err != nil {
|
||||
return channels.Channel{}, errors.Wrap(err, errUpdate)
|
||||
}
|
||||
|
||||
return am.svc.UpdateChannel(ctx, session, channel)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
|
||||
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpUpdateChannelTags, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.ChannelType,
|
||||
Object: channel.ID,
|
||||
}); err != nil {
|
||||
return channels.Channel{}, errors.Wrap(err, errUpdateTags)
|
||||
}
|
||||
|
||||
return am.svc.UpdateChannelTags(ctx, session, channel)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
|
||||
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpEnableChannel, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.ChannelType,
|
||||
Object: id,
|
||||
}); err != nil {
|
||||
return channels.Channel{}, errors.Wrap(err, errEnable)
|
||||
}
|
||||
|
||||
return am.svc.EnableChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
|
||||
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpDisableChannel, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.ChannelType,
|
||||
Object: id,
|
||||
}); err != nil {
|
||||
return channels.Channel{}, errors.Wrap(err, errDisable)
|
||||
}
|
||||
|
||||
return am.svc.DisableChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
|
||||
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpDeleteChannel, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.ChannelType,
|
||||
Object: id,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, errDelete)
|
||||
}
|
||||
|
||||
return am.svc.RemoveChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
|
||||
for _, chID := range chIDs {
|
||||
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpConnectClient, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.ChannelType,
|
||||
Object: chID,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, errConnect)
|
||||
}
|
||||
}
|
||||
|
||||
for _, thID := range thIDs {
|
||||
if err := am.authorize(ctx, session, policies.ClientType, cOperations.OpConnectToChannel, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.ClientType,
|
||||
Object: thID,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, errClientConnectChannels)
|
||||
}
|
||||
}
|
||||
|
||||
return am.svc.Connect(ctx, session, chIDs, thIDs, connTypes)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
|
||||
for _, chID := range chIDs {
|
||||
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpDisconnectClient, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.ChannelType,
|
||||
Object: chID,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, errDisconnect)
|
||||
}
|
||||
}
|
||||
|
||||
for _, thID := range thIDs {
|
||||
if err := am.authorize(ctx, session, policies.ClientType, cOperations.OpDisconnectFromChannel, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.ClientType,
|
||||
Object: thID,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, errClientDisConnectChannels)
|
||||
}
|
||||
}
|
||||
|
||||
return am.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
|
||||
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpSetParentGroup, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.ChannelType,
|
||||
Object: id,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, errSetParentGroup)
|
||||
}
|
||||
|
||||
if err := am.authorize(ctx, session, policies.GroupType, gOperations.OpGroupSetChildChannel, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.GroupType,
|
||||
Object: parentGroupID,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, errGroupSetChildChannels)
|
||||
}
|
||||
|
||||
return am.svc.SetParentGroup(ctx, session, parentGroupID, id)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
|
||||
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpSetParentGroup, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.ChannelType,
|
||||
Object: id,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, errRemoveParentGroup)
|
||||
}
|
||||
|
||||
ch, err := am.repo.RetrieveByID(ctx, id)
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrRemoveEntity, err)
|
||||
}
|
||||
|
||||
if ch.ParentGroup != "" {
|
||||
if err := am.authorize(ctx, session, policies.GroupType, gOperations.OpGroupRemoveChildChannel, smqauthz.PolicyReq{
|
||||
Domain: session.DomainID,
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.DomainUserID,
|
||||
ObjectType: policies.GroupType,
|
||||
Object: ch.ParentGroup,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, errGroupRemoveChildChannels)
|
||||
}
|
||||
|
||||
return am.svc.RemoveParentGroup(ctx, session, id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) authorize(ctx context.Context, session authn.Session, entityType string, op permissions.Operation, req smqauthz.PolicyReq) error {
|
||||
req.Domain = session.DomainID
|
||||
|
||||
perm, err := am.entitiesOps.GetPermission(entityType, op)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Permission = perm.String()
|
||||
|
||||
var pat *smqauthz.PATReq
|
||||
if session.PatID != "" {
|
||||
entityID := req.Object
|
||||
opName := am.entitiesOps.OperationName(entityType, op)
|
||||
if op == operations.OpListUserChannels || op == dOperations.OpCreateDomainChannels || op == dOperations.OpListDomainChannels {
|
||||
entityID = auth.AnyIDs
|
||||
}
|
||||
pat = &smqauthz.PATReq{
|
||||
UserID: session.UserID,
|
||||
PatID: session.PatID,
|
||||
EntityID: entityID,
|
||||
EntityType: patEntityType(entityType),
|
||||
Operation: opName,
|
||||
Domain: session.DomainID,
|
||||
}
|
||||
}
|
||||
|
||||
if err := am.authz.Authorize(ctx, req, pat); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func patEntityType(entityType string) string {
|
||||
switch entityType {
|
||||
case policies.ClientType:
|
||||
return auth.ClientsType.String()
|
||||
default:
|
||||
return auth.ChannelsType.String()
|
||||
}
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session authn.Session) error {
|
||||
if session.Role != authn.SuperAdminRole {
|
||||
return svcerr.ErrSuperAdminAction
|
||||
}
|
||||
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
|
||||
SubjectType: policies.UserType,
|
||||
Subject: session.UserID,
|
||||
Permission: policies.AdminPermission,
|
||||
ObjectType: policies.PlatformType,
|
||||
Object: policies.MagistralaObject,
|
||||
}, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,247 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/channels/operations"
|
||||
dOperations "github.com/absmach/magistrala/domains/operations"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/callout"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
svcerr "github.com/absmach/magistrala/pkg/errors/service"
|
||||
"github.com/absmach/magistrala/pkg/permissions"
|
||||
"github.com/absmach/magistrala/pkg/policies"
|
||||
"github.com/absmach/magistrala/pkg/roles"
|
||||
rolemw "github.com/absmach/magistrala/pkg/roles/rolemanager/middleware"
|
||||
)
|
||||
|
||||
var _ channels.Service = (*calloutMiddleware)(nil)
|
||||
|
||||
type calloutMiddleware struct {
|
||||
svc channels.Service
|
||||
repo channels.Repository
|
||||
callout callout.Callout
|
||||
entitiesOps permissions.EntitiesOperations[permissions.Operation]
|
||||
rolemw.RoleManagerCalloutMiddleware
|
||||
}
|
||||
|
||||
func NewCallout(svc channels.Service, repo channels.Repository, entitiesOps permissions.EntitiesOperations[permissions.Operation], roleOps permissions.Operations[permissions.RoleOperation], callout callout.Callout) (channels.Service, error) {
|
||||
call, err := rolemw.NewCallout(policies.ChannelType, svc, callout, roleOps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := entitiesOps.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &calloutMiddleware{
|
||||
svc: svc,
|
||||
repo: repo,
|
||||
callout: callout,
|
||||
entitiesOps: entitiesOps,
|
||||
RoleManagerCalloutMiddleware: call,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
|
||||
params := map[string]any{
|
||||
"entities": chs,
|
||||
"count": len(chs),
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.DomainType, dOperations.OpCreateDomainChannels, params); err != nil {
|
||||
return []channels.Channel{}, []roles.RoleProvision{}, err
|
||||
}
|
||||
|
||||
return cm.svc.CreateChannels(ctx, session, chs...)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (channels.Channel, error) {
|
||||
params := map[string]any{
|
||||
"entity_id": id,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpViewChannel, params); err != nil {
|
||||
return channels.Channel{}, err
|
||||
}
|
||||
|
||||
return cm.svc.ViewChannel(ctx, session, id, withRoles)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
|
||||
params := map[string]any{
|
||||
"pagemeta": pm,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.DomainType, dOperations.OpListDomainChannels, params); err != nil {
|
||||
return channels.ChannelsPage{}, err
|
||||
}
|
||||
|
||||
return cm.svc.ListChannels(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
|
||||
params := map[string]any{
|
||||
"user_id": userID,
|
||||
"pagemeta": pm,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpListUserChannels, params); err != nil {
|
||||
return channels.ChannelsPage{}, err
|
||||
}
|
||||
|
||||
return cm.svc.ListUserChannels(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
|
||||
params := map[string]any{
|
||||
"entity_id": channel.ID,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpUpdateChannel, params); err != nil {
|
||||
return channels.Channel{}, err
|
||||
}
|
||||
|
||||
return cm.svc.UpdateChannel(ctx, session, channel)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
|
||||
params := map[string]any{
|
||||
"entity_id": channel.ID,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpUpdateChannelTags, params); err != nil {
|
||||
return channels.Channel{}, err
|
||||
}
|
||||
|
||||
return cm.svc.UpdateChannelTags(ctx, session, channel)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
|
||||
params := map[string]any{
|
||||
"entity_id": id,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpEnableChannel, params); err != nil {
|
||||
return channels.Channel{}, err
|
||||
}
|
||||
|
||||
return cm.svc.EnableChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
|
||||
params := map[string]any{
|
||||
"entity_id": id,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpDisableChannel, params); err != nil {
|
||||
return channels.Channel{}, err
|
||||
}
|
||||
|
||||
return cm.svc.DisableChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
|
||||
params := map[string]any{
|
||||
"entity_id": id,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpDeleteChannel, params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return cm.svc.RemoveChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
|
||||
params := map[string]any{
|
||||
"channel_ids": chIDs,
|
||||
"client_ids": thIDs,
|
||||
"connection_types": connTypes,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpConnectClient, params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return cm.svc.Connect(ctx, session, chIDs, thIDs, connTypes)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
|
||||
params := map[string]any{
|
||||
"channel_ids": chIDs,
|
||||
"client_ids": thIDs,
|
||||
"connection_types": connTypes,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpDisconnectClient, params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return cm.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
|
||||
params := map[string]any{
|
||||
"entity_id": id,
|
||||
"parent_group_id": parentGroupID,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpSetParentGroup, params); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return cm.svc.SetParentGroup(ctx, session, parentGroupID, id)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
|
||||
ch, err := cm.repo.RetrieveByID(ctx, id)
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrRemoveEntity, err)
|
||||
}
|
||||
if ch.ParentGroup != "" {
|
||||
params := map[string]any{
|
||||
"entity_id": id,
|
||||
"parent_group_id": ch.ParentGroup,
|
||||
}
|
||||
|
||||
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpRemoveParentGroup, params); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return cm.svc.RemoveParentGroup(ctx, session, id)
|
||||
}
|
||||
|
||||
func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session, entityType string, op permissions.Operation, pld map[string]any) error {
|
||||
var entityID string
|
||||
if id, ok := pld["entity_id"].(string); ok {
|
||||
entityID = id
|
||||
}
|
||||
|
||||
req := callout.Request{
|
||||
BaseRequest: callout.BaseRequest{
|
||||
Operation: cm.entitiesOps.OperationName(entityType, op),
|
||||
EntityType: entityType,
|
||||
EntityID: entityID,
|
||||
CallerID: session.UserID,
|
||||
CallerType: policies.UserType,
|
||||
DomainID: session.DomainID,
|
||||
Time: time.Now().UTC(),
|
||||
},
|
||||
Payload: pld,
|
||||
}
|
||||
|
||||
if err := cm.callout.Callout(ctx, req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package middleware provides authorization, logging, metrics and tracing middleware
|
||||
// for Magistrala Channels Service.
|
||||
//
|
||||
// For more details about tracing instrumentation for Magistrala refer to the
|
||||
// documentation at https://magistrala.absmach.eu/docs/.
|
||||
package middleware
|
||||
@@ -1,294 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/roles"
|
||||
rolemw "github.com/absmach/magistrala/pkg/roles/rolemanager/middleware"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
)
|
||||
|
||||
var _ channels.Service = (*loggingMiddleware)(nil)
|
||||
|
||||
type loggingMiddleware struct {
|
||||
logger *slog.Logger
|
||||
svc channels.Service
|
||||
rolemw.RoleManagerLoggingMiddleware
|
||||
}
|
||||
|
||||
// NewLogging adds logging facilities to the channels service.
|
||||
func NewLogging(svc channels.Service, logger *slog.Logger) channels.Service {
|
||||
return &loggingMiddleware{logger, svc, rolemw.NewLogging("channels", svc, logger)}
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) CreateChannels(ctx context.Context, session authn.Session, clients ...channels.Channel) (cs []channels.Channel, rps []roles.RoleProvision, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn(fmt.Sprintf("Create %d channels failed", len(clients)), args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info(fmt.Sprintf("Create %d channel completed successfully", len(clients)), args...)
|
||||
}(time.Now())
|
||||
return lm.svc.CreateChannels(ctx, session, clients...)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (c channels.Channel, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
slog.Group("channel",
|
||||
slog.String("id", c.ID),
|
||||
slog.String("name", c.Name),
|
||||
slog.Bool("with_roles", withRoles),
|
||||
),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("View channel failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("View channel completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.ViewChannel(ctx, session, id, withRoles)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (cp channels.ChannelsPage, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
slog.Group("page",
|
||||
slog.Uint64("limit", pm.Limit),
|
||||
slog.Uint64("offset", pm.Offset),
|
||||
slog.Uint64("total", cp.Total),
|
||||
),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("List channels failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("List channels completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.ListChannels(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (cp channels.ChannelsPage, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
slog.String("user_id", userID),
|
||||
slog.Group("page",
|
||||
slog.Uint64("limit", pm.Limit),
|
||||
slog.Uint64("offset", pm.Offset),
|
||||
slog.Uint64("total", cp.Total),
|
||||
),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("List user channels failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("List user channels completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.ListUserChannels(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) UpdateChannel(ctx context.Context, session authn.Session, client channels.Channel) (c channels.Channel, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
slog.Group("channel",
|
||||
slog.String("id", client.ID),
|
||||
slog.String("name", client.Name),
|
||||
slog.Any("metadata", client.Metadata),
|
||||
),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Update channel failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Update channel completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.UpdateChannel(ctx, session, client)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, client channels.Channel) (c channels.Channel, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
slog.Group("channel",
|
||||
slog.String("id", c.ID),
|
||||
slog.String("name", c.Name),
|
||||
slog.Any("tags", c.Tags),
|
||||
),
|
||||
}
|
||||
if err != nil {
|
||||
args := append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Update channel tags failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Update channel tags completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.UpdateChannelTags(ctx, session, client)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (c channels.Channel, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
slog.Group("channel",
|
||||
slog.String("id", id),
|
||||
slog.String("name", c.Name),
|
||||
),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Enable channel failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Enable channel completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.EnableChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (c channels.Channel, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
slog.Group("channel",
|
||||
slog.String("id", id),
|
||||
slog.String("name", c.Name),
|
||||
),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Disable channel failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Disable channel completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.DisableChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
slog.String("channel_id", id),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Delete channel failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Delete channel completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.RemoveChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, clIDs []string, connTypes []connections.ConnType) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
slog.Any("channel_ids", chIDs),
|
||||
slog.Any("client_ids", clIDs),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Connect channels and clients failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Connect channels and clients completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.Connect(ctx, session, chIDs, clIDs, connTypes)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, clIDs []string, connTypes []connections.ConnType) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
slog.Any("channel_ids", chIDs),
|
||||
slog.Any("client_ids", clIDs),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Disconnect channels and clients failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Disconnect channels and clients completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.Disconnect(ctx, session, chIDs, clIDs, connTypes)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
slog.String("parent_group_id", parentGroupID),
|
||||
slog.String("channel_id", id),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Set parent group to channel failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Set parent group to channel completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.SetParentGroup(ctx, session, parentGroupID, id)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("domain_id", session.DomainID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
slog.String("channel_id", id),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Remove parent group from channel failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Remove parent group from channel completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.RemoveParentGroup(ctx, session, id)
|
||||
}
|
||||
@@ -1,139 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/roles"
|
||||
rolemw "github.com/absmach/magistrala/pkg/roles/rolemanager/middleware"
|
||||
"github.com/go-kit/kit/metrics"
|
||||
)
|
||||
|
||||
var _ channels.Service = (*metricsMiddleware)(nil)
|
||||
|
||||
type metricsMiddleware struct {
|
||||
counter metrics.Counter
|
||||
latency metrics.Histogram
|
||||
svc channels.Service
|
||||
rolemw.RoleManagerMetricsMiddleware
|
||||
}
|
||||
|
||||
// NewMetrics returns a new metrics middleware wrapper.
|
||||
func NewMetrics(svc channels.Service, counter metrics.Counter, latency metrics.Histogram) channels.Service {
|
||||
return &metricsMiddleware{
|
||||
counter: counter,
|
||||
latency: latency,
|
||||
svc: svc,
|
||||
RoleManagerMetricsMiddleware: rolemw.NewMetrics("channels", svc, counter, latency),
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "register_channels").Add(1)
|
||||
ms.latency.With("method", "register_channels").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.CreateChannels(ctx, session, chs...)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (channels.Channel, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "view_channel").Add(1)
|
||||
ms.latency.With("method", "view_channel").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.ViewChannel(ctx, session, id, withRoles)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "list_channels").Add(1)
|
||||
ms.latency.With("method", "list_channels").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.ListChannels(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "list_user_channels").Add(1)
|
||||
ms.latency.With("method", "list_user_channels").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.ListUserChannels(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "update_channel").Add(1)
|
||||
ms.latency.With("method", "update_channel").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.UpdateChannel(ctx, session, channel)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "update_channel_tags").Add(1)
|
||||
ms.latency.With("method", "update_channel_tags").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.UpdateChannelTags(ctx, session, channel)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "enable_channel").Add(1)
|
||||
ms.latency.With("method", "enable_channel").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.EnableChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "disable_channel").Add(1)
|
||||
ms.latency.With("method", "disable_channel").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.DisableChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "delete_channel").Add(1)
|
||||
ms.latency.With("method", "delete_channel").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.RemoveChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "connect").Add(1)
|
||||
ms.latency.With("method", "connect").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.Connect(ctx, session, chIDs, thIDs, connTypes)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "disconnect").Add(1)
|
||||
ms.latency.With("method", "disconnect").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "set_parent_group").Add(1)
|
||||
ms.latency.With("method", "set_parent_group").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.SetParentGroup(ctx, session, parentGroupID, id)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "remove_parent_group").Add(1)
|
||||
ms.latency.With("method", "remove_parent_group").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.RemoveParentGroup(ctx, session, id)
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/channels"
|
||||
"github.com/absmach/magistrala/pkg/authn"
|
||||
"github.com/absmach/magistrala/pkg/connections"
|
||||
"github.com/absmach/magistrala/pkg/roles"
|
||||
rolemw "github.com/absmach/magistrala/pkg/roles/rolemanager/middleware"
|
||||
"github.com/absmach/magistrala/pkg/tracing"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
var _ channels.Service = (*tracingMiddleware)(nil)
|
||||
|
||||
type tracingMiddleware struct {
|
||||
tracer trace.Tracer
|
||||
svc channels.Service
|
||||
rolemw.RoleManagerTracing
|
||||
}
|
||||
|
||||
// NewTracing returns a new channels service with tracing capabilities.
|
||||
func NewTracing(svc channels.Service, tracer trace.Tracer) channels.Service {
|
||||
return &tracingMiddleware{tracer, svc, rolemw.NewTracing("channels", svc, tracer)}
|
||||
}
|
||||
|
||||
// CreateChannels traces the "CreateChannels" operation of the wrapped policies.Service.
|
||||
func (tm *tracingMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_create_channel")
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.CreateChannels(ctx, session, chs...)
|
||||
}
|
||||
|
||||
// ViewChannel traces the "ViewChannel" operation of the wrapped policies.Service.
|
||||
func (tm *tracingMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (channels.Channel, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_view_channel", trace.WithAttributes(attribute.String("id", id)))
|
||||
defer span.End()
|
||||
return tm.svc.ViewChannel(ctx, session, id, withRoles)
|
||||
}
|
||||
|
||||
// ListChannels traces the "ListChannels" operation of the wrapped policies.Service.
|
||||
func (tm *tracingMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_list_channels")
|
||||
defer span.End()
|
||||
return tm.svc.ListChannels(ctx, session, pm)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.Page) (channels.ChannelsPage, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_list_user_channels")
|
||||
defer span.End()
|
||||
return tm.svc.ListUserChannels(ctx, session, userID, pm)
|
||||
}
|
||||
|
||||
// UpdateChannel traces the "UpdateChannel" operation of the wrapped policies.Service.
|
||||
func (tm *tracingMiddleware) UpdateChannel(ctx context.Context, session authn.Session, cli channels.Channel) (channels.Channel, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_update_channel", trace.WithAttributes(attribute.String("id", cli.ID)))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.UpdateChannel(ctx, session, cli)
|
||||
}
|
||||
|
||||
// UpdateChannelTags traces the "UpdateChannelTags" operation of the wrapped policies.Service.
|
||||
func (tm *tracingMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, cli channels.Channel) (channels.Channel, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_update_channel_tags", trace.WithAttributes(
|
||||
attribute.String("id", cli.ID),
|
||||
attribute.StringSlice("tags", cli.Tags),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.UpdateChannelTags(ctx, session, cli)
|
||||
}
|
||||
|
||||
// EnableChannel traces the "EnableChannel" operation of the wrapped policies.Service.
|
||||
func (tm *tracingMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_enable_channel", trace.WithAttributes(attribute.String("id", id)))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.EnableChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
// DisableChannel traces the "DisableChannel" operation of the wrapped policies.Service.
|
||||
func (tm *tracingMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_disable_channel", trace.WithAttributes(attribute.String("id", id)))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.DisableChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
// DeleteChannel traces the "DeleteChannel" operation of the wrapped channels.Service.
|
||||
func (tm *tracingMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "delete_channel", trace.WithAttributes(attribute.String("id", id)))
|
||||
defer span.End()
|
||||
return tm.svc.RemoveChannel(ctx, session, id)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "connect", trace.WithAttributes(
|
||||
attribute.StringSlice("channel_ids", chIDs),
|
||||
attribute.StringSlice("client_ids", thIDs),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.Connect(ctx, session, chIDs, thIDs, connTypes)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "disconnect", trace.WithAttributes(
|
||||
attribute.StringSlice("channel_ids", chIDs),
|
||||
attribute.StringSlice("client_ids", thIDs),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.Disconnect(ctx, session, chIDs, thIDs, connTypes)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "set_parent_group", trace.WithAttributes(
|
||||
attribute.String("parent_group_id", parentGroupID),
|
||||
attribute.String("id", id),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.SetParentGroup(ctx, session, parentGroupID, id)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "remove_parent_group", trace.WithAttributes(
|
||||
attribute.String("id", id),
|
||||
))
|
||||
defer span.End()
|
||||
return tm.svc.RemoveParentGroup(ctx, session, id)
|
||||
}
|
||||
@@ -1,246 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewCache creates a new instance of Cache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewCache(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *Cache {
|
||||
mock := &Cache{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// Cache is an autogenerated mock type for the Cache type
|
||||
type Cache struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type Cache_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *Cache) EXPECT() *Cache_Expecter {
|
||||
return &Cache_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// ID provides a mock function for the type Cache
|
||||
func (_mock *Cache) ID(ctx context.Context, channelRoute string, domainID string) (string, error) {
|
||||
ret := _mock.Called(ctx, channelRoute, domainID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ID")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (string, error)); ok {
|
||||
return returnFunc(ctx, channelRoute, domainID)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) string); ok {
|
||||
r0 = returnFunc(ctx, channelRoute, domainID)
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
|
||||
r1 = returnFunc(ctx, channelRoute, domainID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Cache_ID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ID'
|
||||
type Cache_ID_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ID is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - channelRoute string
|
||||
// - domainID string
|
||||
func (_e *Cache_Expecter) ID(ctx interface{}, channelRoute interface{}, domainID interface{}) *Cache_ID_Call {
|
||||
return &Cache_ID_Call{Call: _e.mock.On("ID", ctx, channelRoute, domainID)}
|
||||
}
|
||||
|
||||
func (_c *Cache_ID_Call) Run(run func(ctx context.Context, channelRoute string, domainID string)) *Cache_ID_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Cache_ID_Call) Return(s string, err error) *Cache_ID_Call {
|
||||
_c.Call.Return(s, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Cache_ID_Call) RunAndReturn(run func(ctx context.Context, channelRoute string, domainID string) (string, error)) *Cache_ID_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Remove provides a mock function for the type Cache
|
||||
func (_mock *Cache) Remove(ctx context.Context, channelRoute string, domainID string) error {
|
||||
ret := _mock.Called(ctx, channelRoute, domainID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Remove")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
|
||||
r0 = returnFunc(ctx, channelRoute, domainID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Cache_Remove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remove'
|
||||
type Cache_Remove_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Remove is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - channelRoute string
|
||||
// - domainID string
|
||||
func (_e *Cache_Expecter) Remove(ctx interface{}, channelRoute interface{}, domainID interface{}) *Cache_Remove_Call {
|
||||
return &Cache_Remove_Call{Call: _e.mock.On("Remove", ctx, channelRoute, domainID)}
|
||||
}
|
||||
|
||||
func (_c *Cache_Remove_Call) Run(run func(ctx context.Context, channelRoute string, domainID string)) *Cache_Remove_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Cache_Remove_Call) Return(err error) *Cache_Remove_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Cache_Remove_Call) RunAndReturn(run func(ctx context.Context, channelRoute string, domainID string) error) *Cache_Remove_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Save provides a mock function for the type Cache
|
||||
func (_mock *Cache) Save(ctx context.Context, channelRoute string, domainID string, channelID string) error {
|
||||
ret := _mock.Called(ctx, channelRoute, domainID, channelID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Save")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) error); ok {
|
||||
r0 = returnFunc(ctx, channelRoute, domainID, channelID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Cache_Save_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Save'
|
||||
type Cache_Save_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Save is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - channelRoute string
|
||||
// - domainID string
|
||||
// - channelID string
|
||||
func (_e *Cache_Expecter) Save(ctx interface{}, channelRoute interface{}, domainID interface{}, channelID interface{}) *Cache_Save_Call {
|
||||
return &Cache_Save_Call{Call: _e.mock.On("Save", ctx, channelRoute, domainID, channelID)}
|
||||
}
|
||||
|
||||
func (_c *Cache_Save_Call) Run(run func(ctx context.Context, channelRoute string, domainID string, channelID string)) *Cache_Save_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
var arg3 string
|
||||
if args[3] != nil {
|
||||
arg3 = args[3].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Cache_Save_Call) Return(err error) *Cache_Save_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Cache_Save_Call) RunAndReturn(run func(ctx context.Context, channelRoute string, domainID string, channelID string) error) *Cache_Save_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -1,460 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/absmach/magistrala/api/grpc/channels/v1"
|
||||
v10 "github.com/absmach/magistrala/api/grpc/common/v1"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
// NewChannelsServiceClient creates a new instance of ChannelsServiceClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewChannelsServiceClient(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *ChannelsServiceClient {
|
||||
mock := &ChannelsServiceClient{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// ChannelsServiceClient is an autogenerated mock type for the ChannelsServiceClient type
|
||||
type ChannelsServiceClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type ChannelsServiceClient_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *ChannelsServiceClient) EXPECT() *ChannelsServiceClient_Expecter {
|
||||
return &ChannelsServiceClient_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Authorize provides a mock function for the type ChannelsServiceClient
|
||||
func (_mock *ChannelsServiceClient) Authorize(ctx context.Context, in *v1.AuthzReq, opts ...grpc.CallOption) (*v1.AuthzRes, error) {
|
||||
var tmpRet mock.Arguments
|
||||
if len(opts) > 0 {
|
||||
tmpRet = _mock.Called(ctx, in, opts)
|
||||
} else {
|
||||
tmpRet = _mock.Called(ctx, in)
|
||||
}
|
||||
ret := tmpRet
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Authorize")
|
||||
}
|
||||
|
||||
var r0 *v1.AuthzRes
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.AuthzReq, ...grpc.CallOption) (*v1.AuthzRes, error)); ok {
|
||||
return returnFunc(ctx, in, opts...)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.AuthzReq, ...grpc.CallOption) *v1.AuthzRes); ok {
|
||||
r0 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*v1.AuthzRes)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.AuthzReq, ...grpc.CallOption) error); ok {
|
||||
r1 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ChannelsServiceClient_Authorize_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Authorize'
|
||||
type ChannelsServiceClient_Authorize_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Authorize is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *v1.AuthzReq
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *ChannelsServiceClient_Expecter) Authorize(ctx interface{}, in interface{}, opts ...interface{}) *ChannelsServiceClient_Authorize_Call {
|
||||
return &ChannelsServiceClient_Authorize_Call{Call: _e.mock.On("Authorize",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_Authorize_Call) Run(run func(ctx context.Context, in *v1.AuthzReq, opts ...grpc.CallOption)) *ChannelsServiceClient_Authorize_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *v1.AuthzReq
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*v1.AuthzReq)
|
||||
}
|
||||
var arg2 []grpc.CallOption
|
||||
var variadicArgs []grpc.CallOption
|
||||
if len(args) > 2 {
|
||||
variadicArgs = args[2].([]grpc.CallOption)
|
||||
}
|
||||
arg2 = variadicArgs
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2...,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_Authorize_Call) Return(authzRes *v1.AuthzRes, err error) *ChannelsServiceClient_Authorize_Call {
|
||||
_c.Call.Return(authzRes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_Authorize_Call) RunAndReturn(run func(ctx context.Context, in *v1.AuthzReq, opts ...grpc.CallOption) (*v1.AuthzRes, error)) *ChannelsServiceClient_Authorize_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveClientConnections provides a mock function for the type ChannelsServiceClient
|
||||
func (_mock *ChannelsServiceClient) RemoveClientConnections(ctx context.Context, in *v1.RemoveClientConnectionsReq, opts ...grpc.CallOption) (*v1.RemoveClientConnectionsRes, error) {
|
||||
var tmpRet mock.Arguments
|
||||
if len(opts) > 0 {
|
||||
tmpRet = _mock.Called(ctx, in, opts)
|
||||
} else {
|
||||
tmpRet = _mock.Called(ctx, in)
|
||||
}
|
||||
ret := tmpRet
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RemoveClientConnections")
|
||||
}
|
||||
|
||||
var r0 *v1.RemoveClientConnectionsRes
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RemoveClientConnectionsReq, ...grpc.CallOption) (*v1.RemoveClientConnectionsRes, error)); ok {
|
||||
return returnFunc(ctx, in, opts...)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RemoveClientConnectionsReq, ...grpc.CallOption) *v1.RemoveClientConnectionsRes); ok {
|
||||
r0 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*v1.RemoveClientConnectionsRes)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.RemoveClientConnectionsReq, ...grpc.CallOption) error); ok {
|
||||
r1 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ChannelsServiceClient_RemoveClientConnections_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveClientConnections'
|
||||
type ChannelsServiceClient_RemoveClientConnections_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RemoveClientConnections is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *v1.RemoveClientConnectionsReq
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *ChannelsServiceClient_Expecter) RemoveClientConnections(ctx interface{}, in interface{}, opts ...interface{}) *ChannelsServiceClient_RemoveClientConnections_Call {
|
||||
return &ChannelsServiceClient_RemoveClientConnections_Call{Call: _e.mock.On("RemoveClientConnections",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_RemoveClientConnections_Call) Run(run func(ctx context.Context, in *v1.RemoveClientConnectionsReq, opts ...grpc.CallOption)) *ChannelsServiceClient_RemoveClientConnections_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *v1.RemoveClientConnectionsReq
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*v1.RemoveClientConnectionsReq)
|
||||
}
|
||||
var arg2 []grpc.CallOption
|
||||
var variadicArgs []grpc.CallOption
|
||||
if len(args) > 2 {
|
||||
variadicArgs = args[2].([]grpc.CallOption)
|
||||
}
|
||||
arg2 = variadicArgs
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2...,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_RemoveClientConnections_Call) Return(removeClientConnectionsRes *v1.RemoveClientConnectionsRes, err error) *ChannelsServiceClient_RemoveClientConnections_Call {
|
||||
_c.Call.Return(removeClientConnectionsRes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_RemoveClientConnections_Call) RunAndReturn(run func(ctx context.Context, in *v1.RemoveClientConnectionsReq, opts ...grpc.CallOption) (*v1.RemoveClientConnectionsRes, error)) *ChannelsServiceClient_RemoveClientConnections_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RetrieveEntity provides a mock function for the type ChannelsServiceClient
|
||||
func (_mock *ChannelsServiceClient) RetrieveEntity(ctx context.Context, in *v10.RetrieveEntityReq, opts ...grpc.CallOption) (*v10.RetrieveEntityRes, error) {
|
||||
var tmpRet mock.Arguments
|
||||
if len(opts) > 0 {
|
||||
tmpRet = _mock.Called(ctx, in, opts)
|
||||
} else {
|
||||
tmpRet = _mock.Called(ctx, in)
|
||||
}
|
||||
ret := tmpRet
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveEntity")
|
||||
}
|
||||
|
||||
var r0 *v10.RetrieveEntityRes
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v10.RetrieveEntityReq, ...grpc.CallOption) (*v10.RetrieveEntityRes, error)); ok {
|
||||
return returnFunc(ctx, in, opts...)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v10.RetrieveEntityReq, ...grpc.CallOption) *v10.RetrieveEntityRes); ok {
|
||||
r0 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*v10.RetrieveEntityRes)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, *v10.RetrieveEntityReq, ...grpc.CallOption) error); ok {
|
||||
r1 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ChannelsServiceClient_RetrieveEntity_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveEntity'
|
||||
type ChannelsServiceClient_RetrieveEntity_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RetrieveEntity is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *v10.RetrieveEntityReq
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *ChannelsServiceClient_Expecter) RetrieveEntity(ctx interface{}, in interface{}, opts ...interface{}) *ChannelsServiceClient_RetrieveEntity_Call {
|
||||
return &ChannelsServiceClient_RetrieveEntity_Call{Call: _e.mock.On("RetrieveEntity",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_RetrieveEntity_Call) Run(run func(ctx context.Context, in *v10.RetrieveEntityReq, opts ...grpc.CallOption)) *ChannelsServiceClient_RetrieveEntity_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *v10.RetrieveEntityReq
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*v10.RetrieveEntityReq)
|
||||
}
|
||||
var arg2 []grpc.CallOption
|
||||
var variadicArgs []grpc.CallOption
|
||||
if len(args) > 2 {
|
||||
variadicArgs = args[2].([]grpc.CallOption)
|
||||
}
|
||||
arg2 = variadicArgs
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2...,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_RetrieveEntity_Call) Return(retrieveEntityRes *v10.RetrieveEntityRes, err error) *ChannelsServiceClient_RetrieveEntity_Call {
|
||||
_c.Call.Return(retrieveEntityRes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_RetrieveEntity_Call) RunAndReturn(run func(ctx context.Context, in *v10.RetrieveEntityReq, opts ...grpc.CallOption) (*v10.RetrieveEntityRes, error)) *ChannelsServiceClient_RetrieveEntity_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RetrieveIDByRoute provides a mock function for the type ChannelsServiceClient
|
||||
func (_mock *ChannelsServiceClient) RetrieveIDByRoute(ctx context.Context, in *v10.RetrieveIDByRouteReq, opts ...grpc.CallOption) (*v10.RetrieveEntityRes, error) {
|
||||
var tmpRet mock.Arguments
|
||||
if len(opts) > 0 {
|
||||
tmpRet = _mock.Called(ctx, in, opts)
|
||||
} else {
|
||||
tmpRet = _mock.Called(ctx, in)
|
||||
}
|
||||
ret := tmpRet
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveIDByRoute")
|
||||
}
|
||||
|
||||
var r0 *v10.RetrieveEntityRes
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v10.RetrieveIDByRouteReq, ...grpc.CallOption) (*v10.RetrieveEntityRes, error)); ok {
|
||||
return returnFunc(ctx, in, opts...)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v10.RetrieveIDByRouteReq, ...grpc.CallOption) *v10.RetrieveEntityRes); ok {
|
||||
r0 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*v10.RetrieveEntityRes)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, *v10.RetrieveIDByRouteReq, ...grpc.CallOption) error); ok {
|
||||
r1 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ChannelsServiceClient_RetrieveIDByRoute_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveIDByRoute'
|
||||
type ChannelsServiceClient_RetrieveIDByRoute_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RetrieveIDByRoute is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *v10.RetrieveIDByRouteReq
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *ChannelsServiceClient_Expecter) RetrieveIDByRoute(ctx interface{}, in interface{}, opts ...interface{}) *ChannelsServiceClient_RetrieveIDByRoute_Call {
|
||||
return &ChannelsServiceClient_RetrieveIDByRoute_Call{Call: _e.mock.On("RetrieveIDByRoute",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_RetrieveIDByRoute_Call) Run(run func(ctx context.Context, in *v10.RetrieveIDByRouteReq, opts ...grpc.CallOption)) *ChannelsServiceClient_RetrieveIDByRoute_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *v10.RetrieveIDByRouteReq
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*v10.RetrieveIDByRouteReq)
|
||||
}
|
||||
var arg2 []grpc.CallOption
|
||||
var variadicArgs []grpc.CallOption
|
||||
if len(args) > 2 {
|
||||
variadicArgs = args[2].([]grpc.CallOption)
|
||||
}
|
||||
arg2 = variadicArgs
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2...,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_RetrieveIDByRoute_Call) Return(retrieveEntityRes *v10.RetrieveEntityRes, err error) *ChannelsServiceClient_RetrieveIDByRoute_Call {
|
||||
_c.Call.Return(retrieveEntityRes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_RetrieveIDByRoute_Call) RunAndReturn(run func(ctx context.Context, in *v10.RetrieveIDByRouteReq, opts ...grpc.CallOption) (*v10.RetrieveEntityRes, error)) *ChannelsServiceClient_RetrieveIDByRoute_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UnsetParentGroupFromChannels provides a mock function for the type ChannelsServiceClient
|
||||
func (_mock *ChannelsServiceClient) UnsetParentGroupFromChannels(ctx context.Context, in *v1.UnsetParentGroupFromChannelsReq, opts ...grpc.CallOption) (*v1.UnsetParentGroupFromChannelsRes, error) {
|
||||
var tmpRet mock.Arguments
|
||||
if len(opts) > 0 {
|
||||
tmpRet = _mock.Called(ctx, in, opts)
|
||||
} else {
|
||||
tmpRet = _mock.Called(ctx, in)
|
||||
}
|
||||
ret := tmpRet
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UnsetParentGroupFromChannels")
|
||||
}
|
||||
|
||||
var r0 *v1.UnsetParentGroupFromChannelsRes
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.UnsetParentGroupFromChannelsReq, ...grpc.CallOption) (*v1.UnsetParentGroupFromChannelsRes, error)); ok {
|
||||
return returnFunc(ctx, in, opts...)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.UnsetParentGroupFromChannelsReq, ...grpc.CallOption) *v1.UnsetParentGroupFromChannelsRes); ok {
|
||||
r0 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*v1.UnsetParentGroupFromChannelsRes)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.UnsetParentGroupFromChannelsReq, ...grpc.CallOption) error); ok {
|
||||
r1 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// ChannelsServiceClient_UnsetParentGroupFromChannels_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UnsetParentGroupFromChannels'
|
||||
type ChannelsServiceClient_UnsetParentGroupFromChannels_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UnsetParentGroupFromChannels is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *v1.UnsetParentGroupFromChannelsReq
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *ChannelsServiceClient_Expecter) UnsetParentGroupFromChannels(ctx interface{}, in interface{}, opts ...interface{}) *ChannelsServiceClient_UnsetParentGroupFromChannels_Call {
|
||||
return &ChannelsServiceClient_UnsetParentGroupFromChannels_Call{Call: _e.mock.On("UnsetParentGroupFromChannels",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_UnsetParentGroupFromChannels_Call) Run(run func(ctx context.Context, in *v1.UnsetParentGroupFromChannelsReq, opts ...grpc.CallOption)) *ChannelsServiceClient_UnsetParentGroupFromChannels_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *v1.UnsetParentGroupFromChannelsReq
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*v1.UnsetParentGroupFromChannelsReq)
|
||||
}
|
||||
var arg2 []grpc.CallOption
|
||||
var variadicArgs []grpc.CallOption
|
||||
if len(args) > 2 {
|
||||
variadicArgs = args[2].([]grpc.CallOption)
|
||||
}
|
||||
arg2 = variadicArgs
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2...,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_UnsetParentGroupFromChannels_Call) Return(unsetParentGroupFromChannelsRes *v1.UnsetParentGroupFromChannelsRes, err error) *ChannelsServiceClient_UnsetParentGroupFromChannels_Call {
|
||||
_c.Call.Return(unsetParentGroupFromChannelsRes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *ChannelsServiceClient_UnsetParentGroupFromChannels_Call) RunAndReturn(run func(ctx context.Context, in *v1.UnsetParentGroupFromChannelsReq, opts ...grpc.CallOption) (*v1.UnsetParentGroupFromChannelsRes, error)) *ChannelsServiceClient_UnsetParentGroupFromChannels_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,72 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package operations
|
||||
|
||||
import (
|
||||
"github.com/absmach/magistrala/pkg/permissions"
|
||||
)
|
||||
|
||||
// Channel Operations.
|
||||
const (
|
||||
OpViewChannel permissions.Operation = iota
|
||||
OpUpdateChannel
|
||||
OpUpdateChannelTags
|
||||
OpEnableChannel
|
||||
OpDisableChannel
|
||||
OpDeleteChannel
|
||||
OpSetParentGroup
|
||||
OpRemoveParentGroup
|
||||
OpConnectClient
|
||||
OpDisconnectClient
|
||||
OpListUserChannels
|
||||
)
|
||||
|
||||
func OperationDetails() map[permissions.Operation]permissions.OperationDetails {
|
||||
return map[permissions.Operation]permissions.OperationDetails{
|
||||
OpViewChannel: {
|
||||
Name: "view",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpUpdateChannel: {
|
||||
Name: "update",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpUpdateChannelTags: {
|
||||
Name: "update_tags",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpEnableChannel: {
|
||||
Name: "enable",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpDisableChannel: {
|
||||
Name: "disable",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpDeleteChannel: {
|
||||
Name: "delete",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpSetParentGroup: {
|
||||
Name: "set_parent_group",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpRemoveParentGroup: {
|
||||
Name: "remove_parent_group",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpConnectClient: {
|
||||
Name: "connect_client",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpDisconnectClient: {
|
||||
Name: "disconnect_client",
|
||||
PermissionRequired: true,
|
||||
},
|
||||
OpListUserChannels: {
|
||||
Name: "list_user_channels",
|
||||
PermissionRequired: false, // hardcoded to superadmin
|
||||
},
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,26 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import "github.com/absmach/magistrala/pkg/errors"
|
||||
|
||||
var _ errors.Mapper = (*duplicateErrors)(nil)
|
||||
|
||||
type duplicateErrors struct{}
|
||||
|
||||
// GetError maps constraint names to known errors.
|
||||
func (d duplicateErrors) GetError(constraint string) (error, bool) {
|
||||
switch constraint {
|
||||
case "unique_domain_route_not_null":
|
||||
return errors.ErrRouteNotAvailable, true
|
||||
case "channels_pkey":
|
||||
return errors.NewRequestError("channel id already exists"), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func NewDuplicateErrors() errors.Mapper {
|
||||
return duplicateErrors{}
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
gpostgres "github.com/absmach/magistrala/groups/postgres"
|
||||
"github.com/absmach/magistrala/pkg/errors"
|
||||
repoerr "github.com/absmach/magistrala/pkg/errors/repository"
|
||||
rolesPostgres "github.com/absmach/magistrala/pkg/roles/repo/postgres"
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // required for SQL access
|
||||
migrate "github.com/rubenv/sql-migrate"
|
||||
)
|
||||
|
||||
func Migration() (*migrate.MemoryMigrationSource, error) {
|
||||
rolesMigration, err := rolesPostgres.Migration(rolesTableNamePrefix, entityTableName, entityIDColumnName)
|
||||
if err != nil {
|
||||
return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err)
|
||||
}
|
||||
channelsMigration := &migrate.MemoryMigrationSource{
|
||||
Migrations: []*migrate.Migration{
|
||||
{
|
||||
Id: "channels_01",
|
||||
// VARCHAR(36) for colums with IDs as UUIDS have a maximum of 36 characters
|
||||
// STATUS 0 to imply enabled and 1 to imply disabled
|
||||
Up: []string{
|
||||
`CREATE TABLE IF NOT EXISTS channels (
|
||||
id VARCHAR(36) PRIMARY KEY,
|
||||
name VARCHAR(1024),
|
||||
domain_id VARCHAR(36) NOT NULL,
|
||||
parent_group_id VARCHAR(36) DEFAULT NULL,
|
||||
tags TEXT[],
|
||||
metadata JSONB,
|
||||
created_by VARCHAR(254),
|
||||
created_at TIMESTAMP,
|
||||
updated_at TIMESTAMP,
|
||||
updated_by VARCHAR(254),
|
||||
status SMALLINT NOT NULL DEFAULT 0 CHECK (status >= 0),
|
||||
UNIQUE (id, domain_id),
|
||||
UNIQUE (domain_id, name)
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS connections (
|
||||
channel_id VARCHAR(36),
|
||||
domain_id VARCHAR(36),
|
||||
client_id VARCHAR(36),
|
||||
type SMALLINT NOT NULL CHECK (type IN (1, 2)),
|
||||
FOREIGN KEY (channel_id, domain_id) REFERENCES channels (id, domain_id) ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
PRIMARY KEY (channel_id, domain_id, client_id, type)
|
||||
)`,
|
||||
},
|
||||
Down: []string{
|
||||
`DROP TABLE IF EXISTS channels`,
|
||||
`DROP TABLE IF EXISTS connections`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "channels_02",
|
||||
Up: []string{
|
||||
`ALTER TABLE channels DROP CONSTRAINT IF EXISTS channels_domain_id_name_key`,
|
||||
},
|
||||
Down: []string{
|
||||
`ALTER TABLE channels ADD CONSTRAINT channels_domain_id_name_key UNIQUE (domain_id, name)`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "channels_03",
|
||||
Up: []string{
|
||||
`ALTER TABLE channels ADD COLUMN IF NOT EXISTS route VARCHAR(36);`,
|
||||
`CREATE UNIQUE INDEX IF NOT EXISTS unique_domain_route_not_null ON channels (domain_id, route) WHERE route IS NOT NULL;`,
|
||||
},
|
||||
Down: []string{
|
||||
`DROP INDEX IF EXISTS unique_domain_route_not_null;`,
|
||||
`ALTER TABLE channels DROP COLUMN IF EXISTS route;`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "channels_04",
|
||||
Up: []string{
|
||||
`ALTER TABLE channels ALTER COLUMN created_at TYPE TIMESTAMPTZ;`,
|
||||
`ALTER TABLE channels ALTER COLUMN updated_at TYPE TIMESTAMPTZ;`,
|
||||
},
|
||||
Down: []string{
|
||||
`ALTER TABLE channels ALTER COLUMN created_at TYPE TIMESTAMP;`,
|
||||
`ALTER TABLE channels ALTER COLUMN updated_at TYPE TIMESTAMP;`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "channels_05",
|
||||
Up: []string{
|
||||
`UPDATE channels
|
||||
SET metadata = (COALESCE(metadata, '{}'::jsonb) || COALESCE(metadata->'ui', '{}'::jsonb)) - 'ui'
|
||||
WHERE metadata ? 'ui' AND jsonb_typeof(metadata->'ui') = 'object'`,
|
||||
},
|
||||
Down: []string{
|
||||
`SELECT 1`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "channels_06",
|
||||
Up: []string{
|
||||
`CREATE INDEX IF NOT EXISTS idx_channels_domain_id_status ON channels(domain_id, status);`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_channels_parent_group_id ON channels(parent_group_id);`,
|
||||
},
|
||||
Down: []string{
|
||||
`DROP INDEX IF EXISTS idx_channels_domain_id_status;`,
|
||||
`DROP INDEX IF EXISTS idx_channels_parent_group_id;`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
channelsMigration.Migrations = append(channelsMigration.Migrations, rolesMigration.Migrations...)
|
||||
|
||||
groupsMigration, err := gpostgres.Migration()
|
||||
if err != nil {
|
||||
return &migrate.MemoryMigrationSource{}, err
|
||||
}
|
||||
|
||||
channelsMigration.Migrations = append(channelsMigration.Migrations, groupsMigration.Migrations...)
|
||||
|
||||
return channelsMigration, nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user