diff --git a/.github/dependabot.yaml b/.github/dependabot.yml similarity index 83% rename from .github/dependabot.yaml rename to .github/dependabot.yml index 464738903..e4f3e3f06 100644 --- a/.github/dependabot.yaml +++ b/.github/dependabot.yml @@ -4,9 +4,9 @@ version: 2 updates: - package-ecosystem: "github-actions" - directory: "./.github/workflows" + directory: "/" schedule: - interval: "monthly" + interval: "weekly" day: "monday" timezone: "Europe/Paris" groups: @@ -22,9 +22,9 @@ updates: timezone: "Europe/Paris" - package-ecosystem: "docker" - directory: "./docker" + directory: "/docker" schedule: - interval: "monthly" + interval: "weekly" day: "monday" timezone: "Europe/Paris" groups: diff --git a/.github/workflows/api-tests.yaml b/.github/workflows/api-tests.yaml index 37a269740..16e7608f4 100644 --- a/.github/workflows/api-tests.yaml +++ b/.github/workflows/api-tests.yaml @@ -15,9 +15,14 @@ on: - "clients/api/http/**" - "domains/api/http/**" - "groups/api/http/**" - - "http/api/**" - "journal/api/**" - "users/api/**" + - "bootstrap/api/**" + - "certs/api/http/**" + - "readers/api/http/**" + - "re/api/**" + - "alarms/api/**" + - "reports/api/**" - "apidocs/openapi/**" pull_request: branches: @@ -30,9 +35,14 @@ on: - "clients/api/http/**" - "domains/api/http/**" - "groups/api/http/**" - - "http/api/**" - "journal/api/**" - "users/api/**" + - "bootstrap/api/**" + - "certs/api/http/**" + - "readers/api/http/**" + - "re/api/**" + - "alarms/api/**" + - "reports/api/**" - "apidocs/openapi/**" concurrency: @@ -50,9 +60,14 @@ env: CLIENTS_URL: http://localhost:9006 CHANNELS_URL: http://localhost:9005 GROUPS_URL: http://localhost:9004 - HTTP_ADAPTER_URL: http://localhost:8008 AUTH_URL: http://localhost:9001 JOURNAL_URL: http://localhost:9021 + BOOTSTRAP_URL: http://localhost:9013 + CERTS_URL: http://localhost:9019 + READERS_URL: http://localhost:9011 + RE_URL: http://localhost:9008 + ALARMS_URL: http://localhost:8050 + REPORTS_URL: http://localhost:9017 jobs: api-test: @@ -93,10 +108,6 @@ jobs: - "apidocs/openapi/domains.yaml" - "domains/api/http/**" - http: - - "apidocs/openapi/http.yaml" - - "http/api/**" - clients: - "apidocs/openapi/clients.yaml" - "clients/api/http/**" @@ -113,6 +124,30 @@ jobs: - "apidocs/openapi/users.yaml" - "users/api/**" + bootstrap: + - "apidocs/openapi/bootstrap.yaml" + - "bootstrap/api/**" + + certs: + - "apidocs/openapi/certs.yaml" + - "certs/api/http/**" + + readers: + - "apidocs/openapi/readers.yaml" + - "readers/api/http/**" + + re: + - "apidocs/openapi/rules.yaml" + - "re/api/**" + + alarms: + - "apidocs/openapi/alarms.yaml" + - "alarms/api/**" + + reports: + - "apidocs/openapi/reports.yaml" + - "reports/api/**" + - name: Build images run: make all -j $(nproc) && make dockers_dev -j $(nproc) @@ -178,15 +213,6 @@ jobs: checks: all args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples' - - name: Run HTTP Adapter API tests - if: steps.changes.outputs.http == 'true' || steps.changes.outputs.workflow == 'true' - uses: schemathesis/action@v2.1.0 - with: - schema: apidocs/openapi/http.yaml - base-url: ${{ env.HTTP_ADAPTER_URL }} - checks: all - args: '--header "Authorization: Client ${{ env.CLIENT_SECRET }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples' - - name: Run Auth API tests if: steps.changes.outputs.auth == 'true' || steps.changes.outputs.workflow == 'true' uses: schemathesis/action@v2.1.0 @@ -214,6 +240,60 @@ jobs: checks: all args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples' + - name: Run Bootstrap API tests + if: steps.changes.outputs.bootstrap == 'true' || steps.changes.outputs.workflow == 'true' + uses: schemathesis/action@v2.1.0 + with: + schema: apidocs/openapi/bootstrap.yaml + base-url: ${{ env.BOOTSTRAP_URL }} + checks: all + args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples' + + - name: Run Certs API tests + if: steps.changes.outputs.certs == 'true' || steps.changes.outputs.workflow == 'true' + uses: schemathesis/action@v2.1.0 + with: + schema: apidocs/openapi/certs.yaml + base-url: ${{ env.CERTS_URL }} + checks: all + args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples' + + - name: Run Readers API tests + if: steps.changes.outputs.readers == 'true' || steps.changes.outputs.workflow == 'true' + uses: schemathesis/action@v2.1.0 + with: + schema: apidocs/openapi/readers.yaml + base-url: ${{ env.READERS_URL }} + checks: all + args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples' + + - name: Run Rules Engine API tests + if: steps.changes.outputs.re == 'true' || steps.changes.outputs.workflow == 'true' + uses: schemathesis/action@v2.1.0 + with: + schema: apidocs/openapi/rules.yaml + base-url: ${{ env.RE_URL }} + checks: all + args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples' + + - name: Run Alarms API tests + if: steps.changes.outputs.alarms == 'true' || steps.changes.outputs.workflow == 'true' + uses: schemathesis/action@v2.1.0 + with: + schema: apidocs/openapi/alarms.yaml + base-url: ${{ env.ALARMS_URL }} + checks: all + args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples' + + - name: Run Reports API tests + if: steps.changes.outputs.reports == 'true' || steps.changes.outputs.workflow == 'true' + uses: schemathesis/action@v2.1.0 + with: + schema: apidocs/openapi/reports.yaml + base-url: ${{ env.REPORTS_URL }} + checks: all + args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --suppress-health-check=filter_too_much --exclude-checks=positive_data_acceptance --phases=examples' + - name: Stop containers if: always() run: make run_latest down args="-v" && make run_addons down args="-v" diff --git a/.github/workflows/lint-and-build.yaml b/.github/workflows/lint-and-build.yaml index 524558648..28cafddb6 100644 --- a/.github/workflows/lint-and-build.yaml +++ b/.github/workflows/lint-and-build.yaml @@ -60,12 +60,9 @@ jobs: fail-fast: true matrix: variant: - - name: rabbitmq - env: SMQ_MESSAGE_BROKER_TYPE=msg_rabbitmq - target: mqtt - name: redis - env: SMQ_ES_TYPE=es_redis - target: mqtt + env: MG_ES_TYPE=es_redis + target: fluxmq steps: - name: Checkout code diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 17305a92b..0e4b395fe 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -22,24 +22,6 @@ concurrency: cancel-in-progress: true jobs: - check-certs: - name: Check Certs - runs-on: ubuntu-latest - steps: - - name: Checkout Code - uses: actions/checkout@v6 - - - name: Fetch Certs - run: | - make fetch_certs - if [[ -n $(git status --porcelain docker/addons/certs) ]]; then - echo "Certs docker file is not up to date. Please update it" - git diff docker/addons/certs - exit 1 - else - exit 0 - fi - lint-proto: name: Lint Proto runs-on: ubuntu-latest @@ -66,11 +48,9 @@ jobs: protolint . lint-and-build: - needs: [check-certs, lint-proto] + needs: [lint-proto] uses: ./.github/workflows/lint-and-build.yaml - - detect-changes: name: Detect Changes runs-on: ubuntu-latest @@ -131,14 +111,6 @@ jobs: - "domains/api/grpc/**" - "internal/grpc/**" - coap: - - "coap/**" - - "cmd/coap/**" - - "auth.pb.go" - - "auth_grpc.pb.go" - - "clients/**" - - "pkg/messaging/**" - domains: - "domains/**" - "cmd/domains/**" @@ -160,15 +132,6 @@ jobs: - "domains/api/grpc/**" - "internal/grpc/**" - http: - - "http/**" - - "cmd/http/**" - - "auth.pb.go" - - "auth_grpc.pb.go" - - "clients/**" - - "pkg/messaging/**" - - "logger/**" - internal: - "internal/**" @@ -183,16 +146,6 @@ jobs: logger: - "logger/**" - mqtt: - - "mqtt/**" - - "cmd/mqtt/**" - - "auth.pb.go" - - "auth_grpc.pb.go" - - "clients/**" - - "pkg/messaging/**" - - "logger/**" - - "pkg/events/**" - pkg-errors: - "pkg/errors/**" @@ -211,7 +164,6 @@ jobs: - "pkg/errors/**" - "pkg/groups/**" - "auth/**" - - "http/**" - "internal/*" - "clients/**" - "users/**" @@ -220,6 +172,9 @@ jobs: - "groups/**" - "journal/**" - "api/http/**" + - "re/**" + - "alarms/**" + - "reports/**" pkg-transformers: - "pkg/transformers/**" @@ -253,9 +208,28 @@ jobs: consumers: - "consumers/**" + - "cmd/postgres-writer/**" + - "cmd/timescale-writer/**" + - "cmd/smpp-notifier/**" + - "cmd/smtp-notifier/**" readers: - "readers/**" + - "cmd/postgres-reader/**" + - "cmd/timescale-reader/**" + + re: + - "re/**" + - "cmd/re/**" + - "re/api/**" + + alarms: + - "alarms/**" + - "cmd/alarms/**" + + reports: + - "reports/**" + - "cmd/reports/**" - name: Set matrix for changed modules id: set-matrix @@ -264,21 +238,18 @@ jobs: if [[ "${{ steps.changes.outputs.workflow }}" == "true" || "${{ steps.changes.outputs.pkg-errors }}" == "true" ]]; then # If workflow or pkg/errors changed, test everything - modules=("auth" "channels" "cli" "clients" "coap" "domains" "groups" "http" "internal" "journal" "logger" "mqtt" "pkg-errors" "pkg-events" "pkg-grpcclient" "pkg-messaging" "pkg-sdk" "pkg-transformers" "pkg-ulid" "pkg-uuid" "users" "notifications" "api" "consumers" "readers") + modules=("auth" "channels" "cli" "clients" "domains" "groups" "internal" "journal" "logger" "pkg-errors" "pkg-events" "pkg-grpcclient" "pkg-messaging" "pkg-sdk" "pkg-transformers" "pkg-ulid" "pkg-uuid" "users" "notifications" "api" "consumers" "readers" "re" "alarms" "reports") else # Add only changed modules [[ "${{ steps.changes.outputs.auth }}" == "true" ]] && modules+=("auth") [[ "${{ steps.changes.outputs.channels }}" == "true" ]] && modules+=("channels") [[ "${{ steps.changes.outputs.cli }}" == "true" ]] && modules+=("cli") [[ "${{ steps.changes.outputs.clients }}" == "true" ]] && modules+=("clients") - [[ "${{ steps.changes.outputs.coap }}" == "true" ]] && modules+=("coap") [[ "${{ steps.changes.outputs.domains }}" == "true" ]] && modules+=("domains") [[ "${{ steps.changes.outputs.groups }}" == "true" ]] && modules+=("groups") - [[ "${{ steps.changes.outputs.http }}" == "true" ]] && modules+=("http") [[ "${{ steps.changes.outputs.internal }}" == "true" ]] && modules+=("internal") [[ "${{ steps.changes.outputs.journal }}" == "true" ]] && modules+=("journal") [[ "${{ steps.changes.outputs.logger }}" == "true" ]] && modules+=("logger") - [[ "${{ steps.changes.outputs.mqtt }}" == "true" ]] && modules+=("mqtt") [[ "${{ steps.changes.outputs.pkg-errors }}" == "true" ]] && modules+=("pkg-errors") [[ "${{ steps.changes.outputs.pkg-events }}" == "true" ]] && modules+=("pkg-events") [[ "${{ steps.changes.outputs.pkg-grpcclient }}" == "true" ]] && modules+=("pkg-grpcclient") @@ -292,6 +263,9 @@ jobs: [[ "${{ steps.changes.outputs.api }}" == "true" ]] && modules+=("api") [[ "${{ steps.changes.outputs.consumers }}" == "true" ]] && modules+=("consumers") [[ "${{ steps.changes.outputs.readers }}" == "true" ]] && modules+=("readers") + [[ "${{ steps.changes.outputs.re }}" == "true" ]] && modules+=("re") + [[ "${{ steps.changes.outputs.alarms }}" == "true" ]] && modules+=("alarms") + [[ "${{ steps.changes.outputs.reports }}" == "true" ]] && modules+=("reports") fi # Convert to JSON array diff --git a/.gitignore b/.gitignore index 56697b68c..d55d99b2f 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,6 @@ coverage # Ignore Openbao data directory as it contains runtime-generated data docker/addons/certs/openbao/ + +# Ignore SeaweedFS data directory as it contains runtime-generated data +docker/data/* diff --git a/ADOPTERS.md b/ADOPTERS.md index 0d6d0a68f..0b0eac0ea 100644 --- a/ADOPTERS.md +++ b/ADOPTERS.md @@ -1,12 +1,12 @@ # Adopters -As SuperMQ Community grows, we'd like to keep track of SuperMQ adopters to grow the community, contact other users, share experiences and best practices. +As Magistrala Community grows, we'd like to keep track of Magistrala adopters to grow the community, contact other users, share experiences and best practices. -To accomplish this, we created a public ledger. The list of organizations and users who consider themselves as SuperMQ adopters and that **publicly/officially** shared information and/or details of their adoption journey(optional). +To accomplish this, we created a public ledger. The list of organizations and users who consider themselves as Magistrala adopters and that **publicly/officially** shared information and/or details of their adoption journey(optional). Where users themselves directly maintain the list. ## Adding yourself as an adopter -If you are using SuperMQ, please consider adding yourself as an adopter with a brief description of your use case by opening a pull request to this file and adding a section describing your adoption of SuperMQ technology. +If you are using Magistrala, please consider adding yourself as an adopter with a brief description of your use case by opening a pull request to this file and adding a section describing your adoption of Magistrala technology. **Please send PRs to add or remove organizations/users** @@ -25,9 +25,9 @@ Pull request commit must be [signed](https://docs.github.com/en/github/authentic * There is no minimum requirement or adaptation size, but we request to list permanent deployments only, i.e., no demo or trial deployments. Commercial or production use is not required. A well-done home lab setup can be equally impressive as a large-scale commercial deployment. -**The list of organizations/users that have publicly shared the usage of SuperMQ:** +**The list of organizations/users that have publicly shared the usage of Magistrala:** -**Note**: Several other organizations/users couldn't publicly share their usage details but are active project contributors and SuperMQ Community members. +**Note**: Several other organizations/users couldn't publicly share their usage details but are active project contributors and Magistrala Community members. ## Adopters list (alphabetical) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 99de0c74e..c7e3385a2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,6 +1,6 @@ -# Contributing to SuperMQ +# Contributing to Magistrala -The following is a set of guidelines to contribute to SuperMQ and its libraries, which are +The following is a set of guidelines to contribute to Magistrala and its libraries, which are hosted on the [Abstract Machines Organization](https://github.com/absmach) on GitHub. This project adheres to the [Contributor Covenant 1.2](http://contributor-covenant.org/version/1/2/0). @@ -53,11 +53,11 @@ git checkout main git pull --rebase upstream main ``` -Create a new topic branch from `main` using the naming convention `SMQ-[issue-number]` +Create a new topic branch from `main` using the naming convention `MG-[issue-number]` to help us keep track of your contribution scope: ``` -git checkout -b SMQ-[issue-number] +git checkout -b MG-[issue-number] ``` Commit your changes in logical chunks. When you are ready to commit, make sure @@ -80,7 +80,7 @@ git pull --rebase upstream main Push your topic branch up to your fork: ``` -git push origin SMQ-[issue-number] +git push origin MG-[issue-number] ``` [Open a Pull Request](https://help.github.com/articles/using-pull-requests/) with a clear title diff --git a/LICENSE b/LICENSE index 8d0ab5e46..58dbdd78f 100644 --- a/LICENSE +++ b/LICENSE @@ -176,7 +176,7 @@ END OF TERMS AND CONDITIONS - Copyright 2015-2026 SuperMQ + Copyright 2015-2026 Magistrala Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/Makefile b/Makefile index b681e4d12..9aa6f3e13 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,10 @@ # Copyright (c) Abstract Machines # SPDX-License-Identifier: Apache-2.0 -SMQ_DOCKER_IMAGE_NAME_PREFIX ?= supermq +MG_DOCKER_IMAGE_NAME_PREFIX ?= magistrala BUILD_DIR ?= build -SERVICES = auth users clients groups channels domains http coap cli mqtt journal notifications -TEST_API_SERVICES = journal auth certs http clients users channels groups domains +SERVICES = auth users clients groups channels domains notifications certs re postgres-writer postgres-reader timescale-writer timescale-reader cli alarms reports bootstrap journal fluxmq +TEST_API_SERVICES = journal auth certs clients users channels groups domains TEST_API = $(addprefix test_api_,$(TEST_API_SERVICES)) DOCKERS = $(addprefix docker_,$(SERVICES)) DOCKERS_DEV = $(addprefix docker_dev_,$(SERVICES)) @@ -29,21 +29,23 @@ PKG_PROTO_GEN_OUT_DIR=api/grpc INTERNAL_PROTO_DIR=internal/proto INTERNAL_PROTO_FILES := $(shell find $(INTERNAL_PROTO_DIR) -name "*.proto" | sed 's|$(INTERNAL_PROTO_DIR)/||') -ifneq ($(SMQ_MESSAGE_BROKER_TYPE),) - SMQ_MESSAGE_BROKER_TYPE := $(SMQ_MESSAGE_BROKER_TYPE) +ifneq ($(MG_MESSAGE_BROKER_TYPE),) + MG_MESSAGE_BROKER_TYPE := $(MG_MESSAGE_BROKER_TYPE) else - SMQ_MESSAGE_BROKER_TYPE=msg_nats + MG_MESSAGE_BROKER_TYPE=msg_fluxmq endif -ifneq ($(SMQ_ES_TYPE),) - SMQ_ES_TYPE := $(SMQ_ES_TYPE) +ifneq ($(MG_ES_TYPE),) + MG_ES_TYPE := $(MG_ES_TYPE) else - SMQ_ES_TYPE=es_nats + MG_ES_TYPE=es_fluxmq endif +BUILD_TAGS := $(strip $(MG_MESSAGE_BROKER_TYPE) $(MG_ES_TYPE)) + define compile_service CGO_ENABLED=$(CGO_ENABLED) GOOS=$(GOOS) GOARCH=$(GOARCH) GOARM=$(GOARM) \ - go build -tags $(SMQ_MESSAGE_BROKER_TYPE) -tags $(SMQ_ES_TYPE) -ldflags "-s -w \ + go build -tags "$(BUILD_TAGS)" -ldflags "-s -w \ -X 'github.com/absmach/supermq.BuildTime=$(TIME)' \ -X 'github.com/absmach/supermq.Version=$(VERSION)' \ -X 'github.com/absmach/supermq.Commit=$(COMMIT)'" \ @@ -61,7 +63,7 @@ define make_docker --build-arg VERSION=$(VERSION) \ --build-arg COMMIT=$(COMMIT) \ --build-arg TIME=$(TIME) \ - --tag=$(SMQ_DOCKER_IMAGE_NAME_PREFIX)/$(svc) \ + --tag=$(MG_DOCKER_IMAGE_NAME_PREFIX)/$(svc) \ -f docker/Dockerfile . endef @@ -71,7 +73,7 @@ define make_docker_dev docker build \ --no-cache \ --build-arg SVC=$(svc) \ - --tag=$(SMQ_DOCKER_IMAGE_NAME_PREFIX)/$(svc) \ + --tag=$(MG_DOCKER_IMAGE_NAME_PREFIX)/$(svc) \ -f docker/Dockerfile.dev ./build endef @@ -82,20 +84,20 @@ define run_with_arch_detection git checkout $(1); \ GOARCH=arm64 $(MAKE) dockers; \ for svc in $(SERVICES); do \ - docker tag supermq/$$svc supermq/$$svc:latest; \ - docker tag supermq/$$svc docker.io/supermq/$$svc:latest; \ + docker tag magistrala/$$svc magistrala/$$svc:latest; \ + docker tag magistrala/$$svc docker.io/magistrala/$$svc:latest; \ done; \ - sed -i.bak 's/^SMQ_RELEASE_TAG=.*/SMQ_RELEASE_TAG=latest/' docker/.env && rm -f docker/.env.bak; \ + sed -i.bak 's/^MG_RELEASE_TAG=.*/MG_RELEASE_TAG=latest/' docker/.env && rm -f docker/.env.bak; \ docker compose -f docker/docker-compose.yaml --env-file docker/.env -p $(DOCKER_PROJECT) $(DOCKER_COMPOSE_COMMAND) $(args); \ else \ echo "x86_64 architecture detected."; \ git checkout $(1); \ - sed -i.bak 's/^SMQ_RELEASE_TAG=.*/SMQ_RELEASE_TAG=$(2)/' docker/.env && rm -f docker/.env.bak; \ + sed -i.bak 's/^MG_RELEASE_TAG=.*/MG_RELEASE_TAG=$(2)/' docker/.env && rm -f docker/.env.bak; \ docker compose -f docker/docker-compose.yaml --env-file docker/.env -p $(DOCKER_PROJECT) $(DOCKER_COMPOSE_COMMAND) $(args); \ fi endef -ADDON_SERVICES = journal certs +ADDON_SERVICES = bootstrap provision postgres-writer postgres-reader EXTERNAL_SERVICES = prometheus @@ -152,12 +154,12 @@ cleandocker: ifdef pv # Remove unused volumes - docker volume ls -f name=$(SMQ_DOCKER_IMAGE_NAME_PREFIX) -f dangling=true -q | xargs -r docker volume rm + docker volume ls -f name=$(MG_DOCKER_IMAGE_NAME_PREFIX) -f dangling=true -q | xargs -r docker volume rm endif install: for file in $(BUILD_DIR)/*; do \ - cp $$file $(GOBIN)/supermq-`basename $$file`; \ + cp $$file $(GOBIN)/magistrala-`basename $$file`; \ done mocks: $(MOCKERY) @@ -182,34 +184,18 @@ define test_api_service @if [ -z "$(USER_TOKEN)" ]; then \ echo "USER_TOKEN is not set"; \ - echo "Please set it to a valid token"; \ - exit 1; \ + echo "Please set it to a valid token"; \ + exit 1; \ fi - @if [ "$(svc)" = "http" ] && [ -z "$(CLIENT_SECRET)" ]; then \ - echo "CLIENT_SECRET is not set"; \ - echo "Please set it to a valid secret"; \ - exit 1; \ - fi - - @if [ "$(svc)" = "http" ]; then \ - uvx schemathesis run apidocs/openapi/$(svc).yaml \ - --checks all \ - --url $(2) \ - --header "Authorization: Client $(CLIENT_SECRET)" \ - --suppress-health-check=filter_too_much \ - --exclude-checks=positive_data_acceptance \ - --phases=examples,stateful; \ - else \ - uvx schemathesis run apidocs/openapi/$(svc).yaml \ - --checks all \ - --url $(2) \ - --header "Authorization: Bearer $(USER_TOKEN)" \ - --suppress-health-check=filter_too_much \ - --exclude-checks=positive_data_acceptance \ - --exclude-operation-id=requestPasswordReset \ - --phases=examples,stateful; \ - fi + @uvx schemathesis run apidocs/openapi/$(svc).yaml \ + --checks all \ + --url $(2) \ + --header "Authorization: Bearer $(USER_TOKEN)" \ + --suppress-health-check=filter_too_much \ + --exclude-checks=positive_data_acceptance \ + --exclude-operation-id=requestPasswordReset \ + --phases=examples,stateful endef test_api_users: TEST_API_URL := http://localhost:9002 @@ -217,7 +203,6 @@ test_api_clients: TEST_API_URL := http://localhost:9006 test_api_domains: TEST_API_URL := http://localhost:9003 test_api_channels: TEST_API_URL := http://localhost:9005 test_api_groups: TEST_API_URL := http://localhost:9004 -test_api_http: TEST_API_URL := http://localhost:8008 test_api_auth: TEST_API_URL := http://localhost:9001 test_api_certs: TEST_API_URL := http://localhost:9019 test_api_journal: TEST_API_URL := http://localhost:9021 @@ -244,7 +229,7 @@ dockers_dev: $(DOCKERS_DEV) define docker_push for svc in $(SERVICES); do \ - docker push $(SMQ_DOCKER_IMAGE_NAME_PREFIX)/$$svc:$(1); \ + docker push $(MG_DOCKER_IMAGE_NAME_PREFIX)/$$svc:$(1); \ done endef @@ -257,10 +242,10 @@ latest: dockers publish_arch: $(MAKE) dockers GOOS=$(GOOS) GOARCH=$(GOARCH) GOARM=$(GOARM) for svc in $(SERVICES); do \ - docker tag $(SMQ_DOCKER_IMAGE_NAME_PREFIX)/$$svc $(SMQ_DOCKER_IMAGE_NAME_PREFIX)/$$svc:$(VERSION)-$(GOARCH); \ - docker tag $(SMQ_DOCKER_IMAGE_NAME_PREFIX)/$$svc $(SMQ_DOCKER_IMAGE_NAME_PREFIX)/$$svc:latest-$(GOARCH); \ - docker push $(SMQ_DOCKER_IMAGE_NAME_PREFIX)/$$svc:$(VERSION)-$(GOARCH); \ - docker push $(SMQ_DOCKER_IMAGE_NAME_PREFIX)/$$svc:latest-$(GOARCH); \ + docker tag $(MG_DOCKER_IMAGE_NAME_PREFIX)/$$svc $(MG_DOCKER_IMAGE_NAME_PREFIX)/$$svc:$(VERSION)-$(GOARCH); \ + docker tag $(MG_DOCKER_IMAGE_NAME_PREFIX)/$$svc $(MG_DOCKER_IMAGE_NAME_PREFIX)/$$svc:latest-$(GOARCH); \ + docker push $(MG_DOCKER_IMAGE_NAME_PREFIX)/$$svc:$(VERSION)-$(GOARCH); \ + docker push $(MG_DOCKER_IMAGE_NAME_PREFIX)/$$svc:latest-$(GOARCH); \ done release: @@ -268,7 +253,7 @@ release: git checkout $(version) $(MAKE) dockers for svc in $(SERVICES); do \ - docker tag $(SMQ_DOCKER_IMAGE_NAME_PREFIX)/$$svc $(SMQ_DOCKER_IMAGE_NAME_PREFIX)/$$svc:$(version); \ + docker tag $(MG_DOCKER_IMAGE_NAME_PREFIX)/$$svc $(MG_DOCKER_IMAGE_NAME_PREFIX)/$$svc:$(version); \ done $(call docker_push,$(version)) @@ -303,29 +288,21 @@ endif endif endif -fetch_certs: - @./scripts/certs.sh - run_latest: check_certs - git checkout main - $(SED_INPLACE) 's/^SMQ_RELEASE_TAG=.*/SMQ_RELEASE_TAG=latest/' docker/.env + $(SED_INPLACE) 's/^MG_RELEASE_TAG=.*/MG_RELEASE_TAG=latest/' docker/.env $(DOCKER_PLATFORM) docker compose -f docker/docker-compose.yaml --env-file docker/.env -p $(DOCKER_PROJECT) $(DOCKER_COMPOSE_COMMAND) $(args) run_stable: check_certs $(eval version = $(shell git describe --abbrev=0 --tags)) git checkout $(version) - $(SED_INPLACE) 's/^SMQ_RELEASE_TAG=.*/SMQ_RELEASE_TAG=$(version)/' docker/.env + $(SED_INPLACE) 's/^MG_RELEASE_TAG=.*/MG_RELEASE_TAG=$(version)/' docker/.env $(DOCKER_PLATFORM) docker compose -f docker/docker-compose.yaml --env-file docker/.env -p $(DOCKER_PROJECT) $(DOCKER_COMPOSE_COMMAND) $(args) run_addons: check_certs $(foreach SVC,$(RUN_ADDON_ARGS),$(if $(filter $(SVC),$(ADDON_SERVICES) $(EXTERNAL_SERVICES)),,$(error Invalid Service $(SVC)))) @$(DOCKER_PLATFORM) docker compose -f docker/docker-compose.yaml --env-file ./docker/.env -p $(DOCKER_PROJECT) up -d auth domains jaeger @for SVC in $(RUN_ADDON_ARGS); do \ - if [ "$$SVC" = "certs" ]; then \ - $(DOCKER_PLATFORM) docker compose -f docker/addons/$$SVC/docker-compose.yaml -f docker/certs-docker-compose-override.yaml --env-file ./docker/.env --env-file ./docker/addons/$$SVC/.env -p $(DOCKER_PROJECT) $(DOCKER_COMPOSE_COMMAND) $(args) & \ - else \ - SMQ_ADDONS_CERTS_PATH_PREFIX="../." $(DOCKER_PLATFORM) docker compose -f docker/addons/$$SVC/docker-compose.yaml -p $(DOCKER_PROJECT) --env-file ./docker/.env $(DOCKER_COMPOSE_COMMAND) $(args) & \ - fi; \ + MG_ADDONS_CERTS_PATH_PREFIX="../" $(DOCKER_PLATFORM) docker compose -f docker/addons/$$SVC/docker-compose.yaml -p $(DOCKER_PROJECT) --env-file ./docker/.env $(DOCKER_COMPOSE_COMMAND) $(args) & \ done run_live: check_certs diff --git a/alarms/README.md b/alarms/README.md new file mode 100644 index 000000000..226fb39ef --- /dev/null +++ b/alarms/README.md @@ -0,0 +1,197 @@ +# Alarms + +The Alarms service stores, manages and exposes alarms raised by rules and device activity. It consumes alarm events from the message broker, persists them to PostgreSQL, and provides an HTTP API for listing, viewing, updating, and deleting alarms with full authn/authz, metrics, and tracing support. + +## Configuration + +The service is configured using the following environment variables (values shown are from [docker/.env](https://github.com/absmach/magistrala/blob/main/docker/.env) as consumed by [docker/docker-compose.yaml](https://github.com/absmach/magistrala/blob/main/docker/docker-compose.yaml)): + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_ALARMS_LOG_LEVEL` | Log level for the service | `debug` | +| `MG_ALARMS_HTTP_HOST` | HTTP host to bind | `alarms` | +| `MG_ALARMS_HTTP_PORT` | HTTP port to bind | `8050` | +| `MG_ALARMS_HTTP_SERVER_CERT` | Path to PEM-encoded HTTPS server certificate | "" | +| `MG_ALARMS_HTTP_SERVER_KEY` | Path to PEM-encoded HTTPS server key | "" | +| `MG_ALARMS_DB_HOST` | PostgreSQL host | `alarms-db` | +| `MG_ALARMS_DB_PORT` | PostgreSQL port | `5432` | +| `MG_ALARMS_DB_USER` | PostgreSQL user | `magistrala` | +| `MG_ALARMS_DB_PASS` | PostgreSQL password | `magistrala` | +| `MG_ALARMS_DB_NAME` | PostgreSQL database name | `alarms` | +| `MG_ALARMS_DB_SSL_MODE` | PostgreSQL SSL mode | `disable` | +| `MG_ALARMS_DB_SSL_CERT` | PostgreSQL SSL client cert | "" | +| `MG_ALARMS_DB_SSL_KEY` | PostgreSQL SSL client key | "" | +| `MG_ALARMS_DB_SSL_ROOT_CERT` | PostgreSQL SSL root cert | "" | +| `MG_ALARMS_INSTANCE_ID` | Instance ID for tracing/health | "" | +| `MG_MESSAGE_BROKER_URL` | Message broker URL for alarm ingestion | `nats://nats:4222` | +| `MG_JAEGER_URL` | Jaeger collector endpoint | `http://jaeger:4318/v1/traces` | +| `MG_JAEGER_TRACE_RATIO` | Trace sampling ratio | `1.0` | +| `MG_AUTH_GRPC_URL` | Auth gRPC endpoint | `auth:7001` | +| `MG_AUTH_GRPC_TIMEOUT` | Auth gRPC timeout | `300s` | +| `MG_AUTH_GRPC_CLIENT_CERT` | Auth gRPC client cert path | `${GRPC_MTLS:+./ssl/certs/auth-grpc-client.crt}` | +| `MG_AUTH_GRPC_CLIENT_KEY` | Auth gRPC client key path | `${GRPC_MTLS:+./ssl/certs/auth-grpc-client.key}` | +| `MG_AUTH_GRPC_SERVER_CA_CERTS` | Auth gRPC server CA path | `${GRPC_MTLS:+./ssl/certs/ca.crt}` | +| `MG_DOMAINS_GRPC_URL` | Domains gRPC endpoint | `domains:7003` | +| `MG_DOMAINS_GRPC_TIMEOUT` | Domains gRPC timeout | `300s` | +| `MG_DOMAINS_GRPC_CLIENT_CERT` | Domains gRPC client cert path | `${GRPC_MTLS:+./ssl/certs/domains-grpc-client.crt}` | +| `MG_DOMAINS_GRPC_CLIENT_KEY` | Domains gRPC client key path | `${GRPC_MTLS:+./ssl/certs/domains-grpc-client.key}` | +| `MG_DOMAINS_GRPC_SERVER_CA_CERTS` | Domains gRPC server CA path | `${GRPC_MTLS:+./ssl/certs/ca.crt}` | +| `MG_ALLOW_UNVERIFIED_USER` | Allow unverified users to access | `true` | + +## Features + +- **Alarm ingestion**: Consumes alarms from the message broker and persists them to PostgreSQL. +- **Stateful updates**: Updates assignee, acknowledgment, resolution, and metadata fields. +- **Filtering and paging**: Lists alarms by domain, rule, channel, client, subtopic, status, severity, and time range. +- **Observability**: `/metrics` Prometheus endpoint and Jaeger tracing support. +- **Auth and authorization**: Authn/authz enforced via gRPC auth and domains services. + +## Architecture + +### Runtime flow + +1. The message broker publishes alarm events under the `alarms.>` subject. +2. The Alarms consumer decodes the event payload, enriches it with message metadata, validates it, and calls `CreateAlarm`. +3. The repository writes to PostgreSQL while deduplicating repeated active alarms with the same severity. +4. The HTTP API exposes list/view/update/delete operations with authn/authz, metrics, and tracing middleware. + +### Components + +- **HTTP API**: `alarms/api` exposes REST endpoints and health/metrics handlers. +- **Service layer**: `alarms/service.go` validates requests and coordinates repository operations. +- **Repository**: `alarms/postgres/alarms.go` implements persistence and filtering. +- **Consumer**: `alarms/consumer` processes broker messages and creates alarms. +- **Message broker**: `alarms/brokers` uses NATS JetStream with stream `alarms` and subject `alarms.>`. +- **Migrations**: `alarms/postgres/init.go` defines the alarms schema and indexes. + +### Alarms table + +Defined in `alarms/postgres/init.go`: + +| Column | Type | Description | +| --- | --- | --- | +| `id` | `VARCHAR(36)` | Alarm UUID (primary key) | +| `rule_id` | `VARCHAR(36)` | Rule ID that triggered the alarm | +| `domain_id` | `VARCHAR(36)` | Domain ID | +| `channel_id` | `VARCHAR(36)` | Channel ID | +| `subtopic` | `TEXT` | Subtopic associated with the alarm | +| `client_id` | `VARCHAR(36)` | Client ID | +| `measurement` | `TEXT` | Measurement name | +| `value` | `TEXT` | Measured value | +| `unit` | `TEXT` | Measurement unit | +| `threshold` | `TEXT` | Threshold value | +| `cause` | `TEXT` | Cause/description | +| `status` | `SMALLINT` | 0 = active, 1 = cleared | +| `severity` | `SMALLINT` | Severity (0-100) | +| `assignee_id` | `VARCHAR(36)` | Assignee ID | +| `created_at` | `TIMESTAMPTZ` | Creation timestamp | +| `updated_at` | `TIMESTAMPTZ` | Last update timestamp | +| `updated_by` | `VARCHAR(36)` | User who updated | +| `assigned_at` | `TIMESTAMPTZ` | When assigned | +| `assigned_by` | `VARCHAR(36)` | Who assigned | +| `acknowledged_at` | `TIMESTAMPTZ` | When acknowledged | +| `acknowledged_by` | `VARCHAR(36)` | Who acknowledged | +| `resolved_at` | `TIMESTAMPTZ` | When resolved | +| `resolved_by` | `VARCHAR(36)` | Who resolved | +| `metadata` | `JSONB` | Custom metadata | + +Index: `idx_alarms_state (domain_id, rule_id, channel_id, subtopic, client_id, measurement, created_at DESC)` + +## Deployment + +### Build and run locally + +```bash +make alarms + +MG_ALARMS_LOG_LEVEL=debug \ +MG_ALARMS_HTTP_PORT=8050 \ +MG_ALARMS_DB_HOST=localhost \ +MG_ALARMS_DB_PORT=5432 \ +MG_ALARMS_DB_USER=magistrala \ +MG_ALARMS_DB_PASS=magistrala \ +MG_ALARMS_DB_NAME=alarms \ +MG_MESSAGE_BROKER_URL=nats://localhost:4222 \ +MG_AUTH_GRPC_URL=localhost:7001 \ +MG_AUTH_GRPC_TIMEOUT=300s \ +MG_DOMAINS_GRPC_URL=localhost:7003 \ +MG_DOMAINS_GRPC_TIMEOUT=300s \ +./build/alarms +``` + +### Docker Compose + +The service is available as a Docker container. Refer to [docker/docker-compose.yaml](https://github.com/absmach/magistrala/blob/main/docker/docker-compose.yaml) for the `alarms` and `alarms-db` services and their environment variables. For a full local stack, make sure the auth, domains, and message broker services are also running. + +```bash +docker compose -f docker/docker-compose.yaml up alarms alarms-db +``` + +### Health check + +```bash +curl -X GET http://localhost:8050/health \ + -H "accept: application/health+json" +``` + +## Testing + +```bash +go test ./alarms/... +``` + +## Usage + +The Alarms service supports the following operations: + +| Operation | Method & Path | Description | +| --- | --- | --- | +| `listAlarms` | `GET /{domainID}/alarms` | List alarms with filters | +| `viewAlarm` | `GET /{domainID}/alarms/{alarmID}` | Retrieve a single alarm | +| `updateAlarm` | `PUT /{domainID}/alarms/{alarmID}` | Update alarm status/assignee/metadata | +| `deleteAlarm` | `DELETE /{domainID}/alarms/{alarmID}` | Delete an alarm | +| `health` | `GET /health` | Service health check | + +Alarm creation is driven by message broker events and is not exposed as an HTTP endpoint. + +### Example: List alarms + +```bash +curl -X GET "http://localhost:8050//alarms?limit=10&offset=0&status=active&severity=50" \ + -H "Authorization: Bearer " +``` + +### Example: View an alarm + +```bash +curl -X GET http://localhost:8050//alarms/ \ + -H "Authorization: Bearer " +``` + +### Example: Update an alarm + +```bash +curl -X PUT http://localhost:8050//alarms/ \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "status": "cleared", + "assignee_id": "", + "severity": 40, + "metadata": { "note": "cleared after inspection" } + }' +``` + +### Example: Delete an alarm + +```bash +curl -X DELETE http://localhost:8050//alarms/ \ + -H "Authorization: Bearer " +``` + +### Example: Health check + +```bash +curl -X GET http://localhost:8050/health \ + -H "accept: application/health+json" +``` diff --git a/alarms/alarms.go b/alarms/alarms.go new file mode 100644 index 000000000..0ff19fd62 --- /dev/null +++ b/alarms/alarms.go @@ -0,0 +1,123 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms + +import ( + "context" + "errors" + "time" + + "github.com/absmach/supermq/pkg/authn" +) + +const SeverityMax uint8 = 100 + +var ErrInvalidSeverity = errors.New("invalid severity. Must be between 0 and 100") + +type Metadata map[string]any + +// Alarm represents an alarm instance. +type Alarm struct { + ID string `json:"id"` + RuleID string `json:"rule_id"` + DomainID string `json:"domain_id"` + ChannelID string `json:"channel_id"` + ClientID string `json:"client_id"` + Subtopic string `json:"subtopic"` + Status Status `json:"status"` + Measurement string `json:"measurement"` + Value string `json:"value"` + Unit string `json:"unit"` + Threshold string `json:"threshold"` + Cause string `json:"cause"` + Severity uint8 `json:"severity"` + AssigneeID string `json:"assignee_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + UpdatedBy string `json:"updated_by"` + AssignedAt time.Time `json:"assigned_at,omitempty"` + AssignedBy string `json:"assigned_by,omitempty"` + AcknowledgedAt time.Time `json:"acknowledged_at,omitempty"` + AcknowledgedBy string `json:"acknowledged_by,omitempty"` + ResolvedAt time.Time `json:"resolved_at,omitempty"` + ResolvedBy string `json:"resolved_by,omitempty"` + Metadata Metadata `json:"metadata,omitempty"` +} + +type AlarmsPage struct { + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Total uint64 `json:"total"` + Alarms []Alarm `json:"alarms"` +} + +type PageMetadata struct { + Offset uint64 `json:"offset" db:"offset"` + Limit uint64 `json:"limit" db:"limit"` + DomainID string `json:"domain_id" db:"domain_id"` + RuleID string `json:"rule_id" db:"rule_id"` + ChannelID string `json:"channel_id" db:"channel_id"` + ClientID string `json:"client_id" db:"client_id"` + Subtopic string `json:"subtopic" db:"subtopic"` + Measurement string `json:"measurement" db:"measurement"` + Dir string `json:"dir" db:"dir"` + Order string `json:"order" db:"order"` + Status Status `json:"status" db:"status"` + CreatedFrom time.Time `json:"created_from" db:"created_from"` + CreatedTo time.Time `json:"created_to" db:"created_to"` + AssigneeID string `json:"assignee_id" db:"assignee_id"` + Severity uint8 `json:"severity" db:"severity"` + UpdatedBy string `json:"updated_by" db:"updated_by"` + AssignedBy string `json:"assigned_by" db:"assigned_by"` + AcknowledgedBy string `json:"acknowledged_by" db:"acknowledged_by"` + ResolvedBy string `json:"resolved_by" db:"resolved_by"` + UserID string `json:"user_id" db:"user_id"` +} + +func (a Alarm) Validate() error { + if a.RuleID == "" { + return errors.New("rule_id is required") + } + if a.DomainID == "" { + return errors.New("domain_id is required") + } + if a.ChannelID == "" { + return errors.New("channel_id is required") + } + if a.ClientID == "" { + return errors.New("client_id is required") + } + if a.Measurement == "" { + return errors.New("measurement is required") + } + if a.Value == "" { + return errors.New("value is required") + } + if a.Cause == "" { + return errors.New("cause is required") + } + if a.Severity > SeverityMax { + return ErrInvalidSeverity + } + + return nil +} + +// Service specifies an API that must be fulfilled by the domain service. +type Service interface { + CreateAlarm(ctx context.Context, alarm Alarm) error + UpdateAlarm(ctx context.Context, session authn.Session, alarm Alarm) (Alarm, error) + ViewAlarm(ctx context.Context, session authn.Session, id string) (Alarm, error) + ListAlarms(ctx context.Context, session authn.Session, pm PageMetadata) (AlarmsPage, error) + DeleteAlarm(ctx context.Context, session authn.Session, id string) error +} + +type Repository interface { + CreateAlarm(ctx context.Context, alarm Alarm) (Alarm, error) + UpdateAlarm(ctx context.Context, alarm Alarm) (Alarm, error) + ViewAlarm(ctx context.Context, alarmID, domainID string) (Alarm, error) + ListAllAlarms(ctx context.Context, pm PageMetadata) (AlarmsPage, error) + ListUserAlarms(ctx context.Context, userID string, pm PageMetadata) (AlarmsPage, error) + DeleteAlarm(ctx context.Context, id string) error +} diff --git a/alarms/alarms_test.go b/alarms/alarms_test.go new file mode 100644 index 000000000..d446c1b10 --- /dev/null +++ b/alarms/alarms_test.go @@ -0,0 +1,173 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms_test + +import ( + "fmt" + "testing" + + "github.com/absmach/supermq/alarms" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateAlarms(t *testing.T) { + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: nil, + }, + { + desc: "missing rule_id", + alarm: alarms.Alarm{ + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("rule_id is required"), + }, + { + desc: "missing domain_id", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("domain_id is required"), + }, + { + desc: "missing channel_id", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("channel_id is required"), + }, + { + desc: "missing client_id", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("client_id is required"), + }, + { + desc: "missing measurement", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("measurement is required"), + }, + { + desc: "missing value", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: errors.New("value is required"), + }, + { + desc: "missing cause", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Severity: 100, + }, + err: errors.New("cause is required"), + }, + { + desc: "higher severity", + alarm: alarms.Alarm{ + RuleID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + ChannelID: testsutil.GenerateUUID(t), + ClientID: testsutil.GenerateUUID(t), + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: alarms.SeverityMax + 1, + }, + err: alarms.ErrInvalidSeverity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tc.alarm.Validate() + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + }) + } +} diff --git a/coap/api/doc.go b/alarms/api/doc.go similarity index 100% rename from coap/api/doc.go rename to alarms/api/doc.go diff --git a/alarms/api/endpoint.go b/alarms/api/endpoint.go new file mode 100644 index 000000000..1723f252c --- /dev/null +++ b/alarms/api/endpoint.go @@ -0,0 +1,104 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + + "github.com/absmach/supermq/alarms" + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/go-kit/kit/endpoint" +) + +func updateAlarmEndpoint(svc alarms.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(updateAlarmReq) + if err := req.validate(); err != nil { + return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return alarmRes{}, svcerr.ErrAuthorization + } + + alarm, err := svc.UpdateAlarm(ctx, session, req.Alarm) + if err != nil { + return alarmRes{}, err + } + + return alarmRes{ + Alarm: alarm, + }, nil + } +} + +func viewAlarmEndpoint(svc alarms.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(alarmReq) + if err := req.validate(); err != nil { + return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return alarmRes{}, svcerr.ErrAuthorization + } + + alarm, err := svc.ViewAlarm(ctx, session, req.ID) + if err != nil { + return alarmRes{}, err + } + + return alarmRes{ + Alarm: alarm, + }, nil + } +} + +func listAlarmsEndpoint(svc alarms.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(listAlarmsReq) + if err := req.validate(); err != nil { + return alarmsPageRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return alarmsPageRes{}, svcerr.ErrAuthorization + } + + alarms, err := svc.ListAlarms(ctx, session, req.PageMetadata) + if err != nil { + return alarmsPageRes{}, err + } + + return alarmsPageRes{ + AlarmsPage: alarms, + }, nil + } +} + +func deleteAlarmEndpoint(svc alarms.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(alarmReq) + if err := req.validate(); err != nil { + return alarmRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return alarmRes{}, svcerr.ErrAuthorization + } + + if err := svc.DeleteAlarm(ctx, session, req.ID); err != nil { + return alarmRes{}, err + } + + return alarmRes{deleted: true}, nil + } +} diff --git a/alarms/api/requests.go b/alarms/api/requests.go new file mode 100644 index 000000000..5d2b9921e --- /dev/null +++ b/alarms/api/requests.go @@ -0,0 +1,59 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "errors" + + "github.com/absmach/supermq/alarms" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" +) + +type alarmReq struct { + alarms.Alarm `json:",inline"` +} + +func (req alarmReq) validate() error { + if req.Alarm.ID == "" { + return errors.New("missing alarm id") + } + + return nil +} + +type updateAlarmReq struct { + alarms.Alarm `json:",inline"` +} + +func (req updateAlarmReq) validate() error { + if req.Alarm.ID == "" { + return errors.New("missing alarm id") + } + if req.Alarm.AssigneeID == "" && req.Alarm.AcknowledgedBy == "" && req.Alarm.ResolvedBy == "" && len(req.Alarm.Metadata) == 0 { + return errors.New("at least one of assignee_id, acknowledged_by, resolved_by, or metadata must be set") + } + + return nil +} + +type listAlarmsReq struct { + alarms.PageMetadata +} + +func (req listAlarmsReq) validate() error { + if req.Limit > api.MaxLimitSize || req.Limit < 1 { + return apiutil.ErrLimitSize + } + + if req.Order != "" && req.Order != api.UpdatedAtOrder && req.Order != api.CreatedAtOrder { + return apiutil.ErrInvalidOrder + } + + if req.Dir != api.AscDir && req.Dir != api.DescDir { + return apiutil.ErrInvalidDirection + } + + return nil +} diff --git a/alarms/api/responses.go b/alarms/api/responses.go new file mode 100644 index 000000000..ecd08ca7e --- /dev/null +++ b/alarms/api/responses.go @@ -0,0 +1,70 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "fmt" + "net/http" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/alarms" +) + +var ( + _ supermq.Response = (*alarmRes)(nil) + _ supermq.Response = (*alarmsPageRes)(nil) +) + +type alarmRes struct { + alarms.Alarm `json:",inline"` + created bool + deleted bool +} + +func (res alarmRes) Headers() map[string]string { + switch { + case res.created: + return map[string]string{ + "Location": fmt.Sprintf("/%s/alarms/%s", res.DomainID, res.ID), + } + default: + return map[string]string{} + } +} + +func (res alarmRes) Code() int { + switch { + case res.created: + return http.StatusCreated + case res.deleted: + return http.StatusNoContent + default: + return http.StatusOK + } +} + +func (res alarmRes) Empty() bool { + switch { + case res.deleted: + return true + default: + return false + } +} + +type alarmsPageRes struct { + alarms.AlarmsPage `json:",inline"` +} + +func (res alarmsPageRes) Headers() map[string]string { + return map[string]string{} +} + +func (res alarmsPageRes) Code() int { + return http.StatusOK +} + +func (res alarmsPageRes) Empty() bool { + return false +} diff --git a/alarms/api/transport.go b/alarms/api/transport.go new file mode 100644 index 000000000..ac9f1dc78 --- /dev/null +++ b/alarms/api/transport.go @@ -0,0 +1,209 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + "encoding/json" + "log/slog" + "math" + "net/http" + "strings" + "time" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/alarms" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + "github.com/go-chi/chi/v5" + kithttp "github.com/go-kit/kit/transport/http" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" +) + +func MakeHandler(svc alarms.Service, logger *slog.Logger, idp supermq.IDProvider, instanceID string, authn smqauthn.AuthNMiddleware) http.Handler { + opts := []kithttp.ServerOption{ + kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), + } + + mux := chi.NewRouter() + + mux.Route("/{domainID}/alarms", func(r chi.Router) { + r.Group(func(r chi.Router) { + r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware()) + r.Use(api.RequestIDMiddleware(idp)) + + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + listAlarmsEndpoint(svc), + decodeListAlarmsReq, + api.EncodeResponse, + opts..., + ), "list_alarms").ServeHTTP) + r.Route("/{alarmID}", func(r chi.Router) { + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + viewAlarmEndpoint(svc), + decodeAlarmReq, + api.EncodeResponse, + opts..., + ), "get_alarm").ServeHTTP) + r.Put("/", otelhttp.NewHandler(kithttp.NewServer( + updateAlarmEndpoint(svc), + decodeUpdateAlarmReq, + api.EncodeResponse, + opts..., + ), "update_alarm").ServeHTTP) + r.Delete("/", otelhttp.NewHandler(kithttp.NewServer( + deleteAlarmEndpoint(svc), + decodeAlarmReq, + api.EncodeResponse, + opts..., + ), "delete_alarm").ServeHTTP) + }) + }) + }) + + mux.Get("/health", supermq.Health("alarms", instanceID)) + mux.Handle("/metrics", promhttp.Handler()) + + return mux +} + +func decodeListAlarmsReq(_ context.Context, r *http.Request) (any, error) { + offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + domainID, err := apiutil.ReadStringQuery(r, "domain_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + channelID, err := apiutil.ReadStringQuery(r, "channel_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + clientID, err := apiutil.ReadStringQuery(r, "client_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + subtopic, err := apiutil.ReadStringQuery(r, "subtopic", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + ruleID, err := apiutil.ReadStringQuery(r, "rule_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + s, err := apiutil.ReadStringQuery(r, api.StatusKey, alarms.All) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + status, err := alarms.ToStatus(s) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + assigneeID, err := apiutil.ReadStringQuery(r, "assignee_id", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + serverity, err := apiutil.ReadNumQuery(r, "severity", uint64(math.MaxUint8)) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + updatedBy, err := apiutil.ReadStringQuery(r, "updated_by", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + assignedBy, err := apiutil.ReadStringQuery(r, "assigned_by", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + acknowledgedBy, err := apiutil.ReadStringQuery(r, "acknowledged_by", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + resolvedBy, err := apiutil.ReadStringQuery(r, "resolved_by", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + cfrom, err := apiutil.ReadStringQuery(r, "created_from", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + cto, err := apiutil.ReadStringQuery(r, "created_to", "") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + order, err := apiutil.ReadStringQuery(r, api.OrderKey, api.DefOrder) + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + dir, err := apiutil.ReadStringQuery(r, api.DirKey, "desc") + if err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + var createdFrom, createdTo time.Time + if cfrom != "" { + if createdFrom, err = time.Parse(time.RFC3339, cfrom); err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + } + if cto != "" { + if createdTo, err = time.Parse(time.RFC3339, cto); err != nil { + return listAlarmsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + } + + return listAlarmsReq{ + PageMetadata: alarms.PageMetadata{ + Offset: offset, + Limit: limit, + DomainID: domainID, + ChannelID: channelID, + ClientID: clientID, + Subtopic: subtopic, + RuleID: ruleID, + Status: status, + AssigneeID: assigneeID, + ResolvedBy: resolvedBy, + Severity: uint8(serverity), + UpdatedBy: updatedBy, + AcknowledgedBy: acknowledgedBy, + AssignedBy: assignedBy, + CreatedFrom: createdFrom, + CreatedTo: createdTo, + Dir: dir, + Order: order, + }, + }, nil +} + +func decodeAlarmReq(_ context.Context, r *http.Request) (any, error) { + return alarmReq{ + Alarm: alarms.Alarm{ + ID: chi.URLParam(r, "alarmID"), + }, + }, nil +} + +func decodeUpdateAlarmReq(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return updateAlarmReq{}, apiutil.ErrUnsupportedContentType + } + + req := updateAlarmReq{} + if err := json.NewDecoder(r.Body).Decode(&req.Alarm); err != nil { + return updateAlarmReq{}, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + req.Alarm.ID = chi.URLParam(r, "alarmID") + + return req, nil +} diff --git a/alarms/brokers/brokers_fluxmq.go b/alarms/brokers/brokers_fluxmq.go new file mode 100644 index 000000000..149663212 --- /dev/null +++ b/alarms/brokers/brokers_fluxmq.go @@ -0,0 +1,53 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build msg_fluxmq +// +build msg_fluxmq + +package brokers + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/supermq/pkg/messaging" + broker "github.com/absmach/supermq/pkg/messaging/fluxmq" + "github.com/nats-io/nats.go/jetstream" +) + +const ( + AllTopic = "alarms/#" + + prefix = "alarms" +) + +var cfg = jetstream.StreamConfig{ + Name: "alarms", + Description: "SuperMQ stream alarms", + Subjects: []string{"alarms/#"}, + Retention: jetstream.LimitsPolicy, + MaxMsgsPerSubject: 1e6, + MaxAge: time.Hour * 24, + MaxMsgSize: 1024 * 1024, + Discard: jetstream.DiscardOld, + Storage: jetstream.FileStorage, +} + +func NewPubSub(ctx context.Context, url string, logger *slog.Logger) (messaging.PubSub, error) { + pb, err := broker.NewPubSub(ctx, url, logger, broker.Prefix(prefix), broker.JSStreamConfig(cfg), broker.ConnectionName("alarms-msg-pubsub")) + if err != nil { + return nil, err + } + + return pb, nil +} + +func NewPublisher(ctx context.Context, url string) (messaging.Publisher, error) { + pb, err := broker.NewPublisher(ctx, url, broker.Prefix(prefix), broker.JSStreamConfig(cfg), broker.ConnectionName("alarms-msg-pub")) + if err != nil { + return nil, err + } + + return pb, nil +} diff --git a/alarms/brokers/brokers_nats.go b/alarms/brokers/brokers_nats.go new file mode 100644 index 000000000..562bf55ec --- /dev/null +++ b/alarms/brokers/brokers_nats.go @@ -0,0 +1,53 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build !msg_fluxmq && !msg_rabbitmq && !rabbitmq +// +build !msg_fluxmq,!msg_rabbitmq,!rabbitmq + +package brokers + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/supermq/pkg/messaging" + broker "github.com/absmach/supermq/pkg/messaging/nats" + "github.com/nats-io/nats.go/jetstream" +) + +const ( + AllTopic = "alarms.>" + + prefix = "alarms" +) + +var cfg = jetstream.StreamConfig{ + Name: "alarms", + Description: "SuperMQ stream alarms", + Subjects: []string{"alarms.>"}, + Retention: jetstream.LimitsPolicy, + MaxMsgsPerSubject: 1e6, + MaxAge: time.Hour * 24, + MaxMsgSize: 1024 * 1024, + Discard: jetstream.DiscardOld, + Storage: jetstream.FileStorage, +} + +func NewPubSub(ctx context.Context, url string, logger *slog.Logger) (messaging.PubSub, error) { + pb, err := broker.NewPubSub(ctx, url, logger, broker.Prefix(prefix), broker.JSStreamConfig(cfg)) + if err != nil { + return nil, err + } + + return pb, nil +} + +func NewPublisher(ctx context.Context, url string) (messaging.Publisher, error) { + pb, err := broker.NewPublisher(ctx, url, broker.Prefix(prefix), broker.JSStreamConfig(cfg)) + if err != nil { + return nil, err + } + + return pb, nil +} diff --git a/alarms/consumer/consumer.go b/alarms/consumer/consumer.go new file mode 100644 index 000000000..fe5d39928 --- /dev/null +++ b/alarms/consumer/consumer.go @@ -0,0 +1,56 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import ( + "bytes" + "context" + "encoding/gob" + "log/slog" + "time" + + "github.com/absmach/supermq/alarms" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/messaging" +) + +var errFailedToDecode = errors.New("failed to decode alarm") + +type handler struct { + svc alarms.Service + logger *slog.Logger +} + +func NewHandler(svc alarms.Service, logger *slog.Logger) messaging.MessageHandler { + return &handler{svc: svc, logger: logger} +} + +func (h handler) Handle(msg *messaging.Message) (err error) { + if msg == nil { + return errors.New("message is empty") + } + if msg.GetPayload() == nil { + return errors.New("message payload is empty") + } + + var alarm alarms.Alarm + if err := gob.NewDecoder(bytes.NewReader(msg.GetPayload())).Decode(&alarm); err != nil { + return messaging.NewError(errors.Wrap(errFailedToDecode, err), messaging.Term) + } + alarm.DomainID = msg.GetDomain() + alarm.ChannelID = msg.GetChannel() + alarm.ClientID = msg.ClientIdentity() + alarm.Subtopic = msg.GetSubtopic() + alarm.CreatedAt = time.Unix(0, int64(msg.GetCreated())) + + if err := alarm.Validate(); err != nil { + return err + } + + return h.svc.CreateAlarm(context.Background(), alarm) +} + +func (h handler) Cancel() error { + return nil +} diff --git a/alarms/doc.go b/alarms/doc.go new file mode 100644 index 000000000..9f7866f33 --- /dev/null +++ b/alarms/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package alarms contains domain concept definitions needed to support +// Alarms service feature, i.e. create, read, update, and delete alarms. +package alarms diff --git a/alarms/middleware/authorization.go b/alarms/middleware/authorization.go new file mode 100644 index 000000000..6a5aed67e --- /dev/null +++ b/alarms/middleware/authorization.go @@ -0,0 +1,172 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/absmach/supermq/alarms" + "github.com/absmach/supermq/alarms/operations" + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/authn" + smqauthz "github.com/absmach/supermq/pkg/authz" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/permissions" + "github.com/absmach/supermq/pkg/policies" +) + +var ( + errDomainUpdateAlarms = errors.New("not authorized to update alarms in domain") + errDomainDeleteAlarms = errors.New("not authorized to delete alarms in domain") + errDomainViewAlarms = errors.New("not authorized to view alarms in domain") +) + +type authorizationMiddleware struct { + svc alarms.Service + authz smqauthz.Authorization + entitiesOps permissions.EntitiesOperations[permissions.Operation] +} + +var _ alarms.Service = (*authorizationMiddleware)(nil) + +func NewAuthorizationMiddleware(svc alarms.Service, authz smqauthz.Authorization, entitiesOps permissions.EntitiesOperations[permissions.Operation]) (alarms.Service, error) { + if err := entitiesOps.Validate(); err != nil { + return nil, err + } + + return &authorizationMiddleware{ + svc: svc, + authz: authz, + entitiesOps: entitiesOps, + }, nil +} + +func (am *authorizationMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error { + return am.svc.CreateAlarm(ctx, alarm) +} + +func (am *authorizationMiddleware) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error) { + if len(alarm.Metadata) > 0 { + if err := am.authorize(ctx, operations.OpUpdateAlarm, session, policies.DomainType, session.DomainID); err != nil { + return alarms.Alarm{}, errors.Wrap(errDomainUpdateAlarms, err) + } + } + + if alarm.AssigneeID != "" { + if err := am.authorize(ctx, operations.OpAssignAlarm, session, policies.DomainType, session.DomainID); err != nil { + return alarms.Alarm{}, errors.Wrap(errDomainUpdateAlarms, err) + } + domainUserID := auth.EncodeDomainUserID(session.DomainID, alarm.AssigneeID) + if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: domainUserID, + Permission: policies.MembershipPermission, + ObjectType: policies.DomainType, + Object: session.DomainID, + }, nil); err != nil { + return alarms.Alarm{}, err + } + } + + if alarm.AcknowledgedBy != "" { + if err := am.authorize(ctx, operations.OpAcknowledgeAlarm, session, policies.DomainType, session.DomainID); err != nil { + return alarms.Alarm{}, errors.Wrap(errDomainUpdateAlarms, err) + } + } + + if alarm.ResolvedBy != "" { + if err := am.authorize(ctx, operations.OpResolveAlarm, session, policies.DomainType, session.DomainID); err != nil { + return alarms.Alarm{}, errors.Wrap(errDomainUpdateAlarms, err) + } + } + + return am.svc.UpdateAlarm(ctx, session, alarm) +} + +func (am *authorizationMiddleware) DeleteAlarm(ctx context.Context, session authn.Session, id string) error { + if err := am.authorize(ctx, operations.OpDeleteAlarm, session, policies.DomainType, session.DomainID); err != nil { + return errors.Wrap(errDomainDeleteAlarms, err) + } + + return am.svc.DeleteAlarm(ctx, session, id) +} + +func (am *authorizationMiddleware) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + if pm.DomainID == "" { + pm.DomainID = session.DomainID + } + + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: + session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return alarms.AlarmsPage{}, err + } + + return am.svc.ListAlarms(ctx, session, pm) +} + +func (am *authorizationMiddleware) ViewAlarm(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error) { + if err := am.authorize(ctx, operations.OpViewAlarm, session, policies.DomainType, session.DomainID); err != nil { + return alarms.Alarm{}, errors.Wrap(errDomainViewAlarms, err) + } + + return am.svc.ViewAlarm(ctx, session, id) +} + +func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions.Operation, session authn.Session, objType, obj string) error { + perm, err := am.entitiesOps.GetPermission(operations.EntityType, op) + if err != nil { + return err + } + + pr := smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Object: obj, + ObjectType: objType, + Permission: perm.String(), + } + + var pat *smqauthz.PATReq + if session.PatID != "" { + opName := am.entitiesOps.OperationName(operations.EntityType, op) + pat = &smqauthz.PATReq{ + UserID: session.UserID, + PatID: session.PatID, + EntityID: session.DomainID, + EntityType: operations.EntityType, + Operation: opName, + Domain: session.DomainID, + } + } + + if err := am.authz.Authorize(ctx, pr, pat); err != nil { + return err + } + + return nil +} + +func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session authn.Session) error { + if session.Role != authn.SuperAdminRole { + return svcerr.ErrSuperAdminAction + } + if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{ + SubjectType: policies.UserType, + Subject: session.UserID, + Permission: policies.AdminPermission, + ObjectType: policies.PlatformType, + Object: policies.MagistralaObject, + }, nil); err != nil { + return err + } + return nil +} diff --git a/alarms/middleware/doc.go b/alarms/middleware/doc.go new file mode 100644 index 000000000..ce4a296d2 --- /dev/null +++ b/alarms/middleware/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package middleware provides middleware for the alarms service. +// This is logging, metrics, and tracing middleware. +package middleware diff --git a/alarms/middleware/logging.go b/alarms/middleware/logging.go new file mode 100644 index 000000000..3ac3016d4 --- /dev/null +++ b/alarms/middleware/logging.go @@ -0,0 +1,155 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/supermq/alarms" + "github.com/absmach/supermq/pkg/authn" + "github.com/go-chi/chi/v5/middleware" +) + +type loggingMiddleware struct { + logger *slog.Logger + service alarms.Service +} + +var _ alarms.Service = (*loggingMiddleware)(nil) + +func NewLoggingMiddleware(logger *slog.Logger, service alarms.Service) alarms.Service { + return &loggingMiddleware{ + logger: logger, + service: service, + } +} + +func (lm *loggingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.Group("alarm", + slog.String("rule_id", alarm.RuleID), + slog.String("domain_id", alarm.DomainID), + slog.String("channel_id", alarm.ChannelID), + slog.String("client_id", alarm.ClientID), + slog.String("subtopic", alarm.Subtopic), + slog.String("measurement", alarm.Measurement), + slog.String("value", alarm.Value), + slog.String("unit", alarm.Unit), + slog.Uint64("status", uint64(alarm.Status)), + slog.Uint64("severity", uint64(alarm.Severity)), + slog.String("threshold", alarm.Threshold), + slog.String("cause", alarm.Cause), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Create alarm failed", args...) + return + } + if alarm.ID != "" { + lm.logger.Info("Create alarm completed successfully", args...) + } + }(time.Now()) + + return lm.service.CreateAlarm(ctx, alarm) +} + +func (lm *loggingMiddleware) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (dba alarms.Alarm, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.Group("alarm", + slog.String("id", dba.ID), + slog.String("rule_id", dba.RuleID), + slog.String("domain_id", dba.DomainID), + slog.String("channel_id", dba.ChannelID), + slog.String("client_id", dba.ClientID), + slog.String("subtopic", dba.Subtopic), + slog.String("measurement", dba.Measurement), + slog.String("value", dba.Value), + slog.String("unit", dba.Unit), + slog.String("status", dba.Status.String()), + slog.Uint64("severity", uint64(dba.Severity)), + slog.String("threshold", dba.Threshold), + slog.String("cause", dba.Cause), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Update alarm failed", args...) + return + } + lm.logger.Info("Update alarm completed successfully", args...) + }(time.Now()) + + return lm.service.UpdateAlarm(ctx, session, alarm) +} + +func (lm *loggingMiddleware) ViewAlarm(ctx context.Context, session authn.Session, id string) (dba alarms.Alarm, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.String("id", id), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("View alarm failed", args...) + return + } + lm.logger.Info("View alarm completed successfully", args...) + }(time.Now()) + + return lm.service.ViewAlarm(ctx, session, id) +} + +func (lm *loggingMiddleware) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (dbp alarms.AlarmsPage, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.Int("offset", int(pm.Offset)), + slog.Int("limit", int(pm.Limit)), + slog.String("rule_id", pm.RuleID), + slog.String("domain_id", pm.DomainID), + slog.String("channel_id", pm.ChannelID), + slog.String("client_id", pm.ClientID), + slog.String("subtopic", pm.Subtopic), + slog.String("status", pm.Status.String()), + slog.Uint64("severity", uint64(pm.Severity)), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("List alarms failed", args...) + return + } + lm.logger.Info("List alarms completed successfully", args...) + }(time.Now()) + + return lm.service.ListAlarms(ctx, session, pm) +} + +func (lm *loggingMiddleware) DeleteAlarm(ctx context.Context, session authn.Session, id string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + slog.String("id", id), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Delete alarm failed", args...) + return + } + lm.logger.Info("Delete alarm completed successfully", args...) + }(time.Now()) + + return lm.service.DeleteAlarm(ctx, session, id) +} diff --git a/alarms/middleware/metrics.go b/alarms/middleware/metrics.go new file mode 100644 index 000000000..da831513a --- /dev/null +++ b/alarms/middleware/metrics.go @@ -0,0 +1,74 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "time" + + "github.com/absmach/supermq/alarms" + "github.com/absmach/supermq/pkg/authn" + "github.com/go-kit/kit/metrics" +) + +type metricsMiddleware struct { + counter metrics.Counter + latency metrics.Histogram + service alarms.Service +} + +var _ alarms.Service = (*metricsMiddleware)(nil) + +func NewMetricsMiddleware(counter metrics.Counter, latency metrics.Histogram, service alarms.Service) alarms.Service { + return &metricsMiddleware{ + counter: counter, + latency: latency, + service: service, + } +} + +func (mm *metricsMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error { + defer func(begin time.Time) { + mm.counter.With("method", "create_alarm").Add(1) + mm.latency.With("method", "create_alarm").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.CreateAlarm(ctx, alarm) +} + +func (mm *metricsMiddleware) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error) { + defer func(begin time.Time) { + mm.counter.With("method", "update_alarm").Add(1) + mm.latency.With("method", "update_alarm").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.UpdateAlarm(ctx, session, alarm) +} + +func (mm *metricsMiddleware) ViewAlarm(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error) { + defer func(begin time.Time) { + mm.counter.With("method", "get_alarm").Add(1) + mm.latency.With("method", "get_alarm").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.ViewAlarm(ctx, session, id) +} + +func (mm *metricsMiddleware) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + defer func(begin time.Time) { + mm.counter.With("method", "list_alarms").Add(1) + mm.latency.With("method", "list_alarms").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.ListAlarms(ctx, session, pm) +} + +func (mm *metricsMiddleware) DeleteAlarm(ctx context.Context, session authn.Session, id string) error { + defer func(begin time.Time) { + mm.counter.With("method", "delete_alarm").Add(1) + mm.latency.With("method", "delete_alarm").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.DeleteAlarm(ctx, session, id) +} diff --git a/alarms/middleware/tracing.go b/alarms/middleware/tracing.go new file mode 100644 index 000000000..e2b4a055f --- /dev/null +++ b/alarms/middleware/tracing.go @@ -0,0 +1,84 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/absmach/supermq/alarms" + "github.com/absmach/supermq/pkg/authn" + smqTracing "github.com/absmach/supermq/pkg/tracing" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +type tracingMiddleware struct { + tracer trace.Tracer + svc alarms.Service +} + +var _ alarms.Service = (*tracingMiddleware)(nil) + +func NewTracingMiddleware(tracer trace.Tracer, svc alarms.Service) alarms.Service { + return &tracingMiddleware{ + tracer: tracer, + svc: svc, + } +} + +func (tm *tracingMiddleware) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "create_alarm", trace.WithAttributes( + attribute.String("rule_id", alarm.RuleID), + attribute.String("measurement", alarm.Measurement), + attribute.String("value", alarm.Value), + attribute.String("unit", alarm.Unit), + attribute.String("cause", alarm.Cause), + attribute.String("status", alarm.Status.String()), + )) + defer span.End() + + return tm.svc.CreateAlarm(ctx, alarm) +} + +func (tm *tracingMiddleware) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_alarm", trace.WithAttributes( + attribute.String("rule_id", alarm.RuleID), + attribute.String("measurement", alarm.Measurement), + attribute.String("value", alarm.Value), + attribute.String("unit", alarm.Unit), + attribute.String("cause", alarm.Cause), + attribute.String("status", alarm.Status.String()), + )) + defer span.End() + + return tm.svc.UpdateAlarm(ctx, session, alarm) +} + +func (tm *tracingMiddleware) ViewAlarm(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "get_alarm", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.ViewAlarm(ctx, session, id) +} + +func (tm *tracingMiddleware) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "list_alarms", trace.WithAttributes( + attribute.Int("offset", int(pm.Offset)), + attribute.Int("limit", int(pm.Limit)), + )) + defer span.End() + + return tm.svc.ListAlarms(ctx, session, pm) +} + +func (tm *tracingMiddleware) DeleteAlarm(ctx context.Context, session authn.Session, id string) error { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "delete_alarm", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.DeleteAlarm(ctx, session, id) +} diff --git a/alarms/mocks/repository.go b/alarms/mocks/repository.go new file mode 100644 index 000000000..186a098e3 --- /dev/null +++ b/alarms/mocks/repository.go @@ -0,0 +1,442 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/alarms" + mock "github.com/stretchr/testify/mock" +) + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +type Repository_Expecter struct { + mock *mock.Mock +} + +func (_m *Repository) EXPECT() *Repository_Expecter { + return &Repository_Expecter{mock: &_m.Mock} +} + +// CreateAlarm provides a mock function for the type Repository +func (_mock *Repository) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) { + ret := _mock.Called(ctx, alarm) + + if len(ret) == 0 { + panic("no return value specified for CreateAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) (alarms.Alarm, error)); ok { + return returnFunc(ctx, alarm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) alarms.Alarm); ok { + r0 = returnFunc(ctx, alarm) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, alarms.Alarm) error); ok { + r1 = returnFunc(ctx, alarm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_CreateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateAlarm' +type Repository_CreateAlarm_Call struct { + *mock.Call +} + +// CreateAlarm is a helper method to define mock.On call +// - ctx context.Context +// - alarm alarms.Alarm +func (_e *Repository_Expecter) CreateAlarm(ctx interface{}, alarm interface{}) *Repository_CreateAlarm_Call { + return &Repository_CreateAlarm_Call{Call: _e.mock.On("CreateAlarm", ctx, alarm)} +} + +func (_c *Repository_CreateAlarm_Call) Run(run func(ctx context.Context, alarm alarms.Alarm)) *Repository_CreateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 alarms.Alarm + if args[1] != nil { + arg1 = args[1].(alarms.Alarm) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_CreateAlarm_Call) Return(alarm1 alarms.Alarm, err error) *Repository_CreateAlarm_Call { + _c.Call.Return(alarm1, err) + return _c +} + +func (_c *Repository_CreateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error)) *Repository_CreateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// DeleteAlarm provides a mock function for the type Repository +func (_mock *Repository) DeleteAlarm(ctx context.Context, id string) error { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteAlarm") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_DeleteAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteAlarm' +type Repository_DeleteAlarm_Call struct { + *mock.Call +} + +// DeleteAlarm is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) DeleteAlarm(ctx interface{}, id interface{}) *Repository_DeleteAlarm_Call { + return &Repository_DeleteAlarm_Call{Call: _e.mock.On("DeleteAlarm", ctx, id)} +} + +func (_c *Repository_DeleteAlarm_Call) Run(run func(ctx context.Context, id string)) *Repository_DeleteAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_DeleteAlarm_Call) Return(err error) *Repository_DeleteAlarm_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_DeleteAlarm_Call) RunAndReturn(run func(ctx context.Context, id string) error) *Repository_DeleteAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ListAllAlarms provides a mock function for the type Repository +func (_mock *Repository) ListAllAlarms(ctx context.Context, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + ret := _mock.Called(ctx, pm) + + if len(ret) == 0 { + panic("no return value specified for ListAllAlarms") + } + + var r0 alarms.AlarmsPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.PageMetadata) (alarms.AlarmsPage, error)); ok { + return returnFunc(ctx, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.PageMetadata) alarms.AlarmsPage); ok { + r0 = returnFunc(ctx, pm) + } else { + r0 = ret.Get(0).(alarms.AlarmsPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, alarms.PageMetadata) error); ok { + r1 = returnFunc(ctx, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListAllAlarms_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAllAlarms' +type Repository_ListAllAlarms_Call struct { + *mock.Call +} + +// ListAllAlarms is a helper method to define mock.On call +// - ctx context.Context +// - pm alarms.PageMetadata +func (_e *Repository_Expecter) ListAllAlarms(ctx interface{}, pm interface{}) *Repository_ListAllAlarms_Call { + return &Repository_ListAllAlarms_Call{Call: _e.mock.On("ListAllAlarms", ctx, pm)} +} + +func (_c *Repository_ListAllAlarms_Call) Run(run func(ctx context.Context, pm alarms.PageMetadata)) *Repository_ListAllAlarms_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 alarms.PageMetadata + if args[1] != nil { + arg1 = args[1].(alarms.PageMetadata) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_ListAllAlarms_Call) Return(alarmsPage alarms.AlarmsPage, err error) *Repository_ListAllAlarms_Call { + _c.Call.Return(alarmsPage, err) + return _c +} + +func (_c *Repository_ListAllAlarms_Call) RunAndReturn(run func(ctx context.Context, pm alarms.PageMetadata) (alarms.AlarmsPage, error)) *Repository_ListAllAlarms_Call { + _c.Call.Return(run) + return _c +} + +// ListUserAlarms provides a mock function for the type Repository +func (_mock *Repository) ListUserAlarms(ctx context.Context, userID string, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + ret := _mock.Called(ctx, userID, pm) + + if len(ret) == 0 { + panic("no return value specified for ListUserAlarms") + } + + var r0 alarms.AlarmsPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, alarms.PageMetadata) (alarms.AlarmsPage, error)); ok { + return returnFunc(ctx, userID, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, alarms.PageMetadata) alarms.AlarmsPage); ok { + r0 = returnFunc(ctx, userID, pm) + } else { + r0 = ret.Get(0).(alarms.AlarmsPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, alarms.PageMetadata) error); ok { + r1 = returnFunc(ctx, userID, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListUserAlarms_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserAlarms' +type Repository_ListUserAlarms_Call struct { + *mock.Call +} + +// ListUserAlarms is a helper method to define mock.On call +// - ctx context.Context +// - userID string +// - pm alarms.PageMetadata +func (_e *Repository_Expecter) ListUserAlarms(ctx interface{}, userID interface{}, pm interface{}) *Repository_ListUserAlarms_Call { + return &Repository_ListUserAlarms_Call{Call: _e.mock.On("ListUserAlarms", ctx, userID, pm)} +} + +func (_c *Repository_ListUserAlarms_Call) Run(run func(ctx context.Context, userID string, pm alarms.PageMetadata)) *Repository_ListUserAlarms_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 alarms.PageMetadata + if args[2] != nil { + arg2 = args[2].(alarms.PageMetadata) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_ListUserAlarms_Call) Return(alarmsPage alarms.AlarmsPage, err error) *Repository_ListUserAlarms_Call { + _c.Call.Return(alarmsPage, err) + return _c +} + +func (_c *Repository_ListUserAlarms_Call) RunAndReturn(run func(ctx context.Context, userID string, pm alarms.PageMetadata) (alarms.AlarmsPage, error)) *Repository_ListUserAlarms_Call { + _c.Call.Return(run) + return _c +} + +// UpdateAlarm provides a mock function for the type Repository +func (_mock *Repository) UpdateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) { + ret := _mock.Called(ctx, alarm) + + if len(ret) == 0 { + panic("no return value specified for UpdateAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) (alarms.Alarm, error)); ok { + return returnFunc(ctx, alarm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) alarms.Alarm); ok { + r0 = returnFunc(ctx, alarm) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, alarms.Alarm) error); ok { + r1 = returnFunc(ctx, alarm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAlarm' +type Repository_UpdateAlarm_Call struct { + *mock.Call +} + +// UpdateAlarm is a helper method to define mock.On call +// - ctx context.Context +// - alarm alarms.Alarm +func (_e *Repository_Expecter) UpdateAlarm(ctx interface{}, alarm interface{}) *Repository_UpdateAlarm_Call { + return &Repository_UpdateAlarm_Call{Call: _e.mock.On("UpdateAlarm", ctx, alarm)} +} + +func (_c *Repository_UpdateAlarm_Call) Run(run func(ctx context.Context, alarm alarms.Alarm)) *Repository_UpdateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 alarms.Alarm + if args[1] != nil { + arg1 = args[1].(alarms.Alarm) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateAlarm_Call) Return(alarm1 alarms.Alarm, err error) *Repository_UpdateAlarm_Call { + _c.Call.Return(alarm1, err) + return _c +} + +func (_c *Repository_UpdateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error)) *Repository_UpdateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ViewAlarm provides a mock function for the type Repository +func (_mock *Repository) ViewAlarm(ctx context.Context, alarmID string, domainID string) (alarms.Alarm, error) { + ret := _mock.Called(ctx, alarmID, domainID) + + if len(ret) == 0 { + panic("no return value specified for ViewAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (alarms.Alarm, error)); ok { + return returnFunc(ctx, alarmID, domainID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) alarms.Alarm); ok { + r0 = returnFunc(ctx, alarmID, domainID) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, alarmID, domainID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ViewAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewAlarm' +type Repository_ViewAlarm_Call struct { + *mock.Call +} + +// ViewAlarm is a helper method to define mock.On call +// - ctx context.Context +// - alarmID string +// - domainID string +func (_e *Repository_Expecter) ViewAlarm(ctx interface{}, alarmID interface{}, domainID interface{}) *Repository_ViewAlarm_Call { + return &Repository_ViewAlarm_Call{Call: _e.mock.On("ViewAlarm", ctx, alarmID, domainID)} +} + +func (_c *Repository_ViewAlarm_Call) Run(run func(ctx context.Context, alarmID string, domainID string)) *Repository_ViewAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_ViewAlarm_Call) Return(alarm alarms.Alarm, err error) *Repository_ViewAlarm_Call { + _c.Call.Return(alarm, err) + return _c +} + +func (_c *Repository_ViewAlarm_Call) RunAndReturn(run func(ctx context.Context, alarmID string, domainID string) (alarms.Alarm, error)) *Repository_ViewAlarm_Call { + _c.Call.Return(run) + return _c +} diff --git a/alarms/mocks/service.go b/alarms/mocks/service.go new file mode 100644 index 000000000..26c1dc982 --- /dev/null +++ b/alarms/mocks/service.go @@ -0,0 +1,380 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/alarms" + "github.com/absmach/supermq/pkg/authn" + mock "github.com/stretchr/testify/mock" +) + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +type Service_Expecter struct { + mock *mock.Mock +} + +func (_m *Service) EXPECT() *Service_Expecter { + return &Service_Expecter{mock: &_m.Mock} +} + +// CreateAlarm provides a mock function for the type Service +func (_mock *Service) CreateAlarm(ctx context.Context, alarm alarms.Alarm) error { + ret := _mock.Called(ctx, alarm) + + if len(ret) == 0 { + panic("no return value specified for CreateAlarm") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, alarms.Alarm) error); ok { + r0 = returnFunc(ctx, alarm) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_CreateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateAlarm' +type Service_CreateAlarm_Call struct { + *mock.Call +} + +// CreateAlarm is a helper method to define mock.On call +// - ctx context.Context +// - alarm alarms.Alarm +func (_e *Service_Expecter) CreateAlarm(ctx interface{}, alarm interface{}) *Service_CreateAlarm_Call { + return &Service_CreateAlarm_Call{Call: _e.mock.On("CreateAlarm", ctx, alarm)} +} + +func (_c *Service_CreateAlarm_Call) Run(run func(ctx context.Context, alarm alarms.Alarm)) *Service_CreateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 alarms.Alarm + if args[1] != nil { + arg1 = args[1].(alarms.Alarm) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_CreateAlarm_Call) Return(err error) *Service_CreateAlarm_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_CreateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm alarms.Alarm) error) *Service_CreateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// DeleteAlarm provides a mock function for the type Service +func (_mock *Service) DeleteAlarm(ctx context.Context, session authn.Session, id string) error { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteAlarm") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_DeleteAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteAlarm' +type Service_DeleteAlarm_Call struct { + *mock.Call +} + +// DeleteAlarm is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) DeleteAlarm(ctx interface{}, session interface{}, id interface{}) *Service_DeleteAlarm_Call { + return &Service_DeleteAlarm_Call{Call: _e.mock.On("DeleteAlarm", ctx, session, id)} +} + +func (_c *Service_DeleteAlarm_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_DeleteAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_DeleteAlarm_Call) Return(err error) *Service_DeleteAlarm_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_DeleteAlarm_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) error) *Service_DeleteAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ListAlarms provides a mock function for the type Service +func (_mock *Service) ListAlarms(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + ret := _mock.Called(ctx, session, pm) + + if len(ret) == 0 { + panic("no return value specified for ListAlarms") + } + + var r0 alarms.AlarmsPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, alarms.PageMetadata) (alarms.AlarmsPage, error)); ok { + return returnFunc(ctx, session, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, alarms.PageMetadata) alarms.AlarmsPage); ok { + r0 = returnFunc(ctx, session, pm) + } else { + r0 = ret.Get(0).(alarms.AlarmsPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, alarms.PageMetadata) error); ok { + r1 = returnFunc(ctx, session, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListAlarms_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAlarms' +type Service_ListAlarms_Call struct { + *mock.Call +} + +// ListAlarms is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - pm alarms.PageMetadata +func (_e *Service_Expecter) ListAlarms(ctx interface{}, session interface{}, pm interface{}) *Service_ListAlarms_Call { + return &Service_ListAlarms_Call{Call: _e.mock.On("ListAlarms", ctx, session, pm)} +} + +func (_c *Service_ListAlarms_Call) Run(run func(ctx context.Context, session authn.Session, pm alarms.PageMetadata)) *Service_ListAlarms_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 alarms.PageMetadata + if args[2] != nil { + arg2 = args[2].(alarms.PageMetadata) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_ListAlarms_Call) Return(alarmsPage alarms.AlarmsPage, err error) *Service_ListAlarms_Call { + _c.Call.Return(alarmsPage, err) + return _c +} + +func (_c *Service_ListAlarms_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, pm alarms.PageMetadata) (alarms.AlarmsPage, error)) *Service_ListAlarms_Call { + _c.Call.Return(run) + return _c +} + +// UpdateAlarm provides a mock function for the type Service +func (_mock *Service) UpdateAlarm(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error) { + ret := _mock.Called(ctx, session, alarm) + + if len(ret) == 0 { + panic("no return value specified for UpdateAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, alarms.Alarm) (alarms.Alarm, error)); ok { + return returnFunc(ctx, session, alarm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, alarms.Alarm) alarms.Alarm); ok { + r0 = returnFunc(ctx, session, alarm) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, alarms.Alarm) error); ok { + r1 = returnFunc(ctx, session, alarm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_UpdateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAlarm' +type Service_UpdateAlarm_Call struct { + *mock.Call +} + +// UpdateAlarm is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - alarm alarms.Alarm +func (_e *Service_Expecter) UpdateAlarm(ctx interface{}, session interface{}, alarm interface{}) *Service_UpdateAlarm_Call { + return &Service_UpdateAlarm_Call{Call: _e.mock.On("UpdateAlarm", ctx, session, alarm)} +} + +func (_c *Service_UpdateAlarm_Call) Run(run func(ctx context.Context, session authn.Session, alarm alarms.Alarm)) *Service_UpdateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 alarms.Alarm + if args[2] != nil { + arg2 = args[2].(alarms.Alarm) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_UpdateAlarm_Call) Return(alarm1 alarms.Alarm, err error) *Service_UpdateAlarm_Call { + _c.Call.Return(alarm1, err) + return _c +} + +func (_c *Service_UpdateAlarm_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, alarm alarms.Alarm) (alarms.Alarm, error)) *Service_UpdateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ViewAlarm provides a mock function for the type Service +func (_mock *Service) ViewAlarm(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error) { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for ViewAlarm") + } + + var r0 alarms.Alarm + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (alarms.Alarm, error)); ok { + return returnFunc(ctx, session, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) alarms.Alarm); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Get(0).(alarms.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = returnFunc(ctx, session, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ViewAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewAlarm' +type Service_ViewAlarm_Call struct { + *mock.Call +} + +// ViewAlarm is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) ViewAlarm(ctx interface{}, session interface{}, id interface{}) *Service_ViewAlarm_Call { + return &Service_ViewAlarm_Call{Call: _e.mock.On("ViewAlarm", ctx, session, id)} +} + +func (_c *Service_ViewAlarm_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_ViewAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_ViewAlarm_Call) Return(alarm alarms.Alarm, err error) *Service_ViewAlarm_Call { + _c.Call.Return(alarm, err) + return _c +} + +func (_c *Service_ViewAlarm_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) (alarms.Alarm, error)) *Service_ViewAlarm_Call { + _c.Call.Return(run) + return _c +} diff --git a/alarms/operations/operations.go b/alarms/operations/operations.go new file mode 100644 index 000000000..2e536da4c --- /dev/null +++ b/alarms/operations/operations.go @@ -0,0 +1,52 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package operations + +import "github.com/absmach/supermq/pkg/permissions" + +const EntityType = "alarm" + +// Alarm Operations. +const ( + OpViewAlarm permissions.Operation = iota + OpDeleteAlarm + OpListAlarms + OpAssignAlarm + OpAcknowledgeAlarm + OpResolveAlarm + OpUpdateAlarm +) + +func OperationDetails() map[permissions.Operation]permissions.OperationDetails { + return map[permissions.Operation]permissions.OperationDetails{ + OpViewAlarm: { + Name: "view", + PermissionRequired: true, + }, + OpDeleteAlarm: { + Name: "delete", + PermissionRequired: true, + }, + OpListAlarms: { + Name: "list", + PermissionRequired: true, + }, + OpAssignAlarm: { + Name: "assign", + PermissionRequired: true, + }, + OpAcknowledgeAlarm: { + Name: "acknowledge", + PermissionRequired: true, + }, + OpResolveAlarm: { + Name: "resolve", + PermissionRequired: true, + }, + OpUpdateAlarm: { + Name: "update", + PermissionRequired: true, + }, + } +} diff --git a/alarms/postgres/alarms.go b/alarms/postgres/alarms.go new file mode 100644 index 000000000..888748146 --- /dev/null +++ b/alarms/postgres/alarms.go @@ -0,0 +1,518 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "math" + "strings" + "time" + + "github.com/absmach/supermq/alarms" + api "github.com/absmach/supermq/api/http" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/postgres" + "github.com/jmoiron/sqlx" +) + +const alarmColumns = `alarms.id, alarms.rule_id, alarms.domain_id, alarms.channel_id, alarms.client_id, alarms.subtopic, alarms.measurement, alarms.value, alarms.unit, +alarms.threshold, alarms.cause, alarms.status, alarms.severity, alarms.assignee_id, alarms.created_at, alarms.updated_at, alarms.updated_by, alarms.assigned_at, +alarms.assigned_by, alarms.acknowledged_at, alarms.acknowledged_by, alarms.resolved_at, alarms.resolved_by, alarms.metadata` + +type repository struct { + db *sqlx.DB +} + +var _ alarms.Repository = (*repository)(nil) + +func NewAlarmsRepo(db *sqlx.DB) alarms.Repository { + return &repository{db: db} +} + +func (r *repository) CreateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) { + query := ` + WITH existing AS ( + SELECT status, severity + FROM alarms + WHERE domain_id = :domain_id + AND rule_id = :rule_id + AND channel_id = :channel_id + AND client_id = :client_id + AND subtopic = :subtopic + AND measurement = :measurement + AND created_at <= :created_at + ORDER BY created_at DESC + LIMIT 1 + ) + INSERT INTO alarms ( + id, rule_id, domain_id, channel_id, client_id, subtopic, measurement, + value, unit, threshold, cause, status, severity, assignee_id, + created_at, updated_at, updated_by, assigned_at, assigned_by, + acknowledged_at, acknowledged_by, resolved_at, resolved_by, metadata + ) + SELECT + :id, :rule_id, :domain_id, :channel_id, :client_id, :subtopic, :measurement, + :value, :unit, :threshold, :cause, :status, :severity, :assignee_id, + :created_at, :updated_at, :updated_by, :assigned_at, :assigned_by, + :acknowledged_at, :acknowledged_by, :resolved_at, :resolved_by, :metadata + WHERE ( + EXISTS ( + SELECT 1 FROM existing + WHERE existing.status IS DISTINCT FROM :status + OR (:status = 0 AND existing.status = 0 AND existing.severity IS DISTINCT FROM :severity) + ) + OR ( + NOT EXISTS (SELECT 1 FROM existing) AND :status = 0 + ) + ) + RETURNING + id, rule_id, domain_id, channel_id, client_id, subtopic, measurement, + value, unit, threshold, cause, status, severity, created_at, + assignee_id, updated_at, updated_by, assigned_at, assigned_by, + acknowledged_at, acknowledged_by, resolved_at, resolved_by, metadata + ; + ` + dba, err := toDBAlarm(alarm) + if err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrCreateEntity, err) + } + row, err := r.db.NamedQueryContext(ctx, query, dba) + if err != nil { + return alarms.Alarm{}, postgres.HandleError(repoerr.ErrCreateEntity, err) + } + defer row.Close() + + if !row.Next() { + return alarms.Alarm{}, repoerr.ErrNotFound + } + + dba = dbAlarm{} + if err := row.StructScan(&dba); err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return toAlarm(dba) +} + +func (r *repository) UpdateAlarm(ctx context.Context, alarm alarms.Alarm) (alarms.Alarm, error) { + var query []string + var upq string + if alarm.Status != 0 { + query = append(query, "status = :status,") + } + if alarm.AssigneeID != "" { + query = append(query, "assignee_id = :assignee_id,") + } + if !alarm.AssignedAt.IsZero() { + query = append(query, "assigned_at = :assigned_at,") + } + if alarm.AssignedBy != "" { + query = append(query, "assigned_by = :assigned_by,") + } + if alarm.AcknowledgedBy != "" { + query = append(query, "acknowledged_by = :acknowledged_by,") + } + if !alarm.AcknowledgedAt.IsZero() { + query = append(query, "acknowledged_at = :acknowledged_at,") + } + if alarm.ResolvedBy != "" { + query = append(query, "resolved_by = :resolved_by,") + } + if !alarm.ResolvedAt.IsZero() { + query = append(query, "resolved_at = :resolved_at,") + } + if alarm.Metadata != nil { + query = append(query, "metadata = :metadata,") + } + if len(query) > 0 { + upq = strings.Join(query, " ") + } + + q := fmt.Sprintf(`UPDATE alarms SET %s updated_by = :updated_by, updated_at = :updated_at WHERE id = :id + RETURNING id, rule_id, domain_id, channel_id, client_id, subtopic, measurement, value, unit, threshold, + cause, status, severity, assignee_id, assigned_at, assigned_by, acknowledged_at, acknowledged_by, + resolved_by, resolved_at, metadata, created_at, updated_by, updated_at;`, upq) + + dba, err := toDBAlarm(alarm) + if err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + row, err := r.db.NamedQueryContext(ctx, q, dba) + if err != nil { + return alarms.Alarm{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + defer row.Close() + + if !row.Next() { + return alarms.Alarm{}, repoerr.ErrNotFound + } + + dba = dbAlarm{} + if err := row.StructScan(&dba); err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return toAlarm(dba) +} + +func (r *repository) ViewAlarm(ctx context.Context, alarmID, domainID string) (alarms.Alarm, error) { + query := `SELECT * FROM alarms WHERE id = :id AND domain_id = :domain_id;` + row, err := r.db.NamedQueryContext(ctx, query, map[string]any{ + "id": alarmID, "domain_id": domainID, + }) + if err != nil { + return alarms.Alarm{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + defer row.Close() + + if !row.Next() { + return alarms.Alarm{}, repoerr.ErrNotFound + } + + dba := dbAlarm{} + if err := row.StructScan(&dba); err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + alarm, err := toAlarm(dba) + if err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return alarm, nil +} + +func (r *repository) ListAllAlarms(ctx context.Context, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + query, err := pageQuery(pm) + if err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + comQuery := fmt.Sprintf(`SELECT %s FROM alarms %s`, alarmColumns, query) + + return r.alarmsPage(ctx, comQuery, pm) +} + +func (r *repository) ListUserAlarms(ctx context.Context, userID string, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + query, err := pageQuery(pm) + if err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + pm.UserID = userID + comQuery := fmt.Sprintf(`SELECT DISTINCT %s + FROM alarms + INNER JOIN rules_roles rr ON rr.entity_id = alarms.rule_id + INNER JOIN rules_role_members rrm ON rrm.role_id = rr.id AND rrm.member_id = :user_id + %s`, alarmColumns, query) + + return r.alarmsPage(ctx, comQuery, pm) +} + +func (r *repository) alarmsPage(ctx context.Context, comQuery string, pm alarms.PageMetadata) (alarms.AlarmsPage, error) { + dir := api.DescDir + if pm.Dir == api.AscDir { + dir = api.AscDir + } + + var orderClause string + switch pm.Order { + case api.CreatedAtOrder: + orderClause = fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir) + default: + orderClause = fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir) + } + + q := fmt.Sprintf(`SELECT * FROM (%s) AS sub_query %s LIMIT :limit OFFSET :offset;`, comQuery, orderClause) + cq := fmt.Sprintf(`SELECT COUNT(*) AS total_count FROM (%s) AS sub_query;`, comQuery) + + rows, err := r.db.NamedQueryContext(ctx, q, pm) + if err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + var items []alarms.Alarm + for rows.Next() { + dba := dbAlarm{} + if err := rows.StructScan(&dba); err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + a, err := toAlarm(dba) + if err != nil { + return alarms.AlarmsPage{}, err + } + + items = append(items, a) + } + + total, err := postgres.Total(ctx, r.db, cq, pm) + if err != nil { + return alarms.AlarmsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return alarms.AlarmsPage{ + Total: total, + Offset: pm.Offset, + Limit: pm.Limit, + Alarms: items, + }, nil +} + +func (r *repository) DeleteAlarm(ctx context.Context, id string) error { + query := `DELETE FROM alarms WHERE id = :id;` + result, err := r.db.NamedExecContext(ctx, query, map[string]any{"id": id}) + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + if rowsAffected == 0 { + return repoerr.ErrNotFound + } + + return nil +} + +type dbAlarm struct { + ID string `db:"id"` + RuleID string `db:"rule_id"` + DomainID string `db:"domain_id"` + ChannelID string `db:"channel_id"` + ClientID string `db:"client_id"` + Subtopic string `db:"subtopic"` + Measurement string `db:"measurement"` + Value string `db:"value"` + Unit string `db:"unit"` + Cause string `db:"cause"` + Threshold string `db:"threshold"` + Status alarms.Status `db:"status"` + Severity uint8 `db:"severity"` + AssigneeID string `db:"assignee_id"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt sql.NullTime `db:"updated_at,omitempty"` + UpdatedBy *string `db:"updated_by,omitempty"` + AssignedAt sql.NullTime `db:"assigned_at,omitempty"` + AssignedBy *string `db:"assigned_by,omitempty"` + AcknowledgedAt sql.NullTime `db:"acknowledged_at,omitempty"` + AcknowledgedBy *string `db:"acknowledged_by,omitempty"` + ResolvedAt sql.NullTime `db:"resolved_at,omitempty"` + ResolvedBy *string `db:"resolved_by,omitempty"` + Metadata []byte `db:"metadata,omitempty"` +} + +func toDBAlarm(a alarms.Alarm) (dbAlarm, error) { + if a.CreatedAt.IsZero() { + a.CreatedAt = time.Now() + } + var updatedBy *string + if a.UpdatedBy != "" { + updatedBy = &a.UpdatedBy + } + var updatedAt sql.NullTime + if a.UpdatedAt != (time.Time{}) { + updatedAt = sql.NullTime{Time: a.UpdatedAt, Valid: true} + } + + var acknowledgedBy *string + if a.AcknowledgedBy != "" { + acknowledgedBy = &a.AcknowledgedBy + } + var acknowledgedAt sql.NullTime + if a.AcknowledgedAt != (time.Time{}) { + acknowledgedAt = sql.NullTime{Time: a.AcknowledgedAt, Valid: true} + } + + var resolvedBy *string + if a.ResolvedBy != "" { + resolvedBy = &a.ResolvedBy + } + var resolvedAt sql.NullTime + if a.ResolvedAt != (time.Time{}) { + resolvedAt = sql.NullTime{Time: a.ResolvedAt, Valid: true} + } + + var assignedBy *string + if a.AssignedBy != "" { + assignedBy = &a.AssignedBy + } + var assignedAt sql.NullTime + if a.AssignedAt != (time.Time{}) { + assignedAt = sql.NullTime{Time: a.AssignedAt, Valid: true} + } + + metadata := []byte("{}") + if len(a.Metadata) > 0 { + b, err := json.Marshal(a.Metadata) + if err != nil { + return dbAlarm{}, errors.Wrap(repoerr.ErrMalformedEntity, err) + } + metadata = b + } + + return dbAlarm{ + ID: a.ID, + RuleID: a.RuleID, + DomainID: a.DomainID, + ChannelID: a.ChannelID, + ClientID: a.ClientID, + Subtopic: a.Subtopic, + Measurement: a.Measurement, + Value: a.Value, + Unit: a.Unit, + Cause: a.Cause, + Threshold: a.Threshold, + Status: a.Status, + Severity: a.Severity, + AssigneeID: a.AssigneeID, + CreatedAt: a.CreatedAt, + UpdatedAt: updatedAt, + UpdatedBy: updatedBy, + AssignedAt: assignedAt, + AssignedBy: assignedBy, + AcknowledgedAt: acknowledgedAt, + AcknowledgedBy: acknowledgedBy, + ResolvedAt: resolvedAt, + ResolvedBy: resolvedBy, + Metadata: metadata, + }, nil +} + +func toAlarm(dbr dbAlarm) (alarms.Alarm, error) { + var updatedBy string + if dbr.UpdatedBy != nil { + updatedBy = *dbr.UpdatedBy + } + var updatedAt time.Time + if dbr.UpdatedAt.Valid { + updatedAt = dbr.UpdatedAt.Time + } + + var assignedBy string + if dbr.AssignedBy != nil { + assignedBy = *dbr.AssignedBy + } + var assignedAt time.Time + if dbr.AssignedAt.Valid { + assignedAt = dbr.AssignedAt.Time + } + + var acknowledgedBy string + if dbr.AcknowledgedBy != nil { + acknowledgedBy = *dbr.AcknowledgedBy + } + var acknowledgedAt time.Time + if dbr.AcknowledgedAt.Valid { + acknowledgedAt = dbr.AcknowledgedAt.Time + } + + var resolvedBy string + if dbr.ResolvedBy != nil { + resolvedBy = *dbr.ResolvedBy + } + var resolvedAt time.Time + if dbr.ResolvedAt.Valid { + resolvedAt = dbr.ResolvedAt.Time + } + + var metadata map[string]any + if len(dbr.Metadata) > 0 { + err := json.Unmarshal(dbr.Metadata, &metadata) + if err != nil { + return alarms.Alarm{}, errors.Wrap(repoerr.ErrMalformedEntity, err) + } + } + + return alarms.Alarm{ + ID: dbr.ID, + RuleID: dbr.RuleID, + DomainID: dbr.DomainID, + ChannelID: dbr.ChannelID, + ClientID: dbr.ClientID, + Subtopic: dbr.Subtopic, + Measurement: dbr.Measurement, + Value: dbr.Value, + Unit: dbr.Unit, + Threshold: dbr.Threshold, + Cause: dbr.Cause, + Status: dbr.Status, + Severity: dbr.Severity, + AssigneeID: dbr.AssigneeID, + CreatedAt: dbr.CreatedAt, + UpdatedAt: updatedAt, + UpdatedBy: updatedBy, + AssignedAt: assignedAt, + AssignedBy: assignedBy, + AcknowledgedAt: acknowledgedAt, + AcknowledgedBy: acknowledgedBy, + ResolvedAt: resolvedAt, + ResolvedBy: resolvedBy, + Metadata: metadata, + }, nil +} + +func pageQuery(pm alarms.PageMetadata) (string, error) { + var query []string + if pm.DomainID != "" { + query = append(query, "alarms.domain_id = :domain_id") + } + if pm.RuleID != "" { + query = append(query, "alarms.rule_id = :rule_id") + } + if pm.ChannelID != "" { + query = append(query, "alarms.channel_id = :channel_id") + } + if pm.Subtopic != "" { + query = append(query, "alarms.subtopic = :subtopic") + } + if pm.ClientID != "" { + query = append(query, "alarms.client_id = :client_id") + } + if pm.Measurement != "" { + query = append(query, "alarms.measurement = :measurement") + } + if pm.Status != alarms.AllStatus { + query = append(query, "alarms.status = :status") + } + if pm.Severity != math.MaxUint8 { + query = append(query, "alarms.severity = :severity") + } + if pm.AssigneeID != "" { + query = append(query, "alarms.assignee_id = :assignee_id") + } + if pm.UpdatedBy != "" { + query = append(query, "alarms.updated_by = :updated_by") + } + if pm.ResolvedBy != "" { + query = append(query, "alarms.resolved_by = :resolved_by") + } + if pm.AcknowledgedBy != "" { + query = append(query, "alarms.acknowledged_by = :acknowledged_by") + } + if pm.AssignedBy != "" { + query = append(query, "alarms.assigned_by = :assigned_by") + } + if !pm.CreatedFrom.IsZero() { + query = append(query, "alarms.created_at >= :created_from") + } + if !pm.CreatedTo.IsZero() { + query = append(query, "alarms.created_at <= :created_to") + } + + var emq string + if len(query) > 0 { + emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")) + } + + return emq, nil +} diff --git a/alarms/postgres/alarms_test.go b/alarms/postgres/alarms_test.go new file mode 100644 index 000000000..53167cebb --- /dev/null +++ b/alarms/postgres/alarms_test.go @@ -0,0 +1,659 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/0x6flab/namegenerator" + "github.com/absmach/supermq/alarms" + "github.com/absmach/supermq/alarms/postgres" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + namegen = namegenerator.NewGenerator() + idProvider = uuid.New() +) + +func TestCreateAlarm(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + alarm := alarms.Alarm{ + ID: generateUUID(t), + RuleID: generateUUID(t), + DomainID: generateUUID(t), + ChannelID: generateUUID(t), + ClientID: generateUUID(t), + Subtopic: namegen.Generate(), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(t), + CreatedAt: time.Now().UTC(), + Metadata: map[string]any{ + "key": "value", + }, + } + + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarm, + err: nil, + }, + { + desc: "duplicate alarm", + alarm: alarm, + err: repoerr.ErrNotFound, + }, + { + desc: "missing rule id", + alarm: alarms.Alarm{ + ID: generateUUID(t), + DomainID: generateUUID(t), + ChannelID: generateUUID(t), + ClientID: generateUUID(t), + Subtopic: namegen.Generate(), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(t), + CreatedAt: time.Now().UTC(), + + Metadata: map[string]any{ + "key": "value", + }, + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "invalid alarm", + alarm: alarms.Alarm{ + ID: generateUUID(t), + DomainID: generateUUID(t), + ChannelID: generateUUID(t), + ClientID: generateUUID(t), + Subtopic: namegen.Generate(), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(t), + CreatedAt: time.Now().UTC(), + + Metadata: map[string]any{ + "key": make(chan int), + }, + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "empty alarm", + alarm: alarms.Alarm{}, + err: repoerr.ErrCreateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + alarm, err := repo.CreateAlarm(context.Background(), tc.alarm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + assert.NotEmpty(t, alarm.ID) + assert.Equal(t, tc.alarm.RuleID, alarm.RuleID) + assert.Equal(t, tc.alarm.Measurement, alarm.Measurement) + assert.Equal(t, tc.alarm.Value, alarm.Value) + assert.Equal(t, tc.alarm.Unit, alarm.Unit) + assert.Equal(t, tc.alarm.Cause, alarm.Cause) + assert.Equal(t, tc.alarm.Status, alarm.Status) + assert.Equal(t, tc.alarm.DomainID, alarm.DomainID) + assert.Equal(t, tc.alarm.AssigneeID, alarm.AssigneeID) + assert.Equal(t, tc.alarm.Metadata, alarm.Metadata) + }) + } +} + +func TestUpdateAlarm(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + alarm := alarms.Alarm{ + ID: generateUUID(t), + RuleID: generateUUID(t), + DomainID: generateUUID(t), + ChannelID: generateUUID(t), + ClientID: generateUUID(t), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(t), + CreatedAt: time.Now().UTC(), + Metadata: map[string]any{ + "key": "value", + }, + } + alarm, err := repo.CreateAlarm(context.Background(), alarm) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarms.Alarm{ + ID: alarm.ID, + Status: alarms.ClearedStatus, + DomainID: alarm.DomainID, + AssigneeID: generateUUID(t), + AssignedBy: generateUUID(t), + AssignedAt: time.Now().UTC(), + AcknowledgedBy: generateUUID(t), + AcknowledgedAt: time.Now().UTC(), + CreatedAt: alarm.CreatedAt, + UpdatedAt: time.Now().UTC(), + UpdatedBy: generateUUID(t), + ResolvedAt: time.Now().UTC(), + ResolvedBy: generateUUID(t), + Metadata: map[string]any{ + "key": "value", + }, + }, + err: nil, + }, + { + desc: "non existing alarm", + alarm: alarms.Alarm{ + ID: generateUUID(t), + }, + err: repoerr.ErrNotFound, + }, + { + desc: "invalid alarm", + alarm: alarms.Alarm{ + ID: alarm.ID, + RuleID: generateUUID(t), + Status: 0, + DomainID: generateUUID(t), + AssigneeID: strings.Repeat("a", 40), + CreatedAt: time.Now().UTC(), + Metadata: map[string]any{ + "key": "value", + }, + }, + err: repoerr.ErrMalformedEntity, + }, + { + desc: "empty alarm", + alarm: alarms.Alarm{}, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + alarm, err := repo.UpdateAlarm(context.Background(), tc.alarm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + assert.NotEmpty(t, alarm.ID) + assert.Equal(t, tc.alarm.Status, alarm.Status) + assert.Equal(t, tc.alarm.DomainID, alarm.DomainID) + assert.Equal(t, tc.alarm.AssigneeID, alarm.AssigneeID) + assert.Equal(t, tc.alarm.UpdatedBy, alarm.UpdatedBy) + assert.Equal(t, tc.alarm.ResolvedBy, alarm.ResolvedBy) + assert.Equal(t, tc.alarm.AcknowledgedBy, alarm.AcknowledgedBy) + assert.Equal(t, tc.alarm.Metadata, alarm.Metadata) + }) + } +} + +func TestViewAlarm(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + alarm := alarms.Alarm{ + ID: generateUUID(t), + RuleID: generateUUID(t), + DomainID: generateUUID(t), + ChannelID: generateUUID(t), + ClientID: generateUUID(t), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(t), + CreatedAt: time.Now().UTC(), + Metadata: map[string]any{ + "key": "value", + }, + } + alarm, err := repo.CreateAlarm(context.Background(), alarm) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + id string + domainID string + err error + }{ + { + desc: "valid alarm", + id: alarm.ID, + domainID: alarm.DomainID, + err: nil, + }, + { + desc: "non existing alarm id", + id: generateUUID(t), + domainID: alarm.DomainID, + err: repoerr.ErrNotFound, + }, + { + desc: "non existing domain id", + id: alarm.ID, + domainID: generateUUID(t), + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + alarm, err := repo.ViewAlarm(context.Background(), tc.id, tc.domainID) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + assert.NotEmpty(t, alarm.ID) + assert.Equal(t, tc.id, alarm.ID) + }) + } +} + +func TestListAlarms(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + repo := postgres.NewAlarmsRepo(db) + items := make([]alarms.Alarm, 1000) + for i := range 1000 { + items[i] = alarms.Alarm{ + ID: generateUUID(t), + RuleID: generateUUID(t), + DomainID: generateUUID(t), + ChannelID: generateUUID(t), + ClientID: generateUUID(t), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(t), + CreatedAt: time.Now().UTC(), + Metadata: map[string]any{ + "key": "value", + }, + } + alarm, err := repo.CreateAlarm(context.Background(), items[i]) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + items[i].ID = alarm.ID + } + + cases := []struct { + desc string + pm alarms.PageMetadata + response []alarms.Alarm + err error + }{ + { + desc: "valid page", + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 10, + }, + response: items[:10], + err: nil, + }, + { + desc: "offset and limit", + pm: alarms.PageMetadata{ + Offset: 10, + Limit: 50, + }, + response: items[10:60], + err: nil, + }, + { + desc: "empty page", + pm: alarms.PageMetadata{}, + response: []alarms.Alarm{}, + err: nil, + }, + { + desc: "invalid page", + pm: alarms.PageMetadata{ + Offset: 1000, + Limit: 10, + }, + response: []alarms.Alarm{}, + err: nil, + }, + { + desc: "invalid assignee id", + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 10, + AssigneeID: generateUUID(t), + }, + response: []alarms.Alarm{}, + err: nil, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + alarms, err := repo.ListAllAlarms(context.Background(), tc.pm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + assert.Equal(t, len(tc.response), len(alarms.Alarms)) + }) + } +} + +func TestListUserAlarms(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + _, err = db.Exec("DELETE FROM rules") + require.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + domainID := generateUUID(t) + userID := generateUUID(t) + otherUserID := generateUUID(t) + adminUserID := generateUUID(t) + + // Create 10 rules and 10 alarms referencing them. + // Assign userID to the first 6 rules via role membership. + var ruleIDs []string + var createdAlarms []alarms.Alarm + for i := range 10 { + ruleID := generateUUID(t) + _, err := db.Exec(`INSERT INTO rules (id, name, domain_id, status, logic_type, logic_value) VALUES ($1, $2, $3, 0, 0, '')`, + ruleID, fmt.Sprintf("rule-%d", i), domainID) + require.Nil(t, err, fmt.Sprintf("insert rule unexpected error: %s", err)) + ruleIDs = append(ruleIDs, ruleID) + + alarm := alarms.Alarm{ + ID: generateUUID(t), + RuleID: ruleID, + DomainID: domainID, + ChannelID: generateUUID(t), + ClientID: generateUUID(t), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(t), + CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute), + } + alarm, err = repo.CreateAlarm(context.Background(), alarm) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + createdAlarms = append(createdAlarms, alarm) + } + + // Assign userID to the first 6 rules via rules_roles + rules_role_members. + userRoleIDs := make([]string, 6) + for i := range 6 { + roleID := generateUUID(t) + userRoleIDs[i] = roleID + _, err := db.Exec(`INSERT INTO rules_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", ruleIDs[i]) + require.Nil(t, err, fmt.Sprintf("insert rules_roles unexpected error: %s", err)) + _, err = db.Exec(`INSERT INTO rules_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, userID, ruleIDs[i]) + require.Nil(t, err, fmt.Sprintf("insert rules_role_members unexpected error: %s", err)) + } + + for i := range 10 { + var roleID string + if i < 6 { + roleID = userRoleIDs[i] + } else { + roleID = generateUUID(t) + _, err := db.Exec(`INSERT INTO rules_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", ruleIDs[i]) + require.Nil(t, err, fmt.Sprintf("insert rules_roles unexpected error: %s", err)) + } + _, err := db.Exec(`INSERT INTO rules_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, adminUserID, ruleIDs[i]) + require.Nil(t, err, fmt.Sprintf("insert rules_role_members unexpected error: %s", err)) + } + + _ = createdAlarms + + cases := []struct { + desc string + userID string + pm alarms.PageMetadata + count int + err error + }{ + { + desc: "list user alarms returns only accessible alarms", + userID: userID, + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 100, + }, + count: 6, + err: nil, + }, + { + desc: "list user alarms with limit", + userID: userID, + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 3, + }, + count: 3, + err: nil, + }, + { + desc: "list user alarms with offset", + userID: userID, + pm: alarms.PageMetadata{ + Offset: 4, + Limit: 100, + }, + count: 2, + err: nil, + }, + { + desc: "list user alarms with domain filter", + userID: userID, + pm: alarms.PageMetadata{ + DomainID: domainID, + Offset: 0, + Limit: 100, + }, + count: 6, + err: nil, + }, + { + desc: "list user alarms with non-existing domain returns 0", + userID: userID, + pm: alarms.PageMetadata{ + DomainID: generateUUID(t), + Offset: 0, + Limit: 100, + }, + count: 0, + err: nil, + }, + { + desc: "list alarms for user with no role assignments returns 0", + userID: otherUserID, + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 100, + }, + count: 0, + err: nil, + }, + { + desc: "list alarms for admin user with role on all rules returns all alarms", + userID: adminUserID, + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 100, + }, + count: 10, + err: nil, + }, + { + desc: "list user alarms ordered by created_at ascending", + userID: userID, + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 100, + Order: "created_at", + Dir: "asc", + }, + count: 6, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + page, err := repo.ListUserAlarms(context.Background(), tc.userID, tc.pm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + assert.Equal(t, tc.count, len(page.Alarms), fmt.Sprintf("%s: expected %d alarms, got %d", tc.desc, tc.count, len(page.Alarms))) + }) + } +} + +func TestDeleteAlarm(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM alarms") + require.Nil(t, err, fmt.Sprintf("clean alarms unexpected error: %s", err)) + }) + + repo := postgres.NewAlarmsRepo(db) + + alarm := alarms.Alarm{ + ID: generateUUID(t), + RuleID: generateUUID(t), + DomainID: generateUUID(t), + ChannelID: generateUUID(t), + ClientID: generateUUID(t), + Measurement: namegen.Generate(), + Value: namegen.Generate(), + Unit: namegen.Generate(), + Threshold: namegen.Generate(), + Cause: namegen.Generate(), + Status: 0, + AssigneeID: generateUUID(t), + CreatedAt: time.Now().UTC(), + Metadata: map[string]any{ + "key": "value", + }, + } + alarm, err := repo.CreateAlarm(context.Background(), alarm) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "valid alarm", + id: alarm.ID, + err: nil, + }, + { + desc: "non existing alarm", + id: generateUUID(t), + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.DeleteAlarm(context.Background(), tc.id) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + }) + } +} + +func generateUUID(t *testing.T) string { + ulid, err := idProvider.ID() + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + return ulid +} diff --git a/alarms/postgres/init.go b/alarms/postgres/init.go new file mode 100644 index 000000000..8e7389c03 --- /dev/null +++ b/alarms/postgres/init.go @@ -0,0 +1,65 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + rpostgres "github.com/absmach/supermq/re/postgres" + _ "github.com/jackc/pgx/v5/stdlib" // required for SQL access + migrate "github.com/rubenv/sql-migrate" +) + +// Migration of Alarms service. +func Migration() (*migrate.MemoryMigrationSource, error) { + alarmsMigration := &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "alarms_01", + // VARCHAR(36) for columns with IDs as UUIDS have a maximum of 36 characters + Up: []string{ + `CREATE TABLE IF NOT EXISTS alarms ( + id VARCHAR(36) PRIMARY KEY, + rule_id VARCHAR(36) NOT NULL CHECK (length(rule_id) > 0), + domain_id VARCHAR(36) NOT NULL, + channel_id VARCHAR(36) NOT NULL, + subtopic TEXT NOT NULL, + client_id VARCHAR(36) NOT NULL, + measurement TEXT NOT NULL, + value TEXT NOT NULL, + unit TEXT NOT NULL, + threshold TEXT NOT NULL, + cause TEXT NOT NULL, + status SMALLINT NOT NULL DEFAULT 0 CHECK (status >= 0), + severity SMALLINT NOT NULL DEFAULT 0 CHECK (severity >= 0), + assignee_id VARCHAR(36), + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NULL, + updated_by VARCHAR(36) NULL, + assigned_at TIMESTAMPTZ NULL, + assigned_by VARCHAR(36) NULL, + acknowledged_at TIMESTAMPTZ NULL, + acknowledged_by VARCHAR(36) NULL, + resolved_at TIMESTAMPTZ NULL, + resolved_by VARCHAR(36) NULL, + metadata JSONB + );`, + "CREATE INDEX IF NOT EXISTS idx_alarms_state ON alarms (domain_id, rule_id, channel_id, subtopic, client_id, measurement, created_at DESC);", + }, + Down: []string{ + `DROP TABLE IF EXISTS alarms`, + }, + }, + }, + } + + rulesMigration, err := rpostgres.Migration() + if err != nil { + return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err) + } + + alarmsMigration.Migrations = append(alarmsMigration.Migrations, rulesMigration.Migrations...) + + return alarmsMigration, nil +} diff --git a/alarms/postgres/setup_test.go b/alarms/postgres/setup_test.go new file mode 100644 index 000000000..51d35d344 --- /dev/null +++ b/alarms/postgres/setup_test.go @@ -0,0 +1,97 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "database/sql" + "fmt" + "log" + "os" + "testing" + "time" + + apostgres "github.com/absmach/supermq/alarms/postgres" + "github.com/absmach/supermq/pkg/postgres" + "github.com/jmoiron/sqlx" + dockertest "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "go.opentelemetry.io/otel" +) + +var ( + db *sqlx.DB + database postgres.Database + tracer = otel.Tracer("repo_tests") +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16.2-alpine", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + + // exponential backoff-retry, because the application in the container might not be ready to accept connections yet + pool.MaxWait = 120 * time.Second + if err := pool.Retry(func() error { + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + db, err := sql.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + dbConfig := postgres.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + migration, err := apostgres.Migration() + if err != nil { + log.Fatalf("Could not get migration: %s", err) + } + if db, err = postgres.Setup(dbConfig, *migration); err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + database = postgres.NewDatabase(db, dbConfig, tracer) + + code := m.Run() + + // Defers will not be run when using os.Exit + db.Close() + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/alarms/service.go b/alarms/service.go new file mode 100644 index 000000000..75c44abed --- /dev/null +++ b/alarms/service.go @@ -0,0 +1,70 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms + +import ( + "context" + "time" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/pkg/authn" + repoerr "github.com/absmach/supermq/pkg/errors/repository" +) + +type service struct { + idp supermq.IDProvider + repo Repository +} + +var _ Service = (*service)(nil) + +func NewService(idp supermq.IDProvider, repo Repository) Service { + return &service{ + idp: idp, + repo: repo, + } +} + +func (s *service) CreateAlarm(ctx context.Context, alarm Alarm) error { + id, err := s.idp.ID() + if err != nil { + return err + } + alarm.ID = id + if alarm.CreatedAt.IsZero() { + alarm.CreatedAt = time.Now() + } + + if err := alarm.Validate(); err != nil { + return err + } + + if _, err = s.repo.CreateAlarm(ctx, alarm); err != nil && err != repoerr.ErrNotFound { + return err + } + + return nil +} + +func (s *service) ViewAlarm(ctx context.Context, session authn.Session, alarmID string) (Alarm, error) { + return s.repo.ViewAlarm(ctx, alarmID, session.DomainID) +} + +func (s *service) ListAlarms(ctx context.Context, session authn.Session, pm PageMetadata) (AlarmsPage, error) { + if session.SuperAdmin { + return s.repo.ListAllAlarms(ctx, pm) + } + return s.repo.ListUserAlarms(ctx, session.UserID, pm) +} + +func (s *service) DeleteAlarm(ctx context.Context, session authn.Session, alarmID string) error { + return s.repo.DeleteAlarm(ctx, alarmID) +} + +func (s *service) UpdateAlarm(ctx context.Context, session authn.Session, alarm Alarm) (Alarm, error) { + alarm.UpdatedAt = time.Now() + alarm.UpdatedBy = session.UserID + + return s.repo.UpdateAlarm(ctx, alarm) +} diff --git a/alarms/service_test.go b/alarms/service_test.go new file mode 100644 index 000000000..d7180ebb0 --- /dev/null +++ b/alarms/service_test.go @@ -0,0 +1,254 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/absmach/supermq/alarms" + "github.com/absmach/supermq/alarms/mocks" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var idp = uuid.New() + +func newService(t *testing.T, repo *mocks.Repository) alarms.Service { + return alarms.NewService(idp, repo) +} + +func TestCreateAlarm(t *testing.T) { + repo := new(mocks.Repository) + svc := newService(t, repo) + ts := time.Now() + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarms.Alarm{ + RuleID: "rule-id", + DomainID: "domain-id", + ChannelID: "channel-id", + ClientID: "client-id", + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + CreatedAt: ts, + }, + err: nil, + }, + { + desc: "missing rule_id", + alarm: alarms.Alarm{ + DomainID: "domain-id", + ChannelID: "channel-id", + ClientID: "client-id", + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + CreatedAt: ts, + }, + err: errors.New("rule_id is required"), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("CreateAlarm", context.Background(), mock.Anything).Return(tc.alarm, tc.err) + err := svc.CreateAlarm(context.Background(), tc.alarm) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + }) + } +} + +func TestViewAlarm(t *testing.T) { + repo := new(mocks.Repository) + svc := newService(t, repo) + + cases := []struct { + desc string + id string + domainID string + err error + }{ + { + desc: "valid alarm", + id: "alarm-id", + domainID: "domain-id", + err: nil, + }, + { + desc: "non existing alarm id", + id: "alarm-id", + domainID: "domain-id", + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + s := authn.Session{DomainID: tc.domainID} + repoCall := repo.On("ViewAlarm", context.Background(), tc.id, tc.domainID).Return(alarms.Alarm{}, tc.err) + _, err := svc.ViewAlarm(context.Background(), s, tc.id) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + }) + } +} + +func TestUpdateAlarm(t *testing.T) { + repo := new(mocks.Repository) + svc := newService(t, repo) + + cases := []struct { + desc string + alarm alarms.Alarm + err error + }{ + { + desc: "valid alarm", + alarm: alarms.Alarm{ + RuleID: "rule-id", + DomainID: "domain-id", + ChannelID: "channel-id", + ClientID: "client-id", + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: nil, + }, + { + desc: "non existing alarm", + alarm: alarms.Alarm{ + RuleID: "rule-id", + DomainID: "domain-id", + ChannelID: "channel-id", + ClientID: "client-id", + Subtopic: "subtopic", + Measurement: "measurement", + Value: "value", + Unit: "unit", + Cause: "cause", + Severity: 100, + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + s := authn.Session{DomainID: tc.alarm.DomainID} + repoCall := repo.On("UpdateAlarm", context.Background(), mock.Anything).Return(tc.alarm, tc.err) + _, err := svc.UpdateAlarm(context.Background(), s, tc.alarm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + }) + } +} + +func TestListAlarms(t *testing.T) { + repo := new(mocks.Repository) + svc := newService(t, repo) + + cases := []struct { + desc string + pm alarms.PageMetadata + page alarms.AlarmsPage + err error + }{ + { + desc: "valid page", + pm: alarms.PageMetadata{ + Offset: 0, + Limit: 10, + }, + page: alarms.AlarmsPage{ + Offset: 0, + Limit: 10, + Total: 10, + Alarms: []alarms.Alarm{}, + }, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + s := authn.Session{DomainID: tc.pm.DomainID} + repoCall := repo.On("ListUserAlarms", context.Background(), s.UserID, tc.pm).Return(tc.page, tc.err) + _, err := svc.ListAlarms(context.Background(), s, tc.pm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + }) + } +} + +func TestDeleteAlarm(t *testing.T) { + repo := new(mocks.Repository) + svc := newService(t, repo) + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "valid alarm", + id: "alarm-id", + err: nil, + }, + { + desc: "non existing alarm", + id: "alarm-id", + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + s := authn.Session{DomainID: tc.id} + repoCall := repo.On("DeleteAlarm", context.Background(), tc.id).Return(tc.err) + err := svc.DeleteAlarm(context.Background(), s, tc.id) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + return + } + repoCall.Unset() + }) + } +} diff --git a/alarms/status.go b/alarms/status.go new file mode 100644 index 000000000..92e03157b --- /dev/null +++ b/alarms/status.go @@ -0,0 +1,70 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package alarms + +import ( + "encoding/json" + "strings" + + svcerr "github.com/absmach/supermq/pkg/errors/service" +) + +type Status uint8 + +const ( + ActiveStatus Status = iota + ClearedStatus + + // AllStatus is used for querying purposes to list alarms irrespective + // of their status. It is never stored in the database as the actual + // Alarm status and should always be the largest value in this enumeration. + AllStatus +) + +const ( + Active = "active" + Cleared = "cleared" + Unknown = "unknown" + All = "all" +) + +// String converts alarm status to string literal. +func (s Status) String() string { + switch s { + case ActiveStatus: + return Active + case ClearedStatus: + return Cleared + default: + return Unknown + } +} + +// ToStatus converts string value to a valid Alarm status. +func ToStatus(status string) (Status, error) { + switch strings.ToLower(status) { + case Active: + return ActiveStatus, nil + case Cleared: + return ClearedStatus, nil + case All: + return AllStatus, nil + default: + return Status(0), svcerr.ErrInvalidStatus + } +} + +// Custom Marshaller for Alarm. +func (s Status) MarshalJSON() ([]byte, error) { + return json.Marshal(s.String()) +} + +// Custom Unmarshaler for Alarm. +func (s *Status) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), "\"") + val, err := ToStatus(str) + *s = val + + return err +} diff --git a/api/grpc/certs/v1/certs.pb.go b/api/grpc/certs/v1/certs.pb.go new file mode 100644 index 000000000..7879f6e6e --- /dev/null +++ b/api/grpc/certs/v1/certs.pb.go @@ -0,0 +1,228 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.33.0 +// source: certs/v1/certs.proto + +package v1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + emptypb "google.golang.org/protobuf/types/known/emptypb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type EntityReq struct { + state protoimpl.MessageState `protogen:"open.v1"` + SerialNumber string `protobuf:"bytes,1,opt,name=serial_number,json=serialNumber,proto3" json:"serial_number,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *EntityReq) Reset() { + *x = EntityReq{} + mi := &file_certs_v1_certs_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *EntityReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EntityReq) ProtoMessage() {} + +func (x *EntityReq) ProtoReflect() protoreflect.Message { + mi := &file_certs_v1_certs_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EntityReq.ProtoReflect.Descriptor instead. +func (*EntityReq) Descriptor() ([]byte, []int) { + return file_certs_v1_certs_proto_rawDescGZIP(), []int{0} +} + +func (x *EntityReq) GetSerialNumber() string { + if x != nil { + return x.SerialNumber + } + return "" +} + +type EntityRes struct { + state protoimpl.MessageState `protogen:"open.v1"` + EntityId string `protobuf:"bytes,1,opt,name=entity_id,json=entityId,proto3" json:"entity_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *EntityRes) Reset() { + *x = EntityRes{} + mi := &file_certs_v1_certs_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *EntityRes) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EntityRes) ProtoMessage() {} + +func (x *EntityRes) ProtoReflect() protoreflect.Message { + mi := &file_certs_v1_certs_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EntityRes.ProtoReflect.Descriptor instead. +func (*EntityRes) Descriptor() ([]byte, []int) { + return file_certs_v1_certs_proto_rawDescGZIP(), []int{1} +} + +func (x *EntityRes) GetEntityId() string { + if x != nil { + return x.EntityId + } + return "" +} + +type RevokeReq struct { + state protoimpl.MessageState `protogen:"open.v1"` + EntityId string `protobuf:"bytes,1,opt,name=entity_id,json=entityId,proto3" json:"entity_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RevokeReq) Reset() { + *x = RevokeReq{} + mi := &file_certs_v1_certs_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RevokeReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RevokeReq) ProtoMessage() {} + +func (x *RevokeReq) ProtoReflect() protoreflect.Message { + mi := &file_certs_v1_certs_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RevokeReq.ProtoReflect.Descriptor instead. +func (*RevokeReq) Descriptor() ([]byte, []int) { + return file_certs_v1_certs_proto_rawDescGZIP(), []int{2} +} + +func (x *RevokeReq) GetEntityId() string { + if x != nil { + return x.EntityId + } + return "" +} + +var File_certs_v1_certs_proto protoreflect.FileDescriptor + +const file_certs_v1_certs_proto_rawDesc = "" + + "\n" + + "\x14certs/v1/certs.proto\x12\rabsmach.certs\x1a\x1bgoogle/protobuf/empty.proto\"0\n" + + "\tEntityReq\x12#\n" + + "\rserial_number\x18\x01 \x01(\tR\fserialNumber\"(\n" + + "\tEntityRes\x12\x1b\n" + + "\tentity_id\x18\x01 \x01(\tR\bentityId\"(\n" + + "\tRevokeReq\x12\x1b\n" + + "\tentity_id\x18\x01 \x01(\tR\bentityId2\x96\x01\n" + + "\fCertsService\x12C\n" + + "\vGetEntityID\x12\x18.absmach.certs.EntityReq\x1a\x18.absmach.certs.EntityRes\"\x00\x12A\n" + + "\vRevokeCerts\x12\x18.absmach.certs.RevokeReq\x1a\x16.google.protobuf.Empty\"\x00B.Z,github.com/absmach/supermq/api/grpc/certs/v1b\x06proto3" + +var ( + file_certs_v1_certs_proto_rawDescOnce sync.Once + file_certs_v1_certs_proto_rawDescData []byte +) + +func file_certs_v1_certs_proto_rawDescGZIP() []byte { + file_certs_v1_certs_proto_rawDescOnce.Do(func() { + file_certs_v1_certs_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_certs_v1_certs_proto_rawDesc), len(file_certs_v1_certs_proto_rawDesc))) + }) + return file_certs_v1_certs_proto_rawDescData +} + +var file_certs_v1_certs_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_certs_v1_certs_proto_goTypes = []any{ + (*EntityReq)(nil), // 0: absmach.certs.EntityReq + (*EntityRes)(nil), // 1: absmach.certs.EntityRes + (*RevokeReq)(nil), // 2: absmach.certs.RevokeReq + (*emptypb.Empty)(nil), // 3: google.protobuf.Empty +} +var file_certs_v1_certs_proto_depIdxs = []int32{ + 0, // 0: absmach.certs.CertsService.GetEntityID:input_type -> absmach.certs.EntityReq + 2, // 1: absmach.certs.CertsService.RevokeCerts:input_type -> absmach.certs.RevokeReq + 1, // 2: absmach.certs.CertsService.GetEntityID:output_type -> absmach.certs.EntityRes + 3, // 3: absmach.certs.CertsService.RevokeCerts:output_type -> google.protobuf.Empty + 2, // [2:4] is the sub-list for method output_type + 0, // [0:2] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_certs_v1_certs_proto_init() } +func file_certs_v1_certs_proto_init() { + if File_certs_v1_certs_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_certs_v1_certs_proto_rawDesc), len(file_certs_v1_certs_proto_rawDesc)), + NumEnums: 0, + NumMessages: 3, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_certs_v1_certs_proto_goTypes, + DependencyIndexes: file_certs_v1_certs_proto_depIdxs, + MessageInfos: file_certs_v1_certs_proto_msgTypes, + }.Build() + File_certs_v1_certs_proto = out.File + file_certs_v1_certs_proto_goTypes = nil + file_certs_v1_certs_proto_depIdxs = nil +} diff --git a/api/grpc/certs/v1/certs_grpc.pb.go b/api/grpc/certs/v1/certs_grpc.pb.go new file mode 100644 index 000000000..8a5a05ca5 --- /dev/null +++ b/api/grpc/certs/v1/certs_grpc.pb.go @@ -0,0 +1,163 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.6.0 +// - protoc v6.33.0 +// source: certs/v1/certs.proto + +package v1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" + emptypb "google.golang.org/protobuf/types/known/emptypb" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + CertsService_GetEntityID_FullMethodName = "/absmach.certs.CertsService/GetEntityID" + CertsService_RevokeCerts_FullMethodName = "/absmach.certs.CertsService/RevokeCerts" +) + +// CertsServiceClient is the client API for CertsService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type CertsServiceClient interface { + GetEntityID(ctx context.Context, in *EntityReq, opts ...grpc.CallOption) (*EntityRes, error) + RevokeCerts(ctx context.Context, in *RevokeReq, opts ...grpc.CallOption) (*emptypb.Empty, error) +} + +type certsServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewCertsServiceClient(cc grpc.ClientConnInterface) CertsServiceClient { + return &certsServiceClient{cc} +} + +func (c *certsServiceClient) GetEntityID(ctx context.Context, in *EntityReq, opts ...grpc.CallOption) (*EntityRes, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(EntityRes) + err := c.cc.Invoke(ctx, CertsService_GetEntityID_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *certsServiceClient) RevokeCerts(ctx context.Context, in *RevokeReq, opts ...grpc.CallOption) (*emptypb.Empty, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(emptypb.Empty) + err := c.cc.Invoke(ctx, CertsService_RevokeCerts_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// CertsServiceServer is the server API for CertsService service. +// All implementations must embed UnimplementedCertsServiceServer +// for forward compatibility. +type CertsServiceServer interface { + GetEntityID(context.Context, *EntityReq) (*EntityRes, error) + RevokeCerts(context.Context, *RevokeReq) (*emptypb.Empty, error) + mustEmbedUnimplementedCertsServiceServer() +} + +// UnimplementedCertsServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedCertsServiceServer struct{} + +func (UnimplementedCertsServiceServer) GetEntityID(context.Context, *EntityReq) (*EntityRes, error) { + return nil, status.Error(codes.Unimplemented, "method GetEntityID not implemented") +} +func (UnimplementedCertsServiceServer) RevokeCerts(context.Context, *RevokeReq) (*emptypb.Empty, error) { + return nil, status.Error(codes.Unimplemented, "method RevokeCerts not implemented") +} +func (UnimplementedCertsServiceServer) mustEmbedUnimplementedCertsServiceServer() {} +func (UnimplementedCertsServiceServer) testEmbeddedByValue() {} + +// UnsafeCertsServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to CertsServiceServer will +// result in compilation errors. +type UnsafeCertsServiceServer interface { + mustEmbedUnimplementedCertsServiceServer() +} + +func RegisterCertsServiceServer(s grpc.ServiceRegistrar, srv CertsServiceServer) { + // If the following call panics, it indicates UnimplementedCertsServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&CertsService_ServiceDesc, srv) +} + +func _CertsService_GetEntityID_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(EntityReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(CertsServiceServer).GetEntityID(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: CertsService_GetEntityID_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(CertsServiceServer).GetEntityID(ctx, req.(*EntityReq)) + } + return interceptor(ctx, in, info, handler) +} + +func _CertsService_RevokeCerts_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RevokeReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(CertsServiceServer).RevokeCerts(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: CertsService_RevokeCerts_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(CertsServiceServer).RevokeCerts(ctx, req.(*RevokeReq)) + } + return interceptor(ctx, in, info, handler) +} + +// CertsService_ServiceDesc is the grpc.ServiceDesc for CertsService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var CertsService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "absmach.certs.CertsService", + HandlerType: (*CertsServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "GetEntityID", + Handler: _CertsService_GetEntityID_Handler, + }, + { + MethodName: "RevokeCerts", + Handler: _CertsService_RevokeCerts_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "certs/v1/certs.proto", +} diff --git a/api/grpc/readers/v1/readers.pb.go b/api/grpc/readers/v1/readers.pb.go new file mode 100644 index 000000000..7ede7c3cc --- /dev/null +++ b/api/grpc/readers/v1/readers.pb.go @@ -0,0 +1,873 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.33.0 +// source: readers/v1/readers.proto + +package v1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Aggregation defines supported data aggregations. +type Aggregation int32 + +const ( + Aggregation_AGGREGATION_UNSPECIFIED Aggregation = 0 + Aggregation_AGGREGATION_MAX Aggregation = 1 + Aggregation_AGGREGATION_MIN Aggregation = 2 + Aggregation_AGGREGATION_SUM Aggregation = 3 + Aggregation_AGGREGATION_COUNT Aggregation = 4 + Aggregation_AGGREGATION_AVG Aggregation = 5 +) + +// Enum value maps for Aggregation. +var ( + Aggregation_name = map[int32]string{ + 0: "AGGREGATION_UNSPECIFIED", + 1: "AGGREGATION_MAX", + 2: "AGGREGATION_MIN", + 3: "AGGREGATION_SUM", + 4: "AGGREGATION_COUNT", + 5: "AGGREGATION_AVG", + } + Aggregation_value = map[string]int32{ + "AGGREGATION_UNSPECIFIED": 0, + "AGGREGATION_MAX": 1, + "AGGREGATION_MIN": 2, + "AGGREGATION_SUM": 3, + "AGGREGATION_COUNT": 4, + "AGGREGATION_AVG": 5, + } +) + +func (x Aggregation) Enum() *Aggregation { + p := new(Aggregation) + *p = x + return p +} + +func (x Aggregation) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (Aggregation) Descriptor() protoreflect.EnumDescriptor { + return file_readers_v1_readers_proto_enumTypes[0].Descriptor() +} + +func (Aggregation) Type() protoreflect.EnumType { + return &file_readers_v1_readers_proto_enumTypes[0] +} + +func (x Aggregation) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use Aggregation.Descriptor instead. +func (Aggregation) EnumDescriptor() ([]byte, []int) { + return file_readers_v1_readers_proto_rawDescGZIP(), []int{0} +} + +type PageMetadata struct { + state protoimpl.MessageState `protogen:"open.v1"` + Limit uint64 `protobuf:"varint,1,opt,name=limit,proto3" json:"limit,omitempty"` + Offset uint64 `protobuf:"varint,2,opt,name=offset,proto3" json:"offset,omitempty"` + Protocol string `protobuf:"bytes,3,opt,name=protocol,proto3" json:"protocol,omitempty"` + Name string `protobuf:"bytes,4,opt,name=name,proto3" json:"name,omitempty"` + Value float64 `protobuf:"fixed64,5,opt,name=value,proto3" json:"value,omitempty"` + Publisher string `protobuf:"bytes,6,opt,name=publisher,proto3" json:"publisher,omitempty"` + BoolValue bool `protobuf:"varint,7,opt,name=bool_value,json=boolValue,proto3" json:"bool_value,omitempty"` + StringValue string `protobuf:"bytes,8,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` + DataValue string `protobuf:"bytes,9,opt,name=data_value,json=dataValue,proto3" json:"data_value,omitempty"` + From float64 `protobuf:"fixed64,10,opt,name=from,proto3" json:"from,omitempty"` + To float64 `protobuf:"fixed64,11,opt,name=to,proto3" json:"to,omitempty"` + Subtopic string `protobuf:"bytes,12,opt,name=subtopic,proto3" json:"subtopic,omitempty"` + Interval string `protobuf:"bytes,13,opt,name=interval,proto3" json:"interval,omitempty"` + Read bool `protobuf:"varint,14,opt,name=read,proto3" json:"read,omitempty"` + Aggregation Aggregation `protobuf:"varint,15,opt,name=aggregation,proto3,enum=readers.v1.Aggregation" json:"aggregation,omitempty"` + Comparator string `protobuf:"bytes,16,opt,name=comparator,proto3" json:"comparator,omitempty"` + Format string `protobuf:"bytes,17,opt,name=format,proto3" json:"format,omitempty"` + Order string `protobuf:"bytes,18,opt,name=order,proto3" json:"order,omitempty"` + Dir string `protobuf:"bytes,19,opt,name=dir,proto3" json:"dir,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PageMetadata) Reset() { + *x = PageMetadata{} + mi := &file_readers_v1_readers_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PageMetadata) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PageMetadata) ProtoMessage() {} + +func (x *PageMetadata) ProtoReflect() protoreflect.Message { + mi := &file_readers_v1_readers_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PageMetadata.ProtoReflect.Descriptor instead. +func (*PageMetadata) Descriptor() ([]byte, []int) { + return file_readers_v1_readers_proto_rawDescGZIP(), []int{0} +} + +func (x *PageMetadata) GetLimit() uint64 { + if x != nil { + return x.Limit + } + return 0 +} + +func (x *PageMetadata) GetOffset() uint64 { + if x != nil { + return x.Offset + } + return 0 +} + +func (x *PageMetadata) GetProtocol() string { + if x != nil { + return x.Protocol + } + return "" +} + +func (x *PageMetadata) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *PageMetadata) GetValue() float64 { + if x != nil { + return x.Value + } + return 0 +} + +func (x *PageMetadata) GetPublisher() string { + if x != nil { + return x.Publisher + } + return "" +} + +func (x *PageMetadata) GetBoolValue() bool { + if x != nil { + return x.BoolValue + } + return false +} + +func (x *PageMetadata) GetStringValue() string { + if x != nil { + return x.StringValue + } + return "" +} + +func (x *PageMetadata) GetDataValue() string { + if x != nil { + return x.DataValue + } + return "" +} + +func (x *PageMetadata) GetFrom() float64 { + if x != nil { + return x.From + } + return 0 +} + +func (x *PageMetadata) GetTo() float64 { + if x != nil { + return x.To + } + return 0 +} + +func (x *PageMetadata) GetSubtopic() string { + if x != nil { + return x.Subtopic + } + return "" +} + +func (x *PageMetadata) GetInterval() string { + if x != nil { + return x.Interval + } + return "" +} + +func (x *PageMetadata) GetRead() bool { + if x != nil { + return x.Read + } + return false +} + +func (x *PageMetadata) GetAggregation() Aggregation { + if x != nil { + return x.Aggregation + } + return Aggregation_AGGREGATION_UNSPECIFIED +} + +func (x *PageMetadata) GetComparator() string { + if x != nil { + return x.Comparator + } + return "" +} + +func (x *PageMetadata) GetFormat() string { + if x != nil { + return x.Format + } + return "" +} + +func (x *PageMetadata) GetOrder() string { + if x != nil { + return x.Order + } + return "" +} + +func (x *PageMetadata) GetDir() string { + if x != nil { + return x.Dir + } + return "" +} + +type ReadMessagesRes struct { + state protoimpl.MessageState `protogen:"open.v1"` + Total uint64 `protobuf:"varint,1,opt,name=total,proto3" json:"total,omitempty"` + PageMetadata *PageMetadata `protobuf:"bytes,2,opt,name=page_metadata,json=pageMetadata,proto3" json:"page_metadata,omitempty"` + Messages []*Message `protobuf:"bytes,3,rep,name=messages,proto3" json:"messages,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReadMessagesRes) Reset() { + *x = ReadMessagesRes{} + mi := &file_readers_v1_readers_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReadMessagesRes) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReadMessagesRes) ProtoMessage() {} + +func (x *ReadMessagesRes) ProtoReflect() protoreflect.Message { + mi := &file_readers_v1_readers_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReadMessagesRes.ProtoReflect.Descriptor instead. +func (*ReadMessagesRes) Descriptor() ([]byte, []int) { + return file_readers_v1_readers_proto_rawDescGZIP(), []int{1} +} + +func (x *ReadMessagesRes) GetTotal() uint64 { + if x != nil { + return x.Total + } + return 0 +} + +func (x *ReadMessagesRes) GetPageMetadata() *PageMetadata { + if x != nil { + return x.PageMetadata + } + return nil +} + +func (x *ReadMessagesRes) GetMessages() []*Message { + if x != nil { + return x.Messages + } + return nil +} + +type Message struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Types that are valid to be assigned to Payload: + // + // *Message_Senml + // *Message_Json + Payload isMessage_Payload `protobuf_oneof:"payload"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Message) Reset() { + *x = Message{} + mi := &file_readers_v1_readers_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Message) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Message) ProtoMessage() {} + +func (x *Message) ProtoReflect() protoreflect.Message { + mi := &file_readers_v1_readers_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Message.ProtoReflect.Descriptor instead. +func (*Message) Descriptor() ([]byte, []int) { + return file_readers_v1_readers_proto_rawDescGZIP(), []int{2} +} + +func (x *Message) GetPayload() isMessage_Payload { + if x != nil { + return x.Payload + } + return nil +} + +func (x *Message) GetSenml() *SenMLMessage { + if x != nil { + if x, ok := x.Payload.(*Message_Senml); ok { + return x.Senml + } + } + return nil +} + +func (x *Message) GetJson() *JsonMessage { + if x != nil { + if x, ok := x.Payload.(*Message_Json); ok { + return x.Json + } + } + return nil +} + +type isMessage_Payload interface { + isMessage_Payload() +} + +type Message_Senml struct { + Senml *SenMLMessage `protobuf:"bytes,1,opt,name=senml,proto3,oneof"` +} + +type Message_Json struct { + Json *JsonMessage `protobuf:"bytes,2,opt,name=json,proto3,oneof"` +} + +func (*Message_Senml) isMessage_Payload() {} + +func (*Message_Json) isMessage_Payload() {} + +type BaseMessage struct { + state protoimpl.MessageState `protogen:"open.v1"` + Channel string `protobuf:"bytes,1,opt,name=channel,proto3" json:"channel,omitempty"` + Subtopic string `protobuf:"bytes,2,opt,name=subtopic,proto3" json:"subtopic,omitempty"` + Publisher string `protobuf:"bytes,3,opt,name=publisher,proto3" json:"publisher,omitempty"` + Protocol string `protobuf:"bytes,4,opt,name=protocol,proto3" json:"protocol,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *BaseMessage) Reset() { + *x = BaseMessage{} + mi := &file_readers_v1_readers_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *BaseMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BaseMessage) ProtoMessage() {} + +func (x *BaseMessage) ProtoReflect() protoreflect.Message { + mi := &file_readers_v1_readers_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BaseMessage.ProtoReflect.Descriptor instead. +func (*BaseMessage) Descriptor() ([]byte, []int) { + return file_readers_v1_readers_proto_rawDescGZIP(), []int{3} +} + +func (x *BaseMessage) GetChannel() string { + if x != nil { + return x.Channel + } + return "" +} + +func (x *BaseMessage) GetSubtopic() string { + if x != nil { + return x.Subtopic + } + return "" +} + +func (x *BaseMessage) GetPublisher() string { + if x != nil { + return x.Publisher + } + return "" +} + +func (x *BaseMessage) GetProtocol() string { + if x != nil { + return x.Protocol + } + return "" +} + +type SenMLMessage struct { + state protoimpl.MessageState `protogen:"open.v1"` + Base *BaseMessage `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + Unit string `protobuf:"bytes,3,opt,name=unit,proto3" json:"unit,omitempty"` + Time float64 `protobuf:"fixed64,4,opt,name=time,proto3" json:"time,omitempty"` + UpdateTime float64 `protobuf:"fixed64,5,opt,name=update_time,json=updateTime,proto3" json:"update_time,omitempty"` + Value *float64 `protobuf:"fixed64,6,opt,name=value,proto3,oneof" json:"value,omitempty"` + StringValue *string `protobuf:"bytes,7,opt,name=string_value,json=stringValue,proto3,oneof" json:"string_value,omitempty"` + DataValue *string `protobuf:"bytes,8,opt,name=data_value,json=dataValue,proto3,oneof" json:"data_value,omitempty"` + BoolValue *bool `protobuf:"varint,9,opt,name=bool_value,json=boolValue,proto3,oneof" json:"bool_value,omitempty"` + Sum *float64 `protobuf:"fixed64,10,opt,name=sum,proto3,oneof" json:"sum,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SenMLMessage) Reset() { + *x = SenMLMessage{} + mi := &file_readers_v1_readers_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SenMLMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SenMLMessage) ProtoMessage() {} + +func (x *SenMLMessage) ProtoReflect() protoreflect.Message { + mi := &file_readers_v1_readers_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SenMLMessage.ProtoReflect.Descriptor instead. +func (*SenMLMessage) Descriptor() ([]byte, []int) { + return file_readers_v1_readers_proto_rawDescGZIP(), []int{4} +} + +func (x *SenMLMessage) GetBase() *BaseMessage { + if x != nil { + return x.Base + } + return nil +} + +func (x *SenMLMessage) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *SenMLMessage) GetUnit() string { + if x != nil { + return x.Unit + } + return "" +} + +func (x *SenMLMessage) GetTime() float64 { + if x != nil { + return x.Time + } + return 0 +} + +func (x *SenMLMessage) GetUpdateTime() float64 { + if x != nil { + return x.UpdateTime + } + return 0 +} + +func (x *SenMLMessage) GetValue() float64 { + if x != nil && x.Value != nil { + return *x.Value + } + return 0 +} + +func (x *SenMLMessage) GetStringValue() string { + if x != nil && x.StringValue != nil { + return *x.StringValue + } + return "" +} + +func (x *SenMLMessage) GetDataValue() string { + if x != nil && x.DataValue != nil { + return *x.DataValue + } + return "" +} + +func (x *SenMLMessage) GetBoolValue() bool { + if x != nil && x.BoolValue != nil { + return *x.BoolValue + } + return false +} + +func (x *SenMLMessage) GetSum() float64 { + if x != nil && x.Sum != nil { + return *x.Sum + } + return 0 +} + +type JsonMessage struct { + state protoimpl.MessageState `protogen:"open.v1"` + Base *BaseMessage `protobuf:"bytes,1,opt,name=base,proto3" json:"base,omitempty"` + Created int64 `protobuf:"varint,2,opt,name=created,proto3" json:"created,omitempty"` + Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3" json:"payload,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *JsonMessage) Reset() { + *x = JsonMessage{} + mi := &file_readers_v1_readers_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *JsonMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*JsonMessage) ProtoMessage() {} + +func (x *JsonMessage) ProtoReflect() protoreflect.Message { + mi := &file_readers_v1_readers_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use JsonMessage.ProtoReflect.Descriptor instead. +func (*JsonMessage) Descriptor() ([]byte, []int) { + return file_readers_v1_readers_proto_rawDescGZIP(), []int{5} +} + +func (x *JsonMessage) GetBase() *BaseMessage { + if x != nil { + return x.Base + } + return nil +} + +func (x *JsonMessage) GetCreated() int64 { + if x != nil { + return x.Created + } + return 0 +} + +func (x *JsonMessage) GetPayload() []byte { + if x != nil { + return x.Payload + } + return nil +} + +type ReadMessagesReq struct { + state protoimpl.MessageState `protogen:"open.v1"` + ChannelId string `protobuf:"bytes,1,opt,name=channel_id,json=channelId,proto3" json:"channel_id,omitempty"` + DomainId string `protobuf:"bytes,2,opt,name=domain_id,json=domainId,proto3" json:"domain_id,omitempty"` + PageMetadata *PageMetadata `protobuf:"bytes,3,opt,name=page_metadata,json=pageMetadata,proto3" json:"page_metadata,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ReadMessagesReq) Reset() { + *x = ReadMessagesReq{} + mi := &file_readers_v1_readers_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ReadMessagesReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ReadMessagesReq) ProtoMessage() {} + +func (x *ReadMessagesReq) ProtoReflect() protoreflect.Message { + mi := &file_readers_v1_readers_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ReadMessagesReq.ProtoReflect.Descriptor instead. +func (*ReadMessagesReq) Descriptor() ([]byte, []int) { + return file_readers_v1_readers_proto_rawDescGZIP(), []int{6} +} + +func (x *ReadMessagesReq) GetChannelId() string { + if x != nil { + return x.ChannelId + } + return "" +} + +func (x *ReadMessagesReq) GetDomainId() string { + if x != nil { + return x.DomainId + } + return "" +} + +func (x *ReadMessagesReq) GetPageMetadata() *PageMetadata { + if x != nil { + return x.PageMetadata + } + return nil +} + +var File_readers_v1_readers_proto protoreflect.FileDescriptor + +const file_readers_v1_readers_proto_rawDesc = "" + + "\n" + + "\x18readers/v1/readers.proto\x12\n" + + "readers.v1\"\x8c\x04\n" + + "\fPageMetadata\x12\x14\n" + + "\x05limit\x18\x01 \x01(\x04R\x05limit\x12\x16\n" + + "\x06offset\x18\x02 \x01(\x04R\x06offset\x12\x1a\n" + + "\bprotocol\x18\x03 \x01(\tR\bprotocol\x12\x12\n" + + "\x04name\x18\x04 \x01(\tR\x04name\x12\x14\n" + + "\x05value\x18\x05 \x01(\x01R\x05value\x12\x1c\n" + + "\tpublisher\x18\x06 \x01(\tR\tpublisher\x12\x1d\n" + + "\n" + + "bool_value\x18\a \x01(\bR\tboolValue\x12!\n" + + "\fstring_value\x18\b \x01(\tR\vstringValue\x12\x1d\n" + + "\n" + + "data_value\x18\t \x01(\tR\tdataValue\x12\x12\n" + + "\x04from\x18\n" + + " \x01(\x01R\x04from\x12\x0e\n" + + "\x02to\x18\v \x01(\x01R\x02to\x12\x1a\n" + + "\bsubtopic\x18\f \x01(\tR\bsubtopic\x12\x1a\n" + + "\binterval\x18\r \x01(\tR\binterval\x12\x12\n" + + "\x04read\x18\x0e \x01(\bR\x04read\x129\n" + + "\vaggregation\x18\x0f \x01(\x0e2\x17.readers.v1.AggregationR\vaggregation\x12\x1e\n" + + "\n" + + "comparator\x18\x10 \x01(\tR\n" + + "comparator\x12\x16\n" + + "\x06format\x18\x11 \x01(\tR\x06format\x12\x14\n" + + "\x05order\x18\x12 \x01(\tR\x05order\x12\x10\n" + + "\x03dir\x18\x13 \x01(\tR\x03dir\"\x97\x01\n" + + "\x0fReadMessagesRes\x12\x14\n" + + "\x05total\x18\x01 \x01(\x04R\x05total\x12=\n" + + "\rpage_metadata\x18\x02 \x01(\v2\x18.readers.v1.PageMetadataR\fpageMetadata\x12/\n" + + "\bmessages\x18\x03 \x03(\v2\x13.readers.v1.MessageR\bmessages\"u\n" + + "\aMessage\x120\n" + + "\x05senml\x18\x01 \x01(\v2\x18.readers.v1.SenMLMessageH\x00R\x05senml\x12-\n" + + "\x04json\x18\x02 \x01(\v2\x17.readers.v1.JsonMessageH\x00R\x04jsonB\t\n" + + "\apayload\"}\n" + + "\vBaseMessage\x12\x18\n" + + "\achannel\x18\x01 \x01(\tR\achannel\x12\x1a\n" + + "\bsubtopic\x18\x02 \x01(\tR\bsubtopic\x12\x1c\n" + + "\tpublisher\x18\x03 \x01(\tR\tpublisher\x12\x1a\n" + + "\bprotocol\x18\x04 \x01(\tR\bprotocol\"\xfb\x02\n" + + "\fSenMLMessage\x12+\n" + + "\x04base\x18\x01 \x01(\v2\x17.readers.v1.BaseMessageR\x04base\x12\x12\n" + + "\x04name\x18\x02 \x01(\tR\x04name\x12\x12\n" + + "\x04unit\x18\x03 \x01(\tR\x04unit\x12\x12\n" + + "\x04time\x18\x04 \x01(\x01R\x04time\x12\x1f\n" + + "\vupdate_time\x18\x05 \x01(\x01R\n" + + "updateTime\x12\x19\n" + + "\x05value\x18\x06 \x01(\x01H\x00R\x05value\x88\x01\x01\x12&\n" + + "\fstring_value\x18\a \x01(\tH\x01R\vstringValue\x88\x01\x01\x12\"\n" + + "\n" + + "data_value\x18\b \x01(\tH\x02R\tdataValue\x88\x01\x01\x12\"\n" + + "\n" + + "bool_value\x18\t \x01(\bH\x03R\tboolValue\x88\x01\x01\x12\x15\n" + + "\x03sum\x18\n" + + " \x01(\x01H\x04R\x03sum\x88\x01\x01B\b\n" + + "\x06_valueB\x0f\n" + + "\r_string_valueB\r\n" + + "\v_data_valueB\r\n" + + "\v_bool_valueB\x06\n" + + "\x04_sum\"n\n" + + "\vJsonMessage\x12+\n" + + "\x04base\x18\x01 \x01(\v2\x17.readers.v1.BaseMessageR\x04base\x12\x18\n" + + "\acreated\x18\x02 \x01(\x03R\acreated\x12\x18\n" + + "\apayload\x18\x03 \x01(\fR\apayload\"\x8c\x01\n" + + "\x0fReadMessagesReq\x12\x1d\n" + + "\n" + + "channel_id\x18\x01 \x01(\tR\tchannelId\x12\x1b\n" + + "\tdomain_id\x18\x02 \x01(\tR\bdomainId\x12=\n" + + "\rpage_metadata\x18\x03 \x01(\v2\x18.readers.v1.PageMetadataR\fpageMetadata*\x95\x01\n" + + "\vAggregation\x12\x1b\n" + + "\x17AGGREGATION_UNSPECIFIED\x10\x00\x12\x13\n" + + "\x0fAGGREGATION_MAX\x10\x01\x12\x13\n" + + "\x0fAGGREGATION_MIN\x10\x02\x12\x13\n" + + "\x0fAGGREGATION_SUM\x10\x03\x12\x15\n" + + "\x11AGGREGATION_COUNT\x10\x04\x12\x13\n" + + "\x0fAGGREGATION_AVG\x10\x052\\\n" + + "\x0eReadersService\x12J\n" + + "\fReadMessages\x12\x1b.readers.v1.ReadMessagesReq\x1a\x1b.readers.v1.ReadMessagesRes\"\x00B0Z.github.com/absmach/supermq/api/grpc/readers/v1b\x06proto3" + +var ( + file_readers_v1_readers_proto_rawDescOnce sync.Once + file_readers_v1_readers_proto_rawDescData []byte +) + +func file_readers_v1_readers_proto_rawDescGZIP() []byte { + file_readers_v1_readers_proto_rawDescOnce.Do(func() { + file_readers_v1_readers_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_readers_v1_readers_proto_rawDesc), len(file_readers_v1_readers_proto_rawDesc))) + }) + return file_readers_v1_readers_proto_rawDescData +} + +var file_readers_v1_readers_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_readers_v1_readers_proto_msgTypes = make([]protoimpl.MessageInfo, 7) +var file_readers_v1_readers_proto_goTypes = []any{ + (Aggregation)(0), // 0: readers.v1.Aggregation + (*PageMetadata)(nil), // 1: readers.v1.PageMetadata + (*ReadMessagesRes)(nil), // 2: readers.v1.ReadMessagesRes + (*Message)(nil), // 3: readers.v1.Message + (*BaseMessage)(nil), // 4: readers.v1.BaseMessage + (*SenMLMessage)(nil), // 5: readers.v1.SenMLMessage + (*JsonMessage)(nil), // 6: readers.v1.JsonMessage + (*ReadMessagesReq)(nil), // 7: readers.v1.ReadMessagesReq +} +var file_readers_v1_readers_proto_depIdxs = []int32{ + 0, // 0: readers.v1.PageMetadata.aggregation:type_name -> readers.v1.Aggregation + 1, // 1: readers.v1.ReadMessagesRes.page_metadata:type_name -> readers.v1.PageMetadata + 3, // 2: readers.v1.ReadMessagesRes.messages:type_name -> readers.v1.Message + 5, // 3: readers.v1.Message.senml:type_name -> readers.v1.SenMLMessage + 6, // 4: readers.v1.Message.json:type_name -> readers.v1.JsonMessage + 4, // 5: readers.v1.SenMLMessage.base:type_name -> readers.v1.BaseMessage + 4, // 6: readers.v1.JsonMessage.base:type_name -> readers.v1.BaseMessage + 1, // 7: readers.v1.ReadMessagesReq.page_metadata:type_name -> readers.v1.PageMetadata + 7, // 8: readers.v1.ReadersService.ReadMessages:input_type -> readers.v1.ReadMessagesReq + 2, // 9: readers.v1.ReadersService.ReadMessages:output_type -> readers.v1.ReadMessagesRes + 9, // [9:10] is the sub-list for method output_type + 8, // [8:9] is the sub-list for method input_type + 8, // [8:8] is the sub-list for extension type_name + 8, // [8:8] is the sub-list for extension extendee + 0, // [0:8] is the sub-list for field type_name +} + +func init() { file_readers_v1_readers_proto_init() } +func file_readers_v1_readers_proto_init() { + if File_readers_v1_readers_proto != nil { + return + } + file_readers_v1_readers_proto_msgTypes[2].OneofWrappers = []any{ + (*Message_Senml)(nil), + (*Message_Json)(nil), + } + file_readers_v1_readers_proto_msgTypes[4].OneofWrappers = []any{} + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_readers_v1_readers_proto_rawDesc), len(file_readers_v1_readers_proto_rawDesc)), + NumEnums: 1, + NumMessages: 7, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_readers_v1_readers_proto_goTypes, + DependencyIndexes: file_readers_v1_readers_proto_depIdxs, + EnumInfos: file_readers_v1_readers_proto_enumTypes, + MessageInfos: file_readers_v1_readers_proto_msgTypes, + }.Build() + File_readers_v1_readers_proto = out.File + file_readers_v1_readers_proto_goTypes = nil + file_readers_v1_readers_proto_depIdxs = nil +} diff --git a/api/grpc/readers/v1/readers_grpc.pb.go b/api/grpc/readers/v1/readers_grpc.pb.go new file mode 100644 index 000000000..5cc37d45b --- /dev/null +++ b/api/grpc/readers/v1/readers_grpc.pb.go @@ -0,0 +1,130 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.6.0 +// - protoc v6.33.0 +// source: readers/v1/readers.proto + +package v1 + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + ReadersService_ReadMessages_FullMethodName = "/readers.v1.ReadersService/ReadMessages" +) + +// ReadersServiceClient is the client API for ReadersService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// ReadersService is a service that provides access to +// readers functionalities for SuperMQ services. +type ReadersServiceClient interface { + ReadMessages(ctx context.Context, in *ReadMessagesReq, opts ...grpc.CallOption) (*ReadMessagesRes, error) +} + +type readersServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewReadersServiceClient(cc grpc.ClientConnInterface) ReadersServiceClient { + return &readersServiceClient{cc} +} + +func (c *readersServiceClient) ReadMessages(ctx context.Context, in *ReadMessagesReq, opts ...grpc.CallOption) (*ReadMessagesRes, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ReadMessagesRes) + err := c.cc.Invoke(ctx, ReadersService_ReadMessages_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// ReadersServiceServer is the server API for ReadersService service. +// All implementations must embed UnimplementedReadersServiceServer +// for forward compatibility. +// +// ReadersService is a service that provides access to +// readers functionalities for SuperMQ services. +type ReadersServiceServer interface { + ReadMessages(context.Context, *ReadMessagesReq) (*ReadMessagesRes, error) + mustEmbedUnimplementedReadersServiceServer() +} + +// UnimplementedReadersServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedReadersServiceServer struct{} + +func (UnimplementedReadersServiceServer) ReadMessages(context.Context, *ReadMessagesReq) (*ReadMessagesRes, error) { + return nil, status.Error(codes.Unimplemented, "method ReadMessages not implemented") +} +func (UnimplementedReadersServiceServer) mustEmbedUnimplementedReadersServiceServer() {} +func (UnimplementedReadersServiceServer) testEmbeddedByValue() {} + +// UnsafeReadersServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to ReadersServiceServer will +// result in compilation errors. +type UnsafeReadersServiceServer interface { + mustEmbedUnimplementedReadersServiceServer() +} + +func RegisterReadersServiceServer(s grpc.ServiceRegistrar, srv ReadersServiceServer) { + // If the following call panics, it indicates UnimplementedReadersServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&ReadersService_ServiceDesc, srv) +} + +func _ReadersService_ReadMessages_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ReadMessagesReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ReadersServiceServer).ReadMessages(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ReadersService_ReadMessages_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ReadersServiceServer).ReadMessages(ctx, req.(*ReadMessagesReq)) + } + return interceptor(ctx, in, info, handler) +} + +// ReadersService_ServiceDesc is the grpc.ServiceDesc for ReadersService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var ReadersService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "readers.v1.ReadersService", + HandlerType: (*ReadersServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "ReadMessages", + Handler: _ReadersService_ReadMessages_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "readers/v1/readers.proto", +} diff --git a/apidocs/openapi/alarms.yaml b/apidocs/openapi/alarms.yaml new file mode 100644 index 000000000..b3f1a8846 --- /dev/null +++ b/apidocs/openapi/alarms.yaml @@ -0,0 +1,508 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +openapi: 3.0.1 +info: + title: Magistrala Alarms API + description: | + HTTP API for managing alarms service. + Some useful links: + - [The Magistrala repository](https://github.com/absmach/supermq) + contact: + email: info@absmach.eu + license: + name: Apache 2.0 + url: https://github.com/absmach/supermq/blob/main/LICENSE + version: 0.18.5 + +servers: + - url: http://localhost:8050 + - url: https://localhost:8050 + +tags: + - name: alarms + description: Everything about your Alarms + externalDocs: + description: Find out more about alarms + url: https://docs.magistrala.absmach.eu + +paths: + /{domainID}/alarms: + get: + operationId: listAlarms + summary: List Alarms + description: | + Retrieves a list of alarms with optional filtering + tags: + - alarms + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/Offset' + - $ref: '#/components/parameters/Limit' + - $ref: '#/components/parameters/Order' + - $ref: '#/components/parameters/Dir' + - $ref: '#/components/parameters/ChannelID' + - $ref: '#/components/parameters/ClientID' + - $ref: '#/components/parameters/Subtopic' + - $ref: '#/components/parameters/RuleID' + - $ref: '#/components/parameters/Status' + - $ref: '#/components/parameters/AssigneeID' + - $ref: '#/components/parameters/Severity' + - $ref: '#/components/parameters/UpdatedBy' + - $ref: '#/components/parameters/AssignedBy' + - $ref: '#/components/parameters/AcknowledgedBy' + - $ref: '#/components/parameters/ResolvedBy' + - $ref: '#/components/parameters/CreatedFrom' + - $ref: '#/components/parameters/CreatedTo' + security: + - bearerAuth: [] + responses: + '200': + $ref: '#/components/responses/AlarmsPageRes' + '400': + description: Failed due to malformed query parameters + '401': + description: Missing or invalid access token + '422': + description: Database can't process request + '500': + $ref: '#/components/responses/ServiceError' + + /{domainID}/alarms/{alarmID}: + get: + operationId: viewAlarm + summary: View Alarm + description: Retrieves an alarm by ID + tags: + - alarms + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/AlarmID' + security: + - bearerAuth: [] + responses: + '200': + $ref: '#/components/responses/AlarmRes' + '400': + description: Missing or invalid alarm ID + '401': + description: Missing or invalid access token + '403': + description: Failed to perform authorization over the entity + '404': + description: Alarm does not exist + '422': + description: Database can't process request + '500': + $ref: '#/components/responses/ServiceError' + put: + operationId: updateAlarm + summary: Update Alarm + description: Updates an existing alarm + tags: + - alarms + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/AlarmID' + security: + - bearerAuth: [] + requestBody: + $ref: '#/components/requestBodies/AlarmUpdateReq' + responses: + '200': + $ref: '#/components/responses/AlarmRes' + '400': + description: Failed due to malformed JSON + '401': + description: Missing or invalid access token + '403': + description: Failed to perform authorization over the entity + '404': + description: Alarm does not exist + '415': + description: Missing or invalid content type + '422': + description: Database can't process request + '500': + $ref: '#/components/responses/ServiceError' + delete: + operationId: deleteAlarm + summary: Delete Alarm + description: Deletes an alarm + tags: + - alarms + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/AlarmID' + security: + - bearerAuth: [] + responses: + '204': + description: Alarm deleted successfully + '400': + description: Failed due to malformed alarm ID + '401': + description: Missing or invalid access token + '403': + description: Failed to perform authorization over the entity + '404': + description: Alarm does not exist + '422': + description: Database can't process request + '500': + $ref: '#/components/responses/ServiceError' + + /health: + get: + summary: Retrieves service health check info + tags: + - health + security: [] + responses: + '200': + $ref: '#/components/responses/HealthRes' + '500': + $ref: '#/components/responses/ServiceError' + +components: + schemas: + Alarm: + type: object + properties: + id: + type: string + description: Unique alarm identifier + readOnly: true + rule_id: + type: string + description: Rule ID that triggered this alarm + domain_id: + type: string + description: Domain ID this alarm belongs to + channel_id: + type: string + description: Channel ID where the alarm was triggered + client_id: + type: string + description: Client ID that triggered the alarm + subtopic: + type: string + description: Subtopic associated with the alarm + status: + type: string + description: Alarm status + enum: [active, cleared] + measurement: + type: string + description: Measurement that triggered the alarm + value: + type: string + description: Value that triggered the alarm + unit: + type: string + description: Unit of measurement + threshold: + type: string + description: Threshold value that was exceeded + cause: + type: string + description: Cause or description of the alarm + severity: + type: integer + description: Severity level (0-100) + minimum: 0 + maximum: 100 + assignee_id: + type: string + description: ID of the user assigned to this alarm + created_at: + type: string + format: date-time + description: Creation timestamp + readOnly: true + updated_at: + type: string + format: date-time + description: Last update timestamp + readOnly: true + updated_by: + type: string + description: User who last updated the alarm + readOnly: true + assigned_at: + type: string + format: date-time + description: When the alarm was assigned + readOnly: true + assigned_by: + type: string + description: User who assigned the alarm + readOnly: true + acknowledged_at: + type: string + format: date-time + description: When the alarm was acknowledged + readOnly: true + acknowledged_by: + type: string + description: User who acknowledged the alarm + readOnly: true + resolved_at: + type: string + format: date-time + description: When the alarm was resolved + readOnly: true + resolved_by: + type: string + description: User who resolved the alarm + readOnly: true + metadata: + type: object + description: Custom metadata + additionalProperties: true + + AlarmsPage: + type: object + properties: + offset: + type: integer + description: Number of items to skip during retrieval + minimum: 0 + default: 0 + limit: + type: integer + description: Size of the subset to retrieve + minimum: 1 + maximum: 1000 + default: 10 + total: + type: integer + description: Total number of results + minimum: 0 + alarms: + type: array + minItems: 0 + items: + $ref: '#/components/schemas/Alarm' + required: + - alarms + - total + - offset + - limit + + parameters: + DomainID: + name: domainID + description: Domain ID + in: path + required: true + schema: + type: string + AlarmID: + name: alarmID + description: Alarm ID + in: path + required: true + schema: + type: string + Offset: + name: offset + description: Number of items to skip + in: query + required: false + schema: + type: integer + default: 0 + minimum: 0 + Limit: + name: limit + description: Size of the subset to retrieve + in: query + required: false + schema: + type: integer + default: 10 + minimum: 1 + maximum: 1000 + Order: + name: order + description: Order by field + in: query + required: false + schema: + type: string + enum: [created_at, updated_at] + default: created_at + Dir: + name: dir + description: Sort direction + in: query + required: false + schema: + type: string + enum: [asc, desc] + default: desc + ChannelID: + name: channel_id + description: Filter by channel ID + in: query + required: false + schema: + type: string + ClientID: + name: client_id + description: Filter by client ID + in: query + required: false + schema: + type: string + Subtopic: + name: subtopic + description: Filter by subtopic + in: query + required: false + schema: + type: string + RuleID: + name: rule_id + description: Filter by rule ID + in: query + required: false + schema: + type: string + Status: + name: status + description: Filter by alarm status + in: query + required: false + schema: + type: string + enum: [active, cleared, all] + default: all + AssigneeID: + name: assignee_id + description: Filter by assignee ID + in: query + required: false + schema: + type: string + Severity: + name: severity + description: Filter by severity level + in: query + required: false + schema: + type: integer + minimum: 0 + maximum: 100 + UpdatedBy: + name: updated_by + description: Filter by user who updated + in: query + required: false + schema: + type: string + AssignedBy: + name: assigned_by + description: Filter by user who assigned + in: query + required: false + schema: + type: string + AcknowledgedBy: + name: acknowledged_by + description: Filter by user who acknowledged + in: query + required: false + schema: + type: string + ResolvedBy: + name: resolved_by + description: Filter by user who resolved + in: query + required: false + schema: + type: string + CreatedFrom: + name: created_from + description: Filter alarms created after this time (RFC3339 format) + in: query + required: false + schema: + type: string + format: date-time + CreatedTo: + name: created_to + description: Filter alarms created before this time (RFC3339 format) + in: query + required: false + schema: + type: string + format: date-time + + requestBodies: + AlarmUpdateReq: + description: JSON-formatted document describing the alarm update + required: true + content: + application/json: + schema: + type: object + properties: + status: + type: string + description: Alarm status + enum: [active, cleared] + assignee_id: + type: string + description: ID of the user assigned to this alarm + severity: + type: integer + description: Severity level (0-100) + minimum: 0 + maximum: 100 + metadata: + type: object + description: Custom metadata + additionalProperties: true + + responses: + AlarmRes: + description: Alarm data retrieved + content: + application/json: + schema: + $ref: '#/components/schemas/Alarm' + links: + update: + operationId: updateAlarm + parameters: + alarmID: $response.body#/id + domainID: $response.body#/domain_id + delete: + operationId: deleteAlarm + parameters: + alarmID: $response.body#/id + domainID: $response.body#/domain_id + AlarmsPageRes: + description: Alarms page retrieved + content: + application/json: + schema: + $ref: '#/components/schemas/AlarmsPage' + ServiceError: + description: Unexpected server-side error occurred + HealthRes: + description: Service Health Check + content: + application/health+json: + schema: + $ref: "./schemas/health_info.yaml" + + securitySchemes: + bearerAuth: + type: http + scheme: bearer + bearerFormat: JWT + description: | + * Users access: "Authorization: Bearer " diff --git a/apidocs/openapi/bootstrap.yaml b/apidocs/openapi/bootstrap.yaml new file mode 100644 index 000000000..f44303102 --- /dev/null +++ b/apidocs/openapi/bootstrap.yaml @@ -0,0 +1,699 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +openapi: 3.0.1 +info: + title: Magistrala Bootstrap service + description: | + HTTP API for managing platform clients configuration. + Some useful links: + - [The Magistrala repository](https://github.com/absmach/supermq) + contact: + email: info@absmach.eu + license: + name: Apache 2.0 + url: https://github.com/absmach/supermq/blob/main/LICENSE + version: 0.18.5 + +servers: + - url: http://localhost:9013 + - url: https://localhost:9013 + +tags: + - name: configs + description: Everything about your Configs + externalDocs: + description: Find out more about Configs + url: https://docs.magistrala.absmach.eu + +paths: + /{domainID}/clients/configs: + post: + operationId: createConfig + summary: Adds new config + description: | + Adds new config to the list of config owned by user identified using + the provided access token. + tags: + - configs + parameters: + - $ref: "#/components/parameters/DomainID" + requestBody: + $ref: "#/components/requestBodies/ConfigCreateReq" + responses: + "201": + $ref: "#/components/responses/ConfigCreateRes" + "400": + description: Failed due to malformed JSON. + "401": + description: Missing or invalid access token provided. + "403": + description: Failed to perform authorization over the entity. + "404": + description: A non-existent entity request. + "409": + description: Failed due to using an existing identity. + "415": + description: Missing or invalid content type. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + "503": + description: Failed to receive response from the clients service. + get: + operationId: getConfigs + summary: Retrieves managed configs + description: | + Retrieves a list of managed configs. Due to performance concerns, data + is retrieved in subsets. The API configs must ensure that the entire + dataset is consumed either by making subsequent requests, or by + increasing the subset size of the initial request. + tags: + - configs + parameters: + - $ref: "#/components/parameters/DomainID" + - $ref: "#/components/parameters/Limit" + - $ref: "#/components/parameters/Offset" + - $ref: "#/components/parameters/State" + - $ref: "#/components/parameters/Name" + responses: + "200": + $ref: "#/components/responses/ConfigListRes" + "400": + description: Failed due to malformed query parameters. + "401": + description: Missing or invalid access token provided. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /{domainID}/clients/configs/{configID}: + get: + operationId: getConfig + summary: Retrieves config info (with channels). + tags: + - configs + parameters: + - $ref: "#/components/parameters/DomainID" + - $ref: "#/components/parameters/ConfigId" + responses: + "200": + $ref: "#/components/responses/ConfigRes" + "400": + description: Missing or invalid config. + "401": + description: Missing or invalid access token provided. + "403": + description: Failed to perform authorization over the entity. + "404": + description: Config does not exist. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + put: + operationId: updateConfig + summary: Updates config info + description: | + Update is performed by replacing the current resource data with values + provided in a request payload. Note that the owner, ID, external ID, + external key, SuperMQ Client ID and key cannot be changed. + tags: + - configs + parameters: + - $ref: "#/components/parameters/DomainID" + - $ref: "#/components/parameters/ConfigId" + requestBody: + $ref: "#/components/requestBodies/ConfigUpdateReq" + responses: + "200": + description: Config updated. + "400": + description: Failed due to malformed JSON. + "401": + description: Missing or invalid access token provided. + "403": + description: Failed to perform authorization over the entity. + "404": + description: Config does not exist. + "415": + description: Missing or invalid content type. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + delete: + operationId: removeConfig + summary: Removes a Config + description: | + Removes a Config. In case of successful removal the service will ensure + that the removed config is disconnected from all of the SuperMQ channels. + tags: + - configs + parameters: + - $ref: "#/components/parameters/DomainID" + - $ref: "#/components/parameters/ConfigId" + responses: + "204": + description: Config removed. + "400": + description: Failed due to malformed config ID. + "401": + description: Missing or invalid access token provided. + "403": + description: Failed to perform authorization over the entity. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /{domainID}/clients/configs/certs/{configID}: + patch: + operationId: updateConfigCerts + summary: Updates certs + description: | + Update is performed by replacing the current certificate data with values + provided in a request payload. + tags: + - configs + parameters: + - $ref: "#/components/parameters/DomainID" + - $ref: "#/components/parameters/ConfigId" + requestBody: + $ref: "#/components/requestBodies/ConfigCertUpdateReq" + responses: + "200": + description: Config updated. + $ref: "#/components/responses/ConfigUpdateCertsRes" + "400": + description: Failed due to malformed JSON. + "401": + description: Missing or invalid access token provided. + "403": + description: Failed to perform authorization over the entity. + "404": + description: Config does not exist. + "415": + description: Missing or invalid content type. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /{domainID}/clients/configs/connections/{configID}: + put: + operationId: updateConfigConnections + summary: Updates channels the client is connected to + description: | + Update connections performs update of the channel list corresponding + Client is connected to. + tags: + - configs + parameters: + - $ref: "#/components/parameters/DomainID" + - $ref: "#/components/parameters/ConfigId" + requestBody: + $ref: "#/components/requestBodies/ConfigConnUpdateReq" + responses: + "200": + description: Config updated. + "400": + description: Failed due to malformed JSON. + "401": + description: Missing or invalid access token provided. + "403": + description: Failed to perform authorization over the entity. + "404": + description: Config does not exist. + "415": + description: Missing or invalid content type. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /clients/bootstrap/{externalId}: + get: + operationId: getBootstrapConfig + summary: Retrieves configuration. + description: | + Retrieves a configuration with given external ID and external key. + tags: + - configs + security: + - bootstrapAuth: [] + parameters: + - $ref: "#/components/parameters/ExternalId" + responses: + "200": + $ref: "#/components/responses/BootstrapConfigRes" + "400": + description: Failed due to malformed JSON. + "401": + description: Missing or invalid external key provided. + "404": + description: Failed to retrieve corresponding config. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /clients/bootstrap/secure/{externalId}: + get: + operationId: getSecureBootstrapConfig + summary: Retrieves configuration. + description: | + Retrieves a configuration with given external ID and encrypted external key. + tags: + - configs + security: + - bootstrapEncAuth: [] + parameters: + - $ref: "#/components/parameters/ExternalId" + responses: + "200": + $ref: "#/components/responses/BootstrapConfigRes" + "400": + description: Failed due to malformed JSON. + "401": + description: Missing or invalid access token provided. + "404": + description: | + Failed to retrieve corresponding config. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /{domainID}/clients/state/{configID}: + put: + operationId: updateConfigState + summary: Updates Config state. + description: | + Updating state represents enabling/disabling Config, i.e. connecting + and disconnecting corresponding SuperMQ Client to the list of Channels. + tags: + - configs + parameters: + - $ref: "#/components/parameters/DomainID" + - $ref: "#/components/parameters/ConfigId" + requestBody: + $ref: "#/components/requestBodies/ConfigStateUpdateReq" + responses: + "204": + description: Config removed. + "400": + description: Failed due to malformed config's ID. + "401": + description: Missing or invalid access token provided. + "404": + description: A non-existent entity request. + "415": + description: Missing or invalid content type. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /health: + get: + summary: Retrieves service health check info. + tags: + - health + security: [] + responses: + "200": + $ref: "#/components/responses/HealthRes" + "500": + $ref: "#/components/responses/ServiceError" + +components: + schemas: + State: + type: integer + enum: [0, 1] + Config: + type: object + properties: + client_id: + type: string + format: uuid + description: Corresponding SuperMQ Client ID. + magistrala_secret: + type: string + format: uuid + description: Corresponding SuperMQ Client key. + channels: + type: array + minItems: 0 + items: + type: object + properties: + id: + type: string + format: uuid + description: Channel unique identifier. + name: + type: string + description: Name of the Channel. + metadata: + type: object + description: Custom metadata related to the Channel. + external_id: + type: string + description: External ID (MAC address or some unique identifier). + external_key: + type: string + description: External key. + content: + type: string + description: Free-form custom configuration. + state: + $ref: "#/components/schemas/State" + client_cert: + type: string + description: Client certificate. + ca_cert: + type: string + description: Issuing CA certificate. + required: + - external_id + - external_key + ConfigList: + type: object + properties: + total: + type: integer + description: Total number of results. + minimum: 0 + offset: + type: integer + description: Number of items to skip during retrieval. + minimum: 0 + default: 0 + limit: + type: integer + description: Size of the subset to retrieve. + maximum: 100 + default: 10 + configs: + type: array + minItems: 0 + uniqueItems: true + items: + $ref: "#/components/schemas/Config" + required: + - configs + BootstrapConfig: + type: object + properties: + client_id: + type: string + format: uuid + description: Corresponding SuperMQ Client ID. + client_key: + type: string + format: uuid + description: Corresponding SuperMQ Client key. + channels: + type: array + minItems: 0 + items: + type: string + content: + type: string + description: Free-form custom configuration. + client_cert: + type: string + description: Client certificate. + required: + - client_id + - client_key + - channels + - content + ConfigUpdateCerts: + type: object + properties: + client_id: + type: string + format: uuid + description: Corresponding SuperMQ Client ID. + client_cert: + type: string + description: Client certificate. + client_key: + type: string + description: Key for the client_cert. + ca_cert: + type: string + description: Issuing CA certificate. + required: + - client_id + - client_key + - channels + - content + + parameters: + ConfigId: + name: configID + description: Unique Config identifier. It's the ID of the corresponding Client. + in: path + schema: + type: string + format: uuid + required: true + ExternalId: + name: externalId + description: Unique Config identifier provided by external entity. + in: path + schema: + type: string + required: true + Limit: + name: limit + description: Size of the subset to retrieve. + in: query + schema: + type: integer + default: 10 + maximum: 100 + minimum: 1 + required: false + Offset: + name: offset + description: Number of items to skip during retrieval. + in: query + schema: + type: integer + default: 0 + minimum: 0 + required: false + State: + name: state + description: A state of items + in: query + schema: + $ref: "#/components/schemas/State" + required: false + Name: + name: name + description: Name of the config. Search by name is partial-match and case-insensitive. + in: query + schema: + type: string + required: false + DomainID: + name: domainID + description: Unique domain identifier. + in: path + schema: + type: string + format: uuid + required: true + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 + + requestBodies: + ConfigCreateReq: + description: JSON-formatted document describing the new config. + required: true + content: + application/json: + schema: + type: object + properties: + external_id: + type: string + description: External ID (MAC address or some unique identifier). + external_key: + type: string + description: External key. + client_id: + type: string + format: uuid + description: ID of the corresponding SuperMQ Client. + channels: + type: array + minItems: 0 + items: + type: string + format: uuid + content: + type: string + name: + type: string + client_cert: + type: string + description: Client Certificate. + client_key: + type: string + description: Client Private Key. + ca_cert: + type: string + required: + - external_id + - external_key + ConfigUpdateReq: + description: JSON-formatted document describing the updated client. + content: + application/json: + schema: + type: object + properties: + content: + type: string + name: + type: string + required: + - content + - name + ConfigCertUpdateReq: + description: JSON-formatted document describing the updated client. + content: + application/json: + schema: + type: object + properties: + client_cert: + type: string + client_key: + type: string + ca_cert: + type: string + ConfigConnUpdateReq: + description: Array if IDs the client is be connected to. + content: + application/json: + schema: + type: object + properties: + channels: + type: array + minItems: 0 + items: + type: string + format: uuid + ConfigStateUpdateReq: + description: Update the state of the Config. + content: + application/json: + schema: + type: object + properties: + state: + $ref: "#/components/schemas/State" + + responses: + ConfigCreateRes: + description: Config registered. + headers: + Location: + content: + text/plain: + schema: + type: string + description: Created configuration's relative URL (i.e. /clients/configs/{configID}). + ConfigListRes: + description: Data retrieved. Configs from this list don't contain channels. + content: + application/json: + schema: + $ref: "#/components/schemas/ConfigList" + ConfigRes: + description: Data retrieved. + content: + application/json: + schema: + $ref: "#/components/schemas/Config" + links: + update: + operationId: updateConfig + parameters: + configID: $response.body#/id + updateCerts: + operationId: updateConfigCerts + parameters: + configID: $response.body#/id + updateConnections: + operationId: updateConfigConnections + parameters: + configID: $response.body#/id + updateState: + operationId: updateConfigState + parameters: + configID: $response.body#/id + delete: + operationId: removeConfig + parameters: + configID: $response.body#/id + BootstrapConfigRes: + description: | + Data retrieved. If secure, a response is encrypted using + the secret key, so the response is in the binary form. + content: + application/json: + schema: + $ref: "#/components/schemas/BootstrapConfig" + ServiceError: + description: Unexpected server-side error occurred. + HealthRes: + description: Service Health Check. + content: + application/health+json: + schema: + $ref: "./schemas/health_info.yaml" + ConfigUpdateCertsRes: + description: Data retrieved. Config certs updated. + content: + application/json: + schema: + $ref: "#/components/schemas/ConfigUpdateCerts" + + securitySchemes: + bearerAuth: + type: http + scheme: bearer + bearerFormat: JWT + description: | + * Users access: "Authorization: Bearer " + + bootstrapAuth: + type: http + scheme: bearer + bearerFormat: string + description: | + * Clients access: "Authorization: Client " + + bootstrapEncAuth: + type: http + scheme: bearer + bearerFormat: aes-sha256-uuid + description: | + * Clients access: "Authorization: Client " + Hex-encoded configuration external key encrypted using + the AES algorithm and SHA256 sum of the external key + itself as an encryption key. + +security: + - bearerAuth: [] diff --git a/apidocs/openapi/certs.yaml b/apidocs/openapi/certs.yaml new file mode 100644 index 000000000..27622b5a5 --- /dev/null +++ b/apidocs/openapi/certs.yaml @@ -0,0 +1,722 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +openapi: 3.0.3 +info: + title: Certs Service API + description: | + Certificate management service for issuing, renewing, revoking, and managing X.509 certificates. + This service provides PKI functionality including certificate lifecycle management, OCSP responder, + and CRL generation. + version: 1.0.0 + contact: + name: Abstract Machines + license: + name: Apache-2.0 + url: https://www.apache.org/licenses/LICENSE-2.0.html + +servers: + - url: http://localhost:9019 + description: Development server + +tags: + - name: certificates + description: Certificate lifecycle management operations + - name: pki + description: PKI infrastructure operations (OCSP, CRL, CA) + - name: health + description: Service health and monitoring + +security: + - BearerAuth: [] + +paths: + /{domainID}/certs/issue/{entityID}: + post: + tags: + - certificates + summary: Issue a new certificate + description: Issues a new X.509 certificate for the specified entity with custom subject options + operationId: issueCert + security: + - BearerAuth: [] + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/EntityID' + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/IssueCertRequest' + responses: + '201': + description: Certificate successfully issued + content: + application/json: + schema: + $ref: '#/components/schemas/CertificateResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '500': + $ref: '#/components/responses/InternalServerError' + + /{domainID}/certs/{id}/renew: + patch: + tags: + - certificates + summary: Renew a certificate + description: Renews an existing certificate with extended TTL and new serial number + operationId: renewCert + security: + - BearerAuth: [] + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/CertID' + responses: + '200': + description: Certificate successfully renewed + content: + application/json: + schema: + $ref: '#/components/schemas/RenewCertResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + + /{domainID}/certs/{id}/revoke: + patch: + tags: + - certificates + summary: Revoke a certificate + description: Revokes a certificate by its serial number + operationId: revokeCert + security: + - BearerAuth: [] + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/CertID' + responses: + '204': + description: Certificate successfully revoked + '401': + $ref: '#/components/responses/Unauthorized' + '404': + $ref: '#/components/responses/NotFound' + '422': + $ref: '#/components/responses/UnprocessableEntity' + '500': + $ref: '#/components/responses/InternalServerError' + + /{domainID}/certs/{entityID}/delete: + delete: + tags: + - certificates + summary: Delete certificates for an entity + description: Deletes all certificates associated with the specified entity + operationId: deleteCert + security: + - BearerAuth: [] + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/EntityID' + responses: + '204': + description: Certificates successfully deleted + '401': + $ref: '#/components/responses/Unauthorized' + '422': + $ref: '#/components/responses/UnprocessableEntity' + '500': + $ref: '#/components/responses/InternalServerError' + + /{domainID}/certs: + get: + tags: + - certificates + summary: List certificates + description: Retrieves a paginated list of certificates with optional filtering by entity ID + operationId: listCerts + security: + - BearerAuth: [] + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/Offset' + - $ref: '#/components/parameters/Limit' + - $ref: '#/components/parameters/EntityIDFilter' + responses: + '200': + description: Certificates successfully retrieved + content: + application/json: + schema: + $ref: '#/components/schemas/CertificateListResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '500': + $ref: '#/components/responses/InternalServerError' + + /{domainID}/certs/{id}: + get: + tags: + - certificates + summary: View certificate details + description: Retrieves detailed information about a specific certificate by serial number + operationId: viewCert + security: + - BearerAuth: [] + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/CertID' + responses: + '200': + description: Certificate details successfully retrieved + content: + application/json: + schema: + $ref: '#/components/schemas/ViewCertResponse' + '401': + $ref: '#/components/responses/Unauthorized' + '404': + $ref: '#/components/responses/NotFound' + '500': + $ref: '#/components/responses/InternalServerError' + + /{domainID}/certs/csrs/{entityID}: + post: + tags: + - certificates + summary: Issue certificate from CSR + description: Issues a certificate from a Certificate Signing Request (CSR) + operationId: issueFromCSR + security: + - BearerAuth: [] + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/EntityID' + - $ref: '#/components/parameters/TTL' + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/IssueFromCSRRequest' + responses: + '200': + description: Certificate successfully issued from CSR + content: + application/json: + schema: + $ref: '#/components/schemas/IssueFromCSRResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '500': + $ref: '#/components/responses/InternalServerError' + + /certs/csrs/{entityID}: + post: + tags: + - certificates + summary: Issue certificate from CSR (Internal) + description: Issues a certificate from a CSR using internal agent authentication + operationId: issueFromCSRInternal + security: + - AgentAuth: [] + parameters: + - $ref: '#/components/parameters/EntityID' + - $ref: '#/components/parameters/TTL' + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/IssueFromCSRRequest' + responses: + '200': + description: Certificate successfully issued from CSR + content: + application/json: + schema: + $ref: '#/components/schemas/IssueFromCSRResponse' + '400': + $ref: '#/components/responses/BadRequest' + '401': + $ref: '#/components/responses/Unauthorized' + '500': + $ref: '#/components/responses/InternalServerError' + + /certs/ocsp: + post: + tags: + - pki + summary: OCSP responder + description: | + Online Certificate Status Protocol (OCSP) responder endpoint. + Accepts both binary OCSP requests and JSON format requests. + operationId: ocsp + security: [] + parameters: + - name: force_status + in: query + description: Force a specific OCSP status for testing + required: false + schema: + type: string + requestBody: + required: true + content: + application/ocsp-request: + schema: + type: string + format: binary + description: DER-encoded OCSP request + application/json: + schema: + $ref: '#/components/schemas/OCSPRequest' + responses: + '200': + description: OCSP response + content: + application/ocsp-response: + schema: + type: string + format: binary + description: DER-encoded OCSP response + '400': + $ref: '#/components/responses/BadRequest' + '500': + $ref: '#/components/responses/InternalServerError' + + /certs/crl: + get: + tags: + - pki + summary: Generate Certificate Revocation List + description: Generates and returns the current Certificate Revocation List (CRL) + operationId: generateCRL + security: [] + responses: + '200': + description: CRL successfully generated + content: + application/json: + schema: + $ref: '#/components/schemas/CRLResponse' + '500': + $ref: '#/components/responses/InternalServerError' + + /certs/view-ca: + get: + tags: + - pki + summary: View CA certificate + description: Retrieves the CA certificate chain (root and intermediate certificates) + operationId: viewCA + security: [] + responses: + '200': + description: CA certificate successfully retrieved + content: + application/json: + schema: + $ref: '#/components/schemas/ViewCertResponse' + '500': + $ref: '#/components/responses/InternalServerError' + + /certs/download-ca: + get: + tags: + - pki + summary: Download CA certificate + description: Downloads the CA certificate as a ZIP file + operationId: downloadCA + security: [] + responses: + '200': + description: CA certificate ZIP file + content: + application/zip: + schema: + type: string + format: binary + '500': + $ref: '#/components/responses/InternalServerError' + + /health: + get: + summary: Retrieves service health check info. + tags: + - health + security: [] + responses: + '200': + $ref: '#/components/responses/HealthRes' + '500': + $ref: '#/components/responses/InternalServerError' + + /metrics: + get: + tags: + - health + summary: Prometheus metrics + description: Returns Prometheus metrics for monitoring + operationId: metrics + security: [] + responses: + '200': + description: Metrics successfully retrieved + content: + text/plain: + schema: + type: string + +components: + securitySchemes: + BearerAuth: + type: http + scheme: bearer + bearerFormat: JWT + description: User authentication token + AgentAuth: + type: http + scheme: bearer + description: Agent authentication token for internal operations + + parameters: + DomainID: + name: domainID + in: path + required: true + description: Domain identifier + schema: + type: string + EntityID: + name: entityID + in: path + required: true + description: Entity identifier for the certificate + schema: + type: string + CertID: + name: id + in: path + required: true + description: Certificate serial number + schema: + type: string + Offset: + name: offset + in: query + description: Number of items to skip + schema: + type: integer + minimum: 0 + default: 0 + Limit: + name: limit + in: query + description: Maximum number of items to return + schema: + type: integer + minimum: 1 + maximum: 100 + default: 10 + EntityIDFilter: + name: entity_id + in: query + description: Filter certificates by entity ID + schema: + type: string + TTL: + name: ttl + in: query + description: Time to live for the certificate (e.g., "8760h", "365d") + schema: + type: string + + schemas: + IssueCertRequest: + type: object + required: + - options + properties: + ttl: + type: string + description: Time to live for the certificate (e.g., "8760h" for 1 year) + example: "8760h" + ip_addresses: + type: array + items: + type: string + description: IP addresses to include in the certificate + example: ["192.168.1.1", "10.0.0.1"] + options: + $ref: '#/components/schemas/SubjectOptions' + + SubjectOptions: + type: object + required: + - common_name + properties: + common_name: + type: string + description: Common Name (CN) for the certificate subject + example: "example.com" + organization: + type: array + items: + type: string + description: Organization (O) + example: ["Abstract Machines"] + organizational_unit: + type: array + items: + type: string + description: Organizational Unit (OU) + example: ["Engineering"] + country: + type: array + items: + type: string + description: Country (C) + example: ["US"] + province: + type: array + items: + type: string + description: Province or State (ST) + example: ["California"] + locality: + type: array + items: + type: string + description: Locality or City (L) + example: ["San Francisco"] + street_address: + type: array + items: + type: string + description: Street Address + example: ["123 Main St"] + postal_code: + type: array + items: + type: string + description: Postal Code + example: ["94105"] + dns_names: + type: array + items: + type: string + description: DNS names for Subject Alternative Names + example: ["example.com", "www.example.com"] + ip_addresses: + type: array + items: + type: string + description: IP addresses for Subject Alternative Names + example: ["192.168.1.1"] + + CertificateResponse: + type: object + properties: + serial_number: + type: string + description: Unique serial number of the certificate + example: "4a:3f:5e:2c:1b:8d:9e:7f" + certificate: + type: string + description: PEM-encoded certificate + example: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----" + key: + type: string + description: PEM-encoded private key + example: "-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----" + revoked: + type: boolean + description: Whether the certificate is revoked + example: false + expiry_time: + type: string + format: date-time + description: Certificate expiration time + example: "2026-11-05T12:00:00Z" + entity_id: + type: string + description: Entity identifier associated with the certificate + example: "entity-123" + + RenewCertResponse: + type: object + properties: + certificate: + $ref: '#/components/schemas/ViewCertResponse' + + ViewCertResponse: + type: object + properties: + serial_number: + type: string + description: Certificate serial number + example: "4a:3f:5e:2c:1b:8d:9e:7f" + certificate: + type: string + description: PEM-encoded certificate + example: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----" + key: + type: string + description: PEM-encoded private key + example: "-----BEGIN RSA PRIVATE KEY-----\n...\n-----END RSA PRIVATE KEY-----" + revoked: + type: boolean + description: Revocation status + example: false + expiry_time: + type: string + format: date-time + description: Expiration timestamp + example: "2026-11-05T12:00:00Z" + entity_id: + type: string + description: Associated entity identifier + example: "entity-123" + + CertificateListResponse: + type: object + properties: + total: + type: integer + format: uint64 + description: Total number of certificates + example: 100 + offset: + type: integer + format: uint64 + description: Current offset + example: 0 + limit: + type: integer + format: uint64 + description: Current limit + example: 10 + certificates: + type: array + items: + $ref: '#/components/schemas/ViewCertResponse' + + IssueFromCSRRequest: + type: object + required: + - csr + properties: + csr: + type: string + format: byte + description: PEM-encoded Certificate Signing Request + example: "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURSBSRVFVRVNULS0tLS0K..." + + IssueFromCSRResponse: + type: object + properties: + serial_number: + type: string + description: Serial number of the issued certificate + example: "4a:3f:5e:2c:1b:8d:9e:7f" + certificate: + type: string + description: PEM-encoded certificate + example: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----" + revoked: + type: boolean + description: Revocation status + example: false + expiry_time: + type: string + format: date-time + description: Expiration timestamp + example: "2026-11-05T12:00:00Z" + entity_id: + type: string + description: Associated entity identifier + example: "entity-123" + + OCSPRequest: + type: object + properties: + serial_number: + type: string + description: Certificate serial number to check + example: "4a:3f:5e:2c:1b:8d:9e:7f" + certificate: + type: string + description: PEM-encoded certificate to check + example: "-----BEGIN CERTIFICATE-----\n...\n-----END CERTIFICATE-----" + status: + type: string + description: Force a specific status (for testing) + enum: [good, revoked, unknown] + + CRLResponse: + type: object + properties: + crl: + type: string + format: byte + description: DER-encoded Certificate Revocation List + + Error: + type: object + properties: + error: + type: string + description: Error message + example: "invalid request" + + responses: + BadRequest: + description: Bad request - invalid parameters or malformed request + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + Unauthorized: + description: Unauthorized - invalid or missing authentication token + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + NotFound: + description: Resource not found + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + UnprocessableEntity: + description: Unprocessable entity - request cannot be processed + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + InternalServerError: + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + HealthRes: + description: Service Health Check. + content: + application/health+json: + schema: + $ref: './schemas/health_info.yaml' diff --git a/apidocs/openapi/notifiers.yaml b/apidocs/openapi/notifiers.yaml new file mode 100644 index 000000000..f5bc2bb3f --- /dev/null +++ b/apidocs/openapi/notifiers.yaml @@ -0,0 +1,292 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +openapi: 3.0.1 +info: + title: Magistrala Notifiers service + description: | + HTTP API for Notifiers service. + Some useful links: + - [The Magistrala repository](https://github.com/absmach/supermq) + contact: + email: info@absmach.eu + license: + name: Apache 2.0 + url: https://github.com/absmach/supermq/blob/main/LICENSE + version: 0.18.5 + +servers: + - url: http://localhost:9014 + - url: https://localhost:9014 + - url: http://localhost:9015 + - url: https://localhost:9015 + +tags: + - name: notifiers + description: Everything about your Notifiers + externalDocs: + description: Find out more about notifiers + url: https://docs.magistrala.absmach.eu + +paths: + /subscriptions: + post: + operationId: createSubscription + summary: Create subscription + description: Creates a new subscription give a topic and contact. + tags: + - notifiers + requestBody: + $ref: "#/components/requestBodies/Create" + responses: + "201": + $ref: "#/components/responses/Create" + "400": + description: Failed due to malformed JSON. + "401": + description: Missing or invalid access token provided. + "403": + description: Failed to perform authorization over the entity. + "409": + description: Failed due to using an existing topic and contact. + "415": + description: Missing or invalid content type. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + get: + operationId: listSubscriptions + summary: List subscriptions + description: List subscriptions given list parameters. + tags: + - notifiers + parameters: + - $ref: "#/components/parameters/Topic" + - $ref: "#/components/parameters/Contact" + - $ref: "#/components/parameters/Offset" + - $ref: "#/components/parameters/Limit" + responses: + "200": + $ref: "#/components/responses/Page" + "400": + description: Failed due to malformed query parameters. + "401": + description: Missing or invalid access token provided. + "403": + description: Failed to perform authorization over the entity. + "404": + description: A non-existent entity request. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + /subscriptions/{id}: + get: + operationId: viewSubscription + summary: Get subscription with the provided id + description: Retrieves a subscription with the provided id. + tags: + - notifiers + parameters: + - $ref: "#/components/parameters/Id" + responses: + "200": + $ref: "#/components/responses/View" + "400": + description: Failed due to malformed ID. + "401": + description: Missing or invalid access token provided. + "403": + description: Failed to perform authorization over the entity. + "404": + description: A non-existent entity request. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + delete: + operationId: removeSubscription + summary: Delete subscription with the provided id + description: Removes a subscription with the provided id. + tags: + - notifiers + parameters: + - $ref: "#/components/parameters/Id" + responses: + "204": + description: Subscription removed + "401": + description: Missing or invalid access token provided. + "403": + description: Failed to perform authorization over the entity. + "404": + description: A non-existent entity request. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + /health: + get: + summary: Retrieves service health check info. + tags: + - health + security: [] + responses: + "200": + $ref: "#/components/responses/HealthRes" + "500": + $ref: "#/components/responses/ServiceError" + +components: + schemas: + Subscription: + type: object + properties: + id: + type: string + format: ulid + example: 01EWDVKBQSG80B6PQRS9PAAY35 + description: ULID id of the subscription. + owner_id: + type: string + format: uuid + example: 18167738-f7a8-4e96-a123-58c3cd14de3a + description: An id of the owner who created subscription. + topic: + type: string + example: topic.subtopic + description: Topic to which the user subscribes. + contact: + type: string + example: user@example.com + description: The contact of the user to which the notification will be sent. + CreateSubscription: + type: object + properties: + topic: + type: string + example: topic.subtopic + description: Topic to which the user subscribes. + contact: + type: string + example: user@example.com + description: The contact of the user to which the notification will be sent. + Page: + type: object + properties: + subscriptions: + type: array + minItems: 0 + uniqueItems: true + items: + $ref: "#/components/schemas/Subscription" + total: + type: integer + description: Total number of items. + offset: + type: integer + description: Number of items to skip during retrieval. + limit: + type: integer + description: Maximum number of items to return in one page. + + parameters: + Id: + name: id + description: Unique identifier. + in: path + schema: + type: string + format: ulid + required: true + Limit: + name: limit + description: Size of the subset to retrieve. + in: query + schema: + type: integer + default: 10 + maximum: 100 + minimum: 1 + required: false + Offset: + name: offset + description: Number of items to skip during retrieval. + in: query + schema: + type: integer + default: 0 + minimum: 0 + required: false + Topic: + name: topic + description: Topic name. + in: query + schema: + type: string + required: false + Contact: + name: contact + description: Subscription contact. + in: query + schema: + type: string + required: false + + requestBodies: + Create: + description: JSON-formatted document describing the new subscription to be created + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateSubscription" + + responses: + Create: + description: Created a new subscription. + headers: + Location: + content: + text/plain: + schema: + type: string + description: Created subscription relative URL + example: /subscriptions/{id} + View: + description: View subscription. + content: + application/json: + schema: + $ref: "#/components/schemas/Subscription" + links: + delete: + operationId: removeSubscription + parameters: + id: $response.body#/id + Page: + description: Data retrieved. + content: + application/json: + schema: + $ref: "#/components/schemas/Page" + ServiceError: + description: Unexpected server-side error occurred. + HealthRes: + description: Service Health Check. + content: + application/health+json: + schema: + $ref: "./schemas/health_info.yaml" + + securitySchemes: + bearerAuth: + type: http + scheme: bearer + bearerFormat: JWT + description: | + * Users access: "Authorization: Bearer " + +security: + - bearerAuth: [] diff --git a/apidocs/openapi/readers.yaml b/apidocs/openapi/readers.yaml new file mode 100644 index 000000000..f5634a25e --- /dev/null +++ b/apidocs/openapi/readers.yaml @@ -0,0 +1,312 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +openapi: 3.0.1 +info: + title: Magistrala reader service + description: | + HTTP API for reading messages. + Some useful links: + - [The Magistrala repository](https://github.com/absmach/supermq) + contact: + email: info@absmach.eu + license: + name: Apache 2.0 + url: https://github.com/absmach/supermq/blob/main/LICENSE + version: 0.18.5 + +servers: + - url: http://localhost:9003 + - url: https://localhost:9003 + - url: http://localhost:9005 + - url: https://localhost:9005 + - url: http://localhost:9009 + - url: https://localhost:9009 + - url: http://localhost:9011 + - url: https://localhost:9011 + +tags: + - name: readers + description: Everything about your Readers + externalDocs: + description: Find out more about readers + url: https://docs.magistrala.absmach.eu + +paths: + /{domainID}/channels/{chanId}/messages: + get: + operationId: getMessages + summary: Retrieves messages sent to single channel + description: | + Retrieves a list of messages sent to specific channel. Due to + performance concerns, data is retrieved in subsets. The API readers must + ensure that the entire dataset is consumed either by making subsequent + requests, or by increasing the subset size of the initial request. + tags: + - readers + parameters: + - $ref: "#/components/parameters/DomainID" + - $ref: "#/components/parameters/ChanId" + - $ref: "#/components/parameters/Limit" + - $ref: "#/components/parameters/Offset" + - $ref: "#/components/parameters/Publisher" + - $ref: "#/components/parameters/Name" + - $ref: "#/components/parameters/Value" + - $ref: "#/components/parameters/BoolValue" + - $ref: "#/components/parameters/StringValue" + - $ref: "#/components/parameters/DataValue" + - $ref: "#/components/parameters/From" + - $ref: "#/components/parameters/To" + - $ref: "#/components/parameters/Aggregation" + - $ref: "#/components/parameters/Interval" + responses: + "200": + $ref: "#/components/responses/MessagesPageRes" + "400": + description: Failed due to malformed query parameters. + "401": + description: Missing or invalid access token provided. + "500": + $ref: "#/components/responses/ServiceError" + /health: + get: + operationId: health + summary: Retrieves service health check info. + tags: + - health + security: [] + responses: + "200": + $ref: "#/components/responses/HealthRes" + "500": + $ref: "#/components/responses/ServiceError" + +components: + schemas: + MessagesPage: + type: object + properties: + total: + type: number + description: Total number of items that are present on the system. + offset: + type: number + description: Number of items that were skipped during retrieval. + limit: + type: number + description: Size of the subset that was retrieved. + messages: + type: array + minItems: 0 + uniqueItems: true + items: + type: object + properties: + channel: + type: integer + description: Unique channel id. + publisher: + type: integer + description: Unique publisher id. + protocol: + type: string + description: Protocol name. + name: + type: string + description: Measured parameter name. + unit: + type: string + description: Value unit. + value: + type: number + description: Measured value in number. + stringValue: + type: string + description: Measured value in string format. + boolValue: + type: boolean + description: Measured value in boolean format. + dataValue: + type: string + description: Measured value in binary format. + valueSum: + type: number + description: Sum value. + time: + type: number + description: Time of measurement. + updateTime: + type: number + description: Time of updating measurement. + + parameters: + DomainID: + name: domainID + description: Unique domain identifier. + in: path + schema: + type: string + format: uuid + required: true + ChanId: + name: chanId + description: Unique channel identifier. + in: path + schema: + type: string + format: uuid + required: true + Limit: + name: limit + description: Size of the subset to retrieve. + in: query + schema: + type: integer + default: 10 + maximum: 100 + minimum: 1 + required: false + Offset: + name: offset + description: Number of items to skip during retrieval. + in: query + schema: + type: integer + default: 0 + minimum: 0 + required: false + Publisher: + name: Publisher + description: Unique thing identifier. + in: query + schema: + type: string + format: uuid + required: false + Name: + name: name + description: SenML message name. + in: query + schema: + type: string + required: false + Value: + name: v + description: SenML message value. + in: query + schema: + type: string + required: false + BoolValue: + name: vb + description: SenML message bool value. + in: query + schema: + type: boolean + required: false + StringValue: + name: vs + description: SenML message string value. + in: query + schema: + type: string + required: false + DataValue: + name: vd + description: SenML message data value. + in: query + schema: + type: string + required: false + Comparator: + name: comparator + description: Value comparison operator. + in: query + schema: + type: string + default: eq + enum: + - eq + - lt + - le + - gt + - ge + required: false + From: + name: from + description: SenML message time in nanoseconds (integer part represents seconds). + in: query + schema: + type: number + example: 1709218556069 + required: false + To: + name: to + description: SenML message time in nanoseconds (integer part represents seconds). + in: query + schema: + type: number + example: 1709218757503 + required: false + Aggregation: + name: aggregation + description: Aggregation function. + in: query + schema: + type: string + enum: + - MAX + - AVG + - MIN + - SUM + - COUNT + - max + - min + - sum + - avg + - count + example: MAX + required: false + Interval: + name: interval + description: Aggregation interval. + in: query + schema: + type: string + example: 10s + required: false + + responses: + MessagesPageRes: + description: Data retrieved. + content: + application/json: + schema: + $ref: "#/components/schemas/MessagesPage" + ServiceError: + description: Unexpected server-side error occurred. + HealthRes: + description: Service Health Check. + content: + application/health+json: + schema: + $ref: "./schemas/health_info.yaml" + + securitySchemes: + bearerAuth: + type: http + scheme: bearer + bearerFormat: JWT + description: | + * Users access: "Authorization: Bearer " + + thingAuth: + type: http + scheme: bearer + bearerFormat: uuid + description: | + * Things access: "Authorization: Thing " + +security: + - bearerAuth: [] + - thingAuth: [] diff --git a/apidocs/openapi/reports.yaml b/apidocs/openapi/reports.yaml new file mode 100644 index 000000000..7f306627a --- /dev/null +++ b/apidocs/openapi/reports.yaml @@ -0,0 +1,553 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +openapi: 3.0.1 +info: + title: Magistrala Reports Service API + description: | + HTTP API for managing reports service. + version: 0.18.5 +servers: + - url: http://localhost:9017 +tags: + - name: reports + description: Operations related to report configurations and generation +paths: + /{domainID}/reports: + post: + operationId: generateReport + summary: Generate a report + description: Generates a report based on the provided configuration or an existing config. The action determines the response format. + tags: + - reports + parameters: + - $ref: '#/components/parameters/DomainID' + security: + - bearerAuth: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/GenerateReportRequest' + responses: + '200': + description: Report generated successfully (content varies by action) + content: + application/json: + schema: + $ref: '#/components/schemas/GenerateReportResponse' + application/octet-stream: + schema: + type: string + format: binary + '400': + description: Invalid request parameters + '401': + description: Missing or invalid access token + '500': + $ref: '#/components/responses/ServiceError' + + /{domainID}/reports/configs: + post: + operationId: addReportConfig + summary: Create a report configuration + description: Creates a new report configuration. + tags: + - reports + parameters: + - $ref: '#/components/parameters/DomainID' + security: + - bearerAuth: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/AddReportConfigRequest' + responses: + '201': + description: Report configuration created + headers: + Location: + schema: + type: string + content: + application/json: + schema: + $ref: '#/components/schemas/ReportConfig' + '400': + description: Invalid request body + '401': + description: Missing or invalid access token + '500': + $ref: '#/components/responses/ServiceError' + get: + operationId: listReportConfigs + summary: List report configurations + description: Retrieves a paginated list of report configurations. + tags: + - reports + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/Offset' + - $ref: '#/components/parameters/Limit' + security: + - bearerAuth: [] + responses: + '200': + description: List of report configurations + content: + application/json: + schema: + $ref: '#/components/schemas/ListReportsConfigResponse' + '400': + description: Invalid query parameters + '401': + description: Missing or invalid access token + '500': + $ref: '#/components/responses/ServiceError' + + /{domainID}/reports/configs/{reportID}: + get: + operationId: viewReportConfig + summary: View a report configuration + description: Retrieves details of a specific report configuration. + tags: + - reports + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/ReportID' + security: + - bearerAuth: [] + responses: + '200': + description: Report configuration details + content: + application/json: + schema: + $ref: '#/components/schemas/ReportConfig' + '404': + description: Report configuration not found + '401': + description: Missing or invalid access token + '500': + $ref: '#/components/responses/ServiceError' + patch: + operationId: updateReportConfig + summary: Update a report configuration + description: Updates specified fields of a report configuration. + tags: + - reports + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/ReportID' + security: + - bearerAuth: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateReportConfigRequest' + responses: + '200': + description: Report configuration updated + content: + application/json: + schema: + $ref: '#/components/schemas/ReportConfig' + '400': + description: Invalid request body + '401': + description: Missing or invalid access token + '404': + description: Report configuration not found + '500': + $ref: '#/components/responses/ServiceError' + delete: + operationId: deleteReportConfig + summary: Delete a report configuration + description: Permanently deletes a report configuration. + tags: + - reports + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/ReportID' + security: + - bearerAuth: [] + responses: + '204': + description: Report configuration deleted + '401': + description: Missing or invalid access token + '404': + description: Report configuration not found + '500': + $ref: '#/components/responses/ServiceError' + + /{domainID}/reports/configs/{reportID}/schedule: + patch: + operationId: updateReportSchedule + summary: Update report schedule + description: Updates the schedule of a report configuration. + tags: + - reports + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/ReportID' + security: + - bearerAuth: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/Schedule' + responses: + '200': + description: Schedule updated + content: + application/json: + schema: + $ref: '#/components/schemas/ReportConfig' + '400': + description: Invalid schedule + '401': + description: Missing or invalid access token + '404': + description: Report configuration not found + '500': + $ref: '#/components/responses/ServiceError' + + /{domainID}/reports/configs/{reportID}/enable: + post: + operationId: enableReportConfig + summary: Enable a report configuration + description: Enables a report configuration to generate scheduled reports. + tags: + - reports + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/ReportID' + security: + - bearerAuth: [] + responses: + '200': + description: Report configuration enabled + content: + application/json: + schema: + $ref: '#/components/schemas/ReportConfig' + '401': + description: Missing or invalid access token + '404': + description: Report configuration not found + '500': + $ref: '#/components/responses/ServiceError' + + /{domainID}/reports/configs/{reportID}/disable: + post: + operationId: disableReportConfig + summary: Disable a report configuration + description: Disables a report configuration, stopping scheduled reports. + tags: + - reports + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/ReportID' + security: + - bearerAuth: [] + responses: + '200': + description: Report configuration disabled + content: + application/json: + schema: + $ref: '#/components/schemas/ReportConfig' + '401': + description: Missing or invalid access token + '404': + description: Report configuration not found + '500': + $ref: '#/components/responses/ServiceError' + + /health: + get: + summary: Service health check + tags: + - health + responses: + '200': + $ref: '#/components/responses/HealthRes' + +components: + schemas: + ReportConfig: + type: object + properties: + id: + type: string + readOnly: true + name: + type: string + description: + type: string + domain_id: + type: string + readOnly: true + schedule: + $ref: '#/components/schemas/Schedule' + config: + $ref: '#/components/schemas/MetricConfig' + email: + $ref: '#/components/schemas/EmailSetting' + metrics: + type: array + items: + $ref: '#/components/schemas/ReqMetric' + status: + $ref: '#/components/schemas/Status' + created_at: + type: string + format: date-time + readOnly: true + created_by: + type: string + readOnly: true + updated_at: + type: string + format: date-time + readOnly: true + updated_by: + type: string + readOnly: true + required: + - name + - metrics + - config + + Schedule: + type: object + properties: + recurring: + type: string + enum: [None, Daily, Weekly, Monthly] + recurring_period: + type: integer + minimum: 1 + start_time: + type: string + format: date-time + next_run: + type: string + format: date-time + readOnly: true + + MetricConfig: + type: object + properties: + title: + type: string + maxLength: 100 + format: + type: string + enum: [pdf, csv, html] + aggregation: + $ref: '#/components/schemas/AggConfig' + + AggConfig: + type: object + properties: + window: + type: string + function: + type: string + enum: [sum, average, max, min] + + EmailSetting: + type: object + properties: + recipients: + type: array + items: + type: string + format: email + subject: + type: string + body_template: + type: string + required: + - recipients + - subject + + ReqMetric: + type: object + properties: + name: + type: string + type: + type: string + enum: [gauge, counter, histogram] + parameters: + type: object + required: + - name + - type + + Status: + type: string + enum: [enabled, disabled] + + GenerateReportRequest: + type: object + properties: + action: + type: string + enum: [view, download, email] + config_id: + type: string + name: + type: string + description: + type: string + schedule: + $ref: '#/components/schemas/Schedule' + config: + $ref: '#/components/schemas/MetricConfig' + email: + $ref: '#/components/schemas/EmailSetting' + metrics: + type: array + items: + $ref: '#/components/schemas/ReqMetric' + required: + - action + + GenerateReportResponse: + type: object + properties: + total: + type: integer + from: + type: string + format: date-time + to: + type: string + format: date-time + aggregation: + $ref: '#/components/schemas/AggConfig' + reports: + type: array + items: + $ref: '#/components/schemas/Report' + + Report: + type: object + properties: + timestamp: + type: string + format: date-time + value: + type: number + metric_name: + type: string + + AddReportConfigRequest: + type: object + properties: + name: + type: string + description: + type: string + schedule: + $ref: '#/components/schemas/Schedule' + config: + $ref: '#/components/schemas/MetricConfig' + email: + $ref: '#/components/schemas/EmailSetting' + metrics: + type: array + items: + $ref: '#/components/schemas/ReqMetric' + status: + $ref: '#/components/schemas/Status' + required: + - name + - metrics + - config + + UpdateReportConfigRequest: + type: object + properties: + name: + type: string + description: + type: string + schedule: + $ref: '#/components/schemas/Schedule' + config: + $ref: '#/components/schemas/MetricConfig' + email: + $ref: '#/components/schemas/EmailSetting' + metrics: + type: array + items: + $ref: '#/components/schemas/ReqMetric' + status: + $ref: '#/components/schemas/Status' + + ListReportsConfigResponse: + type: object + properties: + total: + type: integer + offset: + type: integer + limit: + type: integer + report_configs: + type: array + items: + $ref: '#/components/schemas/ReportConfig' + + parameters: + DomainID: + name: domainID + in: path + required: true + schema: + type: string + ReportID: + name: reportID + in: path + required: true + schema: + type: string + Offset: + name: offset + in: query + schema: + type: integer + default: 0 + minimum: 0 + Limit: + name: limit + in: query + schema: + type: integer + default: 10 + minimum: 1 + maximum: 100 + + responses: + ServiceError: + description: Unexpected server error + HealthRes: + description: Service Health Check. + content: + application/health+json: + schema: + $ref: './schemas/health_info.yaml' + + securitySchemes: + bearerAuth: + type: http + scheme: bearer + bearerFormat: JWT diff --git a/apidocs/openapi/rules.yaml b/apidocs/openapi/rules.yaml new file mode 100644 index 000000000..58c3a8990 --- /dev/null +++ b/apidocs/openapi/rules.yaml @@ -0,0 +1,586 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +openapi: 3.0.1 +info: + title: Magistrala Rules Engine API + description: | + HTTP API for managing rules engine service. + Some useful links: + - [The Magistrala repository](https://github.com/absmach/supermq) + contact: + email: info@absmach.eu + license: + name: Apache 2.0 + url: https://github.com/absmach/supermq/blob/main/LICENSE + version: 0.18.5 + +servers: + - url: http://localhost:9008 + - url: http://localhost:9008 + +tags: + - name: rules engine + description: Everything about your Rules Engine + externalDocs: + description: Find out more about rules engine + url: https://docs.magistrala.absmach.eu + +paths: + /{domainID}/rules: + post: + operationId: createRule + summary: Create Rule + description: | + Creates a new rule for message processing + tags: + - rules + parameters: + - $ref: '#/components/parameters/DomainID' + security: + - bearerAuth: [] + requestBody: + $ref: '#/components/requestBodies/RuleCreateReq' + responses: + '201': + $ref: '#/components/responses/RuleCreateRes' + '400': + description: Failed due to malformed JSON + '401': + description: Missing or invalid access token + '415': + description: Missing or invalid content type + "500": + $ref: "#/components/responses/ServiceError" + "503": + description: Failed to receive response from the clients service. + get: + operationId: getRules + summary: List Rules + description: | + Retrieves a list of rules with optional filtering + tags: + - rules + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/Offset' + - $ref: '#/components/parameters/Limit' + - $ref: '#/components/parameters/InputChannel' + - $ref: '#/components/parameters/OutputChannel' + - $ref: '#/components/parameters/Status' + security: + - bearerAuth: [] + responses: + '200': + $ref: '#/components/responses/RuleListRes' + '400': + description: Failed due to malformed query parameters + '401': + description: Missing or invalid access token + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /{domainID}/rules/{ruleID}: + get: + operationId: getRule + summary: View Rule + description: Retrieves a rule by ID + tags: + - rules + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/RuleID' + security: + - bearerAuth: [] + responses: + '200': + $ref: '#/components/responses/RuleRes' + "400": + description: Missing or invalid rule + "403": + description: Failed to perform authorization over the entity + '401': + description: Missing or invalid access token + '404': + description: Rule does not exist + "422": + description: Database can't process request + "500": + $ref: "#/components/responses/ServiceError" + put: + operationId: updateRule + summary: Update Rule + description: Updates an existing rule + tags: + - rules + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/RuleID' + security: + - bearerAuth: [] + requestBody: + $ref: '#/components/requestBodies/RuleUpdateReq' + responses: + '200': + $ref: '#/components/responses/RuleRes' + '400': + description: Failed due to malformed JSON + '401': + description: Missing or invalid access token + '404': + description: Rule does not exist + "415": + description: Missing or invalid content type. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + delete: + operationId: removeRule + summary: Delete Rule + description: Deletes a rule + tags: + - rules + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/RuleID' + security: + - bearerAuth: [] + responses: + '204': + description: Rule deleted successfully + "400": + description: Failed due to malformed rule ID + '401': + description: Missing or invalid access token + "403": + description: Failed to perform authorization over the entity + "422": + description: Database can't process request + "500": + $ref: "#/components/responses/ServiceError" + + /{domainID}/rules/{ruleID}/enable: + put: + operationId: enableRule + summary: Enable Rule + description: Enables a rule for processing + tags: + - rules + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/RuleID' + security: + - bearerAuth: [] + responses: + '200': + description: Rule enabled successfully + "400": + description: Failed due to malformed JSON + '401': + description: Missing or invalid access token + "403": + description: Failed to perform authorization over the entity + '404': + description: Rule does not exist + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /{domainID}/rules/{ruleID}/disable: + put: + operationId: disableRule + summary: Disable Rule + description: Disables a rule from processing + tags: + - Rules + parameters: + - $ref: '#/components/parameters/DomainID' + - $ref: '#/components/parameters/RuleID' + security: + - bearerAuth: [] + responses: + '200': + description: Rule disabled successfully + "400": + description: Failed due to malformed JSON + '401': + description: Missing or invalid access token + "403": + description: Failed to perform authorization over the entity + '404': + description: Rule does not exist + "422": + description: Database can't process request + "500": + $ref: "#/components/responses/ServiceError" + + /health: + get: + summary: Retrieves service health check info. + tags: + - health + security: [] + responses: + "200": + $ref: "#/components/responses/HealthRes" + "500": + $ref: "#/components/responses/ServiceError" + +components: + schemas: + RulesListRes: + type: object + properties: + total: + type: integer + description: Total number of results + minimum: 0 + offset: + type: integer + description: Number of items to skip during retrieval + minimum: 0 + default: 0 + limit: + type: integer + description: Size of the subset to retrieve + maximum: 100 + default: 10 + rules: + type: array + minItems: 0 + uniqueItems: true + items: + $ref: '#/components/schemas/Rule' + required: + - rules + + Rule: + type: object + properties: + id: + type: string + description: Unique rule identifier + name: + type: string + description: Rule name + domain: + type: string + description: Domain ID this rule belongs to + metadata: + type: object + description: Custom metadata + additionalProperties: + type: string + input_channel: + type: string + description: Input channel for receiving messages + input_topic: + type: string + description: Input topic for receiving messages + logic: + type: object + description: Rule processing logic script + properties: + script: + type: string + description: Script content + output_channel: + type: string + description: Output channel for processed messages + output_topic: + type: string + description: Output topic for processed messages + schedule: + type: object + description: Rule execution schedule + properties: + start_datetime: + type: string + format: date-time + description: When the schedule becomes active + time: + type: string + format: date-time + description: Specific time for the rule to run + recurring: + type: string + description: Schedule recurrence pattern + enum: [None, Daily, Weekly, Monthly] + recurring_period: + type: integer + minimum: 1 + description: Controls how many intervals to skip between executions (1 = every interval, 2 = every second interval, etc.) + status: + type: string + description: Rule status + enum: [enabled, disabled] + created_at: + type: string + format: date-time + description: Creation timestamp + readOnly: true + created_by: + type: string + description: User who created the rule + updated_at: + type: string + format: date-time + description: Last update timestamp + readOnly: true + updated_by: + type: string + description: User who last updated the rule + required: + - name + - domain + - input_channel + - input_topic + - logic + - status + + parameters: + DomainID: + name: domainID + description: Domain ID + in: path + required: true + schema: + type: string + RuleID: + name: ruleID + description: Rule ID + in: path + required: true + schema: + type: string + Offset: + name: offset + description: Number of items to skip + in: query + required: false + schema: + type: integer + default: 0 + minimum: 0 + Limit: + name: limit + description: Size of the subset + in: query + required: false + schema: + type: integer + default: 10 + minimum: 1 + InputChannel: + name: input_channel + description: Filter by input channel + in: query + required: false + schema: + type: string + OutputChannel: + name: output_channel + description: Filter by output channel + in: query + required: false + schema: + type: string + Status: + name: status + description: Filter by rule status + in: query + required: false + schema: + type: string + enum: [enabled, disabled] + default: enabled + + requestBodies: + RuleCreateReq: + description: JSON-formatted document describing the new rule + required: true + content: + application/json: + schema: + type: object + properties: + name: + type: string + description: Rule name + domain: + type: string + description: Domain ID this rule belongs to + metadata: + type: object + description: Custom metadata + additionalProperties: + type: string + input_channel: + type: string + description: Input channel for receiving messages + input_topic: + type: string + description: Input topic for receiving messages + logic: + type: object + description: Rule processing logic script + properties: + script: + type: string + description: Script content + output_channel: + type: string + description: Output channel for processed messages + output_topic: + type: string + description: Output topic for processed messages + schedule: + type: object + description: Rule execution schedule + properties: + start_datetime: + type: string + format: date-time + description: When the schedule becomes active + time: + type: string + format: date-time + description: Specific time for the rule to run + recurring: + type: string + description: Schedule recurrence pattern + enum: [None, Daily, Weekly, Monthly] + recurring_period: + type: integer + minimum: 1 + description: Controls how many intervals to skip between executions + status: + type: string + description: Rule status + enum: [enabled, disabled] + required: + - name + - domain + - input_channel + - input_topic + - logic + - schedule + RuleUpdateReq: + description: JSON-formatted document describing the rule update + required: true + content: + application/json: + schema: + type: object + properties: + name: + type: string + description: Rule name + metadata: + type: object + description: Custom metadata + additionalProperties: + type: string + input_channel: + type: string + description: Input channel for receiving messages + input_topic: + type: string + description: Input topic for receiving messages + logic: + type: object + description: Rule processing logic script + properties: + script: + type: string + description: Script content + output_channel: + type: string + description: Output channel for processed messages + output_topic: + type: string + description: Output topic for processed messages + schedule: + type: object + description: Rule execution schedule + properties: + start_datetime: + type: string + format: date-time + description: When the schedule becomes active + time: + type: string + format: date-time + description: Specific time for the rule to run + recurring: + type: string + description: Schedule recurrence pattern + enum: [None, Daily, Weekly, Monthly] + recurring_period: + type: integer + minimum: 1 + description: Controls how many intervals to skip between executions + status: + type: string + description: Rule status + enum: [enabled, disabled] + + responses: + RuleCreateRes: + description: Rule registered + headers: + Location: + content: + text/plain: + schema: + type: string + description: Created rule's relative URL (i.e. /rules/{ruleID}) + RuleListRes: + description: Data retrieved + content: + application/json: + schema: + $ref: '#/components/schemas/RulesListRes' + RuleRes: + description: Data retrieved + content: + application/json: + schema: + $ref: '#/components/schemas/Rule' + links: + update: + operationId: updateRule + parameters: + ruleID: $response.body#/id + enable: + operationId: enableRule + parameters: + ruleID: $response.body#/id + disable: + operationId: disableRule + parameters: + ruleID: $response.body#/id + delete: + operationId: removeRule + parameters: + ruleID: $response.body#/id + ServiceError: + description: Unexpected server-side error occurred + HealthRes: + description: Service Health Check + content: + application/health+json: + schema: + $ref: "./schemas/health_info.yaml" + + securitySchemes: + bearerAuth: + type: http + scheme: bearer + bearerFormat: JWT + description: | + * Users access: "Authorization: Bearer " diff --git a/auth/README.md b/auth/README.md index 251ab5a38..32551d96a 100644 --- a/auth/README.md +++ b/auth/README.md @@ -61,48 +61,48 @@ The service is configured using the environment variables presented in the follo | Variable | Description | Default | | :--- | :--- | :--- | -| `SMQ_AUTH_LOG_LEVEL` | Log level for the Auth service (debug, info, warn, error) | info | -| `SMQ_AUTH_DB_HOST` | Database host address | localhost | -| `SMQ_AUTH_DB_PORT` | Database host port | 5432 | -| `SMQ_AUTH_DB_USER` | Database user | supermq | -| `SMQ_AUTH_DB_PASSWORD` | Database password | supermq | -| `SMQ_AUTH_DB_NAME` | Name of the database used by the service | auth | -| `SMQ_AUTH_DB_SSL_MODE` | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | -| `SMQ_AUTH_DB_SSL_CERT` | Path to the PEM encoded certificate file | "" | -| `SMQ_AUTH_DB_SSL_KEY` | Path to the PEM encoded key file | "" | -| `SMQ_AUTH_DB_SSL_ROOT_CERT` | Path to the PEM encoded root certificate file | "" | -| `SMQ_AUTH_HTTP_HOST` | Auth service HTTP host | "" | -| `SMQ_AUTH_HTTP_PORT` | Auth service HTTP port | 8189 | -| `SMQ_AUTH_HTTP_SERVER_CERT` | Path to the PEM encoded HTTP server certificate file | "" | -| `SMQ_AUTH_HTTP_SERVER_KEY` | Path to the PEM encoded HTTP server key file | "" | -| `SMQ_AUTH_GRPC_HOST` | Auth service gRPC host | "" | -| `SMQ_AUTH_GRPC_PORT` | Auth service gRPC port | 8181 | -| `SMQ_AUTH_GRPC_SERVER_CERT` | Path to the PEM encoded gRPC server certificate file | "" | -| `SMQ_AUTH_GRPC_SERVER_KEY` | Path to the PEM encoded gRPC server key file | "" | -| `SMQ_AUTH_GRPC_SERVER_CA_CERTS` | Path to the PEM encoded gRPC server CA certificate file | "" | -| `SMQ_AUTH_GRPC_CLIENT_CA_CERTS` | Path to the PEM encoded gRPC client CA certificate file | "" | -| `SMQ_AUTH_SECRET_KEY` | String used for signing tokens | secret | -| `SMQ_AUTH_ACCESS_TOKEN_DURATION` | The access token expiration period | 1h | -| `SMQ_AUTH_REFRESH_TOKEN_DURATION` | The refresh token expiration period | 24h | -| `SMQ_AUTH_INVITATION_DURATION` | The invitation token expiration period | 168h | -| `SMQ_AUTH_CACHE_URL` | Redis URL for caching PAT scopes | redis://localhost:6379/0 | -| `SMQ_AUTH_CACHE_KEY_DURATION` | Duration for which PAT scope cache keys are valid | 10m | -| `SMQ_SPICEDB_HOST` | SpiceDB host address | localhost | -| `SMQ_SPICEDB_PORT` | SpiceDB host port | 50051 | -| `SMQ_SPICEDB_PRE_SHARED_KEY` | SpiceDB pre-shared key | 12345678 | -| `SMQ_SPICEDB_SCHEMA_FILE` | Path to SpiceDB schema file | ./docker/spicedb/schema.zed | -| `SMQ_JAEGER_URL` | Jaeger server URL | | -| `SMQ_JAEGER_TRACE_RATIO` | Jaeger sampling ratio | 1.0 | -| `SMQ_SEND_TELEMETRY` | Send telemetry to supermq call home server | true | -| `SMQ_ADAPTER_INSTANCE_ID` | Adapter instance ID | "" | -| `SMQ_CALLOUT_URLS` | Comma-separated list of callout URLs | "" | -| `SMQ_CALLOUT_METHOD` | Callout method | POST | -| `SMQ_CALLOUT_TLS_VERIFICATION` | Enable TLS verification for callouts | true | -| `SMQ_CALLOUT_TIMEOUT` | Callout timeout | 10s | -| `SMQ_CALLOUT_CA_CERT` | Path to CA certificate file | "" | -| `SMQ_CALLOUT_CERT` | Path to client certificate file | "" | -| `SMQ_CALLOUT_KEY` | Path to client key file | "" | -| `SMQ_CALLOUT_OPERATIONS` | Invoke callout if the authorization permission matches any of the given permissions. | "" | +| `MG_AUTH_LOG_LEVEL` | Log level for the Auth service (debug, info, warn, error) | info | +| `MG_AUTH_DB_HOST` | Database host address | localhost | +| `MG_AUTH_DB_PORT` | Database host port | 5432 | +| `MG_AUTH_DB_USER` | Database user | supermq | +| `MG_AUTH_DB_PASSWORD` | Database password | supermq | +| `MG_AUTH_DB_NAME` | Name of the database used by the service | auth | +| `MG_AUTH_DB_SSL_MODE` | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | +| `MG_AUTH_DB_SSL_CERT` | Path to the PEM encoded certificate file | "" | +| `MG_AUTH_DB_SSL_KEY` | Path to the PEM encoded key file | "" | +| `MG_AUTH_DB_SSL_ROOT_CERT` | Path to the PEM encoded root certificate file | "" | +| `MG_AUTH_HTTP_HOST` | Auth service HTTP host | "" | +| `MG_AUTH_HTTP_PORT` | Auth service HTTP port | 8189 | +| `MG_AUTH_HTTP_SERVER_CERT` | Path to the PEM encoded HTTP server certificate file | "" | +| `MG_AUTH_HTTP_SERVER_KEY` | Path to the PEM encoded HTTP server key file | "" | +| `MG_AUTH_GRPC_HOST` | Auth service gRPC host | "" | +| `MG_AUTH_GRPC_PORT` | Auth service gRPC port | 8181 | +| `MG_AUTH_GRPC_SERVER_CERT` | Path to the PEM encoded gRPC server certificate file | "" | +| `MG_AUTH_GRPC_SERVER_KEY` | Path to the PEM encoded gRPC server key file | "" | +| `MG_AUTH_GRPC_SERVER_CA_CERTS` | Path to the PEM encoded gRPC server CA certificate file | "" | +| `MG_AUTH_GRPC_CLIENT_CA_CERTS` | Path to the PEM encoded gRPC client CA certificate file | "" | +| `MG_AUTH_SECRET_KEY` | String used for signing tokens | secret | +| `MG_AUTH_ACCESS_TOKEN_DURATION` | The access token expiration period | 1h | +| `MG_AUTH_REFRESH_TOKEN_DURATION` | The refresh token expiration period | 24h | +| `MG_AUTH_INVITATION_DURATION` | The invitation token expiration period | 168h | +| `MG_AUTH_CACHE_URL` | Redis URL for caching PAT scopes | redis://localhost:6379/0 | +| `MG_AUTH_CACHE_KEY_DURATION` | Duration for which PAT scope cache keys are valid | 10m | +| `MG_SPICEDB_HOST` | SpiceDB host address | localhost | +| `MG_SPICEDB_PORT` | SpiceDB host port | 50051 | +| `MG_SPICEDB_PRE_SHARED_KEY` | SpiceDB pre-shared key | 12345678 | +| `MG_SPICEDB_SCHEMA_FILE` | Path to SpiceDB schema file | ./docker/spicedb/schema.zed | +| `MG_JAEGER_URL` | Jaeger server URL | | +| `MG_JAEGER_TRACE_RATIO` | Jaeger sampling ratio | 1.0 | +| `MG_SEND_TELEMETRY` | Send telemetry to supermq call home server | true | +| `MG_ADAPTER_INSTANCE_ID` | Adapter instance ID | "" | +| `MG_CALLOUT_URLS` | Comma-separated list of callout URLs | "" | +| `MG_CALLOUT_METHOD` | Callout method | POST | +| `MG_CALLOUT_TLS_VERIFICATION` | Enable TLS verification for callouts | true | +| `MG_CALLOUT_TIMEOUT` | Callout timeout | 10s | +| `MG_CALLOUT_CA_CERT` | Path to CA certificate file | "" | +| `MG_CALLOUT_CERT` | Path to client certificate file | "" | +| `MG_CALLOUT_KEY` | Path to client key file | "" | +| `MG_CALLOUT_OPERATIONS` | Invoke callout if the authorization permission matches any of the given permissions. | "" | ## Deployment @@ -124,46 +124,46 @@ make auth make install # set the environment variables and run the service -SMQ_AUTH_LOG_LEVEL=info \ -SMQ_AUTH_DB_HOST=localhost \ -SMQ_AUTH_DB_PORT=5432 \ -SMQ_AUTH_DB_USER=supermq \ -SMQ_AUTH_DB_PASSWORD=supermq \ -SMQ_AUTH_DB_NAME=auth \ -SMQ_AUTH_DB_SSL_MODE=disable \ -SMQ_AUTH_DB_SSL_CERT="" \ -SMQ_AUTH_DB_SSL_KEY="" \ -SMQ_AUTH_DB_SSL_ROOT_CERT="" \ -SMQ_AUTH_HTTP_HOST=localhost \ -SMQ_AUTH_HTTP_PORT=8189 \ -SMQ_AUTH_HTTP_SERVER_CERT="" \ -SMQ_AUTH_HTTP_SERVER_KEY="" \ -SMQ_AUTH_GRPC_HOST=localhost \ -SMQ_AUTH_GRPC_PORT=8181 \ -SMQ_AUTH_GRPC_SERVER_CERT="" \ -SMQ_AUTH_GRPC_SERVER_KEY="" \ -SMQ_AUTH_GRPC_SERVER_CA_CERTS="" \ -SMQ_AUTH_GRPC_CLIENT_CA_CERTS="" \ -SMQ_AUTH_SECRET_KEY=secret \ -SMQ_AUTH_ACCESS_TOKEN_DURATION=1h \ -SMQ_AUTH_REFRESH_TOKEN_DURATION=24h \ -SMQ_AUTH_INVITATION_DURATION=168h \ -SMQ_SPICEDB_HOST=localhost \ -SMQ_SPICEDB_PORT=50051 \ -SMQ_SPICEDB_PRE_SHARED_KEY=12345678 \ -SMQ_SPICEDB_SCHEMA_FILE=./docker/spicedb/schema.zed \ -SMQ_JAEGER_URL=http://localhost:14268/api/traces \ -SMQ_JAEGER_TRACE_RATIO=1.0 \ -SMQ_SEND_TELEMETRY=true \ -SMQ_AUTH_ADAPTER_INSTANCE_ID="" \ -SMQ_CALLOUT_URLS="" \ -SMQ_CALLOUT_METHOD="POST" \ -SMQ_CALLOUT_TLS_VERIFICATION=true \ +MG_AUTH_LOG_LEVEL=info \ +MG_AUTH_DB_HOST=localhost \ +MG_AUTH_DB_PORT=5432 \ +MG_AUTH_DB_USER=supermq \ +MG_AUTH_DB_PASSWORD=supermq \ +MG_AUTH_DB_NAME=auth \ +MG_AUTH_DB_SSL_MODE=disable \ +MG_AUTH_DB_SSL_CERT="" \ +MG_AUTH_DB_SSL_KEY="" \ +MG_AUTH_DB_SSL_ROOT_CERT="" \ +MG_AUTH_HTTP_HOST=localhost \ +MG_AUTH_HTTP_PORT=8189 \ +MG_AUTH_HTTP_SERVER_CERT="" \ +MG_AUTH_HTTP_SERVER_KEY="" \ +MG_AUTH_GRPC_HOST=localhost \ +MG_AUTH_GRPC_PORT=8181 \ +MG_AUTH_GRPC_SERVER_CERT="" \ +MG_AUTH_GRPC_SERVER_KEY="" \ +MG_AUTH_GRPC_SERVER_CA_CERTS="" \ +MG_AUTH_GRPC_CLIENT_CA_CERTS="" \ +MG_AUTH_SECRET_KEY=secret \ +MG_AUTH_ACCESS_TOKEN_DURATION=1h \ +MG_AUTH_REFRESH_TOKEN_DURATION=24h \ +MG_AUTH_INVITATION_DURATION=168h \ +MG_SPICEDB_HOST=localhost \ +MG_SPICEDB_PORT=50051 \ +MG_SPICEDB_PRE_SHARED_KEY=12345678 \ +MG_SPICEDB_SCHEMA_FILE=./docker/spicedb/schema.zed \ +MG_JAEGER_URL=http://localhost:14268/api/traces \ +MG_JAEGER_TRACE_RATIO=1.0 \ +MG_SEND_TELEMETRY=true \ +MG_AUTH_ADAPTER_INSTANCE_ID="" \ +MG_CALLOUT_URLS="" \ +MG_CALLOUT_METHOD="POST" \ +MG_CALLOUT_TLS_VERIFICATION=true \ $GOBIN/supermq-auth ``` -Setting `SMQ_AUTH_HTTP_SERVER_CERT` and `SMQ_AUTH_HTTP_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key. -Setting `SMQ_AUTH_GRPC_SERVER_CERT` and `SMQ_AUTH_GRPC_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key. Setting `SMQ_AUTH_GRPC_SERVER_CA_CERTS` will enable TLS against the service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. Setting `SMQ_AUTH_GRPC_CLIENT_CA_CERTS` will enable TLS against the service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. +Setting `MG_AUTH_HTTP_SERVER_CERT` and `MG_AUTH_HTTP_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key. +Setting `MG_AUTH_GRPC_SERVER_CERT` and `MG_AUTH_GRPC_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key. Setting `MG_AUTH_GRPC_SERVER_CA_CERTS` will enable TLS against the service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. Setting `MG_AUTH_GRPC_CLIENT_CA_CERTS` will enable TLS against the service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. ## Personal Access Tokens (PATs) diff --git a/auth/service.go b/auth/service.go index 7f513f934..d02d24df5 100644 --- a/auth/service.go +++ b/auth/service.go @@ -258,7 +258,7 @@ func (svc service) checkPolicy(ctx context.Context, pr policies.Policy) error { } func (svc service) PolicyValidation(pr policies.Policy) error { - if pr.ObjectType == policies.PlatformType && pr.Object != policies.SuperMQObject { + if pr.ObjectType == policies.PlatformType && pr.Object != policies.MagistralaObject { return errPlatform } return nil @@ -375,7 +375,7 @@ func (svc service) checkUserRole(ctx context.Context, key Key) (err error) { Subject: key.Subject, SubjectType: policies.UserType, Permission: policies.AdminPermission, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }, nil); err != nil { return errRoleAuth @@ -386,7 +386,7 @@ func (svc service) checkUserRole(ctx context.Context, key Key) (err error) { Subject: key.Subject, SubjectType: policies.UserType, Permission: policies.MembershipPermission, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }, nil); err != nil { return errRoleAuth @@ -403,7 +403,7 @@ func (svc service) getUserRole(ctx context.Context, userID string) (role Role) { Subject: userID, SubjectType: policies.UserType, Permission: policies.AdminPermission, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }, nil); err == nil { rl = AdminRole diff --git a/auth/service_test.go b/auth/service_test.go index 643f37784..b832855f4 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -141,7 +141,7 @@ func TestIssue(t *testing.T) { Subject: tc.key.Subject, SubjectType: policies.UserType, Permission: policies.MembershipPermission, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }).Return(tc.roleCheckErr) _, err := svc.Issue(context.Background(), tc.token, tc.key) @@ -195,7 +195,7 @@ func TestIssue(t *testing.T) { Subject: tc.key.Subject, SubjectType: policies.UserType, Permission: policies.MembershipPermission, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }).Return(tc.roleCheckErr) _, err := svc.Issue(context.Background(), tc.token, tc.key) @@ -290,7 +290,7 @@ func TestIssue(t *testing.T) { Subject: tc.key.Subject, SubjectType: policies.UserType, Permission: policies.MembershipPermission, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }).Return(tc.roleCheckErr) _, err := svc.Issue(context.Background(), tc.token, tc.key) @@ -404,7 +404,7 @@ func TestIssue(t *testing.T) { Subject: tc.key.Subject, SubjectType: policies.UserType, Permission: policies.MembershipPermission, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }).Return(tc.roleCheckErr) _, err := svc.Issue(context.Background(), tc.token, tc.key) @@ -887,14 +887,14 @@ func TestAuthorize(t *testing.T) { policyReq: policies.Policy{ SubjectType: policies.UserType, SubjectKind: policies.UsersKind, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, Permission: policies.AdminPermission, }, checkPolicyReq: policies.Policy{ SubjectType: policies.UserType, SubjectKind: policies.UsersKind, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, Permission: policies.AdminPermission, }, @@ -949,7 +949,7 @@ func TestAuthorize(t *testing.T) { policyReq: policies.Policy{ SubjectType: policies.UserType, SubjectKind: policies.UsersKind, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, Permission: policies.AdminPermission, }, @@ -964,7 +964,7 @@ func TestAuthorize(t *testing.T) { checkPolicyReq: policies.Policy{ SubjectType: policies.UserType, SubjectKind: policies.UsersKind, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, Permission: policies.AdminPermission, }, @@ -976,7 +976,7 @@ func TestAuthorize(t *testing.T) { policyReq: policies.Policy{ SubjectType: policies.UserType, SubjectKind: policies.UsersKind, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, Permission: policies.AdminPermission, }, @@ -991,7 +991,7 @@ func TestAuthorize(t *testing.T) { checkPolicyReq: policies.Policy{ SubjectType: policies.UserType, SubjectKind: policies.UsersKind, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, Permission: policies.AdminPermission, }, @@ -1049,14 +1049,14 @@ func TestAuthorize(t *testing.T) { policyReq: policies.Policy{ SubjectType: policies.UserType, SubjectKind: policies.UsersKind, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, Permission: policies.AdminPermission, }, checkPolicyReq: policies.Policy{ SubjectType: policies.UserType, SubjectKind: policies.UsersKind, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, Permission: policies.AdminPermission, }, diff --git a/auth/tokenizer/asymmetric/README.md b/auth/tokenizer/asymmetric/README.md index 2bade9fdb..e59e4a98c 100644 --- a/auth/tokenizer/asymmetric/README.md +++ b/auth/tokenizer/asymmetric/README.md @@ -14,8 +14,8 @@ The tokenizer uses environment variables to specify key file paths: | Environment Variable | Required | Description | | --------------------------------- | -------- | ------------------------------------------------ | -| `SMQ_AUTH_KEYS_ACTIVE_KEY_PATH` | Yes | Path to active private key file | -| `SMQ_AUTH_KEYS_RETIRING_KEY_PATH` | No | Path to retiring private key file (for rotation) | +| `MG_AUTH_KEYS_ACTIVE_KEY_PATH` | Yes | Path to active private key file | +| `MG_AUTH_KEYS_RETIRING_KEY_PATH` | No | Path to retiring private key file (for rotation) | Please note that key names are used as **key IDs (kid)**. @@ -24,7 +24,7 @@ Please note that key names are used as **key IDs (kid)**. Set only the active key path: ```bash -export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/private.key" +export MG_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/private.key" ``` The tokenizer will: @@ -38,8 +38,8 @@ The tokenizer will: Set both active and retiring key paths: ```bash -export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/active.key" -export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key" +export MG_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/active.key" +export MG_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key" ``` The tokenizer will: @@ -64,12 +64,12 @@ Move the current active key to retiring position and set the new key as active: ```bash # Before rotation -SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/current.key" -SMQ_AUTH_KEYS_RETIRING_KEY_PATH="" # No retiring key +MG_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/current.key" +MG_AUTH_KEYS_RETIRING_KEY_PATH="" # No retiring key # During rotation (both keys active for grace period) -SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/new.key" -SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/current.key" +MG_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/new.key" +MG_AUTH_KEYS_RETIRING_KEY_PATH="./keys/current.key" # After rotation (restart service with new config) docker-compose restart auth @@ -83,8 +83,8 @@ After the grace period expires (typically 7-30 days), remove the retiring key: ```bash # Remove retiring key configuration -SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/new.key" -SMQ_AUTH_KEYS_RETIRING_KEY_PATH="" # Remove retiring key +MG_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/new.key" +MG_AUTH_KEYS_RETIRING_KEY_PATH="" # Remove retiring key # Restart service docker-compose restart auth @@ -121,20 +121,20 @@ The grace period should be longer than your longest-lived access token duration. ```bash # Day 0: Normal operation -export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/key-2024.pem" -export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="" +export MG_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/key-2024.pem" +export MG_AUTH_KEYS_RETIRING_KEY_PATH="" # Day 1: Start rotation - generate new key openssl genpkey -algorithm Ed25519 -out ./keys/key-2025.pem chmod 600 ./keys/key-2025.pem # Day 1: Update config and restart -export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/key-2025.pem" -export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/key-2024.pem" +export MG_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/key-2025.pem" +export MG_AUTH_KEYS_RETIRING_KEY_PATH="./keys/key-2024.pem" docker-compose restart auth # Day 8: Grace period expired - remove old key -export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="" +export MG_AUTH_KEYS_RETIRING_KEY_PATH="" docker-compose restart auth rm ./keys/key-2024.pem ``` @@ -147,7 +147,7 @@ rm ./keys/key-2024.pem Error: active key file not found: ./keys/active.key ``` -**Solution:** Ensure the file exists and path is correct. Verify `SMQ_AUTH_KEYS_ACTIVE_KEY_PATH` environment variable. +**Solution:** Ensure the file exists and path is correct. Verify `MG_AUTH_KEYS_ACTIVE_KEY_PATH` environment variable. ### Retiring key warning diff --git a/bootstrap/README.md b/bootstrap/README.md new file mode 100644 index 000000000..d5d1c0498 --- /dev/null +++ b/bootstrap/README.md @@ -0,0 +1,122 @@ +# BOOTSTRAP SERVICE + +New devices need to be configured properly and connected to the Magistrala. Bootstrap service is used in order to accomplish that. This service provides the following features: + +1. Creating new Magistrala Clients +2. Providing basic configuration for the newly created Clients +3. Enabling/disabling Clients + +Pre-provisioning a new Client is as simple as sending Configuration data to the Bootstrap service. Once the Client is online, it sends a request for initial config to Bootstrap service. Bootstrap service provides an API for enabling and disabling Clients. Only enabled Clients can exchange messages over Magistrala. Bootstrapping does not implicitly enable Clients, it has to be done manually. + +In order to bootstrap successfully, the Client needs to send bootstrapping request to the specific URL, as well as a secret key. This key and URL are pre-provisioned during the manufacturing process. If the Client is provisioned on the Bootstrap service side, the corresponding configuration will be sent as a response. Otherwise, the Client will be saved so that it can be provisioned later. + +## Client Configuration Entity + +Client Configuration consists of two logical parts: the custom configuration that can be interpreted by the Client itself and Magistrala-related configuration. Magistrala config contains: + +1. corresponding Magistrala Client ID +2. corresponding Magistrala Client key +3. list of the Magistrala channels the Client is connected to + +> Note: list of channels contains IDs of the Magistrala channels. These channels are _pre-provisioned_ on the Magistrala side and, unlike corresponding Magistrala Client, Bootstrap service is not able to create Magistrala Channels. + +Enabling and disabling Client (adding Client to/from whitelist) is as simple as connecting corresponding Magistrala Client to the given list of Channels. Configuration keeps _state_ of the Client: + +| State | What it means | +| -------- | ---------------------------------------------- | +| Inactive | Client is created, but isn't enabled | +| Active | Client is able to communicate using Magistrala | + +Switching between states `Active` and `Inactive` enables and disables Client, respectively. + +Client configuration also contains the so-called `external ID` and `external key`. An external ID is a unique identifier of corresponding Client. For example, a device MAC address is a good choice for external ID. External key is a secret key that is used for authentication during the bootstrapping procedure. + +## Configuration + +The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values. + +| Variable | Description | Default | +| ------------------------------ | -------------------------------------------------------------------------------- | --------------------------------- | +| MG_BOOTSTRAP_LOG_LEVEL | Log level for Bootstrap (debug, info, warn, error) | info | +| MG_BOOTSTRAP_DB_HOST | Database host address | localhost | +| MG_BOOTSTRAP_DB_PORT | Database host port | 5432 | +| MG_BOOTSTRAP_DB_USER | Database user | magistrala | +| MG_BOOTSTRAP_DB_PASS | Database password | magistrala | +| MG_BOOTSTRAP_DB_NAME | Name of the database used by the service | bootstrap | +| MG_BOOTSTRAP_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | +| MG_BOOTSTRAP_DB_SSL_CERT | Path to the PEM encoded certificate file | "" | +| MG_BOOTSTRAP_DB_SSL_KEY | Path to the PEM encoded key file | "" | +| MG_BOOTSTRAP_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" | +| MG_BOOTSTRAP_ENCRYPT_KEY | Secret key for secure bootstrapping encryption | 12345678910111213141516171819202 | +| MG_BOOTSTRAP_HTTP_HOST | Bootstrap service HTTP host | "" | +| MG_BOOTSTRAP_HTTP_PORT | Bootstrap service HTTP port | 9013 | +| MG_BOOTSTRAP_HTTP_SERVER_CERT | Path to server certificate in pem format | "" | +| MG_BOOTSTRAP_HTTP_SERVER_KEY | Path to server key in pem format | "" | +| MG_BOOTSTRAP_EVENT_CONSUMER | Bootstrap service event source consumer name | bootstrap | +| MG_ES_URL | Event store URL | | +| MG_AUTH_GRPC_URL | Auth service Auth gRPC URL | | +| MG_AUTH_GRPC_TIMEOUT | Auth service Auth gRPC request timeout in seconds | 1s | +| MG_AUTH_GRPC_CLIENT_CERT | Path to the PEM encoded auth service Auth gRPC client certificate file | "" | +| MG_AUTH_GRPC_CLIENT_KEY | Path to the PEM encoded auth service Auth gRPC client key file | "" | +| MG_AUTH_GRPC_SERVER_CERTS | Path to the PEM encoded auth server Auth gRPC server trusted CA certificate file | "" | +| MG_CLIENTS_URL | Base URL for Magistrala Clients | | +| MG_JAEGER_URL | Jaeger server URL | | +| MG_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 | +| MG_SEND_TELEMETRY | Send telemetry to magistrala call home server | true | +| MG_BOOTSTRAP_INSTANCE_ID | Bootstrap service instance ID | "" | + +## Deployment + +The service itself is distributed as Docker container. Check the [`bootstrap`](https://github.com/absmach/magistrala/blob/main/docker/addons/bootstrap/docker-compose.yaml) service section in docker-compose file to see how service is deployed. + +To start the service outside of the container, execute the following shell script: + +```bash +# download the latest version of the service +git clone https://github.com/absmach/magistrala + +cd magistrala + +# compile the servic e +make bootstrap + +# copy binary to bin +make install + +# set the environment variables and run the service +MG_BOOTSTRAP_LOG_LEVEL=info \ +MG_BOOTSTRAP_DB_HOST=localhost \ +MG_BOOTSTRAP_DB_PORT=5432 \ +MG_BOOTSTRAP_DB_USER=magistrala \ +MG_BOOTSTRAP_DB_PASS=magistrala \ +MG_BOOTSTRAP_DB_NAME=bootstrap \ +MG_BOOTSTRAP_DB_SSL_MODE=disable \ +MG_BOOTSTRAP_DB_SSL_CERT="" \ +MG_BOOTSTRAP_DB_SSL_KEY="" \ +MG_BOOTSTRAP_DB_SSL_ROOT_CERT="" \ +MG_BOOTSTRAP_HTTP_HOST=localhost \ +MG_BOOTSTRAP_HTTP_PORT=9013 \ +MG_BOOTSTRAP_HTTP_SERVER_CERT="" \ +MG_BOOTSTRAP_HTTP_SERVER_KEY="" \ +MG_BOOTSTRAP_EVENT_CONSUMER=bootstrap \ +MG_ES_URL=nats://localhost:4222 \ +MG_AUTH_GRPC_URL=localhost:8181 \ +MG_AUTH_GRPC_TIMEOUT=1s \ +MG_AUTH_GRPC_CLIENT_CERT="" \ +MG_AUTH_GRPC_CLIENT_KEY="" \ +MG_AUTH_GRPC_SERVER_CERTS="" \ +MG_CLIENTS_URL=http://localhost:9000 \ +MG_JAEGER_URL=http://localhost:14268/api/traces \ +MG_JAEGER_TRACE_RATIO=1.0 \ +MG_SEND_TELEMETRY=true \ +MG_BOOTSTRAP_INSTANCE_ID="" \ +$GOBIN/magistrala-bootstrap +``` + +Setting `MG_BOOTSTRAP_HTTP_SERVER_CERT` and `MG_BOOTSTRAP_HTTP_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key. + +Setting `MG_AUTH_GRPC_CLIENT_CERT` and `MG_AUTH_GRPC_CLIENT_KEY` will enable TLS against the auth service. The service expects a file in PEM format for both the certificate and the key. Setting `MG_AUTH_GRPC_SERVER_CERTS` will enable TLS against the auth service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. + +## Usage + +For more information about service capabilities and its usage, please check out the [API documentation](https://docs.api.magistrala.absmach.eu/?urls.primaryName=bootstrap.yaml). diff --git a/bootstrap/api/doc.go b/bootstrap/api/doc.go new file mode 100644 index 000000000..1e8268ee6 --- /dev/null +++ b/bootstrap/api/doc.go @@ -0,0 +1,5 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package api contains implementation of bootstrap service HTTP API. +package api diff --git a/bootstrap/api/endpoint.go b/bootstrap/api/endpoint.go new file mode 100644 index 000000000..5ed6b778a --- /dev/null +++ b/bootstrap/api/endpoint.go @@ -0,0 +1,289 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/bootstrap" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/go-kit/kit/endpoint" +) + +func addEndpoint(svc bootstrap.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(addReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + channels := []bootstrap.Channel{} + for _, c := range req.Channels { + channels = append(channels, bootstrap.Channel{ID: c}) + } + + config := bootstrap.Config{ + ClientID: req.ClientID, + ExternalID: req.ExternalID, + ExternalKey: req.ExternalKey, + Channels: channels, + Name: req.Name, + ClientCert: req.ClientCert, + ClientKey: req.ClientKey, + CACert: req.CACert, + Content: req.Content, + } + + saved, err := svc.Add(ctx, session, req.token, config) + if err != nil { + return nil, err + } + + res := configRes{ + id: saved.ClientID, + created: true, + } + + return res, nil + } +} + +func updateCertEndpoint(svc bootstrap.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(updateCertReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + cfg, err := svc.UpdateCert(ctx, session, req.clientID, req.ClientCert, req.ClientKey, req.CACert) + if err != nil { + return nil, err + } + + res := updateConfigRes{ + ClientID: cfg.ClientID, + ClientCert: cfg.ClientCert, + CACert: cfg.CACert, + ClientKey: cfg.ClientKey, + } + + return res, nil + } +} + +func viewEndpoint(svc bootstrap.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(entityReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + config, err := svc.View(ctx, session, req.id) + if err != nil { + return nil, err + } + + var channels []channelRes + for _, ch := range config.Channels { + channels = append(channels, channelRes{ + ID: ch.ID, + Name: ch.Name, + Metadata: ch.Metadata, + }) + } + + res := viewRes{ + ClientID: config.ClientID, + CLientSecret: config.ClientSecret, + Channels: channels, + ExternalID: config.ExternalID, + ExternalKey: config.ExternalKey, + Name: config.Name, + Content: config.Content, + State: config.State, + } + + return res, nil + } +} + +func updateEndpoint(svc bootstrap.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(updateReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + config := bootstrap.Config{ + ClientID: req.id, + Name: req.Name, + Content: req.Content, + } + + if err := svc.Update(ctx, session, config); err != nil { + return nil, err + } + + res := configRes{ + id: config.ClientID, + created: false, + } + + return res, nil + } +} + +func updateConnEndpoint(svc bootstrap.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(updateConnReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + if err := svc.UpdateConnections(ctx, session, req.token, req.id, req.Channels); err != nil { + return nil, err + } + + res := configRes{ + id: req.id, + created: false, + } + + return res, nil + } +} + +func listEndpoint(svc bootstrap.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(listReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + page, err := svc.List(ctx, session, req.filter, req.offset, req.limit) + if err != nil { + return nil, err + } + res := listRes{ + Total: page.Total, + Offset: page.Offset, + Limit: page.Limit, + Configs: []viewRes{}, + } + + for _, cfg := range page.Configs { + var channels []channelRes + for _, ch := range cfg.Channels { + channels = append(channels, channelRes{ + ID: ch.ID, + Name: ch.Name, + Metadata: ch.Metadata, + }) + } + + view := viewRes{ + ClientID: cfg.ClientID, + CLientSecret: cfg.ClientSecret, + Channels: channels, + ExternalID: cfg.ExternalID, + ExternalKey: cfg.ExternalKey, + Name: cfg.Name, + Content: cfg.Content, + State: cfg.State, + } + res.Configs = append(res.Configs, view) + } + + return res, nil + } +} + +func removeEndpoint(svc bootstrap.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(entityReq) + if err := req.validate(); err != nil { + return removeRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + if err := svc.Remove(ctx, session, req.id); err != nil { + return nil, err + } + + return removeRes{}, nil + } +} + +func bootstrapEndpoint(svc bootstrap.Service, reader bootstrap.ConfigReader, secure bool) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(bootstrapReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + cfg, err := svc.Bootstrap(ctx, req.key, req.id, secure) + if err != nil { + return nil, err + } + + return reader.ReadConfig(cfg, secure) + } +} + +func stateEndpoint(svc bootstrap.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(changeStateReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + if err := svc.ChangeState(ctx, session, req.token, req.id, req.State); err != nil { + return nil, err + } + + return stateRes{}, nil + } +} diff --git a/bootstrap/api/endpoint_test.go b/bootstrap/api/endpoint_test.go new file mode 100644 index 000000000..4cc2a9b5b --- /dev/null +++ b/bootstrap/api/endpoint_test.go @@ -0,0 +1,1419 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api_test + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/bootstrap" + bsapi "github.com/absmach/supermq/bootstrap/api" + "github.com/absmach/supermq/bootstrap/mocks" + "github.com/absmach/supermq/internal/testsutil" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const ( + validToken = "validToken" + domainID = "b4d7d79e-fd99-4c2b-ac09-524e43df6888" + invalidToken = "invalid" + email = "test@example.com" + unknown = "unknown" + channelsNum = 3 + contentType = "application/json" + wrongID = "wrong_id" + + addName = "name" + addContent = "config" + instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002" + validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22" +) + +var ( + encKey = []byte("1234567891011121") + metadata = map[string]any{"meta": "data"} + addExternalID = testsutil.GenerateUUID(&testing.T{}) + addExternalKey = testsutil.GenerateUUID(&testing.T{}) + addClientID = testsutil.GenerateUUID(&testing.T{}) + addClientSecret = testsutil.GenerateUUID(&testing.T{}) + addReq = struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + ExternalID string `json:"external_id"` + ExternalKey string `json:"external_key"` + Channels []string `json:"channels"` + Name string `json:"name"` + Content string `json:"content"` + }{ + ClientID: addClientID, + ClientSecret: addClientSecret, + ExternalID: addExternalID, + ExternalKey: addExternalKey, + Channels: []string{"1"}, + Name: "name", + Content: "config", + } + + updateReq = struct { + Channels []string `json:"channels,omitempty"` + Content string `json:"content,omitempty"` + State bootstrap.State `json:"state,omitempty"` + ClientCert string `json:"client_cert,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` + CACert string `json:"ca_cert,omitempty"` + }{ + Channels: []string{"1"}, + Content: "config update", + State: 1, + ClientCert: "newcert", + ClientSecret: "newkey", + CACert: "newca", + } + + missingIDRes = toJSON(apiutil.ErrMissingID) + missingKeyRes = toJSON(apiutil.ErrBearerKey) + unknownExternalIDErrorRes = toJSON(svcerr.ErrNotFound) + extKeyRes = toJSON(bootstrap.ErrExternalKey) + extSecKeyRes = toJSON(bootstrap.ErrExternalKeySecure) +) + +type testRequest struct { + client *http.Client + method string + url string + contentType string + token string + key string + body io.Reader +} + +func newConfig() bootstrap.Config { + return bootstrap.Config{ + ClientID: addClientID, + ClientSecret: addClientSecret, + ExternalID: addExternalID, + ExternalKey: addExternalKey, + Channels: []bootstrap.Channel{ + { + ID: "1", + Metadata: metadata, + }, + }, + Name: addName, + Content: addContent, + ClientCert: "newcert", + ClientKey: "newkey", + CACert: "newca", + } +} + +func (tr testRequest) make() (*http.Response, error) { + req, err := http.NewRequest(tr.method, tr.url, tr.body) + if err != nil { + return nil, err + } + + if tr.token != "" { + req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token) + } + if tr.key != "" { + req.Header.Set("Authorization", apiutil.ClientPrefix+tr.key) + } + + if tr.contentType != "" { + req.Header.Set("Content-Type", tr.contentType) + } + + return tr.client.Do(req) +} + +func enc(in []byte) ([]byte, error) { + block, err := aes.NewCipher(encKey) + if err != nil { + return nil, err + } + ciphertext := make([]byte, aes.BlockSize+len(in)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(ciphertext[aes.BlockSize:], in) + return ciphertext, nil +} + +func dec(in []byte) ([]byte, error) { + block, err := aes.NewCipher(encKey) + if err != nil { + return nil, err + } + if len(in) < aes.BlockSize { + return nil, errors.ErrMalformedEntity + } + iv := in[:aes.BlockSize] + in = in[aes.BlockSize:] + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(in, in) + return in, nil +} + +func newBootstrapServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentication) { + logger := smqlog.NewMock() + svc := new(mocks.Service) + authn := new(authnmocks.Authentication) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + mux := bsapi.MakeHandler(svc, am, bootstrap.NewConfigReader(encKey), logger, instanceID) + return httptest.NewServer(mux), svc, authn +} + +func toJSON(data any) string { + jsonData, err := json.Marshal(data) + if err != nil { + return "" + } + return string(jsonData) +} + +func TestAdd(t *testing.T) { + bs, svc, auth := newBootstrapServer() + defer bs.Close() + c := newConfig() + + data := toJSON(addReq) + + neID := addReq + neID.ClientID = testsutil.GenerateUUID(t) + neData := toJSON(neID) + + invalidChannels := addReq + invalidChannels.Channels = []string{wrongID} + wrongData := toJSON(invalidChannels) + + cases := []struct { + desc string + req string + domainID string + token string + session smqauthn.Session + contentType string + status int + location string + authenticateErr error + err error + }{ + { + desc: "add a config with invalid token", + req: data, + domainID: domainID, + token: invalidToken, + contentType: contentType, + status: http.StatusUnauthorized, + location: "", + authenticateErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "add a valid config", + req: data, + domainID: domainID, + token: validToken, + contentType: contentType, + status: http.StatusCreated, + location: "/clients/configs/" + c.ClientID, + err: nil, + }, + { + desc: "add a config with wrong content type", + req: data, + domainID: domainID, + token: validToken, + contentType: "", + status: http.StatusUnsupportedMediaType, + location: "", + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "add an existing config", + req: data, + domainID: domainID, + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + location: "", + err: svcerr.ErrConflict, + }, + { + desc: "add a config with non-existent ID", + req: neData, + domainID: domainID, + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + location: "", + err: svcerr.ErrConflict, + }, + { + desc: "add a config with invalid channels", + req: wrongData, + domainID: domainID, + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + location: "", + err: svcerr.ErrConflict, + }, + { + desc: "add a config with wrong JSON", + req: "{\"external_id\": 5}", + domainID: domainID, + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + err: svcerr.ErrMalformedEntity, + }, + { + desc: "add a config with invalid request format", + req: "}", + domainID: domainID, + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + location: "", + err: svcerr.ErrMalformedEntity, + }, + { + desc: "add a config with empty JSON", + req: "{}", + domainID: domainID, + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + location: "", + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "add a config with an empty request", + req: "", + domainID: domainID, + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + location: "", + err: svcerr.ErrMalformedEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + + svcCall := svc.On("Add", mock.Anything, tc.session, tc.token, mock.Anything).Return(c, tc.err) + req := testRequest{ + client: bs.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/clients/configs", bs.URL, tc.domainID), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(tc.req), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + location := res.Header.Get("Location") + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + assert.Equal(t, tc.location, location, fmt.Sprintf("%s: expected location '%s' got '%s'", tc.desc, tc.location, location)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestView(t *testing.T) { + bs, svc, auth := newBootstrapServer() + defer bs.Close() + c := newConfig() + + var channels []channel + for _, ch := range c.Channels { + channels = append(channels, channel{ID: ch.ID, Name: ch.Name, Metadata: ch.Metadata}) + } + + data := config{ + ClientID: c.ClientID, + ClientSecret: c.ClientSecret, + State: c.State, + Channels: channels, + ExternalID: c.ExternalID, + ExternalKey: c.ExternalKey, + Name: c.Name, + Content: c.Content, + } + + cases := []struct { + desc string + token string + session smqauthn.Session + id string + status int + res config + authenticateErr error + err error + }{ + { + desc: "view a config with invalid token", + token: invalidToken, + id: c.ClientID, + status: http.StatusUnauthorized, + res: config{}, + authenticateErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "view a config", + token: validToken, + id: c.ClientID, + status: http.StatusOK, + res: data, + err: nil, + }, + { + desc: "view a non-existing config", + token: validToken, + id: wrongID, + status: http.StatusNotFound, + res: config{}, + err: svcerr.ErrNotFound, + }, + { + desc: "view a config with an empty token", + token: "", + id: c.ClientID, + status: http.StatusUnauthorized, + res: config{}, + err: apiutil.ErrBearerToken, + }, + { + desc: "view config without authorization", + token: validToken, + id: c.ClientID, + status: http.StatusForbidden, + res: config{}, + err: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("View", mock.Anything, tc.session, tc.id).Return(c, tc.err) + req := testRequest{ + client: bs.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/%s/clients/configs/%s", bs.URL, domainID, tc.id), + token: tc.token, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + var view config + if err := json.NewDecoder(res.Body).Decode(&view); err != io.EOF { + assert.Nil(t, err, fmt.Sprintf("Decoding expected to succeed %s: %s", tc.desc, err)) + } + + assert.ElementsMatch(t, tc.res.Channels, view.Channels, fmt.Sprintf("%s: expected response '%s' got '%s'", tc.desc, tc.res.Channels, view.Channels)) + // Empty channels to prevent order mismatch. + tc.res.Channels = []channel{} + view.Channels = []channel{} + assert.Equal(t, tc.res, view, fmt.Sprintf("%s: expected response '%s' got '%s'", tc.desc, tc.res, view)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdate(t *testing.T) { + bs, svc, auth := newBootstrapServer() + defer bs.Close() + c := newConfig() + + data := toJSON(updateReq) + + cases := []struct { + desc string + req string + id string + token string + session smqauthn.Session + contentType string + status int + authenticateErr error + err error + }{ + { + desc: "update with invalid token", + req: data, + id: c.ClientID, + token: invalidToken, + contentType: contentType, + status: http.StatusUnauthorized, + authenticateErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "update with an empty token", + req: data, + id: c.ClientID, + token: "", + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "update a valid config", + req: data, + id: c.ClientID, + token: validToken, + contentType: contentType, + status: http.StatusOK, + err: nil, + }, + { + desc: "update a config with wrong content type", + req: data, + id: c.ClientID, + token: validToken, + contentType: "", + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "update a non-existing config", + req: data, + id: wrongID, + token: validToken, + contentType: contentType, + status: http.StatusNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "update a config with invalid request format", + req: "}", + id: c.ClientID, + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + err: svcerr.ErrMalformedEntity, + }, + { + desc: "update a config with an empty request", + id: c.ClientID, + req: "", + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + err: svcerr.ErrMalformedEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("Update", mock.Anything, tc.session, mock.Anything).Return(tc.err) + req := testRequest{ + client: bs.Client(), + method: http.MethodPut, + url: fmt.Sprintf("%s/%s/clients/configs/%s", bs.URL, domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(tc.req), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateCert(t *testing.T) { + bs, svc, auth := newBootstrapServer() + defer bs.Close() + c := newConfig() + + data := toJSON(updateReq) + + cases := []struct { + desc string + req string + id string + token string + session smqauthn.Session + contentType string + status int + authenticateErr error + err error + }{ + { + desc: "update with invalid token", + req: data, + id: c.ClientID, + token: invalidToken, + contentType: contentType, + status: http.StatusUnauthorized, + authenticateErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "update with an empty token", + req: data, + id: c.ClientID, + token: "", + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "update a valid config", + req: data, + id: c.ClientID, + token: validToken, + contentType: contentType, + status: http.StatusOK, + err: nil, + }, + { + desc: "update a config with wrong content type", + req: data, + id: c.ClientID, + token: validToken, + contentType: "", + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "update a non-existing config", + req: data, + id: wrongID, + token: validToken, + contentType: contentType, + status: http.StatusNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "update a config with invalid request format", + req: "}", + id: c.ClientSecret, + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + err: svcerr.ErrMalformedEntity, + }, + { + desc: "update a config with an empty request", + id: c.ClientID, + req: "", + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + err: svcerr.ErrMalformedEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("UpdateCert", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(c, tc.err) + req := testRequest{ + client: bs.Client(), + method: http.MethodPatch, + url: fmt.Sprintf("%s/%s/clients/configs/certs/%s", bs.URL, domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(tc.req), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateConnections(t *testing.T) { + bs, svc, auth := newBootstrapServer() + defer bs.Close() + c := newConfig() + data := toJSON(updateReq) + + invalidChannels := updateReq + invalidChannels.Channels = []string{wrongID} + + wrongData := toJSON(invalidChannels) + + cases := []struct { + desc string + req string + id string + token string + session smqauthn.Session + contentType string + status int + authenticateErr error + err error + }{ + { + desc: "update connections with invalid token", + req: data, + id: c.ClientID, + token: invalidToken, + contentType: contentType, + status: http.StatusUnauthorized, + authenticateErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "update connections with an empty token", + req: data, + id: c.ClientID, + token: "", + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "update connections valid config", + req: data, + id: c.ClientID, + token: validToken, + contentType: contentType, + status: http.StatusOK, + err: nil, + }, + { + desc: "update connections with wrong content type", + req: data, + id: c.ClientID, + token: validToken, + contentType: "", + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "update connections for a non-existing config", + req: data, + id: wrongID, + token: validToken, + contentType: contentType, + status: http.StatusNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "update connections with invalid channels", + req: wrongData, + id: c.ClientID, + token: validToken, + contentType: contentType, + status: http.StatusNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "update a config with invalid request format", + req: "}", + id: c.ClientID, + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + err: svcerr.ErrMalformedEntity, + }, + { + desc: "update a config with an empty request", + id: c.ClientID, + req: "", + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + err: svcerr.ErrMalformedEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + repoCall := svc.On("UpdateConnections", mock.Anything, tc.session, tc.token, mock.Anything, mock.Anything).Return(tc.err) + req := testRequest{ + client: bs.Client(), + method: http.MethodPut, + url: fmt.Sprintf("%s/%s/clients/configs/connections/%s", bs.URL, domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(tc.req), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + repoCall.Unset() + authCall.Unset() + }) + } +} + +func TestList(t *testing.T) { + configNum := 101 + changedStateNum := 20 + var active, inactive []config + list := make([]config, configNum) + + bs, svc, auth := newBootstrapServer() + defer bs.Close() + path := fmt.Sprintf("%s/%s/%s", bs.URL, domainID, "clients/configs") + + c := newConfig() + + for i := 0; i < configNum; i++ { + c.ExternalID = strconv.Itoa(i) + c.ClientSecret = c.ExternalID + c.Name = fmt.Sprintf("%s-%d", addName, i) + c.ExternalKey = fmt.Sprintf("%s%s", addExternalKey, strconv.Itoa(i)) + + var channels []channel + for _, ch := range c.Channels { + channels = append(channels, channel{ID: ch.ID, Name: ch.Name, Metadata: ch.Metadata}) + } + s := config{ + ClientID: c.ClientID, + ClientSecret: c.ClientSecret, + Channels: channels, + ExternalID: c.ExternalID, + ExternalKey: c.ExternalKey, + Name: c.Name, + Content: c.Content, + State: c.State, + } + list[i] = s + } + // Change state of first 20 elements for filtering tests. + for i := 0; i < changedStateNum; i++ { + state := bootstrap.Active + if i%2 == 0 { + state = bootstrap.Inactive + } + svcCall := svc.On("ChangeState", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + err := svc.ChangeState(context.Background(), smqauthn.Session{}, validToken, list[i].ClientID, state) + assert.Nil(t, err, fmt.Sprintf("Changing state expected to succeed: %s.\n", err)) + + svcCall.Unset() + + list[i].State = state + if state == bootstrap.Inactive { + inactive = append(inactive, list[i]) + continue + } + active = append(active, list[i]) + } + + cases := []struct { + desc string + token string + session smqauthn.Session + url string + status int + res configPage + authenticateErr error + err error + }{ + { + desc: "view list with invalid token", + token: invalidToken, + url: fmt.Sprintf("%s?offset=%d&limit=%d", path, 0, 10), + status: http.StatusUnauthorized, + res: configPage{}, + authenticateErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "view list with an empty token", + token: "", + url: fmt.Sprintf("%s?offset=%d&limit=%d", path, 0, 10), + status: http.StatusUnauthorized, + res: configPage{}, + err: apiutil.ErrBearerToken, + }, + { + desc: "view list", + token: validToken, + url: fmt.Sprintf("%s?offset=%d&limit=%d", path, 0, 1), + status: http.StatusOK, + res: configPage{ + Total: uint64(len(list)), + Offset: 0, + Limit: 1, + Configs: list[0:1], + }, + err: nil, + }, + { + desc: "view list searching by name", + token: validToken, + url: fmt.Sprintf("%s?offset=%d&limit=%d&name=%s", path, 0, 100, "95"), + status: http.StatusOK, + res: configPage{ + Total: 1, + Offset: 0, + Limit: 100, + Configs: list[95:96], + }, + err: nil, + }, + { + desc: "view last page", + token: validToken, + url: fmt.Sprintf("%s?offset=%d&limit=%d", path, 100, 10), + status: http.StatusOK, + res: configPage{ + Total: uint64(len(list)), + Offset: 100, + Limit: 10, + Configs: list[100:], + }, + err: nil, + }, + { + desc: "view with limit greater than allowed", + token: validToken, + url: fmt.Sprintf("%s?offset=%d&limit=%d", path, 0, 1000), + status: http.StatusBadRequest, + res: configPage{}, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "view list with no specified limit and offset", + token: validToken, + url: path, + status: http.StatusOK, + res: configPage{ + Total: uint64(len(list)), + Offset: 0, + Limit: 10, + Configs: list[0:10], + }, + err: nil, + }, + { + desc: "view list with no specified limit", + token: validToken, + url: fmt.Sprintf("%s?offset=%d", path, 10), + status: http.StatusOK, + res: configPage{ + Total: uint64(len(list)), + Offset: 10, + Limit: 10, + Configs: list[10:20], + }, + err: nil, + }, + { + desc: "view list with no specified offset", + token: validToken, + url: fmt.Sprintf("%s?limit=%d", path, 10), + status: http.StatusOK, + res: configPage{ + Total: uint64(len(list)), + Offset: 0, + Limit: 10, + Configs: list[0:10], + }, + err: nil, + }, + { + desc: "view list with limit < 0", + token: validToken, + url: fmt.Sprintf("%s?limit=%d", path, -10), + status: http.StatusBadRequest, + res: configPage{}, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "view list with offset < 0", + token: validToken, + url: fmt.Sprintf("%s?offset=%d", path, -10), + status: http.StatusBadRequest, + res: configPage{}, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "view list with invalid query parameters", + token: validToken, + url: fmt.Sprintf("%s?offset=%d&limit=%d&state=%d&key=%%", path, 10, 10, bootstrap.Inactive), + status: http.StatusBadRequest, + res: configPage{}, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "view first 10 active", + token: validToken, + url: fmt.Sprintf("%s?offset=%d&limit=%d&state=%d", path, 0, 20, bootstrap.Active), + status: http.StatusOK, + res: configPage{ + Total: uint64(len(active)), + Offset: 0, + Limit: 20, + Configs: active, + }, + err: nil, + }, + { + desc: "view first 10 inactive", + token: validToken, + url: fmt.Sprintf("%s?offset=%d&limit=%d&state=%d", path, 0, 20, bootstrap.Inactive), + status: http.StatusOK, + res: configPage{ + Total: uint64(len(list) - len(inactive)), + Offset: 0, + Limit: 20, + Configs: inactive, + }, + err: nil, + }, + { + desc: "view first 5 active", + token: validToken, + url: fmt.Sprintf("%s?offset=%d&limit=%d&state=%d", path, 0, 10, bootstrap.Active), + status: http.StatusOK, + res: configPage{ + Total: uint64(len(active)), + Offset: 0, + Limit: 10, + Configs: active[:5], + }, + err: nil, + }, + { + desc: "view last 5 inactive", + token: validToken, + url: fmt.Sprintf("%s?offset=%d&limit=%d&state=%d", path, 10, 10, bootstrap.Inactive), + status: http.StatusOK, + res: configPage{ + Total: uint64(len(list) - len(active)), + Offset: 10, + Limit: 10, + Configs: inactive[5:], + }, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("List", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(bootstrap.ConfigsPage{Total: tc.res.Total, Offset: tc.res.Offset, Limit: tc.res.Limit}, tc.err) + req := testRequest{ + client: bs.Client(), + method: http.MethodGet, + url: tc.url, + token: tc.token, + } + + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + var body configPage + + err = json.NewDecoder(res.Body).Decode(&body) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + + assert.Equal(t, tc.res.Total, body.Total, fmt.Sprintf("%s: expected response total '%d' got '%d'", tc.desc, tc.res.Total, body.Total)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestRemove(t *testing.T) { + bs, svc, auth := newBootstrapServer() + defer bs.Close() + c := newConfig() + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + status int + authenticateErr error + err error + }{ + { + desc: "remove with invalid token", + id: c.ClientID, + token: invalidToken, + status: http.StatusUnauthorized, + authenticateErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "remove with an empty token", + id: c.ClientID, + token: "", + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "remove non-existing config", + id: "non-existing", + token: validToken, + status: http.StatusNoContent, + err: nil, + }, + { + desc: "remove config", + id: c.ClientID, + token: validToken, + status: http.StatusNoContent, + err: nil, + }, + { + desc: "remove removed config", + id: wrongID, + token: validToken, + status: http.StatusNoContent, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(tc.err) + req := testRequest{ + client: bs.Client(), + method: http.MethodDelete, + url: fmt.Sprintf("%s/%s/clients/configs/%s", bs.URL, domainID, tc.id), + token: tc.token, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestBootstrap(t *testing.T) { + bs, svc, _ := newBootstrapServer() + defer bs.Close() + c := newConfig() + + encExternKey, err := enc([]byte(c.ExternalKey)) + assert.Nil(t, err, fmt.Sprintf("Encrypting config expected to succeed: %s.\n", err)) + + var channels []channel + for _, ch := range c.Channels { + channels = append(channels, channel{ID: ch.ID, Name: ch.Name, Metadata: ch.Metadata}) + } + + s := struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + Channels []channel `json:"channels"` + Content string `json:"content"` + ClientCert string `json:"client_cert"` + ClientKey string `json:"client_key"` + CACert string `json:"ca_cert"` + }{ + ClientID: c.ClientID, + ClientSecret: c.ClientSecret, + Channels: channels, + Content: c.Content, + ClientCert: c.ClientCert, + ClientKey: c.ClientKey, + CACert: c.CACert, + } + + data := toJSON(s) + + cases := []struct { + desc string + externalID string + externalKey string + status int + res string + secure bool + err error + }{ + { + desc: "bootstrap a Client with unknown ID", + externalID: unknown, + externalKey: c.ExternalKey, + status: http.StatusNotFound, + res: unknownExternalIDErrorRes, + secure: false, + err: svcerr.ErrNotFound, + }, + { + desc: "bootstrap a Client with an empty ID", + externalID: "", + externalKey: c.ExternalKey, + status: http.StatusBadRequest, + res: missingIDRes, + secure: false, + err: apiutil.ErrMissingID, + }, + { + desc: "bootstrap a Client with unknown key", + externalID: c.ExternalID, + externalKey: unknown, + status: http.StatusForbidden, + res: extKeyRes, + secure: false, + err: bootstrap.ErrExternalKey, + }, + { + desc: "bootstrap a Client with an empty key", + externalID: c.ExternalID, + externalKey: "", + status: http.StatusUnauthorized, + res: missingKeyRes, + secure: false, + err: apiutil.ErrBearerKey, + }, + { + desc: "bootstrap known Client", + externalID: c.ExternalID, + externalKey: c.ExternalKey, + status: http.StatusOK, + res: data, + secure: false, + err: nil, + }, + { + desc: "bootstrap secure", + externalID: fmt.Sprintf("secure/%s", c.ExternalID), + externalKey: hex.EncodeToString(encExternKey), + status: http.StatusOK, + res: data, + secure: true, + err: nil, + }, + { + desc: "bootstrap secure with unencrypted key", + externalID: fmt.Sprintf("secure/%s", c.ExternalID), + externalKey: c.ExternalKey, + status: http.StatusForbidden, + res: extSecKeyRes, + secure: true, + err: bootstrap.ErrExternalKeySecure, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("Bootstrap", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(c, tc.err) + req := testRequest{ + client: bs.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/clients/bootstrap/%s", bs.URL, tc.externalID), + key: tc.externalKey, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + body, err := io.ReadAll(res.Body) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + if tc.secure && tc.status == http.StatusOK { + body, err = dec(body) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding body: %s", tc.desc, err)) + } + data := strings.Trim(string(body), "\n") + assert.Equal(t, tc.res, data, fmt.Sprintf("%s: expected response '%s' got '%s'", tc.desc, tc.res, data)) + svcCall.Unset() + }) + } +} + +func TestChangeState(t *testing.T) { + bs, svc, auth := newBootstrapServer() + defer bs.Close() + c := newConfig() + + inactive := fmt.Sprintf("{\"state\": %d}", bootstrap.Inactive) + active := fmt.Sprintf("{\"state\": %d}", bootstrap.Active) + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + state string + contentType string + status int + authenticateErr error + err error + }{ + { + desc: "change state with invalid token", + id: c.ClientID, + token: invalidToken, + state: active, + contentType: contentType, + status: http.StatusUnauthorized, + authenticateErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "change state with an empty token", + id: c.ClientID, + token: "", + state: active, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "change state with invalid content type", + id: c.ClientID, + token: validToken, + state: active, + contentType: "", + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "change state to active", + id: c.ClientID, + token: validToken, + state: active, + contentType: contentType, + status: http.StatusOK, + err: nil, + }, + { + desc: "change state to inactive", + id: c.ClientID, + token: validToken, + state: inactive, + contentType: contentType, + status: http.StatusOK, + err: nil, + }, + { + desc: "change state of non-existing config", + id: wrongID, + token: validToken, + state: active, + contentType: contentType, + status: http.StatusNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "change state to invalid value", + id: c.ClientID, + token: validToken, + state: fmt.Sprintf("{\"state\": %d}", -3), + contentType: contentType, + status: http.StatusBadRequest, + err: svcerr.ErrMalformedEntity, + }, + { + desc: "change state with invalid data", + id: c.ClientID, + token: validToken, + state: "", + contentType: contentType, + status: http.StatusBadRequest, + err: svcerr.ErrMalformedEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("ChangeState", mock.Anything, tc.session, tc.token, mock.Anything, mock.Anything).Return(tc.err) + req := testRequest{ + client: bs.Client(), + method: http.MethodPut, + url: fmt.Sprintf("%s/%s/clients/state/%s", bs.URL, domainID, tc.id), + token: tc.token, + contentType: tc.contentType, + body: strings.NewReader(tc.state), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +type channel struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Metadata any `json:"metadata,omitempty"` +} + +type config struct { + ClientID string `json:"client_id,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` + Channels []channel `json:"channels,omitempty"` + ExternalID string `json:"external_id"` + ExternalKey string `json:"external_key,omitempty"` + Content string `json:"content,omitempty"` + Name string `json:"name"` + State bootstrap.State `json:"state"` +} + +type configPage struct { + Total uint64 `json:"total"` + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Configs []config `json:"configs"` +} diff --git a/bootstrap/api/requests.go b/bootstrap/api/requests.go new file mode 100644 index 000000000..67f4c514c --- /dev/null +++ b/bootstrap/api/requests.go @@ -0,0 +1,163 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/bootstrap" +) + +const maxLimitSize = 100 + +type addReq struct { + token string + ClientID string `json:"client_id"` + ExternalID string `json:"external_id"` + ExternalKey string `json:"external_key"` + Channels []string `json:"channels"` + Name string `json:"name"` + Content string `json:"content"` + ClientCert string `json:"client_cert"` + ClientKey string `json:"client_key"` + CACert string `json:"ca_cert"` +} + +func (req addReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken + } + + if req.ExternalID == "" { + return apiutil.ErrMissingID + } + + if req.ExternalKey == "" { + return apiutil.ErrBearerKey + } + + if len(req.Channels) == 0 { + return apiutil.ErrEmptyList + } + + for _, channel := range req.Channels { + if channel == "" { + return apiutil.ErrMissingID + } + } + + return nil +} + +type entityReq struct { + id string +} + +func (req entityReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + + return nil +} + +type updateReq struct { + id string + Name string `json:"name"` + Content string `json:"content"` +} + +func (req updateReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + + return nil +} + +type updateCertReq struct { + clientID string + ClientCert string `json:"client_cert"` + ClientKey string `json:"client_key"` + CACert string `json:"ca_cert"` +} + +func (req updateCertReq) validate() error { + if req.clientID == "" { + return apiutil.ErrMissingID + } + + return nil +} + +type updateConnReq struct { + token string + id string + Channels []string `json:"channels"` +} + +func (req updateConnReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken + } + + if req.id == "" { + return apiutil.ErrMissingID + } + + return nil +} + +type listReq struct { + filter bootstrap.Filter + offset uint64 + limit uint64 +} + +func (req listReq) validate() error { + if req.limit > maxLimitSize { + return apiutil.ErrLimitSize + } + + return nil +} + +type bootstrapReq struct { + key string + id string +} + +func (req bootstrapReq) validate() error { + if req.key == "" { + return apiutil.ErrBearerKey + } + + if req.id == "" { + return apiutil.ErrMissingID + } + + return nil +} + +type changeStateReq struct { + token string + id string + State bootstrap.State `json:"state"` +} + +func (req changeStateReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken + } + + if req.id == "" { + return apiutil.ErrMissingID + } + + if req.State != bootstrap.Inactive && + req.State != bootstrap.Active { + return bootstrap.ErrBootstrapState + } + + return nil +} diff --git a/bootstrap/api/requests_test.go b/bootstrap/api/requests_test.go new file mode 100644 index 000000000..5906510cc --- /dev/null +++ b/bootstrap/api/requests_test.go @@ -0,0 +1,313 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "fmt" + "testing" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/bootstrap" + "github.com/absmach/supermq/internal/testsutil" + "github.com/stretchr/testify/assert" +) + +var ( + channel1 = testsutil.GenerateUUID(&testing.T{}) + channel2 = testsutil.GenerateUUID(&testing.T{}) +) + +func TestAddReqValidation(t *testing.T) { + cases := []struct { + desc string + token string + externalID string + externalKey string + channels []string + err error + }{ + { + desc: "valid request", + token: "token", + externalID: "external-id", + externalKey: "external-key", + channels: []string{channel1, channel2}, + err: nil, + }, + { + desc: "empty token", + token: "", + externalID: "external-id", + externalKey: "external-key", + channels: []string{channel1, channel2}, + err: apiutil.ErrBearerToken, + }, + { + desc: "empty external ID", + token: "token", + externalID: "", + externalKey: "external-key", + channels: []string{channel1, channel2}, + err: apiutil.ErrMissingID, + }, + { + desc: "empty external key", + token: "token", + externalID: "external-id", + externalKey: "", + channels: []string{channel1, channel2}, + err: apiutil.ErrBearerKey, + }, + { + desc: "empty external key and external ID", + token: "token", + externalID: "", + externalKey: "", + channels: []string{channel1, channel2}, + err: apiutil.ErrMissingID, + }, + { + desc: "empty channels", + token: "token", + externalID: "external-id", + externalKey: "external-key", + channels: []string{}, + err: apiutil.ErrEmptyList, + }, + { + desc: "empty channel value", + token: "token", + externalID: "external-id", + externalKey: "external-key", + channels: []string{channel1, ""}, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + req := addReq{ + token: tc.token, + ExternalID: tc.externalID, + ExternalKey: tc.externalKey, + Channels: tc.channels, + } + + err := req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestEntityReqValidation(t *testing.T) { + cases := []struct { + desc string + id string + err error + }{ + { + desc: "empty id", + id: "", + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + req := entityReq{ + id: tc.id, + } + + err := req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestUpdateReqValidation(t *testing.T) { + cases := []struct { + desc string + id string + err error + }{ + { + desc: "valid request", + id: "id", + err: nil, + }, + { + desc: "empty id", + id: "", + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + req := updateReq{ + id: tc.id, + } + + err := req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestUpdateCertReqValidation(t *testing.T) { + cases := []struct { + desc string + clientID string + err error + }{ + { + desc: "empty client id", + clientID: "", + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + req := updateCertReq{ + clientID: tc.clientID, + } + + err := req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestUpdateConnReqValidation(t *testing.T) { + cases := []struct { + desc string + id string + token string + + err error + }{ + { + desc: "empty token", + token: "", + id: "id", + err: apiutil.ErrBearerToken, + }, + { + desc: "empty id", + token: "token", + id: "", + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + req := updateConnReq{ + token: tc.token, + id: tc.id, + } + + err := req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestListReqValidation(t *testing.T) { + cases := []struct { + desc string + offset uint64 + limit uint64 + err error + }{ + { + desc: "too large limit", + offset: 0, + limit: maxLimitSize + 1, + err: apiutil.ErrLimitSize, + }, + { + desc: "default limit", + offset: 0, + limit: defLimit, + err: nil, + }, + } + + for _, tc := range cases { + req := listReq{ + offset: tc.offset, + limit: tc.limit, + } + + err := req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestBootstrapReqValidation(t *testing.T) { + cases := []struct { + desc string + externKey string + externID string + err error + }{ + { + desc: "empty external key", + externKey: "", + externID: "id", + err: apiutil.ErrBearerKey, + }, + { + desc: "empty external id", + externKey: "key", + externID: "", + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + req := bootstrapReq{ + id: tc.externID, + key: tc.externKey, + } + + err := req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestChangeStateReqValidation(t *testing.T) { + cases := []struct { + desc string + token string + id string + state bootstrap.State + err error + }{ + { + desc: "empty token", + token: "", + id: "id", + state: bootstrap.State(1), + err: apiutil.ErrBearerToken, + }, + { + desc: "empty id", + token: "token", + id: "", + state: bootstrap.State(0), + err: apiutil.ErrMissingID, + }, + { + desc: "invalid state", + token: "token", + id: "id", + state: bootstrap.State(14), + err: bootstrap.ErrBootstrapState, + }, + } + + for _, tc := range cases { + req := changeStateReq{ + token: tc.token, + id: tc.id, + State: tc.state, + } + + err := req.validate() + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} diff --git a/bootstrap/api/responses.go b/bootstrap/api/responses.go new file mode 100644 index 000000000..7020dcb17 --- /dev/null +++ b/bootstrap/api/responses.go @@ -0,0 +1,144 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "fmt" + "net/http" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/bootstrap" +) + +var ( + _ supermq.Response = (*removeRes)(nil) + _ supermq.Response = (*configRes)(nil) + _ supermq.Response = (*stateRes)(nil) + _ supermq.Response = (*viewRes)(nil) + _ supermq.Response = (*listRes)(nil) +) + +type removeRes struct{} + +func (res removeRes) Code() int { + return http.StatusNoContent +} + +func (res removeRes) Headers() map[string]string { + return map[string]string{} +} + +func (res removeRes) Empty() bool { + return true +} + +type configRes struct { + id string + created bool +} + +func (res configRes) Code() int { + if res.created { + return http.StatusCreated + } + + return http.StatusOK +} + +func (res configRes) Headers() map[string]string { + if res.created { + return map[string]string{ + "Location": fmt.Sprintf("/clients/configs/%s", res.id), + } + } + + return map[string]string{} +} + +func (res configRes) Empty() bool { + return true +} + +type channelRes struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Metadata any `json:"metadata,omitempty"` +} + +type viewRes struct { + ClientID string `json:"client_id,omitempty"` + CLientSecret string `json:"client_secret,omitempty"` + Channels []channelRes `json:"channels,omitempty"` + ExternalID string `json:"external_id"` + ExternalKey string `json:"external_key,omitempty"` + Content string `json:"content,omitempty"` + Name string `json:"name,omitempty"` + State bootstrap.State `json:"state"` + ClientCert string `json:"client_cert,omitempty"` + CACert string `json:"ca_cert,omitempty"` +} + +func (res viewRes) Code() int { + return http.StatusOK +} + +func (res viewRes) Headers() map[string]string { + return map[string]string{} +} + +func (res viewRes) Empty() bool { + return false +} + +type listRes struct { + Total uint64 `json:"total"` + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Configs []viewRes `json:"configs"` +} + +func (res listRes) Code() int { + return http.StatusOK +} + +func (res listRes) Headers() map[string]string { + return map[string]string{} +} + +func (res listRes) Empty() bool { + return false +} + +type stateRes struct{} + +func (res stateRes) Code() int { + return http.StatusOK +} + +func (res stateRes) Headers() map[string]string { + return map[string]string{} +} + +func (res stateRes) Empty() bool { + return true +} + +type updateConfigRes struct { + ClientID string `json:"client_id,omitempty"` + CACert string `json:"ca_cert,omitempty"` + ClientCert string `json:"client_cert,omitempty"` + ClientKey string `json:"client_key,omitempty"` +} + +func (res updateConfigRes) Code() int { + return http.StatusOK +} + +func (res updateConfigRes) Headers() map[string]string { + return map[string]string{} +} + +func (res updateConfigRes) Empty() bool { + return false +} diff --git a/bootstrap/api/transport.go b/bootstrap/api/transport.go new file mode 100644 index 000000000..fa0b0d4fb --- /dev/null +++ b/bootstrap/api/transport.go @@ -0,0 +1,283 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "net/url" + "strings" + + "github.com/absmach/supermq" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/bootstrap" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + "github.com/go-chi/chi/v5" + kithttp "github.com/go-kit/kit/transport/http" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" +) + +const ( + contentType = "application/json" + byteContentType = "application/octet-stream" + offsetKey = "offset" + limitKey = "limit" + defOffset = 0 + defLimit = 10 +) + +var ( + fullMatch = []string{"state", "external_id", "client_id", "client_key"} + partialMatch = []string{"name"} + // ErrBootstrap indicates error in getting bootstrap configuration. + ErrBootstrap = errors.New("failed to read bootstrap configuration") +) + +// MakeHandler returns a HTTP handler for API endpoints. +func MakeHandler(svc bootstrap.Service, authn smqauthn.AuthNMiddleware, reader bootstrap.ConfigReader, logger *slog.Logger, instanceID string) http.Handler { + opts := []kithttp.ServerOption{ + kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), + } + + r := chi.NewRouter() + + r.Route("/{domainID}/clients", func(r chi.Router) { + r.Group(func(r chi.Router) { + r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware()) + r.Route("/configs", func(r chi.Router) { + r.Post("/", otelhttp.NewHandler(kithttp.NewServer( + addEndpoint(svc), + decodeAddRequest, + api.EncodeResponse, + opts...), "add").ServeHTTP) + + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + listEndpoint(svc), + decodeListRequest, + api.EncodeResponse, + opts...), "list").ServeHTTP) + + r.Get("/{configID}", otelhttp.NewHandler(kithttp.NewServer( + viewEndpoint(svc), + decodeEntityRequest, + api.EncodeResponse, + opts...), "view").ServeHTTP) + + r.Put("/{configID}", otelhttp.NewHandler(kithttp.NewServer( + updateEndpoint(svc), + decodeUpdateRequest, + api.EncodeResponse, + opts...), "update").ServeHTTP) + + r.Delete("/{configID}", otelhttp.NewHandler(kithttp.NewServer( + removeEndpoint(svc), + decodeEntityRequest, + api.EncodeResponse, + opts...), "remove").ServeHTTP) + + r.Patch("/certs/{certID}", otelhttp.NewHandler(kithttp.NewServer( + updateCertEndpoint(svc), + decodeUpdateCertRequest, + api.EncodeResponse, + opts...), "update_cert").ServeHTTP) + + r.Put("/connections/{connID}", otelhttp.NewHandler(kithttp.NewServer( + updateConnEndpoint(svc), + decodeUpdateConnRequest, + api.EncodeResponse, + opts...), "update_connections").ServeHTTP) + }) + }) + + r.With(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware()).Put("/state/{clientID}", otelhttp.NewHandler(kithttp.NewServer( + stateEndpoint(svc), + decodeStateRequest, + api.EncodeResponse, + opts...), "update_state").ServeHTTP) + }) + + r.Route("/clients/bootstrap", func(r chi.Router) { + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + bootstrapEndpoint(svc, reader, false), + decodeBootstrapRequest, + api.EncodeResponse, + opts...), "bootstrap").ServeHTTP) + r.Get("/{externalID}", otelhttp.NewHandler(kithttp.NewServer( + bootstrapEndpoint(svc, reader, false), + decodeBootstrapRequest, + api.EncodeResponse, + opts...), "bootstrap").ServeHTTP) + r.Get("/secure/{externalID}", otelhttp.NewHandler(kithttp.NewServer( + bootstrapEndpoint(svc, reader, true), + decodeBootstrapRequest, + encodeSecureRes, + opts...), "bootstrap_secure").ServeHTTP) + }) + + r.Get("/health", supermq.Health("bootstrap", instanceID)) + r.Handle("/metrics", promhttp.Handler()) + + return r +} + +func decodeAddRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := addReq{ + token: apiutil.ExtractBearerToken(r), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + return req, nil +} + +func decodeUpdateRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := updateReq{ + id: chi.URLParam(r, "configID"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + return req, nil +} + +func decodeUpdateCertRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := updateCertReq{ + clientID: chi.URLParam(r, "certID"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + return req, nil +} + +func decodeUpdateConnRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := updateConnReq{ + token: apiutil.ExtractBearerToken(r), + id: chi.URLParam(r, "connID"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + return req, nil +} + +func decodeListRequest(_ context.Context, r *http.Request) (any, error) { + o, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + l, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + q, err := url.ParseQuery(r.URL.RawQuery) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidQueryParams) + } + + req := listReq{ + filter: parseFilter(q), + offset: o, + limit: l, + } + + return req, nil +} + +func decodeBootstrapRequest(_ context.Context, r *http.Request) (any, error) { + req := bootstrapReq{ + id: chi.URLParam(r, "externalID"), + key: apiutil.ExtractClientSecret(r), + } + + return req, nil +} + +func decodeStateRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := changeStateReq{ + token: apiutil.ExtractBearerToken(r), + id: chi.URLParam(r, "clientID"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + return req, nil +} + +func decodeEntityRequest(_ context.Context, r *http.Request) (any, error) { + req := entityReq{ + id: chi.URLParam(r, "configID"), + } + + return req, nil +} + +func encodeSecureRes(_ context.Context, w http.ResponseWriter, response any) error { + w.Header().Set("Content-Type", byteContentType) + w.WriteHeader(http.StatusOK) + if b, ok := response.([]byte); ok { + if _, err := w.Write(b); err != nil { + return err + } + } + return nil +} + +func parseFilter(values url.Values) bootstrap.Filter { + ret := bootstrap.Filter{ + FullMatch: make(map[string]string), + PartialMatch: make(map[string]string), + } + for k := range values { + if contains(fullMatch, k) { + ret.FullMatch[k] = values.Get(k) + } + if contains(partialMatch, k) { + ret.PartialMatch[k] = strings.ToLower(values.Get(k)) + } + } + + return ret +} + +func contains(l []string, s string) bool { + for _, v := range l { + if v == s { + return true + } + } + return false +} diff --git a/bootstrap/configs.go b/bootstrap/configs.go new file mode 100644 index 000000000..6561f9668 --- /dev/null +++ b/bootstrap/configs.go @@ -0,0 +1,118 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bootstrap + +import ( + "context" + "time" + + "github.com/absmach/supermq/clients" +) + +// Config represents Configuration entity. It wraps information about external entity +// as well as info about corresponding SuperMQ entities. +// MGClient represents corresponding SuperMQ Client ID. +// MGKey is key of corresponding SuperMQ Client. +// MGChannels is a list of SuperMQ Channels corresponding SuperMQ Client connects to. +type Config struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + DomainID string `json:"domain_id,omitempty"` + Name string `json:"name,omitempty"` + ClientCert string `json:"client_cert,omitempty"` + ClientKey string `json:"client_key,omitempty"` + CACert string `json:"ca_cert,omitempty"` + Channels []Channel `json:"channels,omitempty"` + ExternalID string `json:"external_id"` + ExternalKey string `json:"external_key"` + Content string `json:"content,omitempty"` + State State `json:"state"` +} + +// Channel represents SuperMQ channel corresponding SuperMQ Client is connected to. +type Channel struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + DomainID string `json:"domain_id"` + Parent string `json:"parent_id,omitempty"` + Description string `json:"description,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at,omitempty"` + UpdatedBy string `json:"updated_by,omitempty"` + Status clients.Status `json:"status"` +} + +// Filter is used for the search filters. +type Filter struct { + FullMatch map[string]string + PartialMatch map[string]string +} + +// ConfigsPage contains page related metadata as well as list of Configs that +// belong to this page. +type ConfigsPage struct { + Total uint64 `json:"total"` + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Configs []Config `json:"configs"` +} + +// ConfigRepository specifies a Config persistence API. +type ConfigRepository interface { + // Save persists the Config. Successful operation is indicated by non-nil + // error response. + Save(ctx context.Context, cfg Config, chsConnIDs []string) (string, error) + + // RetrieveByID retrieves the Config having the provided identifier, that is owned + // by the specified user. + RetrieveByID(ctx context.Context, domainID, id string) (Config, error) + + // RetrieveAll retrieves a subset of Configs that are owned + // by the specific user, with given filter parameters. + RetrieveAll(ctx context.Context, domainID string, clientIDs []string, filter Filter, offset, limit uint64) ConfigsPage + + // RetrieveByExternalID returns Config for given external ID. + RetrieveByExternalID(ctx context.Context, externalID string) (Config, error) + + // Update updates an existing Config. A non-nil error is returned + // to indicate operation failure. + Update(ctx context.Context, cfg Config) error + + // UpdateCerts updates and returns an existing Config certificate and domainID. + // A non-nil error is returned to indicate operation failure. + UpdateCert(ctx context.Context, domainID, clientID, clientCert, clientKey, caCert string) (Config, error) + + // UpdateConnections updates a list of Channels the Config is connected to + // adding new Channels if needed. + UpdateConnections(ctx context.Context, domainID, id string, channels []Channel, connections []string) error + + // Remove removes the Config having the provided identifier, that is owned + // by the specified user. + Remove(ctx context.Context, domainID, id string) error + + // ChangeState changes of the Config, that is owned by the specific user. + ChangeState(ctx context.Context, domainID, id string, state State) error + + // ListExisting retrieves those channels from the given list that exist in DB. + ListExisting(ctx context.Context, domainID string, ids []string) ([]Channel, error) + + // Methods RemoveClient, UpdateChannel, and RemoveChannel are related to + // event sourcing. That's why these methods surpass ownership check. + + // RemoveClient removes Config of the Client with the given ID. + RemoveClient(ctx context.Context, id string) error + + // UpdateChannel updates channel with the given ID. + UpdateChannel(ctx context.Context, c Channel) error + + // RemoveChannel removes channel with the given ID. + RemoveChannel(ctx context.Context, id string) error + + // ConnectClient changes state of the Config when the corresponding Client is connected to the Channel. + ConnectClient(ctx context.Context, channelID, clientID string) error + + // DisconnectClient changes state of the Config when the corresponding Client is disconnected from the Channel. + DisconnectClient(ctx context.Context, channelID, clientID string) error +} diff --git a/bootstrap/doc.go b/bootstrap/doc.go new file mode 100644 index 000000000..2e939673d --- /dev/null +++ b/bootstrap/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package bootstrap contains the domain concept definitions needed to support +// SuperMQ bootstrap service functionality. +package bootstrap diff --git a/bootstrap/events/consumer/doc.go b/bootstrap/events/consumer/doc.go new file mode 100644 index 000000000..f3fea76f1 --- /dev/null +++ b/bootstrap/events/consumer/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package consumer contains events consumer for events +// published by Bootstrap service. +package consumer diff --git a/bootstrap/events/consumer/events.go b/bootstrap/events/consumer/events.go new file mode 100644 index 000000000..451001b81 --- /dev/null +++ b/bootstrap/events/consumer/events.go @@ -0,0 +1,24 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import "time" + +type removeEvent struct { + id string +} + +type updateChannelEvent struct { + id string + name string + metadata map[string]any + updatedAt time.Time + updatedBy string +} + +// Connection event is either connect or disconnect event. +type connectionEvent struct { + clientIDs []string + channelID string +} diff --git a/bootstrap/events/consumer/streams.go b/bootstrap/events/consumer/streams.go new file mode 100644 index 000000000..c11802c8e --- /dev/null +++ b/bootstrap/events/consumer/streams.go @@ -0,0 +1,148 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import ( + "context" + "time" + + "github.com/absmach/supermq/bootstrap" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/events" +) + +const ( + clientRemove = "client.remove" + clientConnect = "group.assign" + clientDisconnect = "group.unassign" + + channelPrefix = "channels." + channelUpdate = channelPrefix + "update" + channelRemove = channelPrefix + "remove" + + memberKind = "client" + relation = "group" +) + +type eventHandler struct { + svc bootstrap.Service +} + +// NewEventHandler returns new event store handler. +func NewEventHandler(svc bootstrap.Service) events.EventHandler { + return &eventHandler{ + svc: svc, + } +} + +func (es *eventHandler) Handle(ctx context.Context, event events.Event) error { + msg, err := event.Encode() + if err != nil { + return err + } + + switch msg["operation"] { + case clientRemove: + rte := decodeRemoveClient(msg) + err = es.svc.RemoveConfigHandler(ctx, rte.id) + case clientConnect: + cte := decodeConnectClient(msg) + if cte.channelID == "" || len(cte.clientIDs) == 0 { + return svcerr.ErrMalformedEntity + } + for _, clientID := range cte.clientIDs { + if clientID == "" { + return svcerr.ErrMalformedEntity + } + if err := es.svc.ConnectClientHandler(ctx, cte.channelID, clientID); err != nil { + return err + } + } + case clientDisconnect: + dte := decodeDisconnectClient(msg) + if dte.channelID == "" || len(dte.clientIDs) == 0 { + return svcerr.ErrMalformedEntity + } + for _, clientID := range dte.clientIDs { + if clientID == "" { + return svcerr.ErrMalformedEntity + } + } + + for _, c := range dte.clientIDs { + if err = es.svc.DisconnectClientHandler(ctx, dte.channelID, c); err != nil { + return err + } + } + case channelUpdate: + uce := decodeUpdateChannel(msg) + err = es.handleUpdateChannel(ctx, uce) + case channelRemove: + rce := decodeRemoveChannel(msg) + err = es.svc.RemoveChannelHandler(ctx, rce.id) + } + if err != nil { + return err + } + + return nil +} + +func decodeRemoveClient(event map[string]any) removeEvent { + return removeEvent{ + id: events.Read(event, "id", ""), + } +} + +func decodeUpdateChannel(event map[string]any) updateChannelEvent { + metadata := events.Read(event, "metadata", map[string]any{}) + + return updateChannelEvent{ + id: events.Read(event, "id", ""), + name: events.Read(event, "name", ""), + metadata: metadata, + updatedAt: events.Read(event, "updated_at", time.Now()), + updatedBy: events.Read(event, "updated_by", ""), + } +} + +func decodeRemoveChannel(event map[string]any) removeEvent { + return removeEvent{ + id: events.Read(event, "id", ""), + } +} + +func decodeConnectClient(event map[string]any) connectionEvent { + if events.Read(event, "memberKind", "") != memberKind && events.Read(event, "relation", "") != relation { + return connectionEvent{} + } + + return connectionEvent{ + channelID: events.Read(event, "group_id", ""), + clientIDs: events.ReadStringSlice(event, "member_ids"), + } +} + +func decodeDisconnectClient(event map[string]any) connectionEvent { + if events.Read(event, "memberKind", "") != memberKind && events.Read(event, "relation", "") != relation { + return connectionEvent{} + } + + return connectionEvent{ + channelID: events.Read(event, "group_id", ""), + clientIDs: events.ReadStringSlice(event, "member_ids"), + } +} + +func (es *eventHandler) handleUpdateChannel(ctx context.Context, uce updateChannelEvent) error { + channel := bootstrap.Channel{ + ID: uce.id, + Name: uce.name, + Metadata: uce.metadata, + UpdatedAt: uce.updatedAt, + UpdatedBy: uce.updatedBy, + } + + return es.svc.UpdateChannelHandler(ctx, channel) +} diff --git a/bootstrap/events/doc.go b/bootstrap/events/doc.go new file mode 100644 index 000000000..fa65f5af2 --- /dev/null +++ b/bootstrap/events/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package events provides the domain concept definitions needed to support +// bootstrap events functionality. +package events diff --git a/bootstrap/events/producer/doc.go b/bootstrap/events/producer/doc.go new file mode 100644 index 000000000..ab1537514 --- /dev/null +++ b/bootstrap/events/producer/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package producer contains the domain events needed to support +// event sourcing of Bootstrap service actions. +package producer diff --git a/bootstrap/events/producer/events.go b/bootstrap/events/producer/events.go new file mode 100644 index 000000000..cc31172c3 --- /dev/null +++ b/bootstrap/events/producer/events.go @@ -0,0 +1,277 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package producer + +import ( + "github.com/absmach/supermq/bootstrap" + "github.com/absmach/supermq/pkg/events" +) + +const ( + configPrefix = "bootstrap.config." + configCreate = configPrefix + "create" + configUpdate = configPrefix + "update" + configRemove = configPrefix + "remove" + configView = configPrefix + "view" + configList = configPrefix + "list" + configHandlerRemove = configPrefix + "remove_handler" + + clientPrefix = "bootstrap.client." + clientBootstrap = clientPrefix + "bootstrap" + clientStateChange = clientPrefix + "change_state" + clientUpdateConnections = clientPrefix + "update_connections" + clientConnect = clientPrefix + "connect" + clientDisconnect = clientPrefix + "disconnect" + + channelPrefix = "bootstrap.channel." + channelHandlerRemove = channelPrefix + "remove_handler" + channelUpdateHandler = channelPrefix + "update_handler" + + certUpdate = "bootstrap.cert.update" +) + +var ( + _ events.Event = (*configEvent)(nil) + _ events.Event = (*removeConfigEvent)(nil) + _ events.Event = (*bootstrapEvent)(nil) + _ events.Event = (*changeStateEvent)(nil) + _ events.Event = (*updateConnectionsEvent)(nil) + _ events.Event = (*updateCertEvent)(nil) + _ events.Event = (*listConfigsEvent)(nil) + _ events.Event = (*removeHandlerEvent)(nil) +) + +type configEvent struct { + bootstrap.Config + operation string +} + +func (ce configEvent) Encode() (map[string]any, error) { + val := map[string]any{ + "state": ce.State.String(), + "operation": ce.operation, + } + if ce.ClientID != "" { + val["client_id"] = ce.ClientID + } + if ce.Content != "" { + val["content"] = ce.Content + } + if ce.DomainID != "" { + val["domain_id "] = ce.DomainID + } + if ce.Name != "" { + val["name"] = ce.Name + } + if ce.ExternalID != "" { + val["external_id"] = ce.ExternalID + } + if len(ce.Channels) > 0 { + channels := make([]string, len(ce.Channels)) + for i, ch := range ce.Channels { + channels[i] = ch.ID + } + val["channels"] = channels + } + if ce.ClientCert != "" { + val["client_cert"] = ce.ClientCert + } + if ce.ClientKey != "" { + val["client_key"] = ce.ClientKey + } + if ce.CACert != "" { + val["ca_cert"] = ce.CACert + } + if ce.Content != "" { + val["content"] = ce.Content + } + + return val, nil +} + +type removeConfigEvent struct { + client string +} + +func (rce removeConfigEvent) Encode() (map[string]any, error) { + return map[string]any{ + "client_id": rce.client, + "operation": configRemove, + }, nil +} + +type listConfigsEvent struct { + offset uint64 + limit uint64 + fullMatch map[string]string + partialMatch map[string]string +} + +func (rce listConfigsEvent) Encode() (map[string]any, error) { + val := map[string]any{ + "offset": rce.offset, + "limit": rce.limit, + "operation": configList, + } + if len(rce.fullMatch) > 0 { + val["full_match"] = rce.fullMatch + } + + if len(rce.partialMatch) > 0 { + val["full_match"] = rce.partialMatch + } + return val, nil +} + +type bootstrapEvent struct { + bootstrap.Config + externalID string + success bool +} + +func (be bootstrapEvent) Encode() (map[string]any, error) { + val := map[string]any{ + "external_id": be.externalID, + "success": be.success, + "operation": clientBootstrap, + } + + if be.ClientID != "" { + val["client_id"] = be.ClientID + } + if be.Content != "" { + val["content"] = be.Content + } + if be.DomainID != "" { + val["domain_id "] = be.DomainID + } + if be.Name != "" { + val["name"] = be.Name + } + if be.ExternalID != "" { + val["external_id"] = be.ExternalID + } + if len(be.Channels) > 0 { + channels := make([]string, len(be.Channels)) + for i, ch := range be.Channels { + channels[i] = ch.ID + } + val["channels"] = channels + } + if be.ClientCert != "" { + val["client_cert"] = be.ClientCert + } + if be.ClientKey != "" { + val["client_key"] = be.ClientKey + } + if be.CACert != "" { + val["ca_cert"] = be.CACert + } + if be.Content != "" { + val["content"] = be.Content + } + return val, nil +} + +type changeStateEvent struct { + mgClient string + state bootstrap.State +} + +func (cse changeStateEvent) Encode() (map[string]any, error) { + return map[string]any{ + "client_id": cse.mgClient, + "state": cse.state.String(), + "operation": clientStateChange, + }, nil +} + +type updateConnectionsEvent struct { + mgClient string + mgChannels []string +} + +func (uce updateConnectionsEvent) Encode() (map[string]any, error) { + return map[string]any{ + "client_id": uce.mgClient, + "channels": uce.mgChannels, + "operation": clientUpdateConnections, + }, nil +} + +type updateCertEvent struct { + clientID string + clientCert string + clientKey string + caCert string +} + +func (uce updateCertEvent) Encode() (map[string]any, error) { + return map[string]any{ + "client_id": uce.clientID, + "client_cert": uce.clientCert, + "client_key": uce.clientKey, + "ca_cert": uce.caCert, + "operation": certUpdate, + }, nil +} + +type removeHandlerEvent struct { + id string + operation string +} + +func (rhe removeHandlerEvent) Encode() (map[string]any, error) { + return map[string]any{ + "config_id": rhe.id, + "operation": rhe.operation, + }, nil +} + +type updateChannelHandlerEvent struct { + bootstrap.Channel +} + +func (uche updateChannelHandlerEvent) Encode() (map[string]any, error) { + val := map[string]any{ + "operation": channelUpdateHandler, + } + + if uche.ID != "" { + val["channel_id"] = uche.ID + } + if uche.Name != "" { + val["name"] = uche.Name + } + if uche.Metadata != nil { + val["metadata"] = uche.Metadata + } + return val, nil +} + +type connectClientEvent struct { + clientID string + channelID string +} + +func (cte connectClientEvent) Encode() (map[string]any, error) { + return map[string]any{ + "client_id": cte.clientID, + "channel_id": cte.channelID, + "operation": clientConnect, + }, nil +} + +type disconnectClientEvent struct { + clientID string + channelID string +} + +func (dte disconnectClientEvent) Encode() (map[string]any, error) { + return map[string]any{ + "client_id": dte.clientID, + "channel_id": dte.channelID, + "operation": clientDisconnect, + }, nil +} diff --git a/bootstrap/events/producer/setup_test.go b/bootstrap/events/producer/setup_test.go new file mode 100644 index 000000000..517cd652d --- /dev/null +++ b/bootstrap/events/producer/setup_test.go @@ -0,0 +1,61 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package producer_test + +import ( + "context" + "fmt" + "log" + "os" + "testing" + + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "github.com/redis/go-redis/v9" +) + +var ( + redisClient *redis.Client + redisURL string +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "redis", + Tag: "7.2.4-alpine", + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + redisURL = fmt.Sprintf("redis://localhost:%s/0", container.GetPort("6379/tcp")) + opts, err := redis.ParseURL(redisURL) + if err != nil { + log.Fatalf("Could not parse redis URL: %s", err) + } + + if err := pool.Retry(func() error { + redisClient = redis.NewClient(opts) + + return redisClient.Ping(context.Background()).Err() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + code := m.Run() + + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/bootstrap/events/producer/streams.go b/bootstrap/events/producer/streams.go new file mode 100644 index 000000000..143938e37 --- /dev/null +++ b/bootstrap/events/producer/streams.go @@ -0,0 +1,253 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package producer + +import ( + "context" + + "github.com/absmach/supermq/bootstrap" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/events" +) + +var _ bootstrap.Service = (*eventStore)(nil) + +const ( + magistralaPrefix = "magistrala." + createStream = magistralaPrefix + configCreate + viewStream = magistralaPrefix + configView + listStream = magistralaPrefix + configList + updateStream = magistralaPrefix + configUpdate + removeStream = magistralaPrefix + configRemove + updateCertStream = magistralaPrefix + certUpdate + updateConnectionsStream = magistralaPrefix + clientUpdateConnections + removeHandlerStream = magistralaPrefix + configHandlerRemove + bootstrapStream = magistralaPrefix + clientBootstrap + stateChangeStream = magistralaPrefix + clientStateChange + connectStream = magistralaPrefix + clientConnect + disconnectStream = magistralaPrefix + clientDisconnect + updateHandlerStream = magistralaPrefix + channelUpdateHandler + removeChannelHandlerStream = magistralaPrefix + channelHandlerRemove +) + +type eventStore struct { + events.Publisher + svc bootstrap.Service +} + +// NewEventStoreMiddleware returns wrapper around bootstrap service that sends +// events to event store. +func NewEventStoreMiddleware(svc bootstrap.Service, publisher events.Publisher) bootstrap.Service { + return &eventStore{ + svc: svc, + Publisher: publisher, + } +} + +func (es *eventStore) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (bootstrap.Config, error) { + saved, err := es.svc.Add(ctx, session, token, cfg) + if err != nil { + return saved, err + } + + ev := configEvent{ + saved, configCreate, + } + + if err := es.Publish(ctx, createStream, ev); err != nil { + return saved, err + } + + return saved, err +} + +func (es *eventStore) View(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) { + cfg, err := es.svc.View(ctx, session, id) + if err != nil { + return cfg, err + } + ev := configEvent{ + cfg, configView, + } + + if err := es.Publish(ctx, configView, ev); err != nil { + return cfg, err + } + + return cfg, err +} + +func (es *eventStore) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) error { + if err := es.svc.Update(ctx, session, cfg); err != nil { + return err + } + + ev := configEvent{ + cfg, configUpdate, + } + + return es.Publish(ctx, configUpdate, ev) +} + +func (es eventStore) UpdateCert(ctx context.Context, session smqauthn.Session, clientID, clientCert, clientKey, caCert string) (bootstrap.Config, error) { + cfg, err := es.svc.UpdateCert(ctx, session, clientID, clientCert, clientKey, caCert) + if err != nil { + return cfg, err + } + + ev := updateCertEvent{ + clientID: clientID, + clientCert: clientCert, + clientKey: clientKey, + caCert: caCert, + } + + if err := es.Publish(ctx, updateCertStream, ev); err != nil { + return cfg, err + } + + return cfg, nil +} + +func (es *eventStore) UpdateConnections(ctx context.Context, session smqauthn.Session, token, id string, connections []string) error { + if err := es.svc.UpdateConnections(ctx, session, token, id, connections); err != nil { + return err + } + + ev := updateConnectionsEvent{ + mgClient: id, + mgChannels: connections, + } + + return es.Publish(ctx, updateConnectionsStream, ev) +} + +func (es *eventStore) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (bootstrap.ConfigsPage, error) { + bp, err := es.svc.List(ctx, session, filter, offset, limit) + if err != nil { + return bp, err + } + + ev := listConfigsEvent{ + offset: offset, + limit: limit, + fullMatch: filter.FullMatch, + partialMatch: filter.PartialMatch, + } + + if err := es.Publish(ctx, listStream, ev); err != nil { + return bp, err + } + + return bp, nil +} + +func (es *eventStore) Remove(ctx context.Context, session smqauthn.Session, id string) error { + if err := es.svc.Remove(ctx, session, id); err != nil { + return err + } + + ev := removeConfigEvent{ + client: id, + } + + return es.Publish(ctx, removeStream, ev) +} + +func (es *eventStore) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (bootstrap.Config, error) { + cfg, err := es.svc.Bootstrap(ctx, externalKey, externalID, secure) + + ev := bootstrapEvent{ + cfg, + externalID, + true, + } + + if err != nil { + ev.success = false + } + + if err := es.Publish(ctx, bootstrapStream, ev); err != nil { + return cfg, err + } + + return cfg, err +} + +func (es *eventStore) ChangeState(ctx context.Context, session smqauthn.Session, token, id string, state bootstrap.State) error { + if err := es.svc.ChangeState(ctx, session, token, id, state); err != nil { + return err + } + + ev := changeStateEvent{ + mgClient: id, + state: state, + } + + return es.Publish(ctx, stateChangeStream, ev) +} + +func (es *eventStore) RemoveConfigHandler(ctx context.Context, id string) error { + if err := es.svc.RemoveConfigHandler(ctx, id); err != nil { + return err + } + + ev := removeHandlerEvent{ + id: id, + operation: configHandlerRemove, + } + + return es.Publish(ctx, removeHandlerStream, ev) +} + +func (es *eventStore) RemoveChannelHandler(ctx context.Context, id string) error { + if err := es.svc.RemoveChannelHandler(ctx, id); err != nil { + return err + } + + ev := removeHandlerEvent{ + id: id, + operation: channelHandlerRemove, + } + + return es.Publish(ctx, removeChannelHandlerStream, ev) +} + +func (es *eventStore) UpdateChannelHandler(ctx context.Context, channel bootstrap.Channel) error { + if err := es.svc.UpdateChannelHandler(ctx, channel); err != nil { + return err + } + + ev := updateChannelHandlerEvent{ + channel, + } + + return es.Publish(ctx, updateStream, ev) +} + +func (es *eventStore) ConnectClientHandler(ctx context.Context, channelID, clientID string) error { + if err := es.svc.ConnectClientHandler(ctx, channelID, clientID); err != nil { + return err + } + + ev := connectClientEvent{ + clientID: clientID, + channelID: channelID, + } + + return es.Publish(ctx, connectStream, ev) +} + +func (es *eventStore) DisconnectClientHandler(ctx context.Context, channelID, clientID string) error { + if err := es.svc.DisconnectClientHandler(ctx, channelID, clientID); err != nil { + return err + } + + ev := disconnectClientEvent{ + clientID: clientID, + channelID: channelID, + } + + return es.Publish(ctx, disconnectStream, ev) +} diff --git a/bootstrap/events/producer/streams_test.go b/bootstrap/events/producer/streams_test.go new file mode 100644 index 000000000..1d49fafb7 --- /dev/null +++ b/bootstrap/events/producer/streams_test.go @@ -0,0 +1,1481 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package producer_test + +import ( + "context" + "fmt" + "strconv" + "strings" + "testing" + "time" + + "github.com/absmach/supermq/bootstrap" + "github.com/absmach/supermq/bootstrap/events/producer" + "github.com/absmach/supermq/bootstrap/mocks" + "github.com/absmach/supermq/internal/testsutil" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/events/store" + policysvc "github.com/absmach/supermq/pkg/policies" + policymocks "github.com/absmach/supermq/pkg/policies/mocks" + mgsdk "github.com/absmach/supermq/pkg/sdk" + sdkmocks "github.com/absmach/supermq/pkg/sdk/mocks" + "github.com/absmach/supermq/pkg/uuid" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +const ( + streamID = "supermq.bootstrap" + email = "user@example.com" + validToken = "validToken" + invalidToken = "invalid" + unknownClientID = "unknown" + channelsNum = 3 + defaultTimout = 5 + + configPrefix = "config." + configCreate = configPrefix + "create" + configView = configPrefix + "view" + configUpdate = configPrefix + "update" + configRemove = configPrefix + "remove" + configList = configPrefix + "list" + configHandlerRemove = configPrefix + "remove_handler" + + clientPrefix = "client." + clientBootstrap = clientPrefix + "bootstrap" + clientStateChange = clientPrefix + "change_state" + clientUpdateConnections = clientPrefix + "update_connections" + clientConnect = clientPrefix + "connect" + clientDisconnect = clientPrefix + "disconnect" + + channelPrefix = "group." + channelHandlerRemove = channelPrefix + "remove_handler" + channelUpdateHandler = channelPrefix + "update_handler" + + certUpdate = "cert.update" + instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002" +) + +var ( + encKey = []byte("1234567891011121") + + domainID = testsutil.GenerateUUID(&testing.T{}) + validID = testsutil.GenerateUUID(&testing.T{}) + + channel = bootstrap.Channel{ + ID: testsutil.GenerateUUID(&testing.T{}), + Name: "name", + Metadata: map[string]any{"name": "value"}, + } + + config = bootstrap.Config{ + ClientID: testsutil.GenerateUUID(&testing.T{}), + ClientSecret: testsutil.GenerateUUID(&testing.T{}), + ExternalID: testsutil.GenerateUUID(&testing.T{}), + ExternalKey: testsutil.GenerateUUID(&testing.T{}), + Channels: []bootstrap.Channel{channel}, + Content: "config", + } +) + +type testVariable struct { + svc bootstrap.Service + boot *mocks.ConfigRepository + policies *policymocks.Service + sdk *sdkmocks.SDK +} + +func newTestVariable(t *testing.T, redisURL string) testVariable { + boot := new(mocks.ConfigRepository) + policies := new(policymocks.Service) + sdk := new(sdkmocks.SDK) + idp := uuid.NewMock() + svc := bootstrap.New(policies, boot, sdk, encKey, idp) + publisher, err := store.NewPublisher(context.Background(), redisURL, "bootstrap-es-pub-test") + require.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + svc = producer.NewEventStoreMiddleware(svc, publisher) + return testVariable{ + svc: svc, + boot: boot, + policies: policies, + sdk: sdk, + } +} + +func TestAdd(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + var channels []string + for _, ch := range config.Channels { + channels = append(channels, ch.ID) + } + + invalidConfig := config + invalidConfig.Channels = []bootstrap.Channel{{ID: "empty"}} + invalidConfig.Channels = []bootstrap.Channel{{ID: "empty"}} + + cases := []struct { + desc string + config bootstrap.Config + token string + session smqauthn.Session + id string + domainID string + clientErr error + channel []bootstrap.Channel + listErr error + saveErr error + err error + event map[string]any + }{ + { + desc: "create config successfully", + config: config, + token: validToken, + id: validID, + domainID: domainID, + channel: config.Channels, + event: map[string]any{ + "client_id": "1", + "domain_id": domainID, + "name": config.Name, + "channels": channels, + "external_id": config.ExternalID, + "content": config.Content, + "timestamp": time.Now().Unix(), + "operation": configCreate, + }, + err: nil, + }, + { + desc: "create config with failed to fetch client", + config: config, + token: validToken, + id: validID, + domainID: domainID, + event: nil, + clientErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "create config with failed to list existing", + config: config, + token: validToken, + id: validID, + domainID: domainID, + event: nil, + listErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "create invalid config", + config: invalidConfig, + token: validToken, + id: validID, + domainID: domainID, + event: nil, + listErr: svcerr.ErrMalformedEntity, + err: svcerr.ErrMalformedEntity, + }, + } + + lastID := "0" + for _, tc := range cases { + tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID} + sdkCall := tv.sdk.On("Client", mock.Anything, tc.config.ClientID, tc.domainID, tc.token).Return(mgsdk.Client{ID: tc.config.ClientID, Credentials: mgsdk.ClientCredentials{Secret: tc.config.ClientSecret}}, errors.NewSDKError(tc.clientErr)) + repoCall := tv.boot.On("ListExisting", context.Background(), domainID, mock.Anything).Return(tc.config.Channels, tc.listErr) + repoCall1 := tv.boot.On("Save", context.Background(), mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr) + + _, err := tv.svc.Add(context.Background(), tc.session, tc.token, tc.config) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + event := streams[0].Messages + lastID = event[0].ID + } + + test(t, tc.event, event, tc.desc) + + sdkCall.Unset() + repoCall.Unset() + repoCall1.Unset() + } +} + +func TestView(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + nonExisting := config + nonExisting.ClientID = unknownClientID + + cases := []struct { + desc string + config bootstrap.Config + token string + session smqauthn.Session + id string + domainID string + retrieveErr error + err error + event map[string]any + }{ + { + desc: "view successfully", + config: config, + token: validToken, + id: validID, + domainID: domainID, + err: nil, + event: map[string]any{ + "client_id": config.ClientID, + "domain_id": config.DomainID, + "name": config.Name, + "channels": config.Channels, + "external_id": config.ExternalID, + "content": config.Content, + "timestamp": time.Now().Unix(), + "operation": configView, + }, + }, + { + desc: "view with failed retrieve", + config: nonExisting, + token: validToken, + id: validID, + domainID: domainID, + retrieveErr: svcerr.ErrViewEntity, + err: svcerr.ErrViewEntity, + event: nil, + }, + } + + lastID := "0" + for _, tc := range cases { + tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID} + repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.config.ClientID).Return(config, tc.retrieveErr) + _, err := tv.svc.View(context.Background(), tc.session, tc.config.ClientID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + msg := streams[0].Messages[0] + event = msg.Values + event["timestamp"] = msg.ID + lastID = msg.ID + } + + test(t, tc.event, event, tc.desc) + repoCall.Unset() + } +} + +func TestUpdate(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + c := config + + ch1 := channel + ch1.ID = testsutil.GenerateUUID(t) + + ch2 := channel + ch2.ID = testsutil.GenerateUUID(t) + + c.Channels = append(c.Channels, ch1, ch2) + + modified := c + modified.Content = "new-config" + modified.Name = "new name" + + nonExisting := config + nonExisting.ClientID = unknownClientID + + channels := []string{modified.Channels[0].ID, modified.Channels[1].ID} + + cases := []struct { + desc string + config bootstrap.Config + token string + session smqauthn.Session + id string + domainID string + updateErr error + err error + event map[string]any + }{ + { + desc: "update config successfully", + config: modified, + token: validToken, + id: validID, + domainID: domainID, + err: nil, + event: map[string]any{ + "name": modified.Name, + "content": modified.Content, + "timestamp": time.Now().UnixNano(), + "operation": configUpdate, + "channels": channels, + "external_id": modified.ExternalID, + "client_id": modified.ClientID, + "domain_id": domainID, + "state": "0", + "occurred_at": time.Now().UnixNano(), + }, + }, + { + desc: "update with failed update", + config: nonExisting, + token: validToken, + id: validID, + domainID: domainID, + updateErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + event: nil, + }, + } + + lastID := "0" + for _, tc := range cases { + tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID} + repoCall := tv.boot.On("Update", context.Background(), mock.Anything).Return(tc.updateErr) + err := tv.svc.Update(context.Background(), tc.session, tc.config) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + msg := streams[0].Messages[0] + event = msg.Values + event["timestamp"] = msg.ID + lastID = msg.ID + } + + test(t, tc.event, event, tc.desc) + repoCall.Unset() + } +} + +func TestUpdateConnections(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + cases := []struct { + desc string + configID string + id string + domainID string + token string + session smqauthn.Session + connections []string + clientErr error + channelErr error + retrieveErr error + listErr error + updateErr error + err error + event map[string]any + }{ + { + desc: "update connections successfully", + configID: config.ClientID, + token: validToken, + id: validID, + domainID: domainID, + connections: []string{config.Channels[0].ID}, + err: nil, + event: map[string]any{ + "client_id": config.ClientID, + "channels": "2", + "timestamp": time.Now().Unix(), + "operation": clientUpdateConnections, + }, + }, + { + desc: "update connections with failed channel fetch", + configID: config.ClientID, + token: validToken, + id: validID, + domainID: domainID, + connections: []string{"256"}, + channelErr: errors.NewSDKError(svcerr.ErrNotFound), + err: svcerr.ErrNotFound, + event: nil, + }, + { + desc: "update connections with failed RetrieveByID", + configID: config.ClientID, + token: validToken, + id: validID, + domainID: domainID, + connections: []string{config.Channels[0].ID}, + retrieveErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + event: nil, + }, + { + desc: "update connections with failed ListExisting", + configID: config.ClientID, + token: validToken, + id: validID, + domainID: domainID, + connections: []string{config.Channels[0].ID}, + listErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + event: nil, + }, + { + desc: "update connections with failed UpdateConnections", + configID: config.ClientID, + token: validToken, + id: validID, + domainID: domainID, + connections: []string{config.Channels[0].ID}, + updateErr: svcerr.ErrUpdateEntity, + err: svcerr.ErrUpdateEntity, + event: nil, + }, + } + + lastID := "0" + for _, tc := range cases { + tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID} + sdkCall := tv.sdk.On("Channel", mock.Anything, mock.Anything, tc.domainID, tc.token).Return(mgsdk.Channel{}, tc.channelErr) + repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.configID).Return(config, tc.retrieveErr) + repoCall1 := tv.boot.On("ListExisting", context.Background(), domainID, mock.Anything, mock.Anything).Return(config.Channels, tc.listErr) + repoCall2 := tv.boot.On("UpdateConnections", context.Background(), tc.domainID, tc.configID, mock.Anything, tc.connections).Return(tc.updateErr) + err := tv.svc.UpdateConnections(context.Background(), tc.session, tc.token, tc.configID, tc.connections) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + event := streams[0].Messages + lastID = event[0].ID + } + + test(t, tc.event, event, tc.desc) + sdkCall.Unset() + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + } +} + +func TestUpdateCert(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + cases := []struct { + desc string + configID string + userID string + domainID string + token string + session smqauthn.Session + clientCert string + clientKey string + caCert string + updateErr error + err error + event map[string]any + }{ + { + desc: "update cert successfully", + configID: config.ClientID, + userID: validID, + domainID: domainID, + token: validToken, + clientCert: "clientCert", + clientKey: "clientKey", + caCert: "caCert", + err: nil, + event: map[string]any{ + "client_secret": config.ClientSecret, + "client_cert": "clientCert", + "client_key": "clientKey", + "ca_cert": "caCert", + "operation": certUpdate, + }, + }, + { + desc: "update cert with failed update", + configID: "clientID", + token: validToken, + userID: validID, + domainID: domainID, + clientCert: "clientCert", + clientKey: "clientKey", + caCert: "caCert", + updateErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + event: nil, + }, + { + desc: "update cert with empty client certificate", + configID: config.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + clientCert: "", + clientKey: "clientKey", + caCert: "caCert", + err: nil, + event: nil, + }, + { + desc: "update cert with empty client key", + configID: config.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + clientCert: "clientCert", + clientKey: "", + caCert: "caCert", + err: nil, + event: nil, + }, + { + desc: "update cert with empty CA certificate", + configID: config.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + clientCert: "clientCert", + clientKey: "clientKey", + caCert: "", + err: nil, + event: nil, + }, + { + desc: "successful update without CA certificate", + configID: config.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + clientCert: "clientCert", + clientKey: "clientKey", + caCert: "", + err: nil, + event: map[string]any{ + "client_secret": config.ClientSecret, + "client_cert": "clientCert", + "client_key": "clientKey", + "ca_cert": "caCert", + "operation": certUpdate, + "timestamp": time.Now().Unix(), + }, + }, + } + + lastID := "0" + for _, tc := range cases { + tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} + repoCall := tv.boot.On("UpdateCert", context.Background(), tc.domainID, tc.configID, tc.clientCert, tc.clientKey, tc.caCert).Return(config, tc.updateErr) + _, err := tv.svc.UpdateCert(context.Background(), tc.session, tc.configID, tc.clientCert, tc.clientKey, tc.caCert) + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + event := streams[0].Messages + lastID = event[0].ID + } + + test(t, tc.event, event, tc.desc) + + repoCall.Unset() + } +} + +func TestList(t *testing.T) { + tv := newTestVariable(t, redisURL) + + numClients := 101 + var c bootstrap.Config + saved := make([]bootstrap.Config, 0) + for i := 0; i < numClients; i++ { + c := config + c.ExternalID = testsutil.GenerateUUID(t) + c.ExternalKey = testsutil.GenerateUUID(t) + c.Name = fmt.Sprintf("%s-%d", config.Name, i) + if i == 41 { + c.State = bootstrap.Active + } + saved = append(saved, c) + } + + cases := []struct { + desc string + token string + session smqauthn.Session + userID string + domainID string + config bootstrap.ConfigsPage + filter bootstrap.Filter + offset uint64 + limit uint64 + listObjectsResponse policysvc.PolicyPage + listObjectsErr error + retrieveErr error + err error + event map[string]any + }{ + { + desc: "list successfully as super admin", + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true}, + config: bootstrap.ConfigsPage{ + Total: uint64(len(saved)), + Offset: 0, + Limit: 10, + Configs: saved[0:10], + }, + filter: bootstrap.Filter{}, + offset: 0, + limit: 10, + listObjectsResponse: policysvc.PolicyPage{}, + err: nil, + event: map[string]any{ + "client_id": c.ClientID, + "domain_id": c.DomainID, + "name": c.Name, + "channels": c.Channels, + "external_id": c.ExternalID, + "content": c.Content, + "timestamp": time.Now().Unix(), + "operation": configList, + }, + }, + { + desc: "list successfully as domain admin", + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true}, + config: bootstrap.ConfigsPage{ + Total: uint64(len(saved)), + Offset: 0, + Limit: 10, + Configs: saved[0:10], + }, + filter: bootstrap.Filter{}, + offset: 0, + limit: 10, + listObjectsResponse: policysvc.PolicyPage{}, + err: nil, + event: map[string]any{ + "client_id": c.ClientID, + "domain_id": c.DomainID, + "name": c.Name, + "channels": c.Channels, + "external_id": c.ExternalID, + "content": c.Content, + "timestamp": time.Now().Unix(), + "operation": configList, + }, + }, + { + desc: "list successfully as non admin", + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}, + config: bootstrap.ConfigsPage{ + Total: uint64(len(saved)), + Offset: 0, + Limit: 10, + Configs: saved[0:10], + }, + filter: bootstrap.Filter{}, + offset: 0, + limit: 10, + listObjectsResponse: policysvc.PolicyPage{}, + err: nil, + event: map[string]any{ + "client_id": c.ClientID, + "domain_id": c.DomainID, + "name": c.Name, + "channels": c.Channels, + "external_id": c.ExternalID, + "content": c.Content, + "timestamp": time.Now().Unix(), + "operation": configList, + }, + }, + { + desc: "list as non admin with failed list all objects", + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}, + filter: bootstrap.Filter{}, + offset: 0, + limit: 10, + listObjectsResponse: policysvc.PolicyPage{}, + listObjectsErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + event: nil, + }, + + { + desc: "list as super admin with failed retrieve all", + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true}, + filter: bootstrap.Filter{}, + offset: 0, + limit: 10, + listObjectsResponse: policysvc.PolicyPage{}, + retrieveErr: nil, + err: nil, + event: nil, + }, + { + desc: "list as domain admin with failed retrieve all", + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true}, + filter: bootstrap.Filter{}, + offset: 0, + limit: 10, + listObjectsResponse: policysvc.PolicyPage{}, + retrieveErr: nil, + err: nil, + event: nil, + }, + { + desc: "list as non admin with failed retrieve all", + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}, + filter: bootstrap.Filter{}, + offset: 0, + limit: 10, + listObjectsResponse: policysvc.PolicyPage{}, + retrieveErr: nil, + err: nil, + event: nil, + }, + } + + lastID := "0" + for _, tc := range cases { + policyCall := tv.policies.On("ListAllObjects", mock.Anything, policysvc.Policy{ + SubjectType: policysvc.UserType, + Subject: tc.userID, + Permission: policysvc.ViewPermission, + ObjectType: policysvc.ClientType, + }).Return(tc.listObjectsResponse, tc.listObjectsErr) + repoCall := tv.boot.On("RetrieveAll", context.Background(), mock.Anything, mock.Anything, tc.filter, tc.offset, tc.limit).Return(tc.config, tc.retrieveErr) + + _, err := tv.svc.List(context.Background(), tc.session, tc.filter, tc.offset, tc.limit) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + event := streams[0].Messages + lastID = event[0].ID + } + + test(t, tc.event, event, tc.desc) + + policyCall.Unset() + repoCall.Unset() + } +} + +func TestRemove(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + nonExisting := config + nonExisting.ClientID = unknownClientID + + cases := []struct { + desc string + configID string + userID string + domainID string + token string + session smqauthn.Session + removeErr error + err error + event map[string]any + }{ + { + desc: "remove config successfully", + configID: config.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + err: nil, + event: map[string]any{ + "client_id": config.ClientID, + "timestamp": time.Now().Unix(), + "operation": configRemove, + }, + }, + { + desc: "remove config with failed removal", + configID: nonExisting.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + removeErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + event: nil, + }, + } + + lastID := "0" + for _, tc := range cases { + tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID} + repoCall := tv.boot.On("Remove", context.Background(), mock.Anything, mock.Anything).Return(tc.removeErr) + err := tv.svc.Remove(context.Background(), tc.session, tc.configID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + event := streams[0].Messages + lastID = event[0].ID + } + + test(t, tc.event, event, tc.desc) + repoCall.Unset() + } +} + +func TestBootstrap(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + cases := []struct { + desc string + externalID string + externalKey string + err error + retrieveErr error + event map[string]any + }{ + { + desc: "bootstrap successfully", + externalID: config.ExternalID, + externalKey: config.ExternalKey, + err: nil, + event: map[string]any{ + "external_id": config.ExternalID, + "success": "1", + "timestamp": time.Now().Unix(), + "operation": clientBootstrap, + }, + }, + { + desc: "bootstrap with an error", + externalID: "external_id1", + externalKey: "external_id", + retrieveErr: bootstrap.ErrBootstrap, + err: bootstrap.ErrBootstrap, + event: map[string]any{ + "external_id": "external_id", + "success": "0", + "timestamp": time.Now().Unix(), + "operation": clientBootstrap, + }, + }, + } + + lastID := "0" + for _, tc := range cases { + repoCall := tv.boot.On("RetrieveByExternalID", context.Background(), mock.Anything).Return(config, tc.retrieveErr) + _, err = tv.svc.Bootstrap(context.Background(), tc.externalKey, tc.externalID, false) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + event := streams[0].Messages + lastID = event[0].ID + } + test(t, tc.event, event, tc.desc) + repoCall.Unset() + } +} + +func TestChangeState(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + cases := []struct { + desc string + id string + userID string + domainID string + token string + session smqauthn.Session + state bootstrap.State + authResponse smqauthn.Session + authorizeErr error + connectErr error + retrieveErr error + stateErr error + authenticateErr error + err error + event map[string]any + }{ + { + desc: "change state to active", + id: config.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + state: bootstrap.Active, + authResponse: smqauthn.Session{}, + err: nil, + event: map[string]any{ + "client_id": config.ClientID, + "state": bootstrap.Active.String(), + "timestamp": time.Now().Unix(), + "operation": clientStateChange, + }, + }, + { + desc: "change state with failed retrieve by ID", + id: "", + token: validToken, + userID: validID, + domainID: domainID, + state: bootstrap.Active, + retrieveErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + event: nil, + }, + { + desc: "change state with failed connect", + id: config.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + state: bootstrap.Active, + connectErr: bootstrap.ErrClients, + err: bootstrap.ErrClients, + event: nil, + }, + { + desc: "change state unsuccessfully", + id: config.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + state: bootstrap.Active, + stateErr: svcerr.ErrUpdateEntity, + err: svcerr.ErrUpdateEntity, + event: nil, + }, + } + + lastID := "0" + for _, tc := range cases { + tc.session = smqauthn.Session{UserID: validID, DomainID: tc.domainID, DomainUserID: validID} + repoCall := tv.boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(config, tc.retrieveErr) + sdkCall1 := tv.sdk.On("ConnectClients", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.NewSDKError(tc.connectErr)) + repoCall1 := tv.boot.On("ChangeState", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(tc.stateErr) + err := tv.svc.ChangeState(context.Background(), tc.session, tc.token, tc.id, tc.state) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + event := streams[0].Messages + lastID = event[0].ID + } + + test(t, tc.event, event, tc.desc) + sdkCall1.Unset() + repoCall.Unset() + repoCall1.Unset() + } +} + +func TestUpdateChannelHandler(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + cases := []struct { + desc string + channel bootstrap.Channel + err error + event map[string]any + }{ + { + desc: "update channel handler successfully", + channel: channel, + err: nil, + event: map[string]any{ + "channel_id": channel.ID, + "metadata": "{\"name\":\"value\"}", + "name": channel.Name, + "operation": channelUpdateHandler, + "timestamp": time.Now().UnixNano(), + "occurred_at": time.Now().UnixNano(), + }, + }, + { + desc: "update non-existing channel handler", + channel: bootstrap.Channel{ID: "unknown", Name: "NonExistingChannel"}, + err: nil, + event: nil, + }, + { + desc: "update channel handler with empty ID", + channel: bootstrap.Channel{Name: "ChannelWithEmptyID"}, + err: nil, + event: nil, + }, + { + desc: "update channel handler with empty name", + channel: bootstrap.Channel{ID: "3"}, + err: nil, + event: nil, + }, + { + desc: "update channel handler successfully with modified fields", + channel: channel, + err: nil, + event: map[string]any{ + "channel_id": channel.ID, + "metadata": "{\"name\":\"value\"}", + "name": channel.Name, + "operation": channelUpdateHandler, + "timestamp": time.Now().UnixNano(), + "occurred_at": time.Now().UnixNano(), + }, + }, + } + + lastID := "0" + for _, tc := range cases { + repoCall := tv.boot.On("UpdateChannel", context.Background(), mock.Anything).Return(tc.err) + err := tv.svc.UpdateChannelHandler(context.Background(), tc.channel) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + msg := streams[0].Messages[0] + event = msg.Values + event["timestamp"] = msg.ID + lastID = msg.ID + } + test(t, tc.event, event, tc.desc) + repoCall.Unset() + } +} + +func TestRemoveChannelHandler(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + cases := []struct { + desc string + channelID string + err error + event map[string]any + }{ + { + desc: "remove channel handler successfully", + channelID: channel.ID, + err: nil, + event: map[string]any{ + "config_id": channel.ID, + "operation": channelHandlerRemove, + "timestamp": time.Now().UnixNano(), + "occurred_at": time.Now().UnixNano(), + }, + }, + { + desc: "remove non-existing channel handler", + channelID: "unknown", + err: nil, + event: nil, + }, + { + desc: "remove channel handler with empty ID", + channelID: "", + err: nil, + event: nil, + }, + } + + lastID := "0" + for _, tc := range cases { + repoCall := tv.boot.On("RemoveChannel", context.Background(), mock.Anything).Return(tc.err) + err := tv.svc.RemoveChannelHandler(context.Background(), tc.channelID) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + msg := streams[0].Messages[0] + event = msg.Values + event["timestamp"] = msg.ID + lastID = msg.ID + } + + test(t, tc.event, event, tc.desc) + repoCall.Unset() + } +} + +func TestRemoveConfigHandler(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + cases := []struct { + desc string + configID string + err error + event map[string]any + }{ + { + desc: "remove config handler successfully", + configID: channel.ID, + err: nil, + event: map[string]any{ + "config_id": channel.ID, + "operation": configHandlerRemove, + "timestamp": time.Now().UnixNano(), + "occurred_at": time.Now().UnixNano(), + }, + }, + { + desc: "remove non-existing config handler", + configID: "unknown", + err: nil, + event: nil, + }, + { + desc: "remove config handler with empty ID", + configID: "", + err: nil, + event: nil, + }, + } + + lastID := "0" + for _, tc := range cases { + repoCall := tv.boot.On("RemoveClient", context.Background(), mock.Anything).Return(tc.err) + err := tv.svc.RemoveConfigHandler(context.Background(), tc.configID) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + msg := streams[0].Messages[0] + event = msg.Values + event["timestamp"] = msg.ID + lastID = msg.ID + } + + test(t, tc.event, event, tc.desc) + repoCall.Unset() + } +} + +func TestConnectClientHandler(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + cases := []struct { + desc string + channelID string + clientID string + err error + event map[string]any + }{ + { + desc: "connect client handler successfully", + channelID: channel.ID, + clientID: "1", + err: nil, + event: map[string]any{ + "channel_id": channel.ID, + "client_id": "1", + "operation": clientConnect, + "timestamp": time.Now().UnixNano(), + "occurred_at": time.Now().UnixNano(), + }, + }, + { + desc: "connect non-existing client handler", + channelID: channel.ID, + clientID: "unknown", + err: nil, + event: nil, + }, + { + desc: "connect client handler with empty client ID", + channelID: channel.ID, + clientID: "", + err: nil, + event: nil, + }, + { + desc: "connect client handler with empty channel ID", + channelID: "", + clientID: "1", + err: nil, + event: nil, + }, + } + + lastID := "0" + for _, tc := range cases { + repoCall := tv.boot.On("ConnectClient", context.Background(), mock.Anything, mock.Anything).Return(tc.err) + err := tv.svc.ConnectClientHandler(context.Background(), tc.channelID, tc.clientID) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + msg := streams[0].Messages[0] + event = msg.Values + event["timestamp"] = msg.ID + lastID = msg.ID + } + + test(t, tc.event, event, tc.desc) + repoCall.Unset() + } +} + +func TestDisconnectClientHandler(t *testing.T) { + err := redisClient.FlushAll(context.Background()).Err() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + tv := newTestVariable(t, redisURL) + + cases := []struct { + desc string + channelID string + clientID string + err error + event map[string]any + }{ + { + desc: "disconnect client handler successfully", + channelID: channel.ID, + clientID: "1", + err: nil, + event: map[string]any{ + "channel_id": channel.ID, + "client_id": "1", + "operation": clientDisconnect, + "timestamp": time.Now().UnixNano(), + "occurred_at": time.Now().UnixNano(), + }, + }, + { + desc: "remove non-existing client handler", + channelID: "unknown", + err: nil, + }, + { + desc: "remove client handler with empty client ID", + channelID: channel.ID, + clientID: "", + err: nil, + event: nil, + }, + { + desc: "remove client handler with empty channel ID", + channelID: "", + err: nil, + event: nil, + }, + { + desc: "remove client handler successfully", + channelID: channel.ID, + clientID: "1", + err: nil, + event: map[string]any{ + "channel_id": channel.ID, + "client_id": "1", + "operation": clientDisconnect, + "timestamp": time.Now().UnixNano(), + "occurred_at": time.Now().UnixNano(), + }, + }, + } + + lastID := "0" + for _, tc := range cases { + repoCall := tv.boot.On("DisconnectClient", context.Background(), tc.channelID, tc.clientID).Return(tc.err) + err := tv.svc.DisconnectClientHandler(context.Background(), tc.channelID, tc.clientID) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + streams := redisClient.XRead(context.Background(), &redis.XReadArgs{ + Streams: []string{streamID, lastID}, + Count: 1, + Block: time.Second, + }).Val() + + var event map[string]any + if len(streams) > 0 && len(streams[0].Messages) > 0 { + msg := streams[0].Messages[0] + event = msg.Values + event["timestamp"] = msg.ID + lastID = msg.ID + } + + test(t, tc.event, event, tc.desc) + repoCall.Unset() + } +} + +func test(t *testing.T, expected, actual map[string]any, description string) { + if expected != nil && actual != nil { + ts1 := expected["timestamp"].(int64) + ats := actual["timestamp"].(string) + ts2, err := strconv.ParseInt(strings.Split(ats, "-")[0], 10, 64) + require.Nil(t, err, fmt.Sprintf("%s: expected to get a valid timestamp, got %s", description, err)) + ts1 = ts1 / 1e9 + ts2 = ts2 / 1e3 + if assert.WithinDuration(t, time.Unix(ts1, 0), time.Unix(ts2, 0), time.Second, fmt.Sprintf("%s: timestamp is not in valid range of 1 second", description)) { + delete(expected, "timestamp") + delete(actual, "timestamp") + } + + oa1 := expected["occurred_at"].(int64) + aoa := actual["occurred_at"].(string) + oa2, err := strconv.ParseInt(aoa, 10, 64) + require.Nil(t, err, fmt.Sprintf("%s: expected to get a valid occurred_at, got %s", description, err)) + oa1 = oa1 / 1e9 + oa2 = oa2 / 1e9 + if assert.WithinDuration(t, time.Unix(oa1, 0), time.Unix(oa2, 0), time.Second, fmt.Sprintf("%s: occurred_at is not in valid range of 1 second", description)) { + delete(expected, "occurred_at") + delete(actual, "occurred_at") + } + + exchs := expected["channels"].([]any) + achs := actual["channels"].([]any) + + if exchs != nil && achs != nil { + if assert.Len(t, exchs, len(achs), fmt.Sprintf("%s: got incorrect number of channels\n", description)) { + for _, exch := range exchs { + assert.Contains(t, achs, exch, fmt.Sprintf("%s: got incorrect channel\n", description)) + } + } + } + + assert.Equal(t, expected, actual, fmt.Sprintf("%s: got incorrect event\n", description)) + } +} diff --git a/bootstrap/middleware/authorization.go b/bootstrap/middleware/authorization.go new file mode 100644 index 000000000..07b20cea7 --- /dev/null +++ b/bootstrap/middleware/authorization.go @@ -0,0 +1,150 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/absmach/supermq/bootstrap" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/authz" + "github.com/absmach/supermq/pkg/policies" +) + +const ( + updatePermission = "update_permission" + readPermission = "read_permission" + deletePermission = "delete_permission" +) + +var _ bootstrap.Service = (*authorizationMiddleware)(nil) + +type authorizationMiddleware struct { + svc bootstrap.Service + authz authz.Authorization +} + +// AuthorizationMiddleware adds authorization to the clients service. +func AuthorizationMiddleware(svc bootstrap.Service, authz authz.Authorization) bootstrap.Service { + return &authorizationMiddleware{ + svc: svc, + authz: authz, + } +} + +func (am *authorizationMiddleware) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (bootstrap.Config, error) { + if err := am.authorize(ctx, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.MembershipPermission, policies.DomainType, session.DomainID); err != nil { + return bootstrap.Config{}, err + } + + return am.svc.Add(ctx, session, token, cfg) +} + +func (am *authorizationMiddleware) View(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) { + if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, readPermission, policies.ClientType, id); err != nil { + return bootstrap.Config{}, err + } + + return am.svc.View(ctx, session, id) +} + +func (am *authorizationMiddleware) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) error { + if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, updatePermission, policies.ClientType, cfg.ClientID); err != nil { + return err + } + + return am.svc.Update(ctx, session, cfg) +} + +func (am *authorizationMiddleware) UpdateCert(ctx context.Context, session smqauthn.Session, clientID, clientCert, clientKey, caCert string) (bootstrap.Config, error) { + if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, updatePermission, policies.ClientType, clientID); err != nil { + return bootstrap.Config{}, err + } + + return am.svc.UpdateCert(ctx, session, clientID, clientCert, clientKey, caCert) +} + +func (am *authorizationMiddleware) UpdateConnections(ctx context.Context, session smqauthn.Session, token, id string, connections []string) error { + if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, updatePermission, policies.ClientType, id); err != nil { + return err + } + + return am.svc.UpdateConnections(ctx, session, token, id, connections) +} + +func (am *authorizationMiddleware) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (bootstrap.ConfigsPage, error) { + if err := am.checkSuperAdmin(ctx, session.DomainUserID); err == nil { + session.SuperAdmin = true + } + if err := am.authorize(ctx, "", policies.UserType, policies.UsersKind, session.DomainUserID, policies.AdminPermission, policies.DomainType, session.DomainID); err == nil { + session.SuperAdmin = true + } + + return am.svc.List(ctx, session, filter, offset, limit) +} + +func (am *authorizationMiddleware) Remove(ctx context.Context, session smqauthn.Session, id string) error { + if err := am.authorize(ctx, session.DomainID, policies.UserType, policies.UsersKind, session.DomainUserID, deletePermission, policies.ClientType, id); err != nil { + return err + } + + return am.svc.Remove(ctx, session, id) +} + +func (am *authorizationMiddleware) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (bootstrap.Config, error) { + return am.svc.Bootstrap(ctx, externalKey, externalID, secure) +} + +func (am *authorizationMiddleware) ChangeState(ctx context.Context, session smqauthn.Session, token, id string, state bootstrap.State) error { + return am.svc.ChangeState(ctx, session, token, id, state) +} + +func (am *authorizationMiddleware) UpdateChannelHandler(ctx context.Context, channel bootstrap.Channel) error { + return am.svc.UpdateChannelHandler(ctx, channel) +} + +func (am *authorizationMiddleware) RemoveConfigHandler(ctx context.Context, id string) error { + return am.svc.RemoveConfigHandler(ctx, id) +} + +func (am *authorizationMiddleware) RemoveChannelHandler(ctx context.Context, id string) error { + return am.svc.RemoveChannelHandler(ctx, id) +} + +func (am *authorizationMiddleware) ConnectClientHandler(ctx context.Context, channelID, clientID string) error { + return am.svc.ConnectClientHandler(ctx, channelID, clientID) +} + +func (am *authorizationMiddleware) DisconnectClientHandler(ctx context.Context, channelID, clientID string) error { + return am.svc.DisconnectClientHandler(ctx, channelID, clientID) +} + +func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, adminID string) error { + if err := am.authz.Authorize(ctx, authz.PolicyReq{ + SubjectType: policies.UserType, + Subject: adminID, + Permission: policies.AdminPermission, + ObjectType: policies.PlatformType, + Object: policies.MagistralaObject, + }, nil); err != nil { + return err + } + return nil +} + +func (am *authorizationMiddleware) authorize(ctx context.Context, domain, subjType, subjKind, subj, perm, objType, obj string) error { + req := authz.PolicyReq{ + Domain: domain, + SubjectType: subjType, + SubjectKind: subjKind, + Subject: subj, + Permission: perm, + ObjectType: objType, + Object: obj, + } + if err := am.authz.Authorize(ctx, req, nil); err != nil { + return err + } + return nil +} diff --git a/bootstrap/middleware/logging.go b/bootstrap/middleware/logging.go new file mode 100644 index 000000000..a6d23497c --- /dev/null +++ b/bootstrap/middleware/logging.go @@ -0,0 +1,295 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build !test + +package middleware + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/supermq/bootstrap" + smqauthn "github.com/absmach/supermq/pkg/authn" +) + +var _ bootstrap.Service = (*loggingMiddleware)(nil) + +type loggingMiddleware struct { + logger *slog.Logger + svc bootstrap.Service +} + +// LoggingMiddleware adds logging facilities to the bootstrap service. +func LoggingMiddleware(svc bootstrap.Service, logger *slog.Logger) bootstrap.Service { + return &loggingMiddleware{logger, svc} +} + +// Add logs the add request. It logs the client ID and the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (saved bootstrap.Config, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("client_id", saved.ClientID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Add new bootstrap failed", args...) + return + } + lm.logger.Info("Add new bootstrap completed successfully", args...) + }(time.Now()) + + return lm.svc.Add(ctx, session, token, cfg) +} + +// View logs the view request. It logs the client ID and the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) View(ctx context.Context, session smqauthn.Session, id string) (saved bootstrap.Config, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("client_id", id), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("View client config failed", args...) + return + } + lm.logger.Info("View client config completed successfully", args...) + }(time.Now()) + + return lm.svc.View(ctx, session, id) +} + +// Update logs the update request. It logs bootstrap client ID and the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.Group("config", + slog.String("client_id", cfg.ClientID), + slog.String("name", cfg.Name), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Update bootstrap config failed", args...) + return + } + lm.logger.Info("Update bootstrap config completed successfully", args...) + }(time.Now()) + + return lm.svc.Update(ctx, session, cfg) +} + +// UpdateCert logs the update_cert request. It logs client ID and the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) UpdateCert(ctx context.Context, session smqauthn.Session, clientID, clientCert, clientKey, caCert string) (cfg bootstrap.Config, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("client_id", cfg.ClientID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Update bootstrap config certificate failed", args...) + return + } + lm.logger.Info("Update bootstrap config certificate completed successfully", args...) + }(time.Now()) + + return lm.svc.UpdateCert(ctx, session, clientID, clientCert, clientKey, caCert) +} + +// UpdateConnections logs the update_connections request. It logs bootstrap ID and the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) UpdateConnections(ctx context.Context, session smqauthn.Session, token, id string, connections []string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("client_id", id), + slog.Any("connections", connections), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Update config connections failed", args...) + return + } + lm.logger.Info("Update config connections completed successfully", args...) + }(time.Now()) + + return lm.svc.UpdateConnections(ctx, session, token, id, connections) +} + +// List logs the list request. It logs offset, limit and the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (res bootstrap.ConfigsPage, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.Group("page", + slog.Any("filter", filter), + slog.Uint64("offset", offset), + slog.Uint64("limit", limit), + slog.Uint64("total", res.Total), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("List configs failed", args...) + return + } + lm.logger.Info("List configs completed successfully", args...) + }(time.Now()) + + return lm.svc.List(ctx, session, filter, offset, limit) +} + +// Remove logs the remove request. It logs bootstrap ID and the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) Remove(ctx context.Context, session smqauthn.Session, id string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("client_id", id), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Remove bootstrap config failed", args...) + return + } + lm.logger.Info("Remove bootstrap config completed successfully", args...) + }(time.Now()) + + return lm.svc.Remove(ctx, session, id) +} + +func (lm *loggingMiddleware) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (cfg bootstrap.Config, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("external_id", externalID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("View bootstrap config failed", args...) + return + } + lm.logger.Info("View bootstrap completed successfully", args...) + }(time.Now()) + + return lm.svc.Bootstrap(ctx, externalKey, externalID, secure) +} + +func (lm *loggingMiddleware) ChangeState(ctx context.Context, session smqauthn.Session, token, id string, state bootstrap.State) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("id", id), + slog.Any("state", state), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Change client state failed", args...) + return + } + lm.logger.Info("Change client state completed successfully", args...) + }(time.Now()) + + return lm.svc.ChangeState(ctx, session, token, id, state) +} + +func (lm *loggingMiddleware) UpdateChannelHandler(ctx context.Context, channel bootstrap.Channel) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.Group("channel", + slog.String("id", channel.ID), + slog.String("name", channel.Name), + slog.Any("metadata", channel.Metadata), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Update channel handler failed", args...) + return + } + lm.logger.Info("Update channel handler completed successfully", args...) + }(time.Now()) + + return lm.svc.UpdateChannelHandler(ctx, channel) +} + +func (lm *loggingMiddleware) RemoveConfigHandler(ctx context.Context, id string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("config_id", id), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Remove config handler failed", args...) + return + } + lm.logger.Info("Remove config handler completed successfully", args...) + }(time.Now()) + + return lm.svc.RemoveConfigHandler(ctx, id) +} + +func (lm *loggingMiddleware) RemoveChannelHandler(ctx context.Context, id string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("channel_id", id), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Remove channel handler failed", args...) + return + } + lm.logger.Info("Remove channel handler completed successfully", args...) + }(time.Now()) + + return lm.svc.RemoveChannelHandler(ctx, id) +} + +func (lm *loggingMiddleware) ConnectClientHandler(ctx context.Context, channelID, clientID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("channel_id", channelID), + slog.String("client_id", clientID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Connect client handler failed", args...) + return + } + lm.logger.Info("Connect client handler completed successfully", args...) + }(time.Now()) + + return lm.svc.ConnectClientHandler(ctx, channelID, clientID) +} + +func (lm *loggingMiddleware) DisconnectClientHandler(ctx context.Context, channelID, clientID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("channel_id", channelID), + slog.String("client_id", clientID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Disconnect client handler failed", args...) + return + } + lm.logger.Info("Disconnect client handler completed successfully", args...) + }(time.Now()) + + return lm.svc.DisconnectClientHandler(ctx, channelID, clientID) +} diff --git a/bootstrap/middleware/metrics.go b/bootstrap/middleware/metrics.go new file mode 100644 index 000000000..89f64ecd4 --- /dev/null +++ b/bootstrap/middleware/metrics.go @@ -0,0 +1,172 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build !test + +package middleware + +import ( + "context" + "time" + + "github.com/absmach/supermq/bootstrap" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/go-kit/kit/metrics" +) + +var _ bootstrap.Service = (*metricsMiddleware)(nil) + +type metricsMiddleware struct { + counter metrics.Counter + latency metrics.Histogram + svc bootstrap.Service +} + +// MetricsMiddleware instruments core service by tracking request count and latency. +func MetricsMiddleware(svc bootstrap.Service, counter metrics.Counter, latency metrics.Histogram) bootstrap.Service { + return &metricsMiddleware{ + counter: counter, + latency: latency, + svc: svc, + } +} + +// Add instruments Add method with metrics. +func (mm *metricsMiddleware) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (saved bootstrap.Config, err error) { + defer func(begin time.Time) { + mm.counter.With("method", "add").Add(1) + mm.latency.With("method", "add").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.Add(ctx, session, token, cfg) +} + +// View instruments View method with metrics. +func (mm *metricsMiddleware) View(ctx context.Context, session smqauthn.Session, id string) (saved bootstrap.Config, err error) { + defer func(begin time.Time) { + mm.counter.With("method", "view").Add(1) + mm.latency.With("method", "view").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.View(ctx, session, id) +} + +// Update instruments Update method with metrics. +func (mm *metricsMiddleware) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "update").Add(1) + mm.latency.With("method", "update").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.Update(ctx, session, cfg) +} + +// UpdateCert instruments UpdateCert method with metrics. +func (mm *metricsMiddleware) UpdateCert(ctx context.Context, session smqauthn.Session, clientID, clientCert, clientKey, caCert string) (cfg bootstrap.Config, err error) { + defer func(begin time.Time) { + mm.counter.With("method", "update_cert").Add(1) + mm.latency.With("method", "update_cert").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.UpdateCert(ctx, session, clientID, clientCert, clientKey, caCert) +} + +// UpdateConnections instruments UpdateConnections method with metrics. +func (mm *metricsMiddleware) UpdateConnections(ctx context.Context, session smqauthn.Session, token, id string, connections []string) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "update_connections").Add(1) + mm.latency.With("method", "update_connections").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.UpdateConnections(ctx, session, token, id, connections) +} + +// List instruments List method with metrics. +func (mm *metricsMiddleware) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (saved bootstrap.ConfigsPage, err error) { + defer func(begin time.Time) { + mm.counter.With("method", "list").Add(1) + mm.latency.With("method", "list").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.List(ctx, session, filter, offset, limit) +} + +// Remove instruments Remove method with metrics. +func (mm *metricsMiddleware) Remove(ctx context.Context, session smqauthn.Session, id string) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "remove").Add(1) + mm.latency.With("method", "remove").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.Remove(ctx, session, id) +} + +// Bootstrap instruments Bootstrap method with metrics. +func (mm *metricsMiddleware) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (cfg bootstrap.Config, err error) { + defer func(begin time.Time) { + mm.counter.With("method", "bootstrap").Add(1) + mm.latency.With("method", "bootstrap").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.Bootstrap(ctx, externalKey, externalID, secure) +} + +// ChangeState instruments ChangeState method with metrics. +func (mm *metricsMiddleware) ChangeState(ctx context.Context, session smqauthn.Session, token, id string, state bootstrap.State) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "change_state").Add(1) + mm.latency.With("method", "change_state").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.ChangeState(ctx, session, token, id, state) +} + +// UpdateChannelHandler instruments UpdateChannelHandler method with metrics. +func (mm *metricsMiddleware) UpdateChannelHandler(ctx context.Context, channel bootstrap.Channel) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "update_channel").Add(1) + mm.latency.With("method", "update_channel").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.UpdateChannelHandler(ctx, channel) +} + +// RemoveConfigHandler instruments RemoveConfigHandler method with metrics. +func (mm *metricsMiddleware) RemoveConfigHandler(ctx context.Context, id string) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "remove_config").Add(1) + mm.latency.With("method", "remove_config").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.RemoveConfigHandler(ctx, id) +} + +// RemoveChannelHandler instruments RemoveChannelHandler method with metrics. +func (mm *metricsMiddleware) RemoveChannelHandler(ctx context.Context, id string) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "remove_channel").Add(1) + mm.latency.With("method", "remove_channel").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.RemoveChannelHandler(ctx, id) +} + +// ConnectClientHandler instruments ConnectClientHandler method with metrics. +func (mm *metricsMiddleware) ConnectClientHandler(ctx context.Context, channelID, clientID string) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "connect_client_handler").Add(1) + mm.latency.With("method", "connect_client_handler").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.ConnectClientHandler(ctx, channelID, clientID) +} + +// DisconnectClientHandler instruments DisconnectClientHandler method with metrics. +func (mm *metricsMiddleware) DisconnectClientHandler(ctx context.Context, channelID, clientID string) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "disconnect_client_handler").Add(1) + mm.latency.With("method", "disconnect_client_handler").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.DisconnectClientHandler(ctx, channelID, clientID) +} diff --git a/bootstrap/mocks/config_reader.go b/bootstrap/mocks/config_reader.go new file mode 100644 index 000000000..3c9d9ea5a --- /dev/null +++ b/bootstrap/mocks/config_reader.go @@ -0,0 +1,109 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "github.com/absmach/supermq/bootstrap" + mock "github.com/stretchr/testify/mock" +) + +// NewConfigReader creates a new instance of ConfigReader. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewConfigReader(t interface { + mock.TestingT + Cleanup(func()) +}) *ConfigReader { + mock := &ConfigReader{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// ConfigReader is an autogenerated mock type for the ConfigReader type +type ConfigReader struct { + mock.Mock +} + +type ConfigReader_Expecter struct { + mock *mock.Mock +} + +func (_m *ConfigReader) EXPECT() *ConfigReader_Expecter { + return &ConfigReader_Expecter{mock: &_m.Mock} +} + +// ReadConfig provides a mock function for the type ConfigReader +func (_mock *ConfigReader) ReadConfig(config bootstrap.Config, b bool) (any, error) { + ret := _mock.Called(config, b) + + if len(ret) == 0 { + panic("no return value specified for ReadConfig") + } + + var r0 any + var r1 error + if returnFunc, ok := ret.Get(0).(func(bootstrap.Config, bool) (any, error)); ok { + return returnFunc(config, b) + } + if returnFunc, ok := ret.Get(0).(func(bootstrap.Config, bool) any); ok { + r0 = returnFunc(config, b) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(any) + } + } + if returnFunc, ok := ret.Get(1).(func(bootstrap.Config, bool) error); ok { + r1 = returnFunc(config, b) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// ConfigReader_ReadConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadConfig' +type ConfigReader_ReadConfig_Call struct { + *mock.Call +} + +// ReadConfig is a helper method to define mock.On call +// - config bootstrap.Config +// - b bool +func (_e *ConfigReader_Expecter) ReadConfig(config interface{}, b interface{}) *ConfigReader_ReadConfig_Call { + return &ConfigReader_ReadConfig_Call{Call: _e.mock.On("ReadConfig", config, b)} +} + +func (_c *ConfigReader_ReadConfig_Call) Run(run func(config bootstrap.Config, b bool)) *ConfigReader_ReadConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 bootstrap.Config + if args[0] != nil { + arg0 = args[0].(bootstrap.Config) + } + var arg1 bool + if args[1] != nil { + arg1 = args[1].(bool) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *ConfigReader_ReadConfig_Call) Return(v any, err error) *ConfigReader_ReadConfig_Call { + _c.Call.Return(v, err) + return _c +} + +func (_c *ConfigReader_ReadConfig_Call) RunAndReturn(run func(config bootstrap.Config, b bool) (any, error)) *ConfigReader_ReadConfig_Call { + _c.Call.Return(run) + return _c +} diff --git a/bootstrap/mocks/config_repository.go b/bootstrap/mocks/config_repository.go new file mode 100644 index 000000000..81d455eab --- /dev/null +++ b/bootstrap/mocks/config_repository.go @@ -0,0 +1,1059 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/bootstrap" + mock "github.com/stretchr/testify/mock" +) + +// NewConfigRepository creates a new instance of ConfigRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewConfigRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *ConfigRepository { + mock := &ConfigRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// ConfigRepository is an autogenerated mock type for the ConfigRepository type +type ConfigRepository struct { + mock.Mock +} + +type ConfigRepository_Expecter struct { + mock *mock.Mock +} + +func (_m *ConfigRepository) EXPECT() *ConfigRepository_Expecter { + return &ConfigRepository_Expecter{mock: &_m.Mock} +} + +// ChangeState provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) ChangeState(ctx context.Context, domainID string, id string, state bootstrap.State) error { + ret := _mock.Called(ctx, domainID, id, state) + + if len(ret) == 0 { + panic("no return value specified for ChangeState") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, bootstrap.State) error); ok { + r0 = returnFunc(ctx, domainID, id, state) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// ConfigRepository_ChangeState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ChangeState' +type ConfigRepository_ChangeState_Call struct { + *mock.Call +} + +// ChangeState is a helper method to define mock.On call +// - ctx context.Context +// - domainID string +// - id string +// - state bootstrap.State +func (_e *ConfigRepository_Expecter) ChangeState(ctx interface{}, domainID interface{}, id interface{}, state interface{}) *ConfigRepository_ChangeState_Call { + return &ConfigRepository_ChangeState_Call{Call: _e.mock.On("ChangeState", ctx, domainID, id, state)} +} + +func (_c *ConfigRepository_ChangeState_Call) Run(run func(ctx context.Context, domainID string, id string, state bootstrap.State)) *ConfigRepository_ChangeState_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 bootstrap.State + if args[3] != nil { + arg3 = args[3].(bootstrap.State) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *ConfigRepository_ChangeState_Call) Return(err error) *ConfigRepository_ChangeState_Call { + _c.Call.Return(err) + return _c +} + +func (_c *ConfigRepository_ChangeState_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string, state bootstrap.State) error) *ConfigRepository_ChangeState_Call { + _c.Call.Return(run) + return _c +} + +// ConnectClient provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) ConnectClient(ctx context.Context, channelID string, clientID string) error { + ret := _mock.Called(ctx, channelID, clientID) + + if len(ret) == 0 { + panic("no return value specified for ConnectClient") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = returnFunc(ctx, channelID, clientID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// ConfigRepository_ConnectClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ConnectClient' +type ConfigRepository_ConnectClient_Call struct { + *mock.Call +} + +// ConnectClient is a helper method to define mock.On call +// - ctx context.Context +// - channelID string +// - clientID string +func (_e *ConfigRepository_Expecter) ConnectClient(ctx interface{}, channelID interface{}, clientID interface{}) *ConfigRepository_ConnectClient_Call { + return &ConfigRepository_ConnectClient_Call{Call: _e.mock.On("ConnectClient", ctx, channelID, clientID)} +} + +func (_c *ConfigRepository_ConnectClient_Call) Run(run func(ctx context.Context, channelID string, clientID string)) *ConfigRepository_ConnectClient_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *ConfigRepository_ConnectClient_Call) Return(err error) *ConfigRepository_ConnectClient_Call { + _c.Call.Return(err) + return _c +} + +func (_c *ConfigRepository_ConnectClient_Call) RunAndReturn(run func(ctx context.Context, channelID string, clientID string) error) *ConfigRepository_ConnectClient_Call { + _c.Call.Return(run) + return _c +} + +// DisconnectClient provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) DisconnectClient(ctx context.Context, channelID string, clientID string) error { + ret := _mock.Called(ctx, channelID, clientID) + + if len(ret) == 0 { + panic("no return value specified for DisconnectClient") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = returnFunc(ctx, channelID, clientID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// ConfigRepository_DisconnectClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DisconnectClient' +type ConfigRepository_DisconnectClient_Call struct { + *mock.Call +} + +// DisconnectClient is a helper method to define mock.On call +// - ctx context.Context +// - channelID string +// - clientID string +func (_e *ConfigRepository_Expecter) DisconnectClient(ctx interface{}, channelID interface{}, clientID interface{}) *ConfigRepository_DisconnectClient_Call { + return &ConfigRepository_DisconnectClient_Call{Call: _e.mock.On("DisconnectClient", ctx, channelID, clientID)} +} + +func (_c *ConfigRepository_DisconnectClient_Call) Run(run func(ctx context.Context, channelID string, clientID string)) *ConfigRepository_DisconnectClient_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *ConfigRepository_DisconnectClient_Call) Return(err error) *ConfigRepository_DisconnectClient_Call { + _c.Call.Return(err) + return _c +} + +func (_c *ConfigRepository_DisconnectClient_Call) RunAndReturn(run func(ctx context.Context, channelID string, clientID string) error) *ConfigRepository_DisconnectClient_Call { + _c.Call.Return(run) + return _c +} + +// ListExisting provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) ListExisting(ctx context.Context, domainID string, ids []string) ([]bootstrap.Channel, error) { + ret := _mock.Called(ctx, domainID, ids) + + if len(ret) == 0 { + panic("no return value specified for ListExisting") + } + + var r0 []bootstrap.Channel + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string) ([]bootstrap.Channel, error)); ok { + return returnFunc(ctx, domainID, ids) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string) []bootstrap.Channel); ok { + r0 = returnFunc(ctx, domainID, ids) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]bootstrap.Channel) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { + r1 = returnFunc(ctx, domainID, ids) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// ConfigRepository_ListExisting_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListExisting' +type ConfigRepository_ListExisting_Call struct { + *mock.Call +} + +// ListExisting is a helper method to define mock.On call +// - ctx context.Context +// - domainID string +// - ids []string +func (_e *ConfigRepository_Expecter) ListExisting(ctx interface{}, domainID interface{}, ids interface{}) *ConfigRepository_ListExisting_Call { + return &ConfigRepository_ListExisting_Call{Call: _e.mock.On("ListExisting", ctx, domainID, ids)} +} + +func (_c *ConfigRepository_ListExisting_Call) Run(run func(ctx context.Context, domainID string, ids []string)) *ConfigRepository_ListExisting_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *ConfigRepository_ListExisting_Call) Return(channels []bootstrap.Channel, err error) *ConfigRepository_ListExisting_Call { + _c.Call.Return(channels, err) + return _c +} + +func (_c *ConfigRepository_ListExisting_Call) RunAndReturn(run func(ctx context.Context, domainID string, ids []string) ([]bootstrap.Channel, error)) *ConfigRepository_ListExisting_Call { + _c.Call.Return(run) + return _c +} + +// Remove provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) Remove(ctx context.Context, domainID string, id string) error { + ret := _mock.Called(ctx, domainID, id) + + if len(ret) == 0 { + panic("no return value specified for Remove") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = returnFunc(ctx, domainID, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// ConfigRepository_Remove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remove' +type ConfigRepository_Remove_Call struct { + *mock.Call +} + +// Remove is a helper method to define mock.On call +// - ctx context.Context +// - domainID string +// - id string +func (_e *ConfigRepository_Expecter) Remove(ctx interface{}, domainID interface{}, id interface{}) *ConfigRepository_Remove_Call { + return &ConfigRepository_Remove_Call{Call: _e.mock.On("Remove", ctx, domainID, id)} +} + +func (_c *ConfigRepository_Remove_Call) Run(run func(ctx context.Context, domainID string, id string)) *ConfigRepository_Remove_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *ConfigRepository_Remove_Call) Return(err error) *ConfigRepository_Remove_Call { + _c.Call.Return(err) + return _c +} + +func (_c *ConfigRepository_Remove_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string) error) *ConfigRepository_Remove_Call { + _c.Call.Return(run) + return _c +} + +// RemoveChannel provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) RemoveChannel(ctx context.Context, id string) error { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for RemoveChannel") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// ConfigRepository_RemoveChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveChannel' +type ConfigRepository_RemoveChannel_Call struct { + *mock.Call +} + +// RemoveChannel is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *ConfigRepository_Expecter) RemoveChannel(ctx interface{}, id interface{}) *ConfigRepository_RemoveChannel_Call { + return &ConfigRepository_RemoveChannel_Call{Call: _e.mock.On("RemoveChannel", ctx, id)} +} + +func (_c *ConfigRepository_RemoveChannel_Call) Run(run func(ctx context.Context, id string)) *ConfigRepository_RemoveChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *ConfigRepository_RemoveChannel_Call) Return(err error) *ConfigRepository_RemoveChannel_Call { + _c.Call.Return(err) + return _c +} + +func (_c *ConfigRepository_RemoveChannel_Call) RunAndReturn(run func(ctx context.Context, id string) error) *ConfigRepository_RemoveChannel_Call { + _c.Call.Return(run) + return _c +} + +// RemoveClient provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) RemoveClient(ctx context.Context, id string) error { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for RemoveClient") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// ConfigRepository_RemoveClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveClient' +type ConfigRepository_RemoveClient_Call struct { + *mock.Call +} + +// RemoveClient is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *ConfigRepository_Expecter) RemoveClient(ctx interface{}, id interface{}) *ConfigRepository_RemoveClient_Call { + return &ConfigRepository_RemoveClient_Call{Call: _e.mock.On("RemoveClient", ctx, id)} +} + +func (_c *ConfigRepository_RemoveClient_Call) Run(run func(ctx context.Context, id string)) *ConfigRepository_RemoveClient_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *ConfigRepository_RemoveClient_Call) Return(err error) *ConfigRepository_RemoveClient_Call { + _c.Call.Return(err) + return _c +} + +func (_c *ConfigRepository_RemoveClient_Call) RunAndReturn(run func(ctx context.Context, id string) error) *ConfigRepository_RemoveClient_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveAll provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) RetrieveAll(ctx context.Context, domainID string, clientIDs []string, filter bootstrap.Filter, offset uint64, limit uint64) bootstrap.ConfigsPage { + ret := _mock.Called(ctx, domainID, clientIDs, filter, offset, limit) + + if len(ret) == 0 { + panic("no return value specified for RetrieveAll") + } + + var r0 bootstrap.ConfigsPage + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string, bootstrap.Filter, uint64, uint64) bootstrap.ConfigsPage); ok { + r0 = returnFunc(ctx, domainID, clientIDs, filter, offset, limit) + } else { + r0 = ret.Get(0).(bootstrap.ConfigsPage) + } + return r0 +} + +// ConfigRepository_RetrieveAll_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveAll' +type ConfigRepository_RetrieveAll_Call struct { + *mock.Call +} + +// RetrieveAll is a helper method to define mock.On call +// - ctx context.Context +// - domainID string +// - clientIDs []string +// - filter bootstrap.Filter +// - offset uint64 +// - limit uint64 +func (_e *ConfigRepository_Expecter) RetrieveAll(ctx interface{}, domainID interface{}, clientIDs interface{}, filter interface{}, offset interface{}, limit interface{}) *ConfigRepository_RetrieveAll_Call { + return &ConfigRepository_RetrieveAll_Call{Call: _e.mock.On("RetrieveAll", ctx, domainID, clientIDs, filter, offset, limit)} +} + +func (_c *ConfigRepository_RetrieveAll_Call) Run(run func(ctx context.Context, domainID string, clientIDs []string, filter bootstrap.Filter, offset uint64, limit uint64)) *ConfigRepository_RetrieveAll_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + var arg3 bootstrap.Filter + if args[3] != nil { + arg3 = args[3].(bootstrap.Filter) + } + var arg4 uint64 + if args[4] != nil { + arg4 = args[4].(uint64) + } + var arg5 uint64 + if args[5] != nil { + arg5 = args[5].(uint64) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + ) + }) + return _c +} + +func (_c *ConfigRepository_RetrieveAll_Call) Return(configsPage bootstrap.ConfigsPage) *ConfigRepository_RetrieveAll_Call { + _c.Call.Return(configsPage) + return _c +} + +func (_c *ConfigRepository_RetrieveAll_Call) RunAndReturn(run func(ctx context.Context, domainID string, clientIDs []string, filter bootstrap.Filter, offset uint64, limit uint64) bootstrap.ConfigsPage) *ConfigRepository_RetrieveAll_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveByExternalID provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) RetrieveByExternalID(ctx context.Context, externalID string) (bootstrap.Config, error) { + ret := _mock.Called(ctx, externalID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveByExternalID") + } + + var r0 bootstrap.Config + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (bootstrap.Config, error)); ok { + return returnFunc(ctx, externalID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) bootstrap.Config); ok { + r0 = returnFunc(ctx, externalID) + } else { + r0 = ret.Get(0).(bootstrap.Config) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, externalID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// ConfigRepository_RetrieveByExternalID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveByExternalID' +type ConfigRepository_RetrieveByExternalID_Call struct { + *mock.Call +} + +// RetrieveByExternalID is a helper method to define mock.On call +// - ctx context.Context +// - externalID string +func (_e *ConfigRepository_Expecter) RetrieveByExternalID(ctx interface{}, externalID interface{}) *ConfigRepository_RetrieveByExternalID_Call { + return &ConfigRepository_RetrieveByExternalID_Call{Call: _e.mock.On("RetrieveByExternalID", ctx, externalID)} +} + +func (_c *ConfigRepository_RetrieveByExternalID_Call) Run(run func(ctx context.Context, externalID string)) *ConfigRepository_RetrieveByExternalID_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *ConfigRepository_RetrieveByExternalID_Call) Return(config bootstrap.Config, err error) *ConfigRepository_RetrieveByExternalID_Call { + _c.Call.Return(config, err) + return _c +} + +func (_c *ConfigRepository_RetrieveByExternalID_Call) RunAndReturn(run func(ctx context.Context, externalID string) (bootstrap.Config, error)) *ConfigRepository_RetrieveByExternalID_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveByID provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) RetrieveByID(ctx context.Context, domainID string, id string) (bootstrap.Config, error) { + ret := _mock.Called(ctx, domainID, id) + + if len(ret) == 0 { + panic("no return value specified for RetrieveByID") + } + + var r0 bootstrap.Config + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (bootstrap.Config, error)); ok { + return returnFunc(ctx, domainID, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) bootstrap.Config); ok { + r0 = returnFunc(ctx, domainID, id) + } else { + r0 = ret.Get(0).(bootstrap.Config) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, domainID, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// ConfigRepository_RetrieveByID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveByID' +type ConfigRepository_RetrieveByID_Call struct { + *mock.Call +} + +// RetrieveByID is a helper method to define mock.On call +// - ctx context.Context +// - domainID string +// - id string +func (_e *ConfigRepository_Expecter) RetrieveByID(ctx interface{}, domainID interface{}, id interface{}) *ConfigRepository_RetrieveByID_Call { + return &ConfigRepository_RetrieveByID_Call{Call: _e.mock.On("RetrieveByID", ctx, domainID, id)} +} + +func (_c *ConfigRepository_RetrieveByID_Call) Run(run func(ctx context.Context, domainID string, id string)) *ConfigRepository_RetrieveByID_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *ConfigRepository_RetrieveByID_Call) Return(config bootstrap.Config, err error) *ConfigRepository_RetrieveByID_Call { + _c.Call.Return(config, err) + return _c +} + +func (_c *ConfigRepository_RetrieveByID_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string) (bootstrap.Config, error)) *ConfigRepository_RetrieveByID_Call { + _c.Call.Return(run) + return _c +} + +// Save provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) Save(ctx context.Context, cfg bootstrap.Config, chsConnIDs []string) (string, error) { + ret := _mock.Called(ctx, cfg, chsConnIDs) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Config, []string) (string, error)); ok { + return returnFunc(ctx, cfg, chsConnIDs) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Config, []string) string); ok { + r0 = returnFunc(ctx, cfg, chsConnIDs) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, bootstrap.Config, []string) error); ok { + r1 = returnFunc(ctx, cfg, chsConnIDs) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// ConfigRepository_Save_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Save' +type ConfigRepository_Save_Call struct { + *mock.Call +} + +// Save is a helper method to define mock.On call +// - ctx context.Context +// - cfg bootstrap.Config +// - chsConnIDs []string +func (_e *ConfigRepository_Expecter) Save(ctx interface{}, cfg interface{}, chsConnIDs interface{}) *ConfigRepository_Save_Call { + return &ConfigRepository_Save_Call{Call: _e.mock.On("Save", ctx, cfg, chsConnIDs)} +} + +func (_c *ConfigRepository_Save_Call) Run(run func(ctx context.Context, cfg bootstrap.Config, chsConnIDs []string)) *ConfigRepository_Save_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 bootstrap.Config + if args[1] != nil { + arg1 = args[1].(bootstrap.Config) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *ConfigRepository_Save_Call) Return(s string, err error) *ConfigRepository_Save_Call { + _c.Call.Return(s, err) + return _c +} + +func (_c *ConfigRepository_Save_Call) RunAndReturn(run func(ctx context.Context, cfg bootstrap.Config, chsConnIDs []string) (string, error)) *ConfigRepository_Save_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) Update(ctx context.Context, cfg bootstrap.Config) error { + ret := _mock.Called(ctx, cfg) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Config) error); ok { + r0 = returnFunc(ctx, cfg) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// ConfigRepository_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type ConfigRepository_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - ctx context.Context +// - cfg bootstrap.Config +func (_e *ConfigRepository_Expecter) Update(ctx interface{}, cfg interface{}) *ConfigRepository_Update_Call { + return &ConfigRepository_Update_Call{Call: _e.mock.On("Update", ctx, cfg)} +} + +func (_c *ConfigRepository_Update_Call) Run(run func(ctx context.Context, cfg bootstrap.Config)) *ConfigRepository_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 bootstrap.Config + if args[1] != nil { + arg1 = args[1].(bootstrap.Config) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *ConfigRepository_Update_Call) Return(err error) *ConfigRepository_Update_Call { + _c.Call.Return(err) + return _c +} + +func (_c *ConfigRepository_Update_Call) RunAndReturn(run func(ctx context.Context, cfg bootstrap.Config) error) *ConfigRepository_Update_Call { + _c.Call.Return(run) + return _c +} + +// UpdateCert provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) UpdateCert(ctx context.Context, domainID string, clientID string, clientCert string, clientKey string, caCert string) (bootstrap.Config, error) { + ret := _mock.Called(ctx, domainID, clientID, clientCert, clientKey, caCert) + + if len(ret) == 0 { + panic("no return value specified for UpdateCert") + } + + var r0 bootstrap.Config + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) (bootstrap.Config, error)); ok { + return returnFunc(ctx, domainID, clientID, clientCert, clientKey, caCert) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) bootstrap.Config); ok { + r0 = returnFunc(ctx, domainID, clientID, clientCert, clientKey, caCert) + } else { + r0 = ret.Get(0).(bootstrap.Config) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string, string) error); ok { + r1 = returnFunc(ctx, domainID, clientID, clientCert, clientKey, caCert) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// ConfigRepository_UpdateCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCert' +type ConfigRepository_UpdateCert_Call struct { + *mock.Call +} + +// UpdateCert is a helper method to define mock.On call +// - ctx context.Context +// - domainID string +// - clientID string +// - clientCert string +// - clientKey string +// - caCert string +func (_e *ConfigRepository_Expecter) UpdateCert(ctx interface{}, domainID interface{}, clientID interface{}, clientCert interface{}, clientKey interface{}, caCert interface{}) *ConfigRepository_UpdateCert_Call { + return &ConfigRepository_UpdateCert_Call{Call: _e.mock.On("UpdateCert", ctx, domainID, clientID, clientCert, clientKey, caCert)} +} + +func (_c *ConfigRepository_UpdateCert_Call) Run(run func(ctx context.Context, domainID string, clientID string, clientCert string, clientKey string, caCert string)) *ConfigRepository_UpdateCert_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + var arg5 string + if args[5] != nil { + arg5 = args[5].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + ) + }) + return _c +} + +func (_c *ConfigRepository_UpdateCert_Call) Return(config bootstrap.Config, err error) *ConfigRepository_UpdateCert_Call { + _c.Call.Return(config, err) + return _c +} + +func (_c *ConfigRepository_UpdateCert_Call) RunAndReturn(run func(ctx context.Context, domainID string, clientID string, clientCert string, clientKey string, caCert string) (bootstrap.Config, error)) *ConfigRepository_UpdateCert_Call { + _c.Call.Return(run) + return _c +} + +// UpdateChannel provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) UpdateChannel(ctx context.Context, c bootstrap.Channel) error { + ret := _mock.Called(ctx, c) + + if len(ret) == 0 { + panic("no return value specified for UpdateChannel") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Channel) error); ok { + r0 = returnFunc(ctx, c) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// ConfigRepository_UpdateChannel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateChannel' +type ConfigRepository_UpdateChannel_Call struct { + *mock.Call +} + +// UpdateChannel is a helper method to define mock.On call +// - ctx context.Context +// - c bootstrap.Channel +func (_e *ConfigRepository_Expecter) UpdateChannel(ctx interface{}, c interface{}) *ConfigRepository_UpdateChannel_Call { + return &ConfigRepository_UpdateChannel_Call{Call: _e.mock.On("UpdateChannel", ctx, c)} +} + +func (_c *ConfigRepository_UpdateChannel_Call) Run(run func(ctx context.Context, c bootstrap.Channel)) *ConfigRepository_UpdateChannel_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 bootstrap.Channel + if args[1] != nil { + arg1 = args[1].(bootstrap.Channel) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *ConfigRepository_UpdateChannel_Call) Return(err error) *ConfigRepository_UpdateChannel_Call { + _c.Call.Return(err) + return _c +} + +func (_c *ConfigRepository_UpdateChannel_Call) RunAndReturn(run func(ctx context.Context, c bootstrap.Channel) error) *ConfigRepository_UpdateChannel_Call { + _c.Call.Return(run) + return _c +} + +// UpdateConnections provides a mock function for the type ConfigRepository +func (_mock *ConfigRepository) UpdateConnections(ctx context.Context, domainID string, id string, channels []bootstrap.Channel, connections []string) error { + ret := _mock.Called(ctx, domainID, id, channels, connections) + + if len(ret) == 0 { + panic("no return value specified for UpdateConnections") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, []bootstrap.Channel, []string) error); ok { + r0 = returnFunc(ctx, domainID, id, channels, connections) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// ConfigRepository_UpdateConnections_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateConnections' +type ConfigRepository_UpdateConnections_Call struct { + *mock.Call +} + +// UpdateConnections is a helper method to define mock.On call +// - ctx context.Context +// - domainID string +// - id string +// - channels []bootstrap.Channel +// - connections []string +func (_e *ConfigRepository_Expecter) UpdateConnections(ctx interface{}, domainID interface{}, id interface{}, channels interface{}, connections interface{}) *ConfigRepository_UpdateConnections_Call { + return &ConfigRepository_UpdateConnections_Call{Call: _e.mock.On("UpdateConnections", ctx, domainID, id, channels, connections)} +} + +func (_c *ConfigRepository_UpdateConnections_Call) Run(run func(ctx context.Context, domainID string, id string, channels []bootstrap.Channel, connections []string)) *ConfigRepository_UpdateConnections_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 []bootstrap.Channel + if args[3] != nil { + arg3 = args[3].([]bootstrap.Channel) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *ConfigRepository_UpdateConnections_Call) Return(err error) *ConfigRepository_UpdateConnections_Call { + _c.Call.Return(err) + return _c +} + +func (_c *ConfigRepository_UpdateConnections_Call) RunAndReturn(run func(ctx context.Context, domainID string, id string, channels []bootstrap.Channel, connections []string) error) *ConfigRepository_UpdateConnections_Call { + _c.Call.Return(run) + return _c +} diff --git a/bootstrap/mocks/service.go b/bootstrap/mocks/service.go new file mode 100644 index 000000000..d7b383214 --- /dev/null +++ b/bootstrap/mocks/service.go @@ -0,0 +1,1019 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/bootstrap" + "github.com/absmach/supermq/pkg/authn" + mock "github.com/stretchr/testify/mock" +) + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +type Service_Expecter struct { + mock *mock.Mock +} + +func (_m *Service) EXPECT() *Service_Expecter { + return &Service_Expecter{mock: &_m.Mock} +} + +// Add provides a mock function for the type Service +func (_mock *Service) Add(ctx context.Context, session authn.Session, token string, cfg bootstrap.Config) (bootstrap.Config, error) { + ret := _mock.Called(ctx, session, token, cfg) + + if len(ret) == 0 { + panic("no return value specified for Add") + } + + var r0 bootstrap.Config + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, bootstrap.Config) (bootstrap.Config, error)); ok { + return returnFunc(ctx, session, token, cfg) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, bootstrap.Config) bootstrap.Config); ok { + r0 = returnFunc(ctx, session, token, cfg) + } else { + r0 = ret.Get(0).(bootstrap.Config) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, bootstrap.Config) error); ok { + r1 = returnFunc(ctx, session, token, cfg) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_Add_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Add' +type Service_Add_Call struct { + *mock.Call +} + +// Add is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - token string +// - cfg bootstrap.Config +func (_e *Service_Expecter) Add(ctx interface{}, session interface{}, token interface{}, cfg interface{}) *Service_Add_Call { + return &Service_Add_Call{Call: _e.mock.On("Add", ctx, session, token, cfg)} +} + +func (_c *Service_Add_Call) Run(run func(ctx context.Context, session authn.Session, token string, cfg bootstrap.Config)) *Service_Add_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 bootstrap.Config + if args[3] != nil { + arg3 = args[3].(bootstrap.Config) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_Add_Call) Return(config bootstrap.Config, err error) *Service_Add_Call { + _c.Call.Return(config, err) + return _c +} + +func (_c *Service_Add_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, token string, cfg bootstrap.Config) (bootstrap.Config, error)) *Service_Add_Call { + _c.Call.Return(run) + return _c +} + +// Bootstrap provides a mock function for the type Service +func (_mock *Service) Bootstrap(ctx context.Context, externalKey string, externalID string, secure bool) (bootstrap.Config, error) { + ret := _mock.Called(ctx, externalKey, externalID, secure) + + if len(ret) == 0 { + panic("no return value specified for Bootstrap") + } + + var r0 bootstrap.Config + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, bool) (bootstrap.Config, error)); ok { + return returnFunc(ctx, externalKey, externalID, secure) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, bool) bootstrap.Config); ok { + r0 = returnFunc(ctx, externalKey, externalID, secure) + } else { + r0 = ret.Get(0).(bootstrap.Config) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, bool) error); ok { + r1 = returnFunc(ctx, externalKey, externalID, secure) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_Bootstrap_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Bootstrap' +type Service_Bootstrap_Call struct { + *mock.Call +} + +// Bootstrap is a helper method to define mock.On call +// - ctx context.Context +// - externalKey string +// - externalID string +// - secure bool +func (_e *Service_Expecter) Bootstrap(ctx interface{}, externalKey interface{}, externalID interface{}, secure interface{}) *Service_Bootstrap_Call { + return &Service_Bootstrap_Call{Call: _e.mock.On("Bootstrap", ctx, externalKey, externalID, secure)} +} + +func (_c *Service_Bootstrap_Call) Run(run func(ctx context.Context, externalKey string, externalID string, secure bool)) *Service_Bootstrap_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 bool + if args[3] != nil { + arg3 = args[3].(bool) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_Bootstrap_Call) Return(config bootstrap.Config, err error) *Service_Bootstrap_Call { + _c.Call.Return(config, err) + return _c +} + +func (_c *Service_Bootstrap_Call) RunAndReturn(run func(ctx context.Context, externalKey string, externalID string, secure bool) (bootstrap.Config, error)) *Service_Bootstrap_Call { + _c.Call.Return(run) + return _c +} + +// ChangeState provides a mock function for the type Service +func (_mock *Service) ChangeState(ctx context.Context, session authn.Session, token string, id string, state bootstrap.State) error { + ret := _mock.Called(ctx, session, token, id, state) + + if len(ret) == 0 { + panic("no return value specified for ChangeState") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, bootstrap.State) error); ok { + r0 = returnFunc(ctx, session, token, id, state) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_ChangeState_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ChangeState' +type Service_ChangeState_Call struct { + *mock.Call +} + +// ChangeState is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - token string +// - id string +// - state bootstrap.State +func (_e *Service_Expecter) ChangeState(ctx interface{}, session interface{}, token interface{}, id interface{}, state interface{}) *Service_ChangeState_Call { + return &Service_ChangeState_Call{Call: _e.mock.On("ChangeState", ctx, session, token, id, state)} +} + +func (_c *Service_ChangeState_Call) Run(run func(ctx context.Context, session authn.Session, token string, id string, state bootstrap.State)) *Service_ChangeState_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 bootstrap.State + if args[4] != nil { + arg4 = args[4].(bootstrap.State) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_ChangeState_Call) Return(err error) *Service_ChangeState_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_ChangeState_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, token string, id string, state bootstrap.State) error) *Service_ChangeState_Call { + _c.Call.Return(run) + return _c +} + +// ConnectClientHandler provides a mock function for the type Service +func (_mock *Service) ConnectClientHandler(ctx context.Context, channelID string, clientID string) error { + ret := _mock.Called(ctx, channelID, clientID) + + if len(ret) == 0 { + panic("no return value specified for ConnectClientHandler") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = returnFunc(ctx, channelID, clientID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_ConnectClientHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ConnectClientHandler' +type Service_ConnectClientHandler_Call struct { + *mock.Call +} + +// ConnectClientHandler is a helper method to define mock.On call +// - ctx context.Context +// - channelID string +// - clientID string +func (_e *Service_Expecter) ConnectClientHandler(ctx interface{}, channelID interface{}, clientID interface{}) *Service_ConnectClientHandler_Call { + return &Service_ConnectClientHandler_Call{Call: _e.mock.On("ConnectClientHandler", ctx, channelID, clientID)} +} + +func (_c *Service_ConnectClientHandler_Call) Run(run func(ctx context.Context, channelID string, clientID string)) *Service_ConnectClientHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_ConnectClientHandler_Call) Return(err error) *Service_ConnectClientHandler_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_ConnectClientHandler_Call) RunAndReturn(run func(ctx context.Context, channelID string, clientID string) error) *Service_ConnectClientHandler_Call { + _c.Call.Return(run) + return _c +} + +// DisconnectClientHandler provides a mock function for the type Service +func (_mock *Service) DisconnectClientHandler(ctx context.Context, channelID string, clientID string) error { + ret := _mock.Called(ctx, channelID, clientID) + + if len(ret) == 0 { + panic("no return value specified for DisconnectClientHandler") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = returnFunc(ctx, channelID, clientID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_DisconnectClientHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DisconnectClientHandler' +type Service_DisconnectClientHandler_Call struct { + *mock.Call +} + +// DisconnectClientHandler is a helper method to define mock.On call +// - ctx context.Context +// - channelID string +// - clientID string +func (_e *Service_Expecter) DisconnectClientHandler(ctx interface{}, channelID interface{}, clientID interface{}) *Service_DisconnectClientHandler_Call { + return &Service_DisconnectClientHandler_Call{Call: _e.mock.On("DisconnectClientHandler", ctx, channelID, clientID)} +} + +func (_c *Service_DisconnectClientHandler_Call) Run(run func(ctx context.Context, channelID string, clientID string)) *Service_DisconnectClientHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_DisconnectClientHandler_Call) Return(err error) *Service_DisconnectClientHandler_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_DisconnectClientHandler_Call) RunAndReturn(run func(ctx context.Context, channelID string, clientID string) error) *Service_DisconnectClientHandler_Call { + _c.Call.Return(run) + return _c +} + +// List provides a mock function for the type Service +func (_mock *Service) List(ctx context.Context, session authn.Session, filter bootstrap.Filter, offset uint64, limit uint64) (bootstrap.ConfigsPage, error) { + ret := _mock.Called(ctx, session, filter, offset, limit) + + if len(ret) == 0 { + panic("no return value specified for List") + } + + var r0 bootstrap.ConfigsPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, bootstrap.Filter, uint64, uint64) (bootstrap.ConfigsPage, error)); ok { + return returnFunc(ctx, session, filter, offset, limit) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, bootstrap.Filter, uint64, uint64) bootstrap.ConfigsPage); ok { + r0 = returnFunc(ctx, session, filter, offset, limit) + } else { + r0 = ret.Get(0).(bootstrap.ConfigsPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, bootstrap.Filter, uint64, uint64) error); ok { + r1 = returnFunc(ctx, session, filter, offset, limit) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_List_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'List' +type Service_List_Call struct { + *mock.Call +} + +// List is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - filter bootstrap.Filter +// - offset uint64 +// - limit uint64 +func (_e *Service_Expecter) List(ctx interface{}, session interface{}, filter interface{}, offset interface{}, limit interface{}) *Service_List_Call { + return &Service_List_Call{Call: _e.mock.On("List", ctx, session, filter, offset, limit)} +} + +func (_c *Service_List_Call) Run(run func(ctx context.Context, session authn.Session, filter bootstrap.Filter, offset uint64, limit uint64)) *Service_List_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 bootstrap.Filter + if args[2] != nil { + arg2 = args[2].(bootstrap.Filter) + } + var arg3 uint64 + if args[3] != nil { + arg3 = args[3].(uint64) + } + var arg4 uint64 + if args[4] != nil { + arg4 = args[4].(uint64) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_List_Call) Return(configsPage bootstrap.ConfigsPage, err error) *Service_List_Call { + _c.Call.Return(configsPage, err) + return _c +} + +func (_c *Service_List_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, filter bootstrap.Filter, offset uint64, limit uint64) (bootstrap.ConfigsPage, error)) *Service_List_Call { + _c.Call.Return(run) + return _c +} + +// Remove provides a mock function for the type Service +func (_mock *Service) Remove(ctx context.Context, session authn.Session, id string) error { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for Remove") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_Remove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remove' +type Service_Remove_Call struct { + *mock.Call +} + +// Remove is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) Remove(ctx interface{}, session interface{}, id interface{}) *Service_Remove_Call { + return &Service_Remove_Call{Call: _e.mock.On("Remove", ctx, session, id)} +} + +func (_c *Service_Remove_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_Remove_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_Remove_Call) Return(err error) *Service_Remove_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_Remove_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) error) *Service_Remove_Call { + _c.Call.Return(run) + return _c +} + +// RemoveChannelHandler provides a mock function for the type Service +func (_mock *Service) RemoveChannelHandler(ctx context.Context, id string) error { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for RemoveChannelHandler") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RemoveChannelHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveChannelHandler' +type Service_RemoveChannelHandler_Call struct { + *mock.Call +} + +// RemoveChannelHandler is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Service_Expecter) RemoveChannelHandler(ctx interface{}, id interface{}) *Service_RemoveChannelHandler_Call { + return &Service_RemoveChannelHandler_Call{Call: _e.mock.On("RemoveChannelHandler", ctx, id)} +} + +func (_c *Service_RemoveChannelHandler_Call) Run(run func(ctx context.Context, id string)) *Service_RemoveChannelHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_RemoveChannelHandler_Call) Return(err error) *Service_RemoveChannelHandler_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RemoveChannelHandler_Call) RunAndReturn(run func(ctx context.Context, id string) error) *Service_RemoveChannelHandler_Call { + _c.Call.Return(run) + return _c +} + +// RemoveConfigHandler provides a mock function for the type Service +func (_mock *Service) RemoveConfigHandler(ctx context.Context, id string) error { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for RemoveConfigHandler") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RemoveConfigHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveConfigHandler' +type Service_RemoveConfigHandler_Call struct { + *mock.Call +} + +// RemoveConfigHandler is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Service_Expecter) RemoveConfigHandler(ctx interface{}, id interface{}) *Service_RemoveConfigHandler_Call { + return &Service_RemoveConfigHandler_Call{Call: _e.mock.On("RemoveConfigHandler", ctx, id)} +} + +func (_c *Service_RemoveConfigHandler_Call) Run(run func(ctx context.Context, id string)) *Service_RemoveConfigHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_RemoveConfigHandler_Call) Return(err error) *Service_RemoveConfigHandler_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RemoveConfigHandler_Call) RunAndReturn(run func(ctx context.Context, id string) error) *Service_RemoveConfigHandler_Call { + _c.Call.Return(run) + return _c +} + +// Update provides a mock function for the type Service +func (_mock *Service) Update(ctx context.Context, session authn.Session, cfg bootstrap.Config) error { + ret := _mock.Called(ctx, session, cfg) + + if len(ret) == 0 { + panic("no return value specified for Update") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, bootstrap.Config) error); ok { + r0 = returnFunc(ctx, session, cfg) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type Service_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - cfg bootstrap.Config +func (_e *Service_Expecter) Update(ctx interface{}, session interface{}, cfg interface{}) *Service_Update_Call { + return &Service_Update_Call{Call: _e.mock.On("Update", ctx, session, cfg)} +} + +func (_c *Service_Update_Call) Run(run func(ctx context.Context, session authn.Session, cfg bootstrap.Config)) *Service_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 bootstrap.Config + if args[2] != nil { + arg2 = args[2].(bootstrap.Config) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_Update_Call) Return(err error) *Service_Update_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_Update_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, cfg bootstrap.Config) error) *Service_Update_Call { + _c.Call.Return(run) + return _c +} + +// UpdateCert provides a mock function for the type Service +func (_mock *Service) UpdateCert(ctx context.Context, session authn.Session, clientID string, clientCert string, clientKey string, caCert string) (bootstrap.Config, error) { + ret := _mock.Called(ctx, session, clientID, clientCert, clientKey, caCert) + + if len(ret) == 0 { + panic("no return value specified for UpdateCert") + } + + var r0 bootstrap.Config + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string, string) (bootstrap.Config, error)); ok { + return returnFunc(ctx, session, clientID, clientCert, clientKey, caCert) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string, string) bootstrap.Config); ok { + r0 = returnFunc(ctx, session, clientID, clientCert, clientKey, caCert) + } else { + r0 = ret.Get(0).(bootstrap.Config) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, string, string) error); ok { + r1 = returnFunc(ctx, session, clientID, clientCert, clientKey, caCert) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_UpdateCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCert' +type Service_UpdateCert_Call struct { + *mock.Call +} + +// UpdateCert is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - clientID string +// - clientCert string +// - clientKey string +// - caCert string +func (_e *Service_Expecter) UpdateCert(ctx interface{}, session interface{}, clientID interface{}, clientCert interface{}, clientKey interface{}, caCert interface{}) *Service_UpdateCert_Call { + return &Service_UpdateCert_Call{Call: _e.mock.On("UpdateCert", ctx, session, clientID, clientCert, clientKey, caCert)} +} + +func (_c *Service_UpdateCert_Call) Run(run func(ctx context.Context, session authn.Session, clientID string, clientCert string, clientKey string, caCert string)) *Service_UpdateCert_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + var arg5 string + if args[5] != nil { + arg5 = args[5].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + ) + }) + return _c +} + +func (_c *Service_UpdateCert_Call) Return(config bootstrap.Config, err error) *Service_UpdateCert_Call { + _c.Call.Return(config, err) + return _c +} + +func (_c *Service_UpdateCert_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, clientID string, clientCert string, clientKey string, caCert string) (bootstrap.Config, error)) *Service_UpdateCert_Call { + _c.Call.Return(run) + return _c +} + +// UpdateChannelHandler provides a mock function for the type Service +func (_mock *Service) UpdateChannelHandler(ctx context.Context, channel bootstrap.Channel) error { + ret := _mock.Called(ctx, channel) + + if len(ret) == 0 { + panic("no return value specified for UpdateChannelHandler") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, bootstrap.Channel) error); ok { + r0 = returnFunc(ctx, channel) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_UpdateChannelHandler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateChannelHandler' +type Service_UpdateChannelHandler_Call struct { + *mock.Call +} + +// UpdateChannelHandler is a helper method to define mock.On call +// - ctx context.Context +// - channel bootstrap.Channel +func (_e *Service_Expecter) UpdateChannelHandler(ctx interface{}, channel interface{}) *Service_UpdateChannelHandler_Call { + return &Service_UpdateChannelHandler_Call{Call: _e.mock.On("UpdateChannelHandler", ctx, channel)} +} + +func (_c *Service_UpdateChannelHandler_Call) Run(run func(ctx context.Context, channel bootstrap.Channel)) *Service_UpdateChannelHandler_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 bootstrap.Channel + if args[1] != nil { + arg1 = args[1].(bootstrap.Channel) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_UpdateChannelHandler_Call) Return(err error) *Service_UpdateChannelHandler_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_UpdateChannelHandler_Call) RunAndReturn(run func(ctx context.Context, channel bootstrap.Channel) error) *Service_UpdateChannelHandler_Call { + _c.Call.Return(run) + return _c +} + +// UpdateConnections provides a mock function for the type Service +func (_mock *Service) UpdateConnections(ctx context.Context, session authn.Session, token string, id string, connections []string) error { + ret := _mock.Called(ctx, session, token, id, connections) + + if len(ret) == 0 { + panic("no return value specified for UpdateConnections") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) error); ok { + r0 = returnFunc(ctx, session, token, id, connections) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_UpdateConnections_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateConnections' +type Service_UpdateConnections_Call struct { + *mock.Call +} + +// UpdateConnections is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - token string +// - id string +// - connections []string +func (_e *Service_Expecter) UpdateConnections(ctx interface{}, session interface{}, token interface{}, id interface{}, connections interface{}) *Service_UpdateConnections_Call { + return &Service_UpdateConnections_Call{Call: _e.mock.On("UpdateConnections", ctx, session, token, id, connections)} +} + +func (_c *Service_UpdateConnections_Call) Run(run func(ctx context.Context, session authn.Session, token string, id string, connections []string)) *Service_UpdateConnections_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_UpdateConnections_Call) Return(err error) *Service_UpdateConnections_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_UpdateConnections_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, token string, id string, connections []string) error) *Service_UpdateConnections_Call { + _c.Call.Return(run) + return _c +} + +// View provides a mock function for the type Service +func (_mock *Service) View(ctx context.Context, session authn.Session, id string) (bootstrap.Config, error) { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for View") + } + + var r0 bootstrap.Config + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (bootstrap.Config, error)); ok { + return returnFunc(ctx, session, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) bootstrap.Config); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Get(0).(bootstrap.Config) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = returnFunc(ctx, session, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_View_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'View' +type Service_View_Call struct { + *mock.Call +} + +// View is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) View(ctx interface{}, session interface{}, id interface{}) *Service_View_Call { + return &Service_View_Call{Call: _e.mock.On("View", ctx, session, id)} +} + +func (_c *Service_View_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_View_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_View_Call) Return(config bootstrap.Config, err error) *Service_View_Call { + _c.Call.Return(config, err) + return _c +} + +func (_c *Service_View_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) (bootstrap.Config, error)) *Service_View_Call { + _c.Call.Return(run) + return _c +} diff --git a/bootstrap/postgres/configs.go b/bootstrap/postgres/configs.go new file mode 100644 index 000000000..7981f39ed --- /dev/null +++ b/bootstrap/postgres/configs.go @@ -0,0 +1,771 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/absmach/supermq/bootstrap" + "github.com/absmach/supermq/clients" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/postgres" + "github.com/jackc/pgerrcode" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jmoiron/sqlx" +) + +var ( + errSaveChannels = errors.New("failed to insert channels to database") + errSaveConnections = errors.New("failed to insert connections to database") + errUpdateChannels = errors.New("failed to update channels in bootstrap configuration database") + errRemoveChannels = errors.New("failed to remove channels from bootstrap configuration in database") + errConnectClient = errors.New("failed to connect client in bootstrap configuration in database") + errDisconnectClient = errors.New("failed to disconnect client in bootstrap configuration in database") +) + +const cleanupQuery = `DELETE FROM channels ch WHERE NOT EXISTS ( + SELECT channel_id FROM connections c WHERE ch.magistrala_channel = c.channel_id);` + +var _ bootstrap.ConfigRepository = (*configRepository)(nil) + +type configRepository struct { + db postgres.Database + log *slog.Logger +} + +// NewConfigRepository instantiates a PostgreSQL implementation of config +// repository. +func NewConfigRepository(db postgres.Database, log *slog.Logger) bootstrap.ConfigRepository { + return &configRepository{db: db, log: log} +} + +func (cr configRepository) Save(ctx context.Context, cfg bootstrap.Config, chsConnIDs []string) (clientID string, err error) { + q := `INSERT INTO configs (magistrala_client, domain_id, name, client_cert, client_key, ca_cert, magistrala_secret, external_id, external_key, content, state) + VALUES (:magistrala_client, :domain_id, :name, :client_cert, :client_key, :ca_cert, :magistrala_secret, :external_id, :external_key, :content, :state)` + + tx, err := cr.db.BeginTxx(ctx, nil) + if err != nil { + return "", errors.Wrap(repoerr.ErrCreateEntity, err) + } + dbcfg := toDBConfig(cfg) + + defer func() { + if err != nil { + err = cr.rollback("Save method", err, tx) + } + }() + + if _, err := tx.NamedExec(q, dbcfg); err != nil { + switch pgErr := err.(type) { + case *pgconn.PgError: + if pgErr.Code == pgerrcode.UniqueViolation { + err = repoerr.ErrConflict + } + } + return "", err + } + + if err := insertChannels(cfg.DomainID, cfg.Channels, tx); err != nil { + return "", errors.Wrap(errSaveChannels, err) + } + + if err := insertConnections(ctx, cfg, chsConnIDs, tx); err != nil { + return "", errors.Wrap(errSaveConnections, err) + } + + if commitErr := tx.Commit(); commitErr != nil { + return "", commitErr + } + + return cfg.ClientID, nil +} + +func (cr configRepository) RetrieveByID(ctx context.Context, domainID, id string) (bootstrap.Config, error) { + q := `SELECT magistrala_client, magistrala_secret, external_id, external_key, name, content, state, client_cert, ca_cert + FROM configs + WHERE magistrala_client = :magistrala_client AND domain_id = :domain_id` + + dbcfg := dbConfig{ + ClientID: id, + DomainID: domainID, + } + row, err := cr.db.NamedQueryContext(ctx, q, dbcfg) + if err != nil { + return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + if !row.Next() { + return bootstrap.Config{}, repoerr.ErrNotFound + } + + if err := row.StructScan(&dbcfg); err != nil { + return bootstrap.Config{}, err + } + + q = `SELECT magistrala_channel, name, metadata FROM channels ch + INNER JOIN connections conn + ON ch.magistrala_channel = conn.channel_id AND ch.domain_id = conn.domain_id + WHERE conn.config_id = :magistrala_client AND conn.domain_id = :domain_id` + + rows, err := cr.db.NamedQueryContext(ctx, q, dbcfg) + if err != nil { + cr.log.Error(fmt.Sprintf("Failed to retrieve connected due to %s", err)) + return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + chans := []bootstrap.Channel{} + for rows.Next() { + dbch := dbChannel{} + if err := rows.StructScan(&dbch); err != nil { + cr.log.Error(fmt.Sprintf("Failed to read connected client due to %s", err)) + return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + dbch.DomainID = nullString(dbcfg.DomainID) + + ch, err := toChannel(dbch) + if err != nil { + return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + chans = append(chans, ch) + } + + cfg := toConfig(dbcfg) + cfg.Channels = chans + + return cfg, nil +} + +func (cr configRepository) RetrieveAll(ctx context.Context, domainID string, clientIDs []string, filter bootstrap.Filter, offset, limit uint64) bootstrap.ConfigsPage { + search, params := buildRetrieveQueryParams(domainID, clientIDs, filter) + n := len(params) + + q := `SELECT magistrala_client, magistrala_secret, external_id, external_key, name, content, state + FROM configs %s ORDER BY magistrala_client LIMIT $%d OFFSET $%d` + q = fmt.Sprintf(q, search, n+1, n+2) + + rows, err := cr.db.QueryContext(ctx, q, append(params, limit, offset)...) + if err != nil { + cr.log.Error(fmt.Sprintf("Failed to retrieve configs due to %s", err)) + return bootstrap.ConfigsPage{} + } + defer rows.Close() + + var name, content sql.NullString + configs := []bootstrap.Config{} + + for rows.Next() { + c := bootstrap.Config{DomainID: domainID} + if err := rows.Scan(&c.ClientID, &c.ClientSecret, &c.ExternalID, &c.ExternalKey, &name, &content, &c.State); err != nil { + cr.log.Error(fmt.Sprintf("Failed to read retrieved config due to %s", err)) + return bootstrap.ConfigsPage{} + } + + c.Name = name.String + c.Content = content.String + configs = append(configs, c) + } + + q = fmt.Sprintf(`SELECT COUNT(*) FROM configs %s`, search) + + var total uint64 + if err := cr.db.QueryRowxContext(ctx, q, params...).Scan(&total); err != nil { + cr.log.Error(fmt.Sprintf("Failed to count configs due to %s", err)) + return bootstrap.ConfigsPage{} + } + + return bootstrap.ConfigsPage{ + Total: total, + Limit: limit, + Offset: offset, + Configs: configs, + } +} + +func (cr configRepository) RetrieveByExternalID(ctx context.Context, externalID string) (bootstrap.Config, error) { + q := `SELECT magistrala_client, magistrala_secret, external_key, domain_id, name, client_cert, client_key, ca_cert, content, state + FROM configs + WHERE external_id = :external_id` + dbcfg := dbConfig{ + ExternalID: externalID, + } + + row, err := cr.db.NamedQueryContext(ctx, q, dbcfg) + if err != nil { + return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + if !row.Next() { + return bootstrap.Config{}, repoerr.ErrNotFound + } + + if err := row.StructScan(&dbcfg); err != nil { + return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + q = `SELECT magistrala_channel, name, metadata FROM channels ch + INNER JOIN connections conn + ON ch.magistrala_channel = conn.channel_id AND ch.domain_id = conn.domain_id + WHERE conn.config_id = :magistrala_client AND conn.domain_id = :domain_id` + + rows, err := cr.db.NamedQueryContext(ctx, q, dbcfg) + if err != nil { + cr.log.Error(fmt.Sprintf("Failed to retrieve connected due to %s", err)) + return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + channels := []bootstrap.Channel{} + for rows.Next() { + dbch := dbChannel{} + if err := rows.StructScan(&dbch); err != nil { + cr.log.Error(fmt.Sprintf("Failed to read connected client due to %s", err)) + return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + ch, err := toChannel(dbch) + if err != nil { + cr.log.Error(fmt.Sprintf("Failed to deserialize channel due to %s", err)) + return bootstrap.Config{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + channels = append(channels, ch) + } + + cfg := toConfig(dbcfg) + cfg.Channels = channels + + return cfg, nil +} + +func (cr configRepository) Update(ctx context.Context, cfg bootstrap.Config) error { + q := `UPDATE configs SET name = :name, content = :content WHERE magistrala_client = :magistrala_client AND domain_id = :domain_id ` + + dbcfg := dbConfig{ + Name: nullString(cfg.Name), + Content: nullString(cfg.Content), + ClientID: cfg.ClientID, + DomainID: cfg.DomainID, + } + + res, err := cr.db.NamedExecContext(ctx, q, dbcfg) + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + cnt, err := res.RowsAffected() + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + if cnt == 0 { + return repoerr.ErrNotFound + } + + return nil +} + +func (cr configRepository) UpdateCert(ctx context.Context, domainID, clientID, clientCert, clientKey, caCert string) (bootstrap.Config, error) { + q := `UPDATE configs SET client_cert = :client_cert, client_key = :client_key, ca_cert = :ca_cert WHERE magistrala_client = :magistrala_client AND domain_id = :domain_id + RETURNING magistrala_client, client_cert, client_key, ca_cert` + + dbcfg := dbConfig{ + ClientID: clientID, + ClientCert: nullString(clientCert), + DomainID: domainID, + ClientKey: nullString(clientKey), + CaCert: nullString(caCert), + } + + row, err := cr.db.NamedQueryContext(ctx, q, dbcfg) + if err != nil { + return bootstrap.Config{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + defer row.Close() + + if ok := row.Next(); !ok { + return bootstrap.Config{}, errors.Wrap(repoerr.ErrNotFound, row.Err()) + } + + if err := row.StructScan(&dbcfg); err != nil { + return bootstrap.Config{}, err + } + + return toConfig(dbcfg), nil +} + +func (cr configRepository) UpdateConnections(ctx context.Context, domainID, id string, channels []bootstrap.Channel, connections []string) (err error) { + tx, err := cr.db.BeginTxx(ctx, nil) + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + defer func() { + if err != nil { + err = cr.rollback("UpdateConnections method", err, tx) + } else { + if commitErr := tx.Commit(); commitErr != nil { + err = commitErr + } + } + }() + + if err = insertChannels(domainID, channels, tx); err != nil { + err = errors.Wrap(repoerr.ErrUpdateEntity, err) + return err + } + + if err = updateConnections(domainID, id, connections, tx); err != nil { + if e, ok := err.(*pgconn.PgError); ok { + if e.Code == pgerrcode.ForeignKeyViolation { + err = repoerr.ErrNotFound + } + } + err = errors.Wrap(repoerr.ErrUpdateEntity, err) + return err + } + + return nil +} + +func (cr configRepository) Remove(ctx context.Context, domainID, id string) error { + q := `DELETE FROM configs WHERE magistrala_client = :magistrala_client AND domain_id = :domain_id` + dbcfg := dbConfig{ + ClientID: id, + DomainID: domainID, + } + + if _, err := cr.db.NamedExecContext(ctx, q, dbcfg); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + if _, err := cr.db.ExecContext(ctx, cleanupQuery); err != nil { + cr.log.Warn("Failed to clean dangling channels after removal") + } + + return nil +} + +func (cr configRepository) ChangeState(ctx context.Context, domainID, id string, state bootstrap.State) error { + q := `UPDATE configs SET state = :state WHERE magistrala_client = :magistrala_client AND domain_id = :domain_id;` + + dbcfg := dbConfig{ + ClientID: id, + State: state, + DomainID: domainID, + } + + res, err := cr.db.NamedExecContext(ctx, q, dbcfg) + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + cnt, err := res.RowsAffected() + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + if cnt == 0 { + return repoerr.ErrNotFound + } + + return nil +} + +func (cr configRepository) ListExisting(ctx context.Context, domainID string, ids []string) ([]bootstrap.Channel, error) { + var channels []bootstrap.Channel + if len(ids) == 0 { + return channels, nil + } + + var chans pgtype.TextArray + if err := chans.Set(ids); err != nil { + return []bootstrap.Channel{}, err + } + + q := "SELECT magistrala_channel, name, metadata FROM channels WHERE domain_id = $1 AND magistrala_channel = ANY ($2)" + rows, err := cr.db.QueryxContext(ctx, q, domainID, chans) + if err != nil { + return []bootstrap.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + for rows.Next() { + var dbch dbChannel + if err := rows.StructScan(&dbch); err != nil { + cr.log.Error(fmt.Sprintf("Failed to read retrieved channels due to %s", err)) + return []bootstrap.Channel{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + ch, err := toChannel(dbch) + if err != nil { + cr.log.Error(fmt.Sprintf("Failed to deserialize channel due to %s", err)) + return []bootstrap.Channel{}, err + } + + channels = append(channels, ch) + } + + return channels, nil +} + +func (cr configRepository) RemoveClient(ctx context.Context, id string) error { + q := `DELETE FROM configs WHERE magistrala_client = $1` + _, err := cr.db.ExecContext(ctx, q, id) + + if _, err := cr.db.ExecContext(ctx, cleanupQuery); err != nil { + cr.log.Warn("Failed to clean dangling channels after removal") + } + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + return nil +} + +func (cr configRepository) UpdateChannel(ctx context.Context, c bootstrap.Channel) error { + dbch, err := toDBChannel("", c) + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + q := `UPDATE channels SET name = :name, metadata = :metadata, updated_at = :updated_at, updated_by = :updated_by + WHERE magistrala_channel = :magistrala_channel` + if _, err = cr.db.NamedExecContext(ctx, q, dbch); err != nil { + return errors.Wrap(errUpdateChannels, err) + } + return nil +} + +func (cr configRepository) RemoveChannel(ctx context.Context, id string) error { + q := `DELETE FROM channels WHERE magistrala_channel = $1` + if _, err := cr.db.ExecContext(ctx, q, id); err != nil { + return errors.Wrap(errRemoveChannels, err) + } + return nil +} + +func (cr configRepository) ConnectClient(ctx context.Context, channelID, clientID string) error { + q := `UPDATE configs SET state = $1 + WHERE magistrala_client = $2 + AND EXISTS (SELECT 1 FROM connections WHERE config_id = $2 AND channel_id = $3)` + + result, err := cr.db.ExecContext(ctx, q, bootstrap.Active, clientID, channelID) + if err != nil { + return errors.Wrap(errConnectClient, err) + } + if rows, _ := result.RowsAffected(); rows == 0 { + return repoerr.ErrNotFound + } + return nil +} + +func (cr configRepository) DisconnectClient(ctx context.Context, channelID, clientID string) error { + q := `UPDATE configs SET state = $1 + WHERE magistrala_client = $2 + AND EXISTS (SELECT 1 FROM connections WHERE config_id = $2 AND channel_id = $3)` + _, err := cr.db.ExecContext(ctx, q, bootstrap.Inactive, clientID, channelID) + if err != nil { + return errors.Wrap(errDisconnectClient, err) + } + return nil +} + +func buildRetrieveQueryParams(domainID string, clientIDs []string, filter bootstrap.Filter) (string, []any) { + params := []any{} + queries := []string{} + + if len(clientIDs) != 0 { + queries = append(queries, fmt.Sprintf("magistrala_client IN ('%s')", strings.Join(clientIDs, "','"))) + } else if domainID != "" { + params = append(params, domainID) + queries = append(queries, fmt.Sprintf("domain_id = $%d", len(params))) + } + + // Adjust the starting point for placeholders based on the current length of params + counter := len(params) + 1 + for k, v := range filter.FullMatch { + params = append(params, v) + queries = append(queries, fmt.Sprintf("%s = $%d", k, counter)) + counter++ + } + for k, v := range filter.PartialMatch { + params = append(params, v) + queries = append(queries, fmt.Sprintf("LOWER(%s) LIKE '%%' || $%d || '%%'", k, counter)) + counter++ + } + + if len(queries) > 0 { + return "WHERE " + strings.Join(queries, " AND "), params + } + return "", params +} + +func (cr configRepository) rollback(content string, defErr error, tx *sqlx.Tx) error { + if err := tx.Rollback(); err != nil { + return errors.Wrap(defErr, errors.Wrap(errors.New("failed to rollback at "+content), err)) + } + + return defErr +} + +func insertChannels(domainID string, channels []bootstrap.Channel, tx *sqlx.Tx) error { + if len(channels) == 0 { + return nil + } + + var chans []dbChannel + for _, ch := range channels { + dbch, err := toDBChannel(domainID, ch) + if err != nil { + return err + } + chans = append(chans, dbch) + } + q := `INSERT INTO channels (magistrala_channel, domain_id, name, metadata, parent_id, description, created_at, updated_at, updated_by, status) + VALUES (:magistrala_channel, :domain_id, :name, :metadata, :parent_id, :description, :created_at, :updated_at, :updated_by, :status)` + if _, err := tx.NamedExec(q, chans); err != nil { + e := err + if pqErr, ok := err.(*pgconn.PgError); ok && pqErr.Code == pgerrcode.UniqueViolation { + e = repoerr.ErrConflict + } + return e + } + + return nil +} + +func insertConnections(_ context.Context, cfg bootstrap.Config, connections []string, tx *sqlx.Tx) error { + if len(connections) == 0 { + return nil + } + + q := `INSERT INTO connections (config_id, channel_id, domain_id) + VALUES (:config_id, :channel_id, :domain_id)` + + conns := []dbConnection{} + for _, conn := range connections { + dbconn := dbConnection{ + Config: cfg.ClientID, + Channel: conn, + DomainID: cfg.DomainID, + } + conns = append(conns, dbconn) + } + _, err := tx.NamedExec(q, conns) + + return err +} + +func updateConnections(domainID, id string, connections []string, tx *sqlx.Tx) error { + if len(connections) == 0 { + return nil + } + + q := `DELETE FROM connections + WHERE config_id = $1 AND domain_id = $2 + AND channel_id NOT IN ($3)` + + var conn pgtype.TextArray + if err := conn.Set(connections); err != nil { + return err + } + + res, err := tx.Exec(q, id, domainID, conn) + if err != nil { + return err + } + + cnt, err := res.RowsAffected() + if err != nil { + return err + } + + q = `INSERT INTO connections (config_id, channel_id, domain_id) + VALUES (:config_id, :channel_id, :domain_id)` + + conns := []dbConnection{} + for _, conn := range connections { + dbconn := dbConnection{ + Config: id, + Channel: conn, + DomainID: domainID, + } + conns = append(conns, dbconn) + } + + if _, err := tx.NamedExec(q, conns); err != nil { + return err + } + + if cnt == 0 { + return nil + } + + _, err = tx.Exec(cleanupQuery) + + return err +} + +func nullString(s string) sql.NullString { + if s == "" { + return sql.NullString{} + } + + return sql.NullString{ + String: s, + Valid: true, + } +} + +func nullTime(t time.Time) sql.NullTime { + if t.IsZero() { + return sql.NullTime{} + } + + return sql.NullTime{ + Time: t, + Valid: true, + } +} + +type dbConfig struct { + DomainID string `db:"domain_id"` + ClientID string `db:"magistrala_client"` + ClientSecret string `db:"magistrala_secret"` + Name sql.NullString `db:"name"` + ClientCert sql.NullString `db:"client_cert"` + ClientKey sql.NullString `db:"client_key"` + CaCert sql.NullString `db:"ca_cert"` + ExternalID string `db:"external_id"` + ExternalKey string `db:"external_key"` + Content sql.NullString `db:"content"` + State bootstrap.State `db:"state"` +} + +func toDBConfig(cfg bootstrap.Config) dbConfig { + return dbConfig{ + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, + DomainID: cfg.DomainID, + Name: nullString(cfg.Name), + ClientCert: nullString(cfg.ClientCert), + ClientKey: nullString(cfg.ClientKey), + CaCert: nullString(cfg.CACert), + ExternalID: cfg.ExternalID, + ExternalKey: cfg.ExternalKey, + Content: nullString(cfg.Content), + State: cfg.State, + } +} + +func toConfig(dbcfg dbConfig) bootstrap.Config { + cfg := bootstrap.Config{ + ClientID: dbcfg.ClientID, + ClientSecret: dbcfg.ClientSecret, + DomainID: dbcfg.DomainID, + ExternalID: dbcfg.ExternalID, + ExternalKey: dbcfg.ExternalKey, + State: dbcfg.State, + } + + if dbcfg.Name.Valid { + cfg.Name = dbcfg.Name.String + } + + if dbcfg.Content.Valid { + cfg.Content = dbcfg.Content.String + } + + if dbcfg.ClientCert.Valid { + cfg.ClientCert = dbcfg.ClientCert.String + } + + if dbcfg.ClientKey.Valid { + cfg.ClientKey = dbcfg.ClientKey.String + } + + if dbcfg.CaCert.Valid { + cfg.CACert = dbcfg.CaCert.String + } + return cfg +} + +type dbChannel struct { + ID string `db:"magistrala_channel"` + Name sql.NullString `db:"name"` + DomainID sql.NullString `db:"domain_id"` + Metadata string `db:"metadata"` + Parent sql.NullString `db:"parent_id,omitempty"` + Description string `db:"description,omitempty"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt sql.NullTime `db:"updated_at,omitempty"` + UpdatedBy sql.NullString `db:"updated_by,omitempty"` + Status clients.Status `db:"status"` +} + +func toDBChannel(domainID string, ch bootstrap.Channel) (dbChannel, error) { + dbch := dbChannel{ + ID: ch.ID, + Name: nullString(ch.Name), + DomainID: nullString(domainID), + Parent: nullString(ch.Parent), + Description: ch.Description, + CreatedAt: ch.CreatedAt, + UpdatedAt: nullTime(ch.UpdatedAt), + UpdatedBy: nullString(ch.UpdatedBy), + Status: ch.Status, + } + + metadata, err := json.Marshal(ch.Metadata) + if err != nil { + return dbChannel{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + + dbch.Metadata = string(metadata) + return dbch, nil +} + +func toChannel(dbch dbChannel) (bootstrap.Channel, error) { + ch := bootstrap.Channel{ + ID: dbch.ID, + Description: dbch.Description, + CreatedAt: dbch.CreatedAt, + Status: dbch.Status, + } + + if dbch.Name.Valid { + ch.Name = dbch.Name.String + } + if dbch.DomainID.Valid { + ch.DomainID = dbch.DomainID.String + } + if dbch.Parent.Valid { + ch.Parent = dbch.Parent.String + } + if dbch.UpdatedBy.Valid { + ch.UpdatedBy = dbch.UpdatedBy.String + } + if dbch.UpdatedAt.Valid { + ch.UpdatedAt = dbch.UpdatedAt.Time + } + + if err := json.Unmarshal([]byte(dbch.Metadata), &ch.Metadata); err != nil { + return bootstrap.Channel{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + + return ch, nil +} + +type dbConnection struct { + Config string `db:"config_id"` + Channel string `db:"channel_id"` + DomainID string `db:"domain_id"` +} diff --git a/bootstrap/postgres/configs_test.go b/bootstrap/postgres/configs_test.go new file mode 100644 index 000000000..d65c70c95 --- /dev/null +++ b/bootstrap/postgres/configs_test.go @@ -0,0 +1,913 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "fmt" + "strconv" + "testing" + + "github.com/absmach/supermq/bootstrap" + "github.com/absmach/supermq/bootstrap/postgres" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/gofrs/uuid/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const numConfigs = 10 + +var ( + config = bootstrap.Config{ + ClientID: "smq-client", + ClientSecret: "smq-key", + ExternalID: "external-id", + ExternalKey: "external-key", + DomainID: testsutil.GenerateUUID(&testing.T{}), + Channels: []bootstrap.Channel{ + {ID: "1", Name: "name 1", Metadata: map[string]any{"meta": 1.0}}, + {ID: "2", Name: "name 2", Metadata: map[string]any{"meta": 2.0}}, + }, + Content: "content", + State: bootstrap.Inactive, + } + + channels = []string{"1", "2"} +) + +func TestSave(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + diff := "different" + + duplicateClient := config + duplicateClient.ExternalID = diff + duplicateClient.ClientSecret = diff + duplicateClient.Channels = []bootstrap.Channel{} + + duplicateExternal := config + duplicateExternal.ClientID = diff + duplicateExternal.ClientSecret = diff + duplicateExternal.Channels = []bootstrap.Channel{} + + duplicateChannels := config + duplicateChannels.ExternalID = diff + duplicateChannels.ClientSecret = diff + duplicateChannels.ClientID = diff + + cases := []struct { + desc string + config bootstrap.Config + connections []string + err error + }{ + { + desc: "save a config", + config: config, + connections: channels, + err: nil, + }, + { + desc: "save config with same Client ID", + config: duplicateClient, + connections: nil, + err: repoerr.ErrConflict, + }, + { + desc: "save config with same external ID", + config: duplicateExternal, + connections: nil, + err: repoerr.ErrConflict, + }, + { + desc: "save config with same Channels", + config: duplicateChannels, + connections: channels, + err: repoerr.ErrConflict, + }, + } + for _, tc := range cases { + id, err := repo.Save(context.Background(), tc.config, tc.connections) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, id, tc.config.ClientID, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.config.ClientID, id)) + } + } +} + +func TestRetrieveByID(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + // Use UUID to prevent conflicts. + uid, err := uuid.NewV4() + require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + id, err := repo.Save(context.Background(), c, channels) + require.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + + nonexistentConfID, err := uuid.NewV4() + require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + + cases := []struct { + desc string + domainID string + id string + err error + }{ + { + desc: "retrieve config", + domainID: c.DomainID, + id: id, + err: nil, + }, + { + desc: "retrieve config with wrong domain ID ", + domainID: "2", + id: id, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve a non-existing config", + domainID: c.DomainID, + id: nonexistentConfID.String(), + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve a config with invalid ID", + domainID: c.DomainID, + id: "invalid", + err: repoerr.ErrNotFound, + }, + } + for _, tc := range cases { + _, err := repo.RetrieveByID(context.Background(), tc.domainID, tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestRetrieveAll(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + clientIDs := make([]string, numConfigs) + + for i := 0; i < numConfigs; i++ { + c := config + + // Use UUID to prevent conflict errors. + uid, err := uuid.NewV4() + require.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ExternalID = uid.String() + c.Name = fmt.Sprintf("name %d", i) + c.ClientID = uid.String() + c.ClientSecret = uid.String() + + clientIDs[i] = c.ClientID + + if i%2 == 0 { + c.State = bootstrap.Active + } + + if i > 0 { + c.Channels = nil + } + + _, err = repo.Save(context.Background(), c, channels) + require.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + } + cases := []struct { + desc string + domainID string + clientID []string + offset uint64 + limit uint64 + filter bootstrap.Filter + size int + }{ + { + desc: "retrieve all configs", + domainID: config.DomainID, + clientID: []string{}, + offset: 0, + limit: uint64(numConfigs), + size: numConfigs, + }, + { + desc: "retrieve a subset of configs", + domainID: config.DomainID, + clientID: []string{}, + offset: 5, + limit: uint64(numConfigs - 5), + size: numConfigs - 5, + }, + { + desc: "retrieve with wrong domain ID ", + domainID: "2", + clientID: []string{}, + offset: 0, + limit: uint64(numConfigs), + size: 0, + }, + { + desc: "retrieve all active configs ", + domainID: config.DomainID, + clientID: []string{}, + offset: 0, + limit: uint64(numConfigs), + filter: bootstrap.Filter{FullMatch: map[string]string{"state": bootstrap.Active.String()}}, + size: numConfigs / 2, + }, + { + desc: "retrieve all with partial match filter", + domainID: config.DomainID, + clientID: []string{}, + offset: 0, + limit: uint64(numConfigs), + filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "1"}}, + size: 1, + }, + { + desc: "retrieve search by name", + domainID: config.DomainID, + clientID: []string{}, + offset: 0, + limit: uint64(numConfigs), + filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "1"}}, + size: 1, + }, + { + desc: "retrieve by valid clientIDs", + domainID: config.DomainID, + clientID: clientIDs, + offset: 0, + limit: uint64(numConfigs), + size: 10, + }, + { + desc: "retrieve by non-existing clientID", + domainID: config.DomainID, + clientID: []string{"non-existing"}, + offset: 0, + limit: uint64(numConfigs), + size: 0, + }, + } + for _, tc := range cases { + ret := repo.RetrieveAll(context.Background(), tc.domainID, tc.clientID, tc.filter, tc.offset, tc.limit) + size := len(ret.Configs) + assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.size, size)) + } +} + +func TestRetrieveByExternalID(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + // Use UUID to prevent conflicts. + uid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + _, err = repo.Save(context.Background(), c, channels) + assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + + cases := []struct { + desc string + externalID string + err error + }{ + { + desc: "retrieve with invalid external ID", + externalID: strconv.Itoa(numConfigs + 1), + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve with external key", + externalID: c.ExternalID, + err: nil, + }, + } + for _, tc := range cases { + _, err := repo.RetrieveByExternalID(context.Background(), tc.externalID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestUpdate(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + // Use UUID to prevent conflicts. + uid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + _, err = repo.Save(context.Background(), c, channels) + assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + + c.Content = "new content" + c.Name = "new name" + + wrongDomainID := c + wrongDomainID.DomainID = "3" + + cases := []struct { + desc string + id string + config bootstrap.Config + err error + }{ + { + desc: "update with wrong domainID ", + config: wrongDomainID, + err: repoerr.ErrNotFound, + }, + { + desc: "update a config", + config: c, + err: nil, + }, + } + for _, tc := range cases { + err := repo.Update(context.Background(), tc.config) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestUpdateCert(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + // Use UUID to prevent conflicts. + uid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + _, err = repo.Save(context.Background(), c, channels) + assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + + c.Content = "new content" + c.Name = "new name" + + wrongDomainID := c + wrongDomainID.DomainID = "3" + + cases := []struct { + desc string + clientID string + domainID string + cert string + certKey string + ca string + expectedConfig bootstrap.Config + err error + }{ + { + desc: "update with wrong domain ID ", + clientID: "", + cert: "cert", + certKey: "certKey", + ca: "", + domainID: wrongDomainID.DomainID, + expectedConfig: bootstrap.Config{}, + err: repoerr.ErrNotFound, + }, + { + desc: "update a config", + clientID: c.ClientID, + cert: "cert", + certKey: "certKey", + ca: "ca", + domainID: c.DomainID, + expectedConfig: bootstrap.Config{ + ClientID: c.ClientID, + ClientCert: "cert", + CACert: "ca", + ClientKey: "certKey", + DomainID: c.DomainID, + }, + err: nil, + }, + } + for _, tc := range cases { + cfg, err := repo.UpdateCert(context.Background(), tc.domainID, tc.clientID, tc.cert, tc.certKey, tc.ca) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.expectedConfig, cfg, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.expectedConfig, cfg)) + } +} + +func TestUpdateConnections(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + // Use UUID to prevent conflicts. + uid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + _, err = repo.Save(context.Background(), c, channels) + assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + // Use UUID to prevent conflicts. + uid, err = uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + c.Channels = []bootstrap.Channel{} + c2, err := repo.Save(context.Background(), c, []string{channels[0]}) + assert.Nil(t, err, fmt.Sprintf("Saving a config expected to succeed: %s.\n", err)) + + cases := []struct { + desc string + domainID string + id string + channels []bootstrap.Channel + connections []string + err error + }{ + { + desc: "update connections of non-existing config", + domainID: config.DomainID, + id: "unknown", + channels: nil, + connections: []string{channels[1]}, + err: repoerr.ErrNotFound, + }, + { + desc: "update connections", + domainID: config.DomainID, + id: c.ClientID, + channels: nil, + connections: []string{channels[1]}, + err: nil, + }, + { + desc: "update connections with existing channels", + domainID: config.DomainID, + id: c2, + channels: nil, + connections: channels, + err: nil, + }, + { + desc: "update connections no channels", + domainID: config.DomainID, + id: c.ClientID, + channels: nil, + connections: nil, + err: nil, + }, + } + for _, tc := range cases { + err := repo.UpdateConnections(context.Background(), tc.domainID, tc.id, tc.channels, tc.connections) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestRemove(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + // Use UUID to prevent conflicts. + uid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + id, err := repo.Save(context.Background(), c, channels) + assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + + // Removal works the same for both existing and non-existing + // (removed) config + for i := 0; i < 2; i++ { + err := repo.Remove(context.Background(), c.DomainID, id) + assert.Nil(t, err, fmt.Sprintf("%d: failed to remove config due to: %s", i, err)) + + _, err = repo.RetrieveByID(context.Background(), c.DomainID, id) + assert.True(t, errors.Contains(err, repoerr.ErrNotFound), fmt.Sprintf("%d: expected %s got %s", i, repoerr.ErrNotFound, err)) + } +} + +func TestChangeState(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + // Use UUID to prevent conflicts. + uid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + saved, err := repo.Save(context.Background(), c, channels) + assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + + cases := []struct { + desc string + domainID string + id string + state bootstrap.State + err error + }{ + { + desc: "change state with wrong domain ID ", + id: saved, + domainID: "2", + err: repoerr.ErrNotFound, + }, + { + desc: "change state with wrong id", + id: "wrong", + domainID: c.DomainID, + err: repoerr.ErrNotFound, + }, + { + desc: "change state to Active", + id: saved, + domainID: c.DomainID, + state: bootstrap.Active, + err: nil, + }, + { + desc: "change state to Inactive", + id: saved, + domainID: c.DomainID, + state: bootstrap.Inactive, + err: nil, + }, + } + for _, tc := range cases { + err := repo.ChangeState(context.Background(), tc.domainID, tc.id, tc.state) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestListExisting(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + // Use UUID to prevent conflicts. + uid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + _, err = repo.Save(context.Background(), c, channels) + assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + + var chs []bootstrap.Channel + chs = append(chs, config.Channels...) + + cases := []struct { + desc string + domainID string + connections []string + existing []bootstrap.Channel + }{ + { + desc: "list all existing channels", + domainID: c.DomainID, + connections: channels, + existing: chs, + }, + { + desc: "list a subset of existing channels", + domainID: c.DomainID, + connections: []string{channels[0], "5"}, + existing: []bootstrap.Channel{chs[0]}, + }, + { + desc: "list a subset of existing channels empty", + domainID: c.DomainID, + connections: []string{"5", "6"}, + existing: []bootstrap.Channel{}, + }, + } + for _, tc := range cases { + existing, err := repo.ListExisting(context.Background(), tc.domainID, tc.connections) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error: %s", tc.desc, err)) + assert.ElementsMatch(t, tc.existing, existing, fmt.Sprintf("%s: Got non-matching elements.", tc.desc)) + } +} + +func TestRemoveClient(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + // Use UUID to prevent conflicts. + uid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + saved, err := repo.Save(context.Background(), c, channels) + assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + for i := 0; i < 2; i++ { + err := repo.RemoveClient(context.Background(), saved) + assert.Nil(t, err, fmt.Sprintf("an unexpected error occurred: %s\n", err)) + } +} + +func TestUpdateChannel(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + // Use UUID to prevent conflicts. + uid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + _, err = repo.Save(context.Background(), c, channels) + assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + + id := c.Channels[0].ID + update := bootstrap.Channel{ + ID: id, + Name: "update name", + Metadata: map[string]any{"update": "metadata update"}, + } + err = repo.UpdateChannel(context.Background(), update) + assert.Nil(t, err, fmt.Sprintf("updating config expected to succeed: %s.\n", err)) + + cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ClientID) + assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err)) + var retrieved bootstrap.Channel + for _, c := range cfg.Channels { + if c.ID == id { + retrieved = c + break + } + } + update.DomainID = retrieved.DomainID + assert.Equal(t, update, retrieved, fmt.Sprintf("expected %s, go %s", update, retrieved)) +} + +func TestRemoveChannel(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + uid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + _, err = repo.Save(context.Background(), c, channels) + assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + + err = repo.RemoveChannel(context.Background(), c.Channels[0].ID) + assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err)) + + cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ClientID) + assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err)) + assert.NotContains(t, cfg.Channels, c.Channels[0], fmt.Sprintf("expected to remove channel %s from %s", c.Channels[0], cfg.Channels)) +} + +func TestConnectClient(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + // Use UUID to prevent conflicts. + uid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + c.State = bootstrap.Inactive + saved, err := repo.Save(context.Background(), c, channels) + assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + + wrongID := testsutil.GenerateUUID(&testing.T{}) + + connectedClient := c + + randomClient := c + randomClientID, _ := uuid.NewV4() + randomClient.ClientID = randomClientID.String() + + emptyClient := c + emptyClient.ClientID = "" + + cases := []struct { + desc string + domainID string + id string + state bootstrap.State + channels []bootstrap.Channel + connections []string + err error + }{ + { + desc: "connect disconnected client", + domainID: c.DomainID, + id: saved, + state: bootstrap.Inactive, + channels: c.Channels, + connections: channels, + err: nil, + }, + { + desc: "connect already connected client", + domainID: c.DomainID, + id: connectedClient.ClientID, + state: connectedClient.State, + channels: c.Channels, + connections: channels, + err: nil, + }, + { + desc: "connect non-existent client", + domainID: c.DomainID, + id: wrongID, + channels: c.Channels, + connections: channels, + err: repoerr.ErrNotFound, + }, + { + desc: "connect random client", + domainID: c.DomainID, + id: randomClient.ClientID, + channels: c.Channels, + connections: channels, + err: repoerr.ErrNotFound, + }, + { + desc: "connect empty client", + domainID: c.DomainID, + id: emptyClient.ClientID, + channels: c.Channels, + connections: channels, + err: repoerr.ErrNotFound, + }, + } + for _, tc := range cases { + for i, ch := range tc.channels { + if i == 0 { + err = repo.ConnectClient(context.Background(), ch.ID, tc.id) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: Expected error: %s, got: %s.\n", tc.desc, tc.err, err)) + cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ClientID) + assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err)) + assert.Equal(t, cfg.State, bootstrap.Active, fmt.Sprintf("expected to be active when a connection is added from %s", cfg)) + } else { + _ = repo.ConnectClient(context.Background(), ch.ID, tc.id) + } + } + + cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ClientID) + assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err)) + assert.Equal(t, cfg.State, bootstrap.Active, fmt.Sprintf("expected to be active when a connection is added from %s", cfg)) + } +} + +func TestDisconnectClient(t *testing.T) { + repo := postgres.NewConfigRepository(db, testLog) + err := deleteChannels(context.Background(), repo) + require.Nil(t, err, "Channels cleanup expected to succeed.") + + c := config + // Use UUID to prevent conflicts. + uid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("Got unexpected error: %s.\n", err)) + c.ClientSecret = uid.String() + c.ClientID = uid.String() + c.ExternalID = uid.String() + c.ExternalKey = uid.String() + c.State = bootstrap.Inactive + saved, err := repo.Save(context.Background(), c, channels) + assert.Nil(t, err, fmt.Sprintf("Saving config expected to succeed: %s.\n", err)) + + wrongID := testsutil.GenerateUUID(&testing.T{}) + + connectedClient := c + + randomClient := c + randomClientID, _ := uuid.NewV4() + randomClient.ClientID = randomClientID.String() + + emptyClient := c + emptyClient.ClientID = "" + + cases := []struct { + desc string + domainID string + id string + state bootstrap.State + channels []bootstrap.Channel + connections []string + err error + }{ + { + desc: "disconnect connected client", + domainID: c.DomainID, + id: connectedClient.ClientID, + state: connectedClient.State, + channels: c.Channels, + connections: channels, + err: nil, + }, + { + desc: "disconnect already disconnected client", + domainID: c.DomainID, + id: saved, + state: bootstrap.Inactive, + channels: c.Channels, + connections: channels, + err: nil, + }, + { + desc: "disconnect invalid client", + domainID: c.DomainID, + id: wrongID, + channels: c.Channels, + connections: channels, + err: nil, + }, + { + desc: "disconnect random client", + domainID: c.DomainID, + id: randomClient.ClientID, + channels: c.Channels, + connections: channels, + err: nil, + }, + { + desc: "disconnect empty client", + domainID: c.DomainID, + id: emptyClient.ClientID, + channels: c.Channels, + connections: channels, + err: nil, + }, + } + + for _, tc := range cases { + for _, ch := range tc.channels { + err = repo.DisconnectClient(context.Background(), ch.ID, tc.id) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: Expected error: %s, got: %s.\n", tc.desc, tc.err, err)) + } + + cfg, err := repo.RetrieveByID(context.Background(), c.DomainID, c.ClientID) + assert.Nil(t, err, fmt.Sprintf("Retrieving config expected to succeed: %s.\n", err)) + assert.Equal(t, cfg.State, bootstrap.Inactive, fmt.Sprintf("expected to be inactive when a connection is removed from %s", cfg)) + } +} + +func deleteChannels(ctx context.Context, repo bootstrap.ConfigRepository) error { + for _, ch := range channels { + if err := repo.RemoveChannel(ctx, ch); err != nil { + return err + } + } + + return nil +} diff --git a/bootstrap/postgres/doc.go b/bootstrap/postgres/doc.go new file mode 100644 index 000000000..73a678477 --- /dev/null +++ b/bootstrap/postgres/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package postgres contains repository implementations using PostgreSQL as +// the underlying database. +package postgres diff --git a/bootstrap/postgres/init.go b/bootstrap/postgres/init.go new file mode 100644 index 000000000..5ab55938d --- /dev/null +++ b/bootstrap/postgres/init.go @@ -0,0 +1,108 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import migrate "github.com/rubenv/sql-migrate" + +// Migration of bootstrap service. +func Migration() *migrate.MemoryMigrationSource { + return &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "configs_1", + Up: []string{ + `CREATE TABLE IF NOT EXISTS configs ( + mainflux_client TEXT UNIQUE NOT NULL, + owner VARCHAR(254), + name TEXT, + mainflux_key CHAR(36) UNIQUE NOT NULL, + external_id TEXT UNIQUE NOT NULL, + external_key TEXT NOT NULL, + content TEXT, + client_cert TEXT, + client_key TEXT, + ca_cert TEXT, + state BIGINT NOT NULL, + PRIMARY KEY (mainflux_client, owner) + )`, + `CREATE TABLE IF NOT EXISTS unknown_configs ( + external_id TEXT UNIQUE NOT NULL, + external_key TEXT NOT NULL, + PRIMARY KEY (external_id, external_key) + )`, + `CREATE TABLE IF NOT EXISTS channels ( + mainflux_channel TEXT UNIQUE NOT NULL, + owner VARCHAR(254), + name TEXT, + metadata JSON, + PRIMARY KEY (mainflux_channel, owner) + )`, + `CREATE TABLE IF NOT EXISTS connections ( + channel_id TEXT, + channel_owner VARCHAR(256), + config_id TEXT, + config_owner VARCHAR(256), + FOREIGN KEY (channel_id, channel_owner) REFERENCES channels (mainflux_channel, owner) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (config_id, config_owner) REFERENCES configs (mainflux_client, owner) ON DELETE CASCADE ON UPDATE CASCADE, + PRIMARY KEY (channel_id, channel_owner, config_id, config_owner) + )`, + }, + Down: []string{ + "DROP TABLE connections", + "DROP TABLE configs", + "DROP TABLE channels", + "DROP TABLE unknown_configs", + }, + }, + { + Id: "configs_2", + Up: []string{ + "DROP TABLE IF EXISTS unknown_configs", + }, + Down: []string{ + "CREATE TABLE IF NOT EXISTS unknown_configs", + }, + }, + { + Id: "configs_3", + Up: []string{ + `ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS parent_id VARCHAR(36)`, + `ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS description VARCHAR(1024)`, + `ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS created_at TIMESTAMP`, + `ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS updated_at TIMESTAMP`, + `ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS updated_by VARCHAR(254)`, + `ALTER TABLE IF EXISTS channels ADD COLUMN IF NOT EXISTS status SMALLINT NOT NULL DEFAULT 0 CHECK (status >= 0)`, + }, + }, + { + Id: "configs_4", + Up: []string{ + `ALTER TABLE IF EXISTS configs RENAME COLUMN mainflux_client TO magistrala_client`, + `ALTER TABLE IF EXISTS configs RENAME COLUMN mainflux_key TO magistrala_secret`, + `ALTER TABLE IF EXISTS channels RENAME COLUMN mainflux_channel TO magistrala_channel`, + }, + }, + { + Id: "configs_5", + Up: []string{ + `ALTER TABLE IF EXISTS configs RENAME COLUMN owner TO domain_id`, + `ALTER TABLE IF EXISTS channels RENAME COLUMN owner TO domain_id`, + `ALTER TABLE IF EXISTS configs ADD CONSTRAINT configs_name_domain_id_key UNIQUE (name, domain_id)`, + }, + }, + { + Id: "configs_6", + Up: []string{ + `ALTER TABLE IF EXISTS connections DROP CONSTRAINT IF EXISTS connections_pkey`, + `ALTER TABLE IF EXISTS connections DROP COLUMN IF EXISTS channel_owner`, + `ALTER TABLE IF EXISTS connections DROP COLUMN IF EXISTS config_owner`, + `ALTER TABLE IF EXISTS connections ADD COLUMN IF NOT EXISTS domain_id VARCHAR(256) NOT NULL`, + `ALTER TABLE IF EXISTS connections ADD CONSTRAINT connections_pkey PRIMARY KEY (channel_id, config_id, domain_id)`, + `ALTER TABLE IF EXISTS connections ADD FOREIGN KEY (channel_id, domain_id) REFERENCES channels (magistrala_channel, domain_id) ON DELETE CASCADE ON UPDATE CASCADE`, + `ALTER TABLE IF EXISTS connections ADD FOREIGN KEY (config_id, domain_id) REFERENCES configs (magistrala_client, domain_id) ON DELETE CASCADE ON UPDATE CASCADE`, + }, + }, + }, + } +} diff --git a/bootstrap/postgres/setup_test.go b/bootstrap/postgres/setup_test.go new file mode 100644 index 000000000..4151768d9 --- /dev/null +++ b/bootstrap/postgres/setup_test.go @@ -0,0 +1,86 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "fmt" + "log" + "os" + "testing" + + "github.com/absmach/supermq/bootstrap/postgres" + smqlog "github.com/absmach/supermq/logger" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/jmoiron/sqlx" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +var ( + testLog, _ = smqlog.New(os.Stdout, "info") + db *sqlx.DB +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + testLog.Error(fmt.Sprintf("Could not connect to docker: %s", err)) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16.2-alpine", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + + if err := pool.Retry(func() error { + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + db, err = sqlx.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + testLog.Error(fmt.Sprintf("Could not connect to docker: %s", err)) + } + + dbConfig := pgclient.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + if db, err = pgclient.Setup(dbConfig, *postgres.Migration()); err != nil { + testLog.Error(fmt.Sprintf("Could not setup test DB connection: %s", err)) + } + + code := m.Run() + + // Defers will not be run when using os.Exit + db.Close() + if err := pool.Purge(container); err != nil { + testLog.Error(fmt.Sprintf("Could not purge container: %s", err)) + } + + os.Exit(code) +} diff --git a/bootstrap/reader.go b/bootstrap/reader.go new file mode 100644 index 000000000..91b4b5e83 --- /dev/null +++ b/bootstrap/reader.go @@ -0,0 +1,95 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bootstrap + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/json" + "io" + "net/http" +) + +// bootstrapRes represent SuperMQ Response to the Bootatrap request. +// This is used as a response from ConfigReader and can easily be +// replace with any other response format. +type bootstrapRes struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + Channels []channelRes `json:"channels"` + Content string `json:"content,omitempty"` + ClientCert string `json:"client_cert,omitempty"` + ClientKey string `json:"client_key,omitempty"` + CACert string `json:"ca_cert,omitempty"` +} + +type channelRes struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Metadata any `json:"metadata,omitempty"` +} + +func (res bootstrapRes) Code() int { + return http.StatusOK +} + +func (res bootstrapRes) Headers() map[string]string { + return map[string]string{} +} + +func (res bootstrapRes) Empty() bool { + return false +} + +type reader struct { + encKey []byte +} + +// NewConfigReader return new reader which is used to generate response +// from the config. +func NewConfigReader(encKey []byte) ConfigReader { + return reader{encKey: encKey} +} + +func (r reader) ReadConfig(cfg Config, secure bool) (any, error) { + var channels []channelRes + for _, ch := range cfg.Channels { + channels = append(channels, channelRes{ID: ch.ID, Name: ch.Name, Metadata: ch.Metadata}) + } + + res := bootstrapRes{ + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, + Channels: channels, + Content: cfg.Content, + ClientCert: cfg.ClientCert, + ClientKey: cfg.ClientKey, + CACert: cfg.CACert, + } + if secure { + b, err := json.Marshal(res) + if err != nil { + return nil, err + } + return r.encrypt(b) + } + + return res, nil +} + +func (r reader) encrypt(in []byte) ([]byte, error) { + block, err := aes.NewCipher(r.encKey) + if err != nil { + return nil, err + } + ciphertext := make([]byte, aes.BlockSize+len(in)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(ciphertext[aes.BlockSize:], in) + return ciphertext, nil +} diff --git a/bootstrap/reader_test.go b/bootstrap/reader_test.go new file mode 100644 index 000000000..8f617ec20 --- /dev/null +++ b/bootstrap/reader_test.go @@ -0,0 +1,126 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bootstrap_test + +import ( + "crypto/aes" + "crypto/cipher" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/bootstrap" + "github.com/absmach/supermq/pkg/errors" + "github.com/stretchr/testify/assert" +) + +type readChan struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Metadata any `json:"metadata,omitempty"` +} + +type readResp struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + Channels []readChan `json:"channels"` + Content string `json:"content,omitempty"` + ClientCert string `json:"client_cert,omitempty"` + ClientKey string `json:"client_key,omitempty"` + CACert string `json:"ca_cert,omitempty"` +} + +func dec(in []byte) ([]byte, error) { + block, err := aes.NewCipher(encKey) + if err != nil { + return nil, err + } + if len(in) < aes.BlockSize { + return nil, errors.ErrMalformedEntity + } + iv := in[:aes.BlockSize] + in = in[aes.BlockSize:] + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(in, in) + return in, nil +} + +func TestReadConfig(t *testing.T) { + cfg := bootstrap.Config{ + ClientID: "smq_id", + ClientCert: "client_cert", + ClientKey: "client_key", + CACert: "ca_cert", + ClientSecret: "smq_key", + Channels: []bootstrap.Channel{ + { + ID: "smq_id", + Name: "smq_name", + Metadata: map[string]any{"key": "value}"}, + }, + }, + Content: "content", + } + ret := readResp{ + ClientID: "smq_id", + ClientSecret: "smq_key", + Channels: []readChan{ + { + ID: "smq_id", + Name: "smq_name", + Metadata: map[string]any{"key": "value}"}, + }, + }, + Content: "content", + ClientCert: "client_cert", + ClientKey: "client_key", + CACert: "ca_cert", + } + + bin, err := json.Marshal(ret) + assert.Nil(t, err, fmt.Sprintf("Marshalling expected to succeed: %s.\n", err)) + + reader := bootstrap.NewConfigReader(encKey) + cases := []struct { + desc string + config bootstrap.Config + enc []byte + secret bool + err error + }{ + { + desc: "read a config", + config: cfg, + enc: bin, + secret: false, + }, + { + desc: "read encrypted config", + config: cfg, + enc: bin, + secret: true, + }, + } + + for _, tc := range cases { + res, err := reader.ReadConfig(tc.config, tc.secret) + assert.Nil(t, err, fmt.Sprintf("Reading config to succeed: %s.\n", err)) + + if tc.secret { + d, err := dec(res.([]byte)) + assert.Nil(t, err, fmt.Sprintf("Decrypting expected to succeed: %s.\n", err)) + assert.Equal(t, tc.enc, d, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.enc, d)) + continue + } + b, err := json.Marshal(res) + assert.Nil(t, err, fmt.Sprintf("Marshalling expected to succeed: %s.\n", err)) + assert.Equal(t, tc.enc, b, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.enc, b)) + resp, ok := res.(supermq.Response) + assert.True(t, ok, "If not encrypted, reader should return response.") + assert.False(t, resp.Empty(), fmt.Sprintf("Response should not be empty %s.", err)) + assert.Equal(t, http.StatusOK, resp.Code(), "Default config response code should be 200.") + } +} diff --git a/bootstrap/service.go b/bootstrap/service.go new file mode 100644 index 000000000..cd4462c26 --- /dev/null +++ b/bootstrap/service.go @@ -0,0 +1,503 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bootstrap + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "encoding/hex" + + "github.com/absmach/supermq" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/policies" + mgsdk "github.com/absmach/supermq/pkg/sdk" +) + +var ( + // ErrClients indicates failure to communicate with SuperMQ Clients service. + // It can be due to networking error or invalid/unauthenticated request. + ErrClients = errors.New("failed to receive response from Clients service") + + // ErrExternalKey indicates a non-existent bootstrap configuration for given external key. + ErrExternalKey = errors.NewAuthZError("failed to get bootstrap configuration for given external key") + + // ErrExternalKeySecure indicates error in getting bootstrap configuration for given encrypted external key. + ErrExternalKeySecure = errors.NewAuthZError("failed to get bootstrap configuration for given encrypted external key") + + // ErrBootstrap indicates error in getting bootstrap configuration. + ErrBootstrap = errors.New("failed to read bootstrap configuration") + + // ErrAddBootstrap indicates error in adding bootstrap configuration. + ErrAddBootstrap = errors.NewServiceError("failed to add bootstrap configuration") + + // ErrBootstrapState indicates an invalid bootstrap state. + ErrBootstrapState = errors.NewRequestError("invalid bootstrap state") + + // ErrNotInSameDomain indicates entities are not in the same domain. + errNotInSameDomain = errors.New("entities are not in the same domain") + + errUpdateConnections = errors.New("failed to update connections") + errRemoveBootstrap = errors.New("failed to remove bootstrap configuration") + errChangeState = errors.New("failed to change state of bootstrap configuration") + errUpdateChannel = errors.New("failed to update channel") + errRemoveConfig = errors.New("failed to remove bootstrap configuration") + errRemoveChannel = errors.New("failed to remove channel") + errCreateClient = errors.New("failed to create client") + errConnectClient = errors.New("failed to connect client") + errDisconnectClient = errors.New("failed to disconnect client") + errCheckChannels = errors.New("failed to check if channels exists") + errConnectionChannels = errors.New("failed to check channels connections") + errClientNotFound = errors.New("failed to find client") + errUpdateCert = errors.New("failed to update cert") +) + +var _ Service = (*bootstrapService)(nil) + +// Service specifies an API that must be fulfilled by the domain service +// implementation, and all of its decorators (e.g. logging & metrics). +type Service interface { + // Add adds new Client Config to the user identified by the provided token. + Add(ctx context.Context, session smqauthn.Session, token string, cfg Config) (Config, error) + + // View returns Client Config with given ID belonging to the user identified by the given token. + View(ctx context.Context, session smqauthn.Session, id string) (Config, error) + + // Update updates editable fields of the provided Config. + Update(ctx context.Context, session smqauthn.Session, cfg Config) error + + // UpdateCert updates an existing Config certificate and token. + // A non-nil error is returned to indicate operation failure. + UpdateCert(ctx context.Context, session smqauthn.Session, clientID, clientCert, clientKey, caCert string) (Config, error) + + // UpdateConnections updates list of Channels related to given Config. + UpdateConnections(ctx context.Context, session smqauthn.Session, token, id string, connections []string) error + + // List returns subset of Configs with given search params that belong to the + // user identified by the given token. + List(ctx context.Context, session smqauthn.Session, filter Filter, offset, limit uint64) (ConfigsPage, error) + + // Remove removes Config with specified token that belongs to the user identified by the given token. + Remove(ctx context.Context, session smqauthn.Session, id string) error + + // Bootstrap returns Config to the Client with provided external ID using external key. + Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (Config, error) + + // ChangeState changes state of the Client with given client ID and domain ID. + ChangeState(ctx context.Context, session smqauthn.Session, token, id string, state State) error + + // Methods RemoveConfig, UpdateChannel, and RemoveChannel are used as + // handlers for events. That's why these methods surpass ownership check. + + // UpdateChannelHandler updates Channel with data received from an event. + UpdateChannelHandler(ctx context.Context, channel Channel) error + + // RemoveConfigHandler removes Configuration with id received from an event. + RemoveConfigHandler(ctx context.Context, id string) error + + // RemoveChannelHandler removes Channel with id received from an event. + RemoveChannelHandler(ctx context.Context, id string) error + + // ConnectClientHandler changes state of the Config to active when connect event occurs. + ConnectClientHandler(ctx context.Context, channelID, clientID string) error + + // DisconnectClientHandler changes state of the Config to inactive when disconnect event occurs. + DisconnectClientHandler(ctx context.Context, channelID, clientID string) error +} + +// ConfigReader is used to parse Config into format which will be encoded +// as a JSON and consumed from the client side. The purpose of this interface +// is to provide convenient way to generate custom configuration response +// based on the specific Config which will be consumed by the client. +type ConfigReader interface { + ReadConfig(Config, bool) (any, error) +} + +type bootstrapService struct { + policies policies.Service + configs ConfigRepository + sdk mgsdk.SDK + encKey []byte + idProvider supermq.IDProvider +} + +// New returns new Bootstrap service. +func New(policyService policies.Service, configs ConfigRepository, sdk mgsdk.SDK, encKey []byte, idp supermq.IDProvider) Service { + return &bootstrapService{ + configs: configs, + sdk: sdk, + policies: policyService, + encKey: encKey, + idProvider: idp, + } +} + +func (bs bootstrapService) Add(ctx context.Context, session smqauthn.Session, token string, cfg Config) (Config, error) { + toConnect := bs.toIDList(cfg.Channels) + + // Check if channels exist. This is the way to prevent fetching channels that already exist. + existing, err := bs.configs.ListExisting(ctx, session.DomainID, toConnect) + if err != nil { + return Config{}, errors.Wrap(errCheckChannels, err) + } + + cfg.Channels, err = bs.connectionChannels(ctx, toConnect, bs.toIDList(existing), session.DomainID, token) + if err != nil { + return Config{}, errors.Wrap(errConnectionChannels, err) + } + + id := cfg.ClientID + mgClient, err := bs.client(ctx, session.DomainID, id, token) + if err != nil { + return Config{}, errors.Wrap(errClientNotFound, err) + } + + for _, channel := range cfg.Channels { + if channel.DomainID != mgClient.DomainID { + return Config{}, errors.Wrap(svcerr.ErrMalformedEntity, errNotInSameDomain) + } + } + + cfg.ClientID = mgClient.ID + cfg.DomainID = session.DomainID + cfg.State = Inactive + cfg.ClientSecret = mgClient.Credentials.Secret + + saved, err := bs.configs.Save(ctx, cfg, toConnect) + if err != nil { + // If id is empty, then a new client has been created function - bs.client(id, token) + // So, on bootstrap config save error , delete the newly created client. + if id == "" { + if errT := bs.sdk.DeleteClient(ctx, cfg.ClientID, cfg.DomainID, token); errT != nil { + err = errors.Wrap(err, errT) + } + } + return Config{}, errors.Wrap(ErrAddBootstrap, err) + } + + cfg.ClientID = saved + cfg.Channels = append(cfg.Channels, existing...) + + return cfg, nil +} + +func (bs bootstrapService) View(ctx context.Context, session smqauthn.Session, id string) (Config, error) { + cfg, err := bs.configs.RetrieveByID(ctx, session.DomainID, id) + if err != nil { + return Config{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + return cfg, nil +} + +func (bs bootstrapService) Update(ctx context.Context, session smqauthn.Session, cfg Config) error { + cfg.DomainID = session.DomainID + if err := bs.configs.Update(ctx, cfg); err != nil { + return errors.Wrap(errUpdateConnections, err) + } + return nil +} + +func (bs bootstrapService) UpdateCert(ctx context.Context, session smqauthn.Session, clientID, clientCert, clientKey, caCert string) (Config, error) { + cfg, err := bs.configs.UpdateCert(ctx, session.DomainID, clientID, clientCert, clientKey, caCert) + if err != nil { + return Config{}, errors.Wrap(errUpdateCert, err) + } + return cfg, nil +} + +func (bs bootstrapService) UpdateConnections(ctx context.Context, session smqauthn.Session, token, id string, connections []string) error { + cfg, err := bs.configs.RetrieveByID(ctx, session.DomainID, id) + if err != nil { + return errors.Wrap(errUpdateConnections, err) + } + + add, remove := bs.updateList(cfg, connections) + + // Check if channels exist. This is the way to prevent fetching channels that already exist. + existing, err := bs.configs.ListExisting(ctx, session.DomainID, connections) + if err != nil { + return errors.Wrap(errUpdateConnections, err) + } + + channels, err := bs.connectionChannels(ctx, connections, bs.toIDList(existing), session.DomainID, token) + if err != nil { + return errors.Wrap(errUpdateConnections, err) + } + + cfg.Channels = channels + var connect, disconnect []string + + if cfg.State == Active { + connect = add + disconnect = remove + } + + for _, c := range disconnect { + if err := bs.sdk.DisconnectClients(ctx, c, []string{id}, []string{"Publish", "Subscribe"}, session.DomainID, token); err != nil { + if errors.Contains(err, repoerr.ErrNotFound) { + continue + } + return ErrClients + } + } + + for _, c := range connect { + conIDs := mgsdk.Connection{ + ChannelIDs: []string{c}, + ClientIDs: []string{id}, + Types: []string{"Publish", "Subscribe"}, + } + if err := bs.sdk.Connect(ctx, conIDs, session.DomainID, token); err != nil { + return ErrClients + } + } + if err := bs.configs.UpdateConnections(ctx, session.DomainID, id, channels, connections); err != nil { + return errors.Wrap(errUpdateConnections, err) + } + return nil +} + +func (bs bootstrapService) listClientIDs(ctx context.Context, userID string) ([]string, error) { + tids, err := bs.policies.ListAllObjects(ctx, policies.Policy{ + SubjectType: policies.UserType, + Subject: userID, + Permission: policies.ViewPermission, + ObjectType: policies.ClientType, + }) + if err != nil { + return nil, errors.Wrap(svcerr.ErrNotFound, err) + } + return tids.Policies, nil +} + +func (bs bootstrapService) List(ctx context.Context, session smqauthn.Session, filter Filter, offset, limit uint64) (ConfigsPage, error) { + if session.SuperAdmin { + return bs.configs.RetrieveAll(ctx, session.DomainID, []string{}, filter, offset, limit), nil + } + + // Handle non-admin users + clientIDs, err := bs.listClientIDs(ctx, session.DomainUserID) + if err != nil { + return ConfigsPage{}, errors.Wrap(svcerr.ErrNotFound, err) + } + + if len(clientIDs) == 0 { + return ConfigsPage{ + Total: 0, + Offset: offset, + Limit: limit, + Configs: []Config{}, + }, nil + } + + return bs.configs.RetrieveAll(ctx, session.DomainID, clientIDs, filter, offset, limit), nil +} + +func (bs bootstrapService) Remove(ctx context.Context, session smqauthn.Session, id string) error { + if err := bs.configs.Remove(ctx, session.DomainID, id); err != nil { + return errors.Wrap(errRemoveBootstrap, err) + } + return nil +} + +func (bs bootstrapService) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (Config, error) { + cfg, err := bs.configs.RetrieveByExternalID(ctx, externalID) + if err != nil { + return cfg, errors.Wrap(ErrBootstrap, err) + } + if secure { + dec, err := bs.dec(externalKey) + if err != nil { + return Config{}, errors.Wrap(ErrExternalKeySecure, err) + } + externalKey = dec + } + if cfg.ExternalKey != externalKey { + return Config{}, ErrExternalKey + } + + return cfg, nil +} + +func (bs bootstrapService) ChangeState(ctx context.Context, session smqauthn.Session, token, id string, state State) error { + cfg, err := bs.configs.RetrieveByID(ctx, session.DomainID, id) + if err != nil { + return errors.Wrap(errChangeState, err) + } + + if cfg.State == state { + return nil + } + + switch state { + case Active: + for _, c := range cfg.Channels { + if err := bs.sdk.ConnectClients(ctx, c.ID, []string{cfg.ClientID}, []string{"Publish", "Subscribe"}, session.DomainID, token); err != nil { + // Ignore conflict errors as they indicate the connection already exists. + if errors.Contains(err, svcerr.ErrConflict) { + continue + } + return ErrClients + } + } + case Inactive: + for _, c := range cfg.Channels { + if err := bs.sdk.DisconnectClients(ctx, c.ID, []string{cfg.ClientID}, []string{"Publish", "Subscribe"}, session.DomainID, token); err != nil { + if errors.Contains(err, repoerr.ErrNotFound) { + continue + } + return ErrClients + } + } + } + if err := bs.configs.ChangeState(ctx, session.DomainID, id, state); err != nil { + return errors.Wrap(errChangeState, err) + } + return nil +} + +func (bs bootstrapService) UpdateChannelHandler(ctx context.Context, channel Channel) error { + if err := bs.configs.UpdateChannel(ctx, channel); err != nil { + return errors.Wrap(errUpdateChannel, err) + } + return nil +} + +func (bs bootstrapService) RemoveConfigHandler(ctx context.Context, id string) error { + if err := bs.configs.RemoveClient(ctx, id); err != nil { + return errors.Wrap(errRemoveConfig, err) + } + return nil +} + +func (bs bootstrapService) RemoveChannelHandler(ctx context.Context, id string) error { + if err := bs.configs.RemoveChannel(ctx, id); err != nil { + return errors.Wrap(errRemoveChannel, err) + } + return nil +} + +func (bs bootstrapService) ConnectClientHandler(ctx context.Context, channelID, clientID string) error { + if err := bs.configs.ConnectClient(ctx, channelID, clientID); err != nil { + return errors.Wrap(errConnectClient, err) + } + return nil +} + +func (bs bootstrapService) DisconnectClientHandler(ctx context.Context, channelID, clientID string) error { + if err := bs.configs.DisconnectClient(ctx, channelID, clientID); err != nil { + return errors.Wrap(errDisconnectClient, err) + } + return nil +} + +// Method client retrieves SuperMQ Client creating one if an empty ID is passed. +func (bs bootstrapService) client(ctx context.Context, domainID, id, token string) (mgsdk.Client, error) { + // If Client ID is not provided, then create new client. + if id == "" { + id, err := bs.idProvider.ID() + if err != nil { + return mgsdk.Client{}, errors.Wrap(errCreateClient, err) + } + client, sdkErr := bs.sdk.CreateClient(ctx, mgsdk.Client{ID: id, Name: "Bootstrapped Client " + id}, domainID, token) + if sdkErr != nil { + return mgsdk.Client{}, errors.Wrap(errCreateClient, sdkErr) + } + return client, nil + } + // If Client ID is provided, then retrieve client + client, sdkErr := bs.sdk.Client(ctx, id, domainID, token) + if sdkErr != nil { + return mgsdk.Client{}, errors.Wrap(ErrClients, sdkErr) + } + return client, nil +} + +func (bs bootstrapService) connectionChannels(ctx context.Context, channels, existing []string, domainID, token string) ([]Channel, error) { + add := make(map[string]bool, len(channels)) + for _, ch := range channels { + add[ch] = true + } + + for _, ch := range existing { + if add[ch] { + delete(add, ch) + } + } + + var ret []Channel + for id := range add { + ch, err := bs.sdk.Channel(ctx, id, domainID, token) + if err != nil { + return nil, errors.Wrap(errors.ErrMalformedEntity, err) + } + + ret = append(ret, Channel{ + ID: ch.ID, + Name: ch.Name, + Metadata: ch.Metadata, + DomainID: ch.DomainID, + }) + } + + return ret, nil +} + +// Method updateList accepts config and channel IDs and returns three lists: +// 1) IDs of Channels to be added +// 2) IDs of Channels to be removed +// 3) IDs of common Channels for these two configs. +func (bs bootstrapService) updateList(cfg Config, connections []string) (add, remove []string) { + disconnect := make(map[string]bool, len(cfg.Channels)) + for _, c := range cfg.Channels { + disconnect[c.ID] = true + } + + for _, c := range connections { + if disconnect[c] { + // Don't disconnect common elements. + delete(disconnect, c) + continue + } + // Connect new elements. + add = append(add, c) + } + + for v := range disconnect { + remove = append(remove, v) + } + + return +} + +func (bs bootstrapService) toIDList(channels []Channel) []string { + var ret []string + for _, ch := range channels { + ret = append(ret, ch.ID) + } + + return ret +} + +func (bs bootstrapService) dec(in string) (string, error) { + ciphertext, err := hex.DecodeString(in) + if err != nil { + return "", err + } + block, err := aes.NewCipher(bs.encKey) + if err != nil { + return "", err + } + if len(ciphertext) < aes.BlockSize { + return "", err + } + iv := ciphertext[:aes.BlockSize] + ciphertext = ciphertext[aes.BlockSize:] + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(ciphertext, ciphertext) + return string(ciphertext), nil +} diff --git a/bootstrap/service_test.go b/bootstrap/service_test.go new file mode 100644 index 000000000..4ee70aafc --- /dev/null +++ b/bootstrap/service_test.go @@ -0,0 +1,1113 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bootstrap_test + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/hex" + "fmt" + "io" + "sort" + "testing" + + "github.com/absmach/supermq/bootstrap" + mocks "github.com/absmach/supermq/bootstrap/mocks" + "github.com/absmach/supermq/internal/testsutil" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + policysvc "github.com/absmach/supermq/pkg/policies" + policymocks "github.com/absmach/supermq/pkg/policies/mocks" + mgsdk "github.com/absmach/supermq/pkg/sdk" + sdkmocks "github.com/absmach/supermq/pkg/sdk/mocks" + "github.com/absmach/supermq/pkg/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const ( + validToken = "validToken" + invalidToken = "invalid" + invalidDomainID = "invalid" + email = "test@example.com" + unknown = "unknown" + channelsNum = 3 + instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002" + validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22" +) + +var ( + encKey = []byte("1234567891011121") + domainID = testsutil.GenerateUUID(&testing.T{}) + channel = bootstrap.Channel{ + ID: testsutil.GenerateUUID(&testing.T{}), + Name: "name", + Metadata: map[string]any{"name": "value"}, + } + + config = bootstrap.Config{ + ClientID: testsutil.GenerateUUID(&testing.T{}), + ClientSecret: testsutil.GenerateUUID(&testing.T{}), + ExternalID: testsutil.GenerateUUID(&testing.T{}), + ExternalKey: testsutil.GenerateUUID(&testing.T{}), + Channels: []bootstrap.Channel{channel}, + Content: "config", + } +) + +var ( + boot *mocks.ConfigRepository + policies *policymocks.Service + sdk *sdkmocks.SDK +) + +func newService() bootstrap.Service { + boot = new(mocks.ConfigRepository) + policies = new(policymocks.Service) + sdk = new(sdkmocks.SDK) + idp := uuid.NewMock() + return bootstrap.New(policies, boot, sdk, encKey, idp) +} + +func enc(in []byte) ([]byte, error) { + block, err := aes.NewCipher(encKey) + if err != nil { + return nil, err + } + ciphertext := make([]byte, aes.BlockSize+len(in)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(ciphertext[aes.BlockSize:], in) + return ciphertext, nil +} + +func TestAdd(t *testing.T) { + svc := newService() + + neID := config + neID.ClientID = "non-existent" + + wrongChannels := config + ch := channel + ch.ID = "invalid" + wrongChannels.Channels = append(wrongChannels.Channels, ch) + + cases := []struct { + desc string + config bootstrap.Config + token string + session smqauthn.Session + userID string + domainID string + clientErr error + createClientErr error + channelErr error + listExistingErr error + saveErr error + deleteClientErr error + err error + }{ + { + desc: "add a new config", + config: config, + token: validToken, + userID: validID, + domainID: domainID, + err: nil, + }, + { + desc: "add a config with an invalid ID", + config: neID, + token: validToken, + userID: validID, + domainID: domainID, + clientErr: errors.NewSDKError(svcerr.ErrNotFound), + err: svcerr.ErrNotFound, + }, + { + desc: "add a config with invalid list of channels", + config: wrongChannels, + token: validToken, + userID: validID, + domainID: domainID, + listExistingErr: svcerr.ErrMalformedEntity, + err: svcerr.ErrMalformedEntity, + }, + { + desc: "add empty config", + config: bootstrap.Config{}, + token: validToken, + userID: validID, + domainID: domainID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} + repoCall := sdk.On("Client", mock.Anything, tc.config.ClientID, mock.Anything, tc.token).Return(mgsdk.Client{ID: tc.config.ClientID, Credentials: mgsdk.ClientCredentials{Secret: tc.config.ClientSecret}}, tc.clientErr) + repoCall1 := sdk.On("CreateClient", mock.Anything, mock.Anything, tc.domainID, tc.token).Return(mgsdk.Client{}, tc.createClientErr) + repoCall2 := sdk.On("DeleteClient", mock.Anything, tc.config.ClientID, tc.domainID, tc.token).Return(tc.deleteClientErr) + repoCall3 := boot.On("ListExisting", context.Background(), tc.domainID, mock.Anything).Return(tc.config.Channels, tc.listExistingErr) + repoCall4 := boot.On("Save", context.Background(), mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr) + _, err := svc.Add(context.Background(), tc.session, tc.token, tc.config) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + repoCall3.Unset() + repoCall4.Unset() + }) + } +} + +func TestView(t *testing.T) { + svc := newService() + + cases := []struct { + desc string + configID string + userID string + domain string + clientDomain string + token string + session smqauthn.Session + retrieveErr error + clientErr error + channelErr error + err error + }{ + { + desc: "view an existing config", + configID: config.ClientID, + userID: validID, + clientDomain: domainID, + domain: domainID, + token: validToken, + err: nil, + }, + { + desc: "view a non-existing config", + configID: unknown, + userID: validID, + clientDomain: domainID, + domain: domainID, + token: validToken, + retrieveErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "view a config with invalid domain", + configID: config.ClientID, + userID: validID, + clientDomain: invalidDomainID, + domain: invalidDomainID, + token: validToken, + retrieveErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domain, DomainUserID: validID} + repoCall := boot.On("RetrieveByID", context.Background(), tc.clientDomain, tc.configID).Return(config, tc.retrieveErr) + _, err := svc.View(context.Background(), tc.session, tc.configID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + }) + } +} + +func TestUpdate(t *testing.T) { + svc := newService() + + c := config + ch := channel + ch.ID = "2" + c.Channels = append(c.Channels, ch) + + modifiedCreated := c + modifiedCreated.Content = "new-config" + modifiedCreated.Name = "new name" + + nonExisting := c + nonExisting.ClientID = unknown + + cases := []struct { + desc string + config bootstrap.Config + token string + session smqauthn.Session + userID string + domainID string + updateErr error + err error + }{ + { + desc: "update a config with state Created", + config: modifiedCreated, + token: validToken, + userID: validID, + domainID: domainID, + err: nil, + }, + { + desc: "update a non-existing config", + config: nonExisting, + token: validToken, + userID: validID, + domainID: domainID, + updateErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "update a config with update error", + config: c, + token: validToken, + userID: validID, + domainID: domainID, + updateErr: svcerr.ErrUpdateEntity, + err: svcerr.ErrUpdateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} + repoCall := boot.On("Update", context.Background(), mock.Anything).Return(tc.updateErr) + err := svc.Update(context.Background(), tc.session, tc.config) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + }) + } +} + +func TestUpdateCert(t *testing.T) { + svc := newService() + + c := config + ch := channel + ch.ID = "2" + c.Channels = append(c.Channels, ch) + + cases := []struct { + desc string + token string + session smqauthn.Session + userID string + domainID string + clientID string + clientCert string + clientKey string + caCert string + expectedConfig bootstrap.Config + authorizeErr error + authenticateErr error + updateErr error + err error + }{ + { + desc: "update certs for the valid config", + userID: validID, + domainID: domainID, + clientID: c.ClientID, + clientCert: "newCert", + clientKey: "newKey", + caCert: "newCert", + token: validToken, + expectedConfig: bootstrap.Config{ + Name: c.Name, + ClientSecret: c.ClientSecret, + Channels: c.Channels, + ExternalID: c.ExternalID, + ExternalKey: c.ExternalKey, + Content: c.Content, + State: c.State, + DomainID: c.DomainID, + ClientID: c.ClientID, + ClientCert: "newCert", + CACert: "newCert", + ClientKey: "newKey", + }, + err: nil, + }, + { + desc: "update cert for a non-existing config", + userID: validID, + domainID: domainID, + clientID: "empty", + clientCert: "newCert", + clientKey: "newKey", + caCert: "newCert", + token: validToken, + expectedConfig: bootstrap.Config{}, + updateErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} + repoCall := boot.On("UpdateCert", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.expectedConfig, tc.updateErr) + cfg, err := svc.UpdateCert(context.Background(), tc.session, tc.clientID, tc.clientCert, tc.clientKey, tc.caCert) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + sort.Slice(cfg.Channels, func(i, j int) bool { + return cfg.Channels[i].ID < cfg.Channels[j].ID + }) + sort.Slice(tc.expectedConfig.Channels, func(i, j int) bool { + return tc.expectedConfig.Channels[i].ID < tc.expectedConfig.Channels[j].ID + }) + assert.Equal(t, tc.expectedConfig, cfg, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.expectedConfig, cfg)) + repoCall.Unset() + }) + } +} + +func TestUpdateConnections(t *testing.T) { + svc := newService() + + c := config + c.State = bootstrap.Inactive + + activeConf := config + activeConf.State = bootstrap.Active + + ch := channel + + cases := []struct { + desc string + token string + session smqauthn.Session + id string + state bootstrap.State + userID string + domainID string + connections []string + updateErr error + clientErr error + channelErr error + retrieveErr error + listErr error + err error + }{ + { + desc: "update connections for config with state Inactive", + token: validToken, + userID: validID, + domainID: domainID, + id: c.ClientID, + state: c.State, + connections: []string{ch.ID}, + err: nil, + }, + { + desc: "update connections for config with state Active", + token: validToken, + userID: validID, + domainID: domainID, + id: activeConf.ClientID, + state: activeConf.State, + connections: []string{ch.ID}, + err: nil, + }, + { + desc: "update connections with invalid channels", + token: validToken, + userID: validID, + domainID: domainID, + id: c.ClientID, + connections: []string{"wrong"}, + channelErr: errors.NewSDKError(svcerr.ErrNotFound), + err: svcerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} + sdkCall := sdk.On("Channel", mock.Anything, mock.Anything, tc.domainID, tc.token).Return(mgsdk.Channel{}, tc.channelErr) + repoCall := boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(c, tc.retrieveErr) + repoCall1 := boot.On("ListExisting", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(c.Channels, tc.listErr) + repoCall2 := boot.On("UpdateConnections", context.Background(), mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.updateErr) + err := svc.UpdateConnections(context.Background(), tc.session, tc.token, tc.id, tc.connections) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + sdkCall.Unset() + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + }) + } +} + +func TestList(t *testing.T) { + svc := newService() + + numClients := 101 + var saved []bootstrap.Config + for i := 0; i < numClients; i++ { + c := config + c.ExternalID = testsutil.GenerateUUID(t) + c.ExternalKey = testsutil.GenerateUUID(t) + c.Name = fmt.Sprintf("%s-%d", config.Name, i) + if i == 41 { + c.State = bootstrap.Active + } + saved = append(saved, c) + } + cases := []struct { + desc string + config bootstrap.ConfigsPage + filter bootstrap.Filter + offset uint64 + limit uint64 + token string + session smqauthn.Session + userID string + domainID string + listObjectsResponse policysvc.PolicyPage + listObjectsErr error + retrieveErr error + err error + }{ + { + desc: "list configs successfully as super admin", + config: bootstrap.ConfigsPage{ + Total: uint64(len(saved)), + Offset: 0, + Limit: 10, + Configs: saved[0:10], + }, + filter: bootstrap.Filter{}, + token: validToken, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true}, + userID: validID, + domainID: domainID, + offset: 0, + limit: 10, + err: nil, + }, + { + desc: "list configs with failed super admin check", + config: bootstrap.ConfigsPage{}, + filter: bootstrap.Filter{}, + token: validID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}, + userID: validID, + domainID: domainID, + listObjectsResponse: policysvc.PolicyPage{}, + offset: 0, + limit: 10, + err: nil, + }, + { + desc: "list configs successfully as domain admin", + config: bootstrap.ConfigsPage{ + Total: uint64(len(saved)), + Offset: 0, + Limit: 10, + Configs: saved[0:10], + }, + filter: bootstrap.Filter{}, + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true}, + listObjectsResponse: policysvc.PolicyPage{}, + offset: 0, + limit: 10, + err: nil, + }, + { + desc: "list configs successfully as non admin", + config: bootstrap.ConfigsPage{ + Total: uint64(len(saved)), + Offset: 0, + Limit: 10, + Configs: saved[0:10], + }, + filter: bootstrap.Filter{}, + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}, + listObjectsResponse: policysvc.PolicyPage{Policies: []string{"test", "test"}}, + offset: 0, + limit: 10, + err: nil, + }, + { + desc: "list configs with specified name as super admin", + config: bootstrap.ConfigsPage{ + Total: 1, + Offset: 0, + Limit: 100, + Configs: saved[95:96], + }, + filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "95"}}, + token: validToken, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true}, + userID: validID, + domainID: domainID, + offset: 0, + limit: 100, + err: nil, + }, + { + desc: "list configs with specified name as domain admin", + config: bootstrap.ConfigsPage{ + Total: 1, + Offset: 0, + Limit: 100, + Configs: saved[95:96], + }, + filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "95"}}, + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true}, + offset: 0, + limit: 100, + err: nil, + }, + { + desc: "list configs with specified name as non admin", + config: bootstrap.ConfigsPage{ + Total: 1, + Offset: 0, + Limit: 100, + Configs: saved[95:96], + }, + filter: bootstrap.Filter{PartialMatch: map[string]string{"name": "95"}}, + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}, + listObjectsResponse: policysvc.PolicyPage{Policies: []string{"test", "test"}}, + offset: 0, + limit: 100, + err: nil, + }, + { + desc: "list last page as super admin", + config: bootstrap.ConfigsPage{ + Total: uint64(len(saved)), + Offset: 95, + Limit: 10, + Configs: saved[95:], + }, + filter: bootstrap.Filter{}, + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true}, + offset: 95, + limit: 10, + err: nil, + }, + { + desc: "list last page as domain admin", + config: bootstrap.ConfigsPage{ + Total: uint64(len(saved)), + Offset: 95, + Limit: 10, + Configs: saved[95:], + }, + filter: bootstrap.Filter{}, + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true}, + offset: 95, + limit: 10, + err: nil, + }, + { + desc: "list last page as non admin", + config: bootstrap.ConfigsPage{ + Total: uint64(len(saved)), + Offset: 95, + Limit: 10, + Configs: saved[95:], + }, + filter: bootstrap.Filter{}, + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}, + listObjectsResponse: policysvc.PolicyPage{Policies: []string{"test", "test"}}, + offset: 95, + limit: 10, + err: nil, + }, + { + desc: "list configs with Active state as super admin", + config: bootstrap.ConfigsPage{ + Total: 1, + Offset: 35, + Limit: 20, + Configs: []bootstrap.Config{saved[41]}, + }, + filter: bootstrap.Filter{FullMatch: map[string]string{"state": bootstrap.Active.String()}}, + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true}, + offset: 35, + limit: 20, + err: nil, + }, + { + desc: "list configs with Active state as domain admin", + config: bootstrap.ConfigsPage{ + Total: 1, + Offset: 35, + Limit: 20, + Configs: []bootstrap.Config{saved[41]}, + }, + filter: bootstrap.Filter{FullMatch: map[string]string{"state": bootstrap.Active.String()}}, + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID, SuperAdmin: true}, + offset: 35, + limit: 20, + err: nil, + }, + { + desc: "list configs with Active state as non admin", + config: bootstrap.ConfigsPage{ + Total: 1, + Offset: 35, + Limit: 20, + Configs: []bootstrap.Config{saved[41]}, + }, + filter: bootstrap.Filter{FullMatch: map[string]string{"state": bootstrap.Active.String()}}, + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}, + listObjectsResponse: policysvc.PolicyPage{Policies: []string{"test", "test"}}, + offset: 35, + limit: 20, + err: nil, + }, + { + desc: "list configs with failed to list objects", + config: bootstrap.ConfigsPage{}, + filter: bootstrap.Filter{}, + offset: 0, + limit: 10, + token: validToken, + userID: validID, + domainID: domainID, + session: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: validID}, + listObjectsResponse: policysvc.PolicyPage{}, + listObjectsErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + policyCall := policies.On("ListAllObjects", mock.Anything, policysvc.Policy{ + SubjectType: policysvc.UserType, + Subject: tc.userID, + Permission: policysvc.ViewPermission, + ObjectType: policysvc.ClientType, + }).Return(tc.listObjectsResponse, tc.listObjectsErr) + repoCall := boot.On("RetrieveAll", context.Background(), mock.Anything, mock.Anything, tc.filter, tc.offset, tc.limit).Return(tc.config, tc.retrieveErr) + + result, err := svc.List(context.Background(), tc.session, tc.filter, tc.offset, tc.limit) + assert.ElementsMatch(t, tc.config.Configs, result.Configs, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.config.Configs, result.Configs)) + assert.Equal(t, tc.config.Total, result.Total, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.config.Total, result.Total)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + policyCall.Unset() + repoCall.Unset() + }) + } +} + +func TestRemove(t *testing.T) { + svc := newService() + + c := config + cases := []struct { + desc string + id string + token string + session smqauthn.Session + userID string + domainID string + removeErr error + err error + }{ + { + desc: "remove an existing config", + id: c.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + err: nil, + }, + { + desc: "remove removed config", + id: c.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + err: nil, + }, + { + desc: "remove a config with failed remove", + id: c.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + removeErr: svcerr.ErrRemoveEntity, + err: svcerr.ErrRemoveEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} + repoCall := boot.On("Remove", context.Background(), mock.Anything, mock.Anything).Return(tc.removeErr) + err := svc.Remove(context.Background(), tc.session, tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + }) + } +} + +func TestBootstrap(t *testing.T) { + svc := newService() + + c := config + e, err := enc([]byte(c.ExternalKey)) + assert.Nil(t, err, fmt.Sprintf("Encrypting external key expected to succeed: %s.\n", err)) + + cases := []struct { + desc string + config bootstrap.Config + externalKey string + externalID string + userID string + domainID string + err error + encrypted bool + }{ + { + desc: "bootstrap using invalid external id", + config: bootstrap.Config{}, + externalID: "invalid", + externalKey: c.ExternalKey, + userID: validID, + domainID: invalidDomainID, + err: svcerr.ErrNotFound, + encrypted: false, + }, + { + desc: "bootstrap using invalid external key", + config: bootstrap.Config{}, + externalID: c.ExternalID, + externalKey: "invalid", + userID: validID, + domainID: domainID, + err: bootstrap.ErrExternalKey, + encrypted: false, + }, + { + desc: "bootstrap an existing config", + config: c, + externalID: c.ExternalID, + externalKey: c.ExternalKey, + userID: validID, + domainID: domainID, + err: nil, + encrypted: false, + }, + { + desc: "bootstrap encrypted", + config: c, + externalID: c.ExternalID, + externalKey: hex.EncodeToString(e), + userID: validID, + domainID: domainID, + err: nil, + encrypted: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := boot.On("RetrieveByExternalID", context.Background(), mock.Anything).Return(tc.config, tc.err) + config, err := svc.Bootstrap(context.Background(), tc.externalKey, tc.externalID, tc.encrypted) + assert.Equal(t, tc.config, config, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.config, config)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + }) + } +} + +func TestChangeState(t *testing.T) { + svc := newService() + + c := config + cases := []struct { + desc string + state bootstrap.State + id string + token string + session smqauthn.Session + userID string + domainID string + retrieveErr error + connectErr errors.SDKError + disconenctErr error + stateErr error + err error + }{ + { + desc: "change state of non-existing config", + state: bootstrap.Active, + id: unknown, + token: validToken, + userID: validID, + domainID: domainID, + retrieveErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "change state to Active", + state: bootstrap.Active, + id: c.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + err: nil, + }, + { + desc: "change state to current state", + state: bootstrap.Active, + id: c.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + err: nil, + }, + { + desc: "change state to Inactive", + state: bootstrap.Inactive, + id: c.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + err: nil, + }, + { + desc: "change state with failed Connect", + state: bootstrap.Active, + id: c.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + connectErr: errors.NewSDKError(bootstrap.ErrClients), + err: bootstrap.ErrClients, + }, + { + desc: "change state with invalid state", + state: bootstrap.State(2), + id: c.ClientID, + token: validToken, + userID: validID, + domainID: domainID, + stateErr: svcerr.ErrMalformedEntity, + err: svcerr.ErrMalformedEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: validID} + repoCall := boot.On("RetrieveByID", context.Background(), tc.domainID, tc.id).Return(c, tc.retrieveErr) + sdkCall := sdk.On("ConnectClients", mock.Anything, mock.Anything, mock.Anything, []string{"Publish", "Subscribe"}, mock.Anything, tc.token).Return(tc.connectErr) + repoCall1 := boot.On("ChangeState", context.Background(), mock.Anything, mock.Anything, mock.Anything).Return(tc.stateErr) + err := svc.ChangeState(context.Background(), tc.session, tc.token, tc.id, tc.state) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + sdkCall.Unset() + repoCall.Unset() + repoCall1.Unset() + }) + } +} + +func TestUpdateChannelHandler(t *testing.T) { + svc := newService() + + ch := bootstrap.Channel{ + ID: channel.ID, + Name: "new name", + Metadata: map[string]any{"meta": "new"}, + } + + cases := []struct { + desc string + channel bootstrap.Channel + err error + }{ + { + desc: "update an existing channel", + channel: ch, + err: nil, + }, + { + desc: "update a non-existing channel", + channel: bootstrap.Channel{ID: ""}, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := boot.On("UpdateChannel", context.Background(), mock.Anything).Return(tc.err) + err := svc.UpdateChannelHandler(context.Background(), tc.channel) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + }) + } +} + +func TestRemoveChannelHandler(t *testing.T) { + svc := newService() + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "remove an existing channel", + id: config.Channels[0].ID, + err: nil, + }, + { + desc: "remove a non-existing channel", + id: "unknown", + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := boot.On("RemoveChannel", context.Background(), mock.Anything).Return(tc.err) + err := svc.RemoveChannelHandler(context.Background(), tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + }) + } +} + +func TestRemoveConfigHandler(t *testing.T) { + svc := newService() + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "remove an existing config", + id: config.ClientID, + err: nil, + }, + { + desc: "remove a non-existing channel", + id: "unknown", + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := boot.On("RemoveClient", context.Background(), mock.Anything).Return(tc.err) + err := svc.RemoveConfigHandler(context.Background(), tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + }) + } +} + +func TestConnectClientHandler(t *testing.T) { + svc := newService() + + cases := []struct { + desc string + clientID string + channelID string + err error + }{ + { + desc: "connect", + channelID: channel.ID, + clientID: config.ClientID, + err: nil, + }, + { + desc: "connect connected", + channelID: channel.ID, + clientID: config.ClientID, + err: svcerr.ErrAddPolicies, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := boot.On("ConnectClient", context.Background(), mock.Anything, mock.Anything).Return(tc.err) + err := svc.ConnectClientHandler(context.Background(), tc.channelID, tc.clientID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + }) + } +} + +func TestDisconnectClientsHandler(t *testing.T) { + svc := newService() + + cases := []struct { + desc string + clientID string + channelID string + err error + }{ + { + desc: "disconnect", + channelID: channel.ID, + clientID: config.ClientID, + err: nil, + }, + { + desc: "disconnect disconnected", + channelID: channel.ID, + clientID: config.ClientID, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := boot.On("DisconnectClient", context.Background(), mock.Anything, mock.Anything).Return(tc.err) + err := svc.DisconnectClientHandler(context.Background(), tc.channelID, tc.clientID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + }) + } +} diff --git a/bootstrap/state.go b/bootstrap/state.go new file mode 100644 index 000000000..e97aedb0c --- /dev/null +++ b/bootstrap/state.go @@ -0,0 +1,26 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bootstrap + +import "strconv" + +const ( + // Inactive Client is created, but not able to exchange messages using SuperMQ. + Inactive State = iota + // Active Client is created, configured, and whitelisted. + Active +) + +// State represents corresponding SuperMQ Client state. The possible Config States +// as well as description of what that State represents are given in the table: +// | State | What it means | +// |----------+--------------------------------------------------------------------------------| +// | Inactive | Client is created, but isn't able to communicate over SuperMQ | +// | Active | Client is able to communicate using SuperMQ |. +type State int + +// String returns string representation of State. +func (s State) String() string { + return strconv.Itoa(int(s)) +} diff --git a/mqtt/tracing/doc.go b/bootstrap/tracing/doc.go similarity index 74% rename from mqtt/tracing/doc.go rename to bootstrap/tracing/doc.go index 51a9adb06..8a7079a42 100644 --- a/mqtt/tracing/doc.go +++ b/bootstrap/tracing/doc.go @@ -1,11 +1,11 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -// Package tracing provides tracing instrumentation for SuperMQ MQTT adapter service. +// Package tracing provides tracing instrumentation for SuperMQ Users service. // -// This package provides tracing middleware for SuperMQ MQTT adapter service. +// This package provides tracing middleware for SuperMQ Users service. // It can be used to trace incoming requests and add tracing capabilities to -// SuperMQ MQTT adapter service. +// SuperMQ Users service. // // For more details about tracing instrumentation for SuperMQ messaging refer // to the documentation at https://docs.supermq.absmach.eu/tracing/. diff --git a/bootstrap/tracing/tracing.go b/bootstrap/tracing/tracing.go new file mode 100644 index 000000000..c18530e03 --- /dev/null +++ b/bootstrap/tracing/tracing.go @@ -0,0 +1,182 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package tracing + +import ( + "context" + + "github.com/absmach/supermq/bootstrap" + smqauthn "github.com/absmach/supermq/pkg/authn" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +var _ bootstrap.Service = (*tracingMiddleware)(nil) + +type tracingMiddleware struct { + tracer trace.Tracer + svc bootstrap.Service +} + +// New returns a new bootstrap service with tracing capabilities. +func New(svc bootstrap.Service, tracer trace.Tracer) bootstrap.Service { + return &tracingMiddleware{tracer, svc} +} + +// Add traces the "Add" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) Add(ctx context.Context, session smqauthn.Session, token string, cfg bootstrap.Config) (bootstrap.Config, error) { + ctx, span := tm.tracer.Start(ctx, "svc_register_user", trace.WithAttributes( + attribute.String("client_id", cfg.ClientID), + attribute.String("domain_id ", cfg.DomainID), + attribute.String("name", cfg.Name), + attribute.String("external_id", cfg.ExternalID), + attribute.String("content", cfg.Content), + attribute.String("state", cfg.State.String()), + )) + defer span.End() + + return tm.svc.Add(ctx, session, token, cfg) +} + +// View traces the "View" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) View(ctx context.Context, session smqauthn.Session, id string) (bootstrap.Config, error) { + ctx, span := tm.tracer.Start(ctx, "svc_view_user", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.View(ctx, session, id) +} + +// Update traces the "Update" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) Update(ctx context.Context, session smqauthn.Session, cfg bootstrap.Config) error { + ctx, span := tm.tracer.Start(ctx, "svc_update_user", trace.WithAttributes( + attribute.String("name", cfg.Name), + attribute.String("content", cfg.Content), + attribute.String("client_id", cfg.ClientID), + attribute.String("domain_id ", cfg.DomainID), + )) + defer span.End() + + return tm.svc.Update(ctx, session, cfg) +} + +// UpdateCert traces the "UpdateCert" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) UpdateCert(ctx context.Context, session smqauthn.Session, clientID, clientCert, clientKey, caCert string) (bootstrap.Config, error) { + ctx, span := tm.tracer.Start(ctx, "svc_update_cert", trace.WithAttributes( + attribute.String("client_id", clientID), + )) + defer span.End() + + return tm.svc.UpdateCert(ctx, session, clientID, clientCert, clientKey, caCert) +} + +// UpdateConnections traces the "UpdateConnections" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) UpdateConnections(ctx context.Context, session smqauthn.Session, token, id string, connections []string) error { + ctx, span := tm.tracer.Start(ctx, "svc_update_connections", trace.WithAttributes( + attribute.String("id", id), + attribute.StringSlice("connections", connections), + )) + defer span.End() + + return tm.svc.UpdateConnections(ctx, session, token, id, connections) +} + +// List traces the "List" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) List(ctx context.Context, session smqauthn.Session, filter bootstrap.Filter, offset, limit uint64) (bootstrap.ConfigsPage, error) { + ctx, span := tm.tracer.Start(ctx, "svc_list_users", trace.WithAttributes( + attribute.Int64("offset", int64(offset)), + attribute.Int64("limit", int64(limit)), + )) + defer span.End() + + return tm.svc.List(ctx, session, filter, offset, limit) +} + +// Remove traces the "Remove" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) Remove(ctx context.Context, session smqauthn.Session, id string) error { + ctx, span := tm.tracer.Start(ctx, "svc_remove_user", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.Remove(ctx, session, id) +} + +// Bootstrap traces the "Bootstrap" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) Bootstrap(ctx context.Context, externalKey, externalID string, secure bool) (bootstrap.Config, error) { + ctx, span := tm.tracer.Start(ctx, "svc_bootstrap_user", trace.WithAttributes( + attribute.String("external_key", externalKey), + attribute.String("external_id", externalID), + attribute.Bool("secure", secure), + )) + defer span.End() + + return tm.svc.Bootstrap(ctx, externalKey, externalID, secure) +} + +// ChangeState traces the "ChangeState" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) ChangeState(ctx context.Context, session smqauthn.Session, token, id string, state bootstrap.State) error { + ctx, span := tm.tracer.Start(ctx, "svc_change_state", trace.WithAttributes( + attribute.String("id", id), + attribute.String("state", state.String()), + )) + defer span.End() + + return tm.svc.ChangeState(ctx, session, token, id, state) +} + +// UpdateChannelHandler traces the "UpdateChannelHandler" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) UpdateChannelHandler(ctx context.Context, channel bootstrap.Channel) error { + ctx, span := tm.tracer.Start(ctx, "svc_update_channel_handler", trace.WithAttributes( + attribute.String("id", channel.ID), + attribute.String("name", channel.Name), + attribute.String("description", channel.Description), + )) + defer span.End() + + return tm.svc.UpdateChannelHandler(ctx, channel) +} + +// RemoveConfigHandler traces the "RemoveConfigHandler" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) RemoveConfigHandler(ctx context.Context, id string) error { + ctx, span := tm.tracer.Start(ctx, "svc_remove_config_handler", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.RemoveConfigHandler(ctx, id) +} + +// RemoveChannelHandler traces the "RemoveChannelHandler" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) RemoveChannelHandler(ctx context.Context, id string) error { + ctx, span := tm.tracer.Start(ctx, "svc_remove_channel_handler", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.RemoveChannelHandler(ctx, id) +} + +// ConnectClientHandler traces the "ConnectClientHandler" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) ConnectClientHandler(ctx context.Context, channelID, clientID string) error { + ctx, span := tm.tracer.Start(ctx, "svc_connect_client_handler", trace.WithAttributes( + attribute.String("channel_id", channelID), + attribute.String("client_id", clientID), + )) + defer span.End() + + return tm.svc.ConnectClientHandler(ctx, channelID, clientID) +} + +// DisconnectClientHandler traces the "DisconnectClientHandler" operation of the wrapped bootstrap.Service. +func (tm *tracingMiddleware) DisconnectClientHandler(ctx context.Context, channelID, clientID string) error { + ctx, span := tm.tracer.Start(ctx, "svc_disconnect_client_handler", trace.WithAttributes( + attribute.String("channel_id", channelID), + attribute.String("client_id", clientID), + )) + defer span.End() + + return tm.svc.DisconnectClientHandler(ctx, channelID, clientID) +} diff --git a/certs/api/grpc/client.go b/certs/api/grpc/client.go new file mode 100644 index 000000000..a8014b594 --- /dev/null +++ b/certs/api/grpc/client.go @@ -0,0 +1,93 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package grpc + +import ( + "context" + "time" + + grpcCertsV1 "github.com/absmach/supermq/api/grpc/certs/v1" + "github.com/absmach/supermq/certs/api" + "github.com/go-kit/kit/endpoint" + kitgrpc "github.com/go-kit/kit/transport/grpc" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/emptypb" +) + +const svcName = "certs.ClientService" + +type grpcClient struct { + timeout time.Duration + getEntityID endpoint.Endpoint + revokeCerts endpoint.Endpoint +} + +func NewClient(conn *grpc.ClientConn, timeout time.Duration) grpcCertsV1.CertsServiceClient { + return &grpcClient{ + getEntityID: kitgrpc.NewClient( + conn, + svcName, + "GetEntityID", + encodeGetEntityIDRequest, + decodeGetEntityIDResponse, + grpcCertsV1.EntityRes{}, + ).Endpoint(), + + revokeCerts: kitgrpc.NewClient( + conn, + svcName, + "RevokeCerts", + encodeRevokeCertsRequest, + decodeRevokeCertsResponse, + emptypb.Empty{}, + ).Endpoint(), + + timeout: timeout, + } +} + +func (c *grpcClient) GetEntityID(ctx context.Context, req *grpcCertsV1.EntityReq, _ ...grpc.CallOption) (*grpcCertsV1.EntityRes, error) { + ctx, cancel := context.WithTimeout(ctx, c.timeout) + defer cancel() + res, err := c.getEntityID(ctx, req) + if err != nil { + return nil, err + } + return res.(*grpcCertsV1.EntityRes), nil +} + +func (c *grpcClient) RevokeCerts(ctx context.Context, req *grpcCertsV1.RevokeReq, _ ...grpc.CallOption) (*emptypb.Empty, error) { + ctx, cancel := context.WithTimeout(ctx, c.timeout) + defer cancel() + res, err := c.revokeCerts(ctx, req) + if err != nil { + return nil, err + } + return res.(*emptypb.Empty), nil +} + +func encodeGetEntityIDRequest(_ context.Context, request any) (any, error) { + req := request.(*grpcCertsV1.EntityReq) + return &grpcCertsV1.EntityReq{ + SerialNumber: api.NormalizeSerialNumber(req.GetSerialNumber()), + }, nil +} + +func decodeGetEntityIDResponse(_ context.Context, response any) (any, error) { + res := response.(*grpcCertsV1.EntityRes) + return &grpcCertsV1.EntityRes{ + EntityId: res.EntityId, + }, nil +} + +func encodeRevokeCertsRequest(_ context.Context, request any) (any, error) { + req := request.(*grpcCertsV1.RevokeReq) + return &grpcCertsV1.RevokeReq{ + EntityId: req.GetEntityId(), + }, nil +} + +func decodeRevokeCertsResponse(_ context.Context, response any) (any, error) { + return &emptypb.Empty{}, nil +} diff --git a/certs/api/grpc/endpoint.go b/certs/api/grpc/endpoint.go new file mode 100644 index 000000000..b4ea52d8d --- /dev/null +++ b/certs/api/grpc/endpoint.go @@ -0,0 +1,46 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package grpc + +import ( + "context" + + grpcCertsV1 "github.com/absmach/supermq/api/grpc/certs/v1" + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/pkg/authn" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/go-kit/kit/endpoint" + "google.golang.org/protobuf/types/known/emptypb" +) + +func getEntityEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(*grpcCertsV1.EntityReq) + + entityID, err := svc.GetEntityID(ctx, req.SerialNumber) + if err != nil { + return nil, err + } + + return &grpcCertsV1.EntityRes{EntityId: entityID}, nil + } +} + +func revokeCertsEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(*grpcCertsV1.RevokeReq) + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthentication + } + + err := svc.RevokeAll(ctx, session, req.EntityId) + if err != nil { + return nil, err + } + + return &emptypb.Empty{}, nil + } +} diff --git a/certs/api/grpc/server.go b/certs/api/grpc/server.go new file mode 100644 index 000000000..b318216c4 --- /dev/null +++ b/certs/api/grpc/server.go @@ -0,0 +1,93 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package grpc + +import ( + "context" + + grpcCertsV1 "github.com/absmach/supermq/api/grpc/certs/v1" + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/certs/api/http" + "github.com/absmach/supermq/pkg/errors" + kitgrpc "github.com/go-kit/kit/transport/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/emptypb" +) + +var _ grpcCertsV1.CertsServiceServer = (*grpcServer)(nil) + +type grpcServer struct { + getEntity kitgrpc.Handler + revokeCerts kitgrpc.Handler + grpcCertsV1.UnimplementedCertsServiceServer +} + +func NewServer(svc certs.Service) grpcCertsV1.CertsServiceServer { + return &grpcServer{ + getEntity: kitgrpc.NewServer( + (getEntityEndpoint(svc)), + decodeGetEntityReq, + encodeGetEntityRes, + ), + revokeCerts: kitgrpc.NewServer( + (revokeCertsEndpoint(svc)), + decodeRevokeCertsReq, + encodeRevokeCertsRes, + ), + } +} + +func decodeGetEntityReq(_ context.Context, req any) (any, error) { + return req.(*grpcCertsV1.EntityReq), nil +} + +func encodeGetEntityRes(_ context.Context, res any) (any, error) { + return res.(*grpcCertsV1.EntityRes), nil +} + +func decodeRevokeCertsReq(_ context.Context, req any) (any, error) { + return req.(*grpcCertsV1.RevokeReq), nil +} + +func encodeRevokeCertsRes(_ context.Context, res any) (any, error) { + return res.(*emptypb.Empty), nil +} + +// GetEntityID returns the entity ID for the given entity request. +func (g *grpcServer) GetEntityID(ctx context.Context, req *grpcCertsV1.EntityReq) (*grpcCertsV1.EntityRes, error) { + _, res, err := g.getEntity.ServeGRPC(ctx, req) + if err != nil { + return &grpcCertsV1.EntityRes{}, encodeError(err) + } + return res.(*grpcCertsV1.EntityRes), nil +} + +func (g *grpcServer) RevokeCerts(ctx context.Context, req *grpcCertsV1.RevokeReq) (*emptypb.Empty, error) { + _, res, err := g.revokeCerts.ServeGRPC(ctx, req) + if err != nil { + return &emptypb.Empty{}, encodeError(err) + } + return res.(*emptypb.Empty), nil +} + +func encodeError(err error) error { + switch { + case errors.Contains(err, nil): + return nil + case errors.Contains(err, certs.ErrMalformedEntity), + errors.Contains(err, http.ErrMissingEntityID): + return status.Error(codes.InvalidArgument, err.Error()) + case errors.Contains(err, certs.ErrNotFound): + return status.Error(codes.NotFound, err.Error()) + case errors.Contains(err, certs.ErrConflict): + return status.Error(codes.AlreadyExists, err.Error()) + case errors.Contains(err, certs.ErrCreateEntity), + errors.Contains(err, certs.ErrUpdateEntity), + errors.Contains(err, certs.ErrViewEntity): + return status.Error(codes.Internal, err.Error()) + default: + return status.Error(codes.Internal, "internal server error") + } +} diff --git a/certs/api/http/common.go b/certs/api/http/common.go new file mode 100644 index 000000000..e508412bb --- /dev/null +++ b/certs/api/http/common.go @@ -0,0 +1,97 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/pkg/errors" +) + +const ( + // ContentType represents JSON content type. + ContentType = "application/json" + OCSPType = "application/ocsp-response" +) + +// Response contains HTTP response specific methods. +type Response interface { + // Code returns HTTP response code. + Code() int + + // Headers returns map of HTTP headers with their values. + Headers() map[string]string + + // Empty indicates if HTTP response has content. + Empty() bool +} + +// EncodeError encodes an error response. +func EncodeError(_ context.Context, err error, w http.ResponseWriter) { + var wrapper error + if errors.Contains(err, ErrValidation) { + wrapper, err = errors.Unwrap(err) + } + + w.Header().Set("Content-Type", ContentType) + switch { + case errors.Contains(err, certs.ErrCertExpired): + err = unwrap(err) + w.WriteHeader(http.StatusForbidden) + + case errors.Contains(err, certs.ErrCertRevoked): + err = unwrap(err) + w.WriteHeader(http.StatusUnauthorized) + case errors.Contains(err, certs.ErrMalformedEntity), + errors.Contains(err, ErrMissingEntityID), + errors.Contains(err, ErrEmptySerialNo), + errors.Contains(err, ErrEmptyToken), + errors.Contains(err, ErrInvalidQueryParams), + errors.Contains(err, ErrValidation), + errors.Contains(err, ErrInvalidRequest): + err = unwrap(err) + w.WriteHeader(http.StatusBadRequest) + + case errors.Contains(err, certs.ErrCreateEntity), + errors.Contains(err, certs.ErrUpdateEntity), + errors.Contains(err, certs.ErrViewEntity), + errors.Contains(err, certs.ErrFailedCertCreation): + err = unwrap(err) + w.WriteHeader(http.StatusUnprocessableEntity) + + case errors.Contains(err, certs.ErrNotFound), + errors.Contains(err, certs.ErrRootCANotFound), + errors.Contains(err, certs.ErrIntermediateCANotFound): + err = unwrap(err) + w.WriteHeader(http.StatusNotFound) + + case errors.Contains(err, certs.ErrConflict): + err = unwrap(err) + w.WriteHeader(http.StatusConflict) + + default: + w.WriteHeader(http.StatusInternalServerError) + } + + if wrapper != nil { + err = errors.Wrap(wrapper, err) + } + + if errorVal, ok := err.(errors.Error); ok { + if err := json.NewEncoder(w).Encode(errorVal); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + } +} + +func unwrap(err error) error { + wrapper, err := errors.Unwrap(err) + if wrapper != nil { + return wrapper + } + return err +} diff --git a/certs/api/http/endpoint.go b/certs/api/http/endpoint.go new file mode 100644 index 000000000..a699be285 --- /dev/null +++ b/certs/api/http/endpoint.go @@ -0,0 +1,306 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import ( + "context" + + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/pkg/authn" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/go-kit/kit/endpoint" +) + +func renewCertEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (response any, err error) { + req := request.(viewReq) + if err := req.validate(); err != nil { + return renewCertRes{}, err + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return renewCertRes{}, svcerr.ErrAuthentication + } + + cert, err := svc.RenewCert(ctx, session, req.id) + if err != nil { + return renewCertRes{}, err + } + + return renewCertRes{renewed: true, Certificate: cert}, nil + } +} + +func revokeCertEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (response any, err error) { + req := request.(viewReq) + if err := req.validate(); err != nil { + return revokeCertRes{revoked: false}, err + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return revokeCertRes{revoked: false}, svcerr.ErrAuthentication + } + + if err = svc.RevokeBySerial(ctx, session, req.id); err != nil { + return revokeCertRes{revoked: false}, err + } + + return revokeCertRes{revoked: true}, nil + } +} + +func deleteCertEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (response any, err error) { + req := request.(deleteReq) + if err := req.validate(); err != nil { + return deleteCertRes{deleted: false}, err + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return deleteCertRes{deleted: false}, svcerr.ErrAuthentication + } + + if err = svc.RevokeAll(ctx, session, req.entityID); err != nil { + return deleteCertRes{deleted: false}, err + } + + return deleteCertRes{deleted: true}, nil + } +} + +func issueCertEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (response any, err error) { + req := request.(issueCertReq) + if err := req.validate(); err != nil { + return issueCertRes{}, err + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return issueCertRes{}, svcerr.ErrAuthentication + } + + cert, err := svc.IssueCert(ctx, session, req.entityID, req.TTL, req.IpAddrs, req.Options) + if err != nil { + return issueCertRes{}, err + } + + return issueCertRes{ + SerialNumber: cert.SerialNumber, + Certificate: string(cert.Certificate), + Key: string(cert.Key), + ExpiryTime: cert.ExpiryTime, + EntityID: cert.EntityID, + Revoked: cert.Revoked, + issued: true, + }, nil + } +} + +func listCertsEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (response any, err error) { + req := request.(listCertsReq) + if err := req.validate(); err != nil { + return listCertsRes{}, err + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return listCertsRes{}, svcerr.ErrAuthentication + } + + certPage, err := svc.ListCerts(ctx, session, req.pm) + if err != nil { + return listCertsRes{}, err + } + + var crts []viewCertRes + for _, c := range certPage.Certificates { + crts = append(crts, viewCertRes{ + SerialNumber: c.SerialNumber, + Revoked: c.Revoked, + EntityID: c.EntityID, + ExpiryTime: c.ExpiryTime, + }) + } + + return listCertsRes{ + Total: certPage.Total, + Offset: certPage.Offset, + Limit: certPage.Limit, + Certificates: crts, + }, nil + } +} + +func viewCertEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (response any, err error) { + req := request.(viewReq) + if err := req.validate(); err != nil { + return viewCertRes{}, err + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return viewCertRes{}, svcerr.ErrAuthentication + } + + cert, err := svc.ViewCert(ctx, session, req.id) + if err != nil { + return viewCertRes{}, err + } + + return viewCertRes{ + SerialNumber: cert.SerialNumber, + Certificate: string(cert.Certificate), + Key: string(cert.Key), + Revoked: cert.Revoked, + ExpiryTime: cert.ExpiryTime, + EntityID: cert.EntityID, + }, nil + } +} + +func ocspEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (response any, err error) { + req := request.(ocspReq) + if err := req.validate(); err != nil { + return nil, err + } + + var resBytes []byte + if req.SerialNumber != "" { + resBytes, err = svc.OCSP(ctx, req.SerialNumber, nil) + if err != nil { + return nil, err + } + } else { + ocspRequestDER, err := req.req.Marshal() + if err != nil { + return nil, err + } + resBytes, err = svc.OCSP(ctx, "", ocspRequestDER) + if err != nil { + return nil, err + } + } + + return ocspRawRes{ + Data: resBytes, + }, nil + } +} + +func generateCRLEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (response any, err error) { + req := request.(crlReq) + if err := req.validate(); err != nil { + return crlRes{}, err + } + + crlBytes, err := svc.GenerateCRL(ctx) + if err != nil { + return crlRes{}, err + } + + return crlRes{ + CrlBytes: crlBytes, + }, nil + } +} + +func downloadCAEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (response any, err error) { + req := request.(downloadReq) + if err := req.validate(); err != nil { + return fileDownloadRes{}, err + } + + cert, err := svc.RetrieveCAChain(ctx) + if err != nil { + return fileDownloadRes{}, err + } + + return fileDownloadRes{ + Certificate: cert.Certificate, + Filename: "ca.zip", + ContentType: "application/zip", + }, nil + } +} + +func viewCAEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (response any, err error) { + req := request.(downloadReq) + if err := req.validate(); err != nil { + return viewCertRes{}, err + } + + cert, err := svc.RetrieveCAChain(ctx) + if err != nil { + return viewCertRes{}, err + } + + return viewCertRes{ + SerialNumber: cert.SerialNumber, + Certificate: string(cert.Certificate), + Revoked: cert.Revoked, + ExpiryTime: cert.ExpiryTime, + EntityID: cert.EntityID, + }, nil + } +} + +func issueFromCSREndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (response any, err error) { + req := request.(IssueFromCSRReq) + if err := req.validate(); err != nil { + return issueFromCSRRes{}, err + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return issueFromCSRRes{}, svcerr.ErrAuthentication + } + + cert, err := svc.IssueFromCSR(ctx, session, req.entityID, req.ttl, certs.CSR{CSR: req.CSR}) + if err != nil { + return issueFromCSRRes{}, err + } + + return issueFromCSRRes{ + SerialNumber: cert.SerialNumber, + Certificate: string(cert.Certificate), + Revoked: cert.Revoked, + ExpiryTime: cert.ExpiryTime, + EntityID: cert.EntityID, + }, nil + } +} + +func issueFromCSRInternalEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (response any, err error) { + req := request.(IssueFromCSRInternalReq) + if err := req.validate(); err != nil { + return issueFromCSRRes{}, err + } + + cert, err := svc.IssueFromCSRInternal(ctx, req.entityID, req.ttl, certs.CSR{CSR: req.CSR}) + if err != nil { + return issueFromCSRRes{}, err + } + + return issueFromCSRRes{ + SerialNumber: cert.SerialNumber, + Certificate: string(cert.Certificate), + Revoked: cert.Revoked, + ExpiryTime: cert.ExpiryTime, + EntityID: cert.EntityID, + }, nil + } +} diff --git a/certs/api/http/errors.go b/certs/api/http/errors.go new file mode 100644 index 000000000..9264abcd5 --- /dev/null +++ b/certs/api/http/errors.go @@ -0,0 +1,44 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import "github.com/absmach/supermq/pkg/errors" + +var ( + // ErrEmptySerialNo indicates that the serial number is empty. + ErrEmptySerialNo = errors.New("empty serial number provided") + + // ErrEmptyTTL indicates that the TTL is empty. + ErrEmptyTTL = errors.New("empty TTL provided") + + // ErrEmptyToken indicates that the token is empty. + ErrEmptyToken = errors.New("empty token provided") + + // ErrEmptyList indicates that entity data is empty. + ErrEmptyList = errors.New("empty list provided") + + // ErrMissingEntityID indicates missing entity ID. + ErrMissingEntityID = errors.New("missing entity ID") + + // ErrMissingCommonName indicates missing common name. + ErrMissingCommonName = errors.New("missing common name") + + // ErrUnsupportedContentType indicates unacceptable or lack of Content-Type. + ErrUnsupportedContentType = errors.New("unsupported content type") + + // ErrValidation indicates that an error was returned by the API. + ErrValidation = errors.New("something went wrong with the request") + + // ErrInvalidQueryParams indicates invalid query parameters. + ErrInvalidQueryParams = errors.New("invalid query parameters") + + // ErrInvalidRequest indicates that the request is invalid. + ErrInvalidRequest = errors.New("invalid request") + + // ErrMissingCSR indicates missing csr. + ErrMissingCSR = errors.New("missing CSR") + + // ErrMissingPrivKey indicates missing csr. + ErrMissingPrivKey = errors.New("missing private key") +) diff --git a/certs/api/http/requests.go b/certs/api/http/requests.go new file mode 100644 index 000000000..1fdb78c94 --- /dev/null +++ b/certs/api/http/requests.go @@ -0,0 +1,152 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import ( + "crypto/x509" + "encoding/pem" + "fmt" + + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/certs/api" + "github.com/absmach/supermq/pkg/errors" + "golang.org/x/crypto/ocsp" +) + +type downloadReq struct{} + +func (req downloadReq) validate() error { + return nil +} + +type viewReq struct { + id string +} + +func (req viewReq) validate() error { + if req.id == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrEmptySerialNo) + } + return nil +} + +type deleteReq struct { + entityID string +} + +func (req deleteReq) validate() error { + if req.entityID == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + } + return nil +} + +type crlReq struct{} + +func (req crlReq) validate() error { + return nil +} + +type issueCertReq struct { + entityID string `json:"-"` + TTL string `json:"ttl"` + IpAddrs []string `json:"ip_addresses"` + Options certs.SubjectOptions `json:"options"` +} + +func (req issueCertReq) validate() error { + if req.entityID == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + } + + if req.Options.CommonName == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingCommonName) + } + + return nil +} + +type listCertsReq struct { + pm certs.PageMetadata +} + +func (req listCertsReq) validate() error { + return nil +} + +type ocspReq struct { + req *ocsp.Request + StatusParam string `json:"status,omitempty"` + SerialNumber string `json:"serial_number,omitempty"` + Certificate string `json:"certificate,omitempty"` +} + +func (req *ocspReq) validate() error { + if req.req == nil && req.SerialNumber == "" && req.Certificate == "" { + return certs.ErrMalformedEntity + } + + if req.Certificate != "" { + serialNumber, err := extractSerialFromCertContent(req.Certificate) + if err != nil { + return errors.Wrap(certs.ErrMalformedEntity, fmt.Errorf("failed to extract serial from certificate: %w", err)) + } + req.SerialNumber = serialNumber + } + + req.SerialNumber = api.NormalizeSerialNumber(req.SerialNumber) + + return nil +} + +func extractSerialFromCertContent(certContent string) (string, error) { + certData := []byte(certContent) + + block, _ := pem.Decode(certData) + if block == nil { + return "", fmt.Errorf("failed to decode PEM block") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return "", fmt.Errorf("failed to parse certificate: %w", err) + } + + serialHex := cert.SerialNumber.Text(16) + return api.NormalizeSerialNumber(serialHex), nil +} + +type IssueFromCSRReq struct { + entityID string + ttl string + CSR []byte `json:"csr"` +} + +func (req IssueFromCSRReq) validate() error { + if req.entityID == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + } + if len(req.CSR) == 0 { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingCSR) + } + + return nil +} + +type IssueFromCSRInternalReq struct { + entityID string + ttl string + CSR []byte `json:"csr"` +} + +func (req IssueFromCSRInternalReq) validate() error { + if req.entityID == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + } + if len(req.CSR) == 0 { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingCSR) + } + + return nil +} diff --git a/certs/api/http/responses.go b/certs/api/http/responses.go new file mode 100644 index 000000000..9e0cd8102 --- /dev/null +++ b/certs/api/http/responses.go @@ -0,0 +1,205 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import ( + "net/http" + "time" + + "github.com/absmach/supermq/certs" +) + +var ( + _ Response = (*revokeCertRes)(nil) + _ Response = (*issueCertRes)(nil) + _ Response = (*renewCertRes)(nil) + _ Response = (*ocspRawRes)(nil) +) + +type renewCertRes struct { + renewed bool + Certificate certs.Certificate `json:"certificate,omitempty"` +} + +func (res renewCertRes) Code() int { + if res.renewed { + return http.StatusOK + } + + return http.StatusBadRequest +} + +func (res renewCertRes) Headers() map[string]string { + return map[string]string{} +} + +func (res renewCertRes) Empty() bool { + return false +} + +type revokeCertRes struct { + revoked bool +} + +func (res revokeCertRes) Code() int { + if res.revoked { + return http.StatusNoContent + } + + return http.StatusUnprocessableEntity +} + +func (res revokeCertRes) Headers() map[string]string { + return map[string]string{} +} + +func (res revokeCertRes) Empty() bool { + return true +} + +type deleteCertRes struct { + deleted bool +} + +func (res deleteCertRes) Code() int { + if res.deleted { + return http.StatusNoContent + } + + return http.StatusUnprocessableEntity +} + +func (res deleteCertRes) Headers() map[string]string { + return map[string]string{} +} + +func (res deleteCertRes) Empty() bool { + return true +} + +type issueCertRes struct { + SerialNumber string `json:"serial_number"` + Certificate string `json:"certificate,omitempty"` + Key string `json:"key,omitempty"` + Revoked bool `json:"revoked"` + ExpiryTime time.Time `json:"expiry_time"` + EntityID string `json:"entity_id"` + issued bool +} + +func (res issueCertRes) Code() int { + if res.issued { + return http.StatusCreated + } + + return http.StatusBadRequest +} + +func (res issueCertRes) Headers() map[string]string { + return map[string]string{} +} + +func (res issueCertRes) Empty() bool { + return false +} + +type listCertsRes struct { + Total uint64 `json:"total"` + Offset uint64 `json:"offset,omitempty"` + Limit uint64 `json:"limit,omitempty"` + Certificates []viewCertRes `json:"certificates,omitempty"` +} + +func (res listCertsRes) Code() int { + return http.StatusOK +} + +func (res listCertsRes) Headers() map[string]string { + return map[string]string{} +} + +func (res listCertsRes) Empty() bool { + return false +} + +type viewCertRes struct { + SerialNumber string `json:"serial_number,omitempty"` + Certificate string `json:"certificate,omitempty"` + Key string `json:"key,omitempty"` + Revoked bool `json:"revoked"` + ExpiryTime time.Time `json:"expiry_time,omitempty"` + EntityID string `json:"entity_id,omitempty"` +} + +func (res viewCertRes) Code() int { + return http.StatusOK +} + +func (res viewCertRes) Headers() map[string]string { + return map[string]string{} +} + +func (res viewCertRes) Empty() bool { + return false +} + +type crlRes struct { + CrlBytes []byte `json:"crl"` +} + +func (res crlRes) Code() int { + return http.StatusOK +} + +func (res crlRes) Headers() map[string]string { + return map[string]string{} +} + +func (res crlRes) Empty() bool { + return false +} + +type ocspRawRes struct { + Data []byte `json:"-"` +} + +func (res ocspRawRes) Code() int { + return http.StatusOK +} + +func (res ocspRawRes) Headers() map[string]string { + return map[string]string{} +} + +func (res ocspRawRes) Empty() bool { + return false +} + +type fileDownloadRes struct { + Certificate []byte `json:"certificate"` + PrivateKey []byte `json:"private_key"` + CA []byte `json:"ca"` + Filename string + ContentType string +} + +type issueFromCSRRes struct { + SerialNumber string `json:"serial_number"` + Certificate string `json:"certificate,omitempty"` + Revoked bool `json:"revoked"` + ExpiryTime time.Time `json:"expiry_time"` + EntityID string `json:"entity_id"` +} + +func (res issueFromCSRRes) Code() int { + return http.StatusOK +} + +func (res issueFromCSRRes) Headers() map[string]string { + return map[string]string{} +} + +func (res issueFromCSRRes) Empty() bool { + return false +} diff --git a/certs/api/http/transport.go b/certs/api/http/transport.go new file mode 100644 index 000000000..da5c1f2ae --- /dev/null +++ b/certs/api/http/transport.go @@ -0,0 +1,393 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import ( + "archive/zip" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "strconv" + "strings" + + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/certs" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + "github.com/go-chi/chi/v5" + kithttp "github.com/go-kit/kit/transport/http" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "golang.org/x/crypto/ocsp" +) + +const ( + offsetKey = "offset" + limitKey = "limit" + entityKey = "entity_id" + ocspStatusParam = "force_status" + entityIDParam = "entityID" + ttl = "ttl" + defOffset = 0 + defLimit = 10 +) + +func authMiddleware(expectedSecret string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := apiutil.ExtractBearerToken(r) + if token == "" { + EncodeError(r.Context(), apiutil.ErrBearerToken, w) + return + } + + if token != expectedSecret { + EncodeError(r.Context(), errors.Wrap(certs.ErrMalformedEntity, errors.New("invalid authentication token")), w) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// MakeHandler returns a HTTP handler for API endpoints. +func MakeHandler(svc certs.Service, authn smqauthn.AuthNMiddleware, logger *slog.Logger, instanceID string, secret string) http.Handler { + opts := []kithttp.ServerOption{ + kithttp.ServerErrorEncoder(loggingErrorEncoder(logger, EncodeError)), + } + + mux := chi.NewRouter() + + mux.Route("/{domainID}", func(r chi.Router) { + r.Route("/certs", func(r chi.Router) { + r.Group(func(r chi.Router) { + r.Use(authn.Middleware()) + r.Post("/issue/{entityID}", otelhttp.NewHandler(kithttp.NewServer( + issueCertEndpoint(svc), + decodeIssueCert, + api.EncodeResponse, + opts..., + ), "issue_cert").ServeHTTP) + r.Patch("/{id}/renew", otelhttp.NewHandler(kithttp.NewServer( + renewCertEndpoint(svc), + decodeView, + api.EncodeResponse, + opts..., + ), "renew_cert").ServeHTTP) + r.Patch("/{id}/revoke", otelhttp.NewHandler(kithttp.NewServer( + revokeCertEndpoint(svc), + decodeView, + api.EncodeResponse, + opts..., + ), "revoke_cert").ServeHTTP) + r.Delete("/{entityID}/delete", otelhttp.NewHandler(kithttp.NewServer( + deleteCertEndpoint(svc), + decodeDelete, + api.EncodeResponse, + opts..., + ), "delete_cert").ServeHTTP) + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + listCertsEndpoint(svc), + decodeListCerts, + api.EncodeResponse, + opts..., + ), "list_certs").ServeHTTP) + r.Get("/{id}", otelhttp.NewHandler(kithttp.NewServer( + viewCertEndpoint(svc), + decodeView, + api.EncodeResponse, + opts..., + ), "view_cert").ServeHTTP) + r.Route("/csrs", func(r chi.Router) { + r.Post("/{entityID}", otelhttp.NewHandler(kithttp.NewServer( + issueFromCSREndpoint(svc), + decodeIssueFromCSR, + api.EncodeResponse, + opts..., + ), "issue_from_csr").ServeHTTP) + }) + }) + }) + }) + + mux.Route("/certs", func(r chi.Router) { + r.Post("/ocsp", otelhttp.NewHandler(kithttp.NewServer( + ocspEndpoint(svc), + decodeOCSPRequest, + encodeOSCPResponse, + opts..., + ), "ocsp").ServeHTTP) + r.Get("/crl", otelhttp.NewHandler(kithttp.NewServer( + generateCRLEndpoint(svc), + decodeCRL, + api.EncodeResponse, + opts..., + ), "generate_crl").ServeHTTP) + r.Get("/view-ca", otelhttp.NewHandler(kithttp.NewServer( + viewCAEndpoint(svc), + decodeViewCA, + api.EncodeResponse, + opts..., + ), "view_ca").ServeHTTP) + r.Get("/download-ca", otelhttp.NewHandler(kithttp.NewServer( + downloadCAEndpoint(svc), + decodeDownloadCA, + encodeCADownloadResponse, + opts..., + ), "download_ca").ServeHTTP) + }) + + mux.Group(func(r chi.Router) { + r.Use(authMiddleware(secret)) + r.Post("/certs/csrs/{entityID}", otelhttp.NewHandler(kithttp.NewServer( + issueFromCSRInternalEndpoint(svc), + decodeIssueFromCSRInternal, + api.EncodeResponse, + opts..., + ), "issue_from_csr_internal").ServeHTTP) + }) + + mux.Get("/health", certs.Health("certs", instanceID)) + mux.Handle("/metrics", promhttp.Handler()) + + return mux +} + +func decodeView(_ context.Context, r *http.Request) (any, error) { + req := viewReq{ + id: chi.URLParam(r, "id"), + } + return req, nil +} + +func decodeDelete(_ context.Context, r *http.Request) (any, error) { + req := deleteReq{ + entityID: chi.URLParam(r, "entityID"), + } + return req, nil +} + +func decodeCRL(_ context.Context, r *http.Request) (any, error) { + req := crlReq{} + return req, nil +} + +func decodeDownloadCA(_ context.Context, r *http.Request) (any, error) { + req := downloadReq{} + return req, nil +} + +func decodeViewCA(_ context.Context, r *http.Request) (any, error) { + req := downloadReq{} + return req, nil +} + +func decodeOCSPRequest(_ context.Context, r *http.Request) (any, error) { + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, errors.Wrap(certs.ErrMalformedEntity, err) + } + defer r.Body.Close() + + req, err := ocsp.ParseRequest(body) + if err != nil { + contentType := r.Header.Get("Content-Type") + if strings.Contains(contentType, "application/json") { + return decodeJsonOCSPRequest(body) + } + return nil, fmt.Errorf("invalid OCSP request: %w", err) + } + + request := ocspReq{ + req: req, + StatusParam: strings.TrimSpace(r.URL.Query().Get(ocspStatusParam)), + } + + return request, nil +} + +func decodeJsonOCSPRequest(body []byte) (any, error) { + var simple ocspReq + if err := json.Unmarshal(body, &simple); err != nil { + return nil, fmt.Errorf("invalid JSON OCSP request: %w", err) + } + + request := ocspReq{ + SerialNumber: simple.SerialNumber, + Certificate: simple.Certificate, + } + + return request, nil +} + +func decodeIssueCert(_ context.Context, r *http.Request) (any, error) { + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + req := issueCertReq{ + entityID: chi.URLParam(r, entityIDParam), + } + if err := json.Unmarshal(body, &req); err != nil { + return nil, errors.Wrap(ErrInvalidRequest, err) + } + + return req, nil +} + +func decodeListCerts(_ context.Context, r *http.Request) (any, error) { + o, err := readNumQuery(r, offsetKey, defOffset) + if err != nil { + return nil, err + } + + l, err := readNumQuery(r, limitKey, defLimit) + if err != nil { + return nil, err + } + + entity, err := readStringQuery(r, entityKey, "") + if err != nil { + return nil, err + } + + req := listCertsReq{ + pm: certs.PageMetadata{ + Offset: o, + Limit: l, + EntityID: entity, + }, + } + return req, nil +} + +func decodeIssueFromCSR(_ context.Context, r *http.Request) (any, error) { + t, err := readStringQuery(r, ttl, "") + if err != nil { + return nil, err + } + + req := IssueFromCSRReq{ + entityID: chi.URLParam(r, "entityID"), + ttl: t, + } + + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, errors.Wrap(ErrInvalidRequest, errors.New("failed to read request body")) + } + defer r.Body.Close() + + if err := json.Unmarshal(body, &req); err != nil { + return nil, errors.Wrap(ErrInvalidRequest, errors.New("failed to decode JSON")) + } + + return req, nil +} + +func decodeIssueFromCSRInternal(_ context.Context, r *http.Request) (any, error) { + t, err := readStringQuery(r, ttl, "") + if err != nil { + return nil, err + } + + req := IssueFromCSRInternalReq{ + entityID: chi.URLParam(r, "entityID"), + ttl: t, + } + + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, errors.Wrap(ErrInvalidRequest, errors.New("failed to read request body")) + } + defer r.Body.Close() + + if err := json.Unmarshal(body, &req); err != nil { + return nil, errors.Wrap(ErrInvalidRequest, errors.New("failed to decode JSON")) + } + + return req, nil +} + +func encodeOSCPResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + res := response.(ocspRawRes) + + w.Header().Set("Content-Type", OCSPType) + _, err := w.Write(res.Data) + return err +} + +func encodeCADownloadResponse(_ context.Context, w http.ResponseWriter, response any) error { + resp := response.(fileDownloadRes) + var buffer bytes.Buffer + zw := zip.NewWriter(&buffer) + + f, err := zw.Create("ca.crt") + if err != nil { + return err + } + + if _, err = f.Write(resp.Certificate); err != nil { + return err + } + + if err := zw.Close(); err != nil { + return err + } + + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", resp.Filename)) + w.Header().Set("Content-Type", resp.ContentType) + + _, err = w.Write(buffer.Bytes()) + + return err +} + +// loggingErrorEncoder is a go-kit error encoder logging decorator. +func loggingErrorEncoder(logger *slog.Logger, enc kithttp.ErrorEncoder) kithttp.ErrorEncoder { + return func(ctx context.Context, err error, w http.ResponseWriter) { + if errors.Contains(err, ErrValidation) { + logger.Error(err.Error()) + } + enc(ctx, err, w) + } +} + +// readStringQuery reads the value of string http query parameters for a given key. +func readStringQuery(r *http.Request, key, def string) (string, error) { + vals := r.URL.Query()[key] + if len(vals) > 1 { + return "", ErrInvalidQueryParams + } + + if len(vals) == 0 { + return def, nil + } + + return vals[0], nil +} + +// readNumQuery returns a numeric value. +func readNumQuery(r *http.Request, key string, def uint64) (uint64, error) { + vals := r.URL.Query()[key] + if len(vals) > 1 { + return 0, ErrInvalidQueryParams + } + if len(vals) == 0 { + return def, nil + } + val := vals[0] + + v, err := strconv.ParseUint(val, 10, 64) + if err != nil { + return 0, errors.Wrap(ErrInvalidQueryParams, err) + } + return v, nil +} diff --git a/certs/api/utils.go b/certs/api/utils.go new file mode 100644 index 000000000..d3f19ae1d --- /dev/null +++ b/certs/api/utils.go @@ -0,0 +1,32 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import "strings" + +var serialReplacer = strings.NewReplacer(":", "", " ", "") + +// NormalizeSerialNumber normalizes a serial number to use colon-separated hex format. +func NormalizeSerialNumber(serial string) string { + if len(serial) < 2 { + return serialReplacer.Replace(serial) + } + cleaned := serialReplacer.Replace(serial) + cleaned = strings.ToLower(cleaned) + if len(cleaned)%2 != 0 { + cleaned = "0" + cleaned + } + + capacity := len(cleaned) + (len(cleaned)/2 - 1) + var result strings.Builder + result.Grow(capacity) + for i := 0; i < len(cleaned); i += 2 { + if i > 0 { + result.WriteString(":") + } + result.WriteString(cleaned[i : i+2]) + } + + return result.String() +} diff --git a/certs/api/utils_test.go b/certs/api/utils_test.go new file mode 100644 index 000000000..fe64182de --- /dev/null +++ b/certs/api/utils_test.go @@ -0,0 +1,136 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "sync" + "testing" +) + +func TestNormalizeSerialNumber(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "already normalized", + input: "1a:2b:3c:4d", + expected: "1a:2b:3c:4d", + }, + { + name: "no separators", + input: "1a2b3c4d", + expected: "1a:2b:3c:4d", + }, + { + name: "with spaces", + input: "1a 2b 3c 4d", + expected: "1a:2b:3c:4d", + }, + { + name: "mixed separators", + input: "1a:2b 3c:4d", + expected: "1a:2b:3c:4d", + }, + { + name: "uppercase input", + input: "1A:2B:3C:4D", + expected: "1a:2b:3c:4d", + }, + { + name: "odd length - needs padding", + input: "1a2b3", + expected: "01:a2:b3", + }, + { + name: "single character", + input: "a", + expected: "a", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "long serial number", + input: "01:23:45:67:89:ab:cd:ef:12:34:56:78", + expected: "01:23:45:67:89:ab:cd:ef:12:34:56:78", + }, + { + name: "complex mixed format", + input: "01 23:45 67:89AB cd ef", + expected: "01:23:45:67:89:ab:cd:ef", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := NormalizeSerialNumber(tt.input) + if result != tt.expected { + t.Errorf("NormalizeSerialNumber(%q) = %q, expected %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestNormalizeSerialNumberConcurrent(t *testing.T) { + input := "1A:2B 3C:4D" + expected := "1a:2b:3c:4d" + + const numGoroutines = 100 + var wg sync.WaitGroup + results := make(chan string, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + result := NormalizeSerialNumber(input) + results <- result + }() + } + + wg.Wait() + close(results) + + for result := range results { + if result != expected { + t.Errorf("Concurrent execution failed: got %q, expected %q", result, expected) + } + } +} + +func BenchmarkNormalizeSerialNumber(b *testing.B) { + testCases := []struct { + name string + input string + }{ + {"short", "1a2b"}, + {"medium", "1a:2b:3c:4d:5e:6f"}, + {"long", "01:23:45:67:89:ab:cd:ef:12:34:56:78:90:ab:cd:ef"}, + {"mixed_format", "01 23:45 67:89AB cd ef 12:34"}, + } + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + NormalizeSerialNumber(tc.input) + } + }) + } +} + +func BenchmarkNormalizeSerialNumberParallel(b *testing.B) { + input := "1A:2B 3C:4D:5E:6F:7G:8H" + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + NormalizeSerialNumber(input) + } + }) +} diff --git a/certs/certs.go b/certs/certs.go new file mode 100644 index 000000000..3600046e5 --- /dev/null +++ b/certs/certs.go @@ -0,0 +1,179 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package certs + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "net" + "time" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" +) + +type CertType int + +const ( + RootCA CertType = iota + IntermediateCA + ClientCert +) + +const ( + Root = "RootCA" + Inter = "IntermediateCA" + Client = "ClientCert" + Unknown = "Unknown" +) + +func (c CertType) String() string { + switch c { + case RootCA: + return Root + case IntermediateCA: + return Inter + case ClientCert: + return Client + default: + return Unknown + } +} + +func CertTypeFromString(s string) (CertType, error) { + switch s { + case Root: + return RootCA, nil + case Inter: + return IntermediateCA, nil + case Client: + return ClientCert, nil + default: + return -1, errors.New("unknown cert type") + } +} + +type CA struct { + Type CertType + Certificate *x509.Certificate + PrivateKey *rsa.PrivateKey + SerialNumber string +} + +type Certificate struct { + SerialNumber string `json:"serial_number"` + Certificate []byte `json:"certificate"` + Key []byte `json:"key"` + Revoked bool `json:"revoked"` + ExpiryTime time.Time `json:"expiry_time"` + EntityID string `json:"entity_id"` + Type CertType `json:"type"` + DownloadUrl string `json:"-"` +} + +type CertificatePage struct { + PageMetadata + Certificates []Certificate +} + +type PageMetadata struct { + Total uint64 `json:"total"` + Offset uint64 `json:"offset,omitempty"` + Limit uint64 `json:"limit,omitempty"` + EntityID string `json:"entity_id,omitempty"` +} + +type CSRMetadata struct { + CommonName string `json:"common_name"` + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` + DNSNames []string `json:"dns_names"` + IPAddresses []string `json:"ip_addresses"` + EmailAddresses []string `json:"email_addresses"` + ExtraExtensions []pkix.Extension `json:"extra_extensions"` +} + +type CSR struct { + CSR []byte `json:"csr,omitempty"` + PrivateKey []byte `json:"private_key,omitempty"` +} + +type CSRPage struct { + PageMetadata + CSRs []CSR `json:"csrs,omitempty"` +} + +type SubjectOptions struct { + CommonName string `json:"common_name"` + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` + DnsNames []string `json:"dns_names"` + IpAddresses []net.IP `json:"ip_addresses"` +} + +type Service interface { + // RenewCert renews a certificate by issuing a new certificate with the same parameters. + // Returns the new certificate with extended TTL and a new serial number. + RenewCert(ctx context.Context, session authn.Session, serialNumber string) (Certificate, error) + + // RevokeBySerial revokes a single certificate by its serial number. + RevokeBySerial(ctx context.Context, session authn.Session, serialNumber string) error + + // RevokeAll revokes all certificates for a given entity ID. + RevokeAll(ctx context.Context, session authn.Session, entityID string) error + + // ViewCert retrieves a certificate record from the database. + ViewCert(ctx context.Context, session authn.Session, serialNumber string) (Certificate, error) + + // ListCerts retrieves the certificates from the database while applying filters. + ListCerts(ctx context.Context, session authn.Session, pm PageMetadata) (CertificatePage, error) + + // IssueCert issues a certificate from the database. + IssueCert(ctx context.Context, session authn.Session, entityID, ttl string, ipAddrs []string, option SubjectOptions) (Certificate, error) + + // OCSP forwards OCSP requests to OpenBao's OCSP endpoint. + // If ocspRequestDER is provided, it will be used directly; otherwise, a request will be built from the serialNumber. + OCSP(ctx context.Context, serialNumber string, ocspRequestDER []byte) ([]byte, error) + + // GetEntityID retrieves the entity ID for a certificate. + GetEntityID(ctx context.Context, serialNumber string) (string, error) + + // GenerateCRL creates cert revocation list. + GenerateCRL(ctx context.Context) ([]byte, error) + + // RetrieveCAChain retrieves the chain of CA i.e. root and intermediate cert concat together. + RetrieveCAChain(ctx context.Context) (Certificate, error) + + // IssueFromCSR creates a certificate from a given CSR. + IssueFromCSR(ctx context.Context, session authn.Session, entityID, ttl string, csr CSR) (Certificate, error) + + // IssueFromCSRInternal creates a certificate from a given CSR using agent token authentication. + IssueFromCSRInternal(ctx context.Context, entityID, ttl string, csr CSR) (Certificate, error) +} + +type Repository interface { + // SaveCertEntityMapping saves the mapping between certificate serial number and entity ID. + SaveCertEntityMapping(ctx context.Context, serialNumber, entityID string) error + + // GetEntityIDBySerial retrieves the entity ID for a given certificate serial number. + GetEntityIDBySerial(ctx context.Context, serialNumber string) (string, error) + + // ListCertsByEntityID lists all certificate serial numbers for a given entity ID. + ListCertsByEntityID(ctx context.Context, entityID string) ([]string, error) + + // RemoveCertEntityMapping removes the mapping between certificate and entity ID. + RemoveCertEntityMapping(ctx context.Context, serialNumber string) error +} diff --git a/certs/certs_test.go b/certs/certs_test.go new file mode 100644 index 000000000..839671901 --- /dev/null +++ b/certs/certs_test.go @@ -0,0 +1,569 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package certs_test + +import ( + "context" + "testing" + "time" + + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/certs/mocks" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +const ( + serialNumber = "20:f4:bd:43:2c:c7:06:82:c7:f2:00:47:51:b6:81:6f:fa:c4:46:0c" + entityID = "c1a1daea-ce24-4847-b892-1780bf25b10c" + domainID = "domain-id" + testCertPEM = "-----BEGIN CERTIFICATE-----\nMIIEMjCCAxqgAwIBAgIUIPS9QyzHBoLH8gBHUbaBb/rERgwwDQYJKoZIhvcNAQEL\nBQAwgaAxDzANBgNVBAYTBkZSQU5DRTEOMAwGA1UECBMFUEFSSVMxDjAMBgNVBAcT\nBVBBUklTMRowGAYDVQQKExFBYnN0cmFjdCBNYWNoaW5lczEaMBgGA1UECxMRQWJz\ndHJhY3QgTWFjaGluZXMxNTAzBgNVBAMTLEFic3RyYWN0IE1hY2hpbmVzIFJvb3Qg\nQ2VydGlmaWNhdGUgQXV0aG9yaXR5MB4XDTI1MDgyNTExNTAyNFoXDTI1MDgyNTIx\nNTA1NFowDzENMAsGA1UEAxMEMDAwMTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC\nAQoCggEBAMT4eHWFYUVAmQWC0bcgcBuBQjDVWdXD2WJWx8ybeC8vIwsGyCRMEem4\nlveP937ZjM3TTX0Nst4chF0L3WN0FTGTztwlqtpCK67AxcMEdGj54kIlVMAZexLz\nY4mQ5Oe/S4L4elv/ARHDV87BZ0m7oD1b2AC+8CBdm9aWcaD1RZk6qtzLRjs17ouY\nuslj5dN33VuzTYYUlPaTFjCY2nnebK0FLNjJkBVjoIlmT1Oo56uw9SQpLczk4PtL\nlVzeNKHGh0mx3g13tyNOAjKrMvxb7GTQ3tKsL6zZfiWggw4gROqjGQuCejAibfrr\nftN77YndLF4JYqiUZRCsZlRMSkpcSWMCAwEAAaOB8zCB8DAOBgNVHQ8BAf8EBAMC\nA6gwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQWBBSEDX9D\nU9O6ORjZOJzceZmE2yC93DAfBgNVHSMEGDAWgBSZCSNs3yScbg5YSiuN1VuS6o3g\nyTA7BggrBgEFBQcBAQQvMC0wKwYIKwYBBQUHMAKGH2h0dHA6Ly8xMjcuMC4wLjE6\nODIwMC92MS9wa2kvY2EwDwYDVR0RBAgwBocEwKhkFDAxBgNVHR8EKjAoMCagJKAi\nhiBodHRwOi8vMTI3LjAuMC4xOjgyMDAvdjEvcGtpL2NybDANBgkqhkiG9w0BAQsF\nAAOCAQEAK5fOOweOOJzWmjC0/6A9T/xnTOeXcwdp3gBmMNkaCs/qlh+3Dofo9vHS\nX1vitXbcqbMmJnXuRLkA+qTTlJvhVD8fa4RtixJZ5N0uDMPJ5FVv9tipSoqcnQH8\nwR4iPvrlQQr5hiBt/nfsaTLuDLZgMcKs5N30yHslJXfeLcWrawaQHpIddgavbgqM\n/9L/PoWM2hJknUyg7kis5SNejUGwOh/U1MUf1b18kaUKeK3Q4vhVHVz4foiRZ9M0\niw9xTj2rJJdOE/omE6qJFIfWIF0DuOCYt7z8TKhqKuTfNjmmiqlcgT14P6hniFkK\nl/5upJw86TWS8J0RXQJ1Nbw68EMEuQ==\n-----END CERTIFICATE-----" +) + +var ( + certValidityPeriod = time.Hour * 24 * 30 + testSession = smqauthn.Session{ + DomainUserID: entityID, + UserID: entityID, + DomainID: domainID, + } +) + +func TestIssueCert(t *testing.T) { + agent := new(mocks.Agent) + repo := new(mocks.Repository) + svc, err := certs.NewService(context.Background(), agent, repo) + require.NoError(t, err) + + testCases := []struct { + desc string + entityID string + ttl string + cert certs.Certificate + err error + agentErr error + repoErr error + expectedCert certs.Certificate + }{ + { + desc: "issue cert successfully", + entityID: "entityID", + ttl: "1h", + cert: certs.Certificate{ + SerialNumber: serialNumber, + }, + expectedCert: certs.Certificate{ + SerialNumber: serialNumber, + EntityID: "entityID", + }, + err: nil, + }, + { + desc: "failed agent issue cert", + entityID: "entityID", + ttl: "1h", + cert: certs.Certificate{}, + agentErr: errors.New("agent error"), + err: certs.ErrFailedCertCreation, + }, + { + desc: "failed repository save mapping", + entityID: "entityID", + ttl: "1h", + cert: certs.Certificate{ + SerialNumber: serialNumber, + }, + repoErr: errors.New("repo error"), + err: certs.ErrFailedCertCreation, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + options := certs.SubjectOptions{ + CommonName: tc.entityID, + } + agentCall := agent.On("Issue", tc.ttl, []string{}, options).Return(tc.cert, tc.agentErr) + repoCall := repo.On("SaveCertEntityMapping", mock.Anything, tc.cert.SerialNumber, tc.entityID).Return(tc.repoErr) + + cert, err := svc.IssueCert(context.Background(), testSession, tc.entityID, tc.ttl, []string{}, options) + if tc.err != nil { + require.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedCert, cert) + } + + agentCall.Unset() + repoCall.Unset() + }) + } +} + +func TestRevokeBySerial(t *testing.T) { + agent := new(mocks.Agent) + repo := new(mocks.Repository) + svc, err := certs.NewService(context.Background(), agent, repo) + require.NoError(t, err) + + testCases := []struct { + desc string + serial string + agentErr error + err error + }{ + { + desc: "revoke cert by serial successfully", + serial: serialNumber, + err: nil, + }, + { + desc: "failed agent revoke", + serial: serialNumber, + agentErr: errors.New("agent error"), + err: certs.ErrUpdateEntity, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + agentCall := agent.On("Revoke", tc.serial).Return(tc.agentErr) + + err = svc.RevokeBySerial(context.Background(), testSession, tc.serial) + if tc.err != nil { + require.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) + } else { + require.NoError(t, err) + } + + agentCall.Unset() + }) + } +} + +func TestRenewCert(t *testing.T) { + agent := new(mocks.Agent) + repo := new(mocks.Repository) + svc, err := certs.NewService(context.Background(), agent, repo) + require.NoError(t, err) + + newCert := certs.Certificate{ + SerialNumber: serialNumber, + EntityID: entityID, + Certificate: []byte(testCertPEM), + ExpiryTime: time.Now().Add(30 * 24 * time.Hour), + } + + testCases := []struct { + desc string + serial string + viewErr error + renewErr error + newCert certs.Certificate + revoked bool + expectedErr error + }{ + { + desc: "renew cert successfully", + serial: serialNumber, + newCert: newCert, + expectedErr: nil, + }, + { + desc: "failed agent renew", + serial: serialNumber, + renewErr: certs.ErrUpdateEntity, + newCert: certs.Certificate{}, + expectedErr: certs.ErrUpdateEntity, + }, + { + desc: "failed agent view", + serial: serialNumber, + viewErr: certs.ErrViewEntity, + newCert: certs.Certificate{}, + expectedErr: certs.ErrViewEntity, + }, + { + desc: "revoked certificate cannot be renewed", + serial: serialNumber, + newCert: certs.Certificate{}, + revoked: true, + expectedErr: certs.ErrCertRevoked, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + agentCall := agent.On("Renew", mock.Anything, certValidityPeriod.String()).Return(tc.newCert, tc.renewErr) + agentCall1 := agent.On("View", tc.serial).Return(certs.Certificate{Certificate: []byte(testCertPEM), Revoked: tc.revoked}, tc.viewErr) + + renewedCert, err := svc.RenewCert(context.Background(), testSession, tc.serial) + require.True(t, errors.Contains(err, tc.expectedErr), "expected error %v, got %v", tc.expectedErr, err) + if tc.expectedErr == nil { + require.Equal(t, tc.newCert, renewedCert) + } + agentCall1.Unset() + agentCall.Unset() + }) + } +} + +func TestGetEntityID(t *testing.T) { + agent := new(mocks.Agent) + repo := new(mocks.Repository) + svc, err := certs.NewService(context.Background(), agent, repo) + require.NoError(t, err) + + testCases := []struct { + desc string + serial string + entityID string + repoErr error + err error + }{ + { + desc: "get entity ID successfully", + serial: serialNumber, + entityID: "entity-123", + err: nil, + }, + { + desc: "error retrieving from repository", + serial: serialNumber, + repoErr: errors.New("not found"), + err: certs.ErrViewEntity, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("GetEntityIDBySerial", mock.Anything, tc.serial).Return(tc.entityID, tc.repoErr) + + entityID, err := svc.GetEntityID(context.Background(), tc.serial) + if tc.err != nil { + require.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) + require.Empty(t, entityID) + } else { + require.NoError(t, err) + require.Equal(t, tc.entityID, entityID) + } + + repoCall.Unset() + }) + } +} + +func TestListCerts(t *testing.T) { + agent := new(mocks.Agent) + repo := new(mocks.Repository) + svc, err := certs.NewService(context.Background(), agent, repo) + require.NoError(t, err) + + pageMetadata := certs.PageMetadata{Limit: 10, Offset: 0} + pageMetadataWithEntity := certs.PageMetadata{Limit: 10, Offset: 0, EntityID: "entity-123"} + + expectedCertPage := certs.CertificatePage{ + Certificates: []certs.Certificate{ + {SerialNumber: "123"}, + {SerialNumber: "456"}, + }, + PageMetadata: pageMetadata, + } + + testCases := []struct { + desc string + pm certs.PageMetadata + certPage certs.CertificatePage + serialNumbers []string + agentErr error + repoErr error + expectedResult certs.CertificatePage + err error + }{ + { + desc: "list certs successfully without entity filter", + pm: pageMetadata, + certPage: expectedCertPage, + expectedResult: expectedCertPage, + err: nil, + }, + { + desc: "list certs successfully with entity filter", + pm: pageMetadataWithEntity, + serialNumbers: []string{"123", "456"}, + expectedResult: certs.CertificatePage{ + Certificates: []certs.Certificate{ + {SerialNumber: "123", EntityID: "entity-123"}, + {SerialNumber: "456", EntityID: "entity-123"}, + }, + PageMetadata: certs.PageMetadata{ + Limit: 10, + Offset: 0, + EntityID: "entity-123", + Total: 2, // Set the total count + }, + }, + err: nil, + }, + { + desc: "error listing certs from agent", + pm: pageMetadata, + certPage: certs.CertificatePage{}, + agentErr: errors.New("agent error"), + err: certs.ErrViewEntity, + }, + { + desc: "error listing certs by entity from repo", + pm: pageMetadataWithEntity, + repoErr: errors.New("repo error"), + err: certs.ErrViewEntity, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + var agentCall, repoCall *mock.Call + var agentViewCalls []*mock.Call + + if tc.pm.EntityID != "" { + repoCall = repo.On("ListCertsByEntityID", mock.Anything, tc.pm.EntityID).Return(tc.serialNumbers, tc.repoErr) + if tc.repoErr == nil && len(tc.serialNumbers) > 0 { + for _, serial := range tc.serialNumbers { + viewCall := agent.On("View", serial).Return(certs.Certificate{SerialNumber: serial}, nil) + agentViewCalls = append(agentViewCalls, viewCall) + } + } + } else { + agentCall = agent.On("ListCerts", tc.pm).Return(tc.certPage, tc.agentErr) + if tc.agentErr == nil { + for _, cert := range tc.certPage.Certificates { + repo.On("GetEntityIDBySerial", mock.Anything, cert.SerialNumber).Return("", errors.New("not found")) + } + } + } + + certPage, err := svc.ListCerts(context.Background(), testSession, tc.pm) + if tc.err != nil { + require.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedResult.Total, certPage.Total) + require.Len(t, certPage.Certificates, len(tc.expectedResult.Certificates)) + } + + if agentCall != nil { + agentCall.Unset() + } + if repoCall != nil { + repoCall.Unset() + } + for _, viewCall := range agentViewCalls { + viewCall.Unset() + } + }) + } +} + +func TestRevokeAll(t *testing.T) { + agent := new(mocks.Agent) + repo := new(mocks.Repository) + svc, err := certs.NewService(context.Background(), agent, repo) + require.NoError(t, err) + + testCases := []struct { + desc string + entityID string + serialNumbers []string + repoErr error + agentErr error + removeErr error + err error + }{ + { + desc: "revoke all certs successfully", + entityID: "entity-123", + serialNumbers: []string{"123", "456"}, + err: nil, + }, + { + desc: "error listing certs by entity", + entityID: "entity-123", + repoErr: errors.New("repo error"), + err: certs.ErrViewEntity, + }, + { + desc: "error revoking cert", + entityID: "entity-123", + serialNumbers: []string{"123"}, + agentErr: errors.New("agent error"), + err: certs.ErrUpdateEntity, + }, + { + desc: "no certificates found for entity", + entityID: "entity-123", + err: certs.ErrNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("ListCertsByEntityID", mock.Anything, tc.entityID).Return(tc.serialNumbers, tc.repoErr) + + var agentCalls, removeCalls []*mock.Call + if tc.repoErr == nil && len(tc.serialNumbers) > 0 { + for _, serial := range tc.serialNumbers { + agentCall := agent.On("Revoke", serial).Return(tc.agentErr) + agentCalls = append(agentCalls, agentCall) + + if tc.agentErr == nil { + removeCall := repo.On("RemoveCertEntityMapping", mock.Anything, serial).Return(tc.removeErr) + removeCalls = append(removeCalls, removeCall) + } + } + } + + err := svc.RevokeAll(context.Background(), testSession, tc.entityID) + if tc.err != nil { + require.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) + } else { + require.NoError(t, err) + } + + // Clean up mocks + repoCall.Unset() + for _, call := range agentCalls { + call.Unset() + } + for _, call := range removeCalls { + call.Unset() + } + }) + } +} + +func TestIssueFromCSR(t *testing.T) { + agent := new(mocks.Agent) + repo := new(mocks.Repository) + svc, err := certs.NewService(context.Background(), agent, repo) + require.NoError(t, err) + + testCSR := certs.CSR{ + CSR: []byte("test-csr-data"), + } + + testCases := []struct { + desc string + entityID string + ttl string + csr certs.CSR + cert certs.Certificate + expectedCert certs.Certificate + agentErr error + repoErr error + err error + }{ + { + desc: "issue cert from CSR successfully", + entityID: "entity-123", + ttl: "1h", + csr: testCSR, + cert: certs.Certificate{ + SerialNumber: serialNumber, + }, + expectedCert: certs.Certificate{ + SerialNumber: serialNumber, + EntityID: "entity-123", + }, + err: nil, + }, + { + desc: "failed agent sign CSR", + entityID: "entity-123", + ttl: "1h", + csr: testCSR, + cert: certs.Certificate{}, + agentErr: errors.New("agent error"), + err: certs.ErrFailedCertCreation, + }, + { + desc: "failed repository save mapping", + entityID: "entity-123", + ttl: "1h", + csr: testCSR, + cert: certs.Certificate{ + SerialNumber: serialNumber, + }, + repoErr: errors.New("repo error"), + err: certs.ErrFailedCertCreation, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + agentCall := agent.On("SignCSR", tc.csr.CSR, tc.ttl).Return(tc.cert, tc.agentErr) + var repoCall *mock.Call + if tc.agentErr == nil { + repoCall = repo.On("SaveCertEntityMapping", mock.Anything, tc.cert.SerialNumber, tc.entityID).Return(tc.repoErr) + } + + cert, err := svc.IssueFromCSR(context.Background(), testSession, tc.entityID, tc.ttl, tc.csr) + if tc.err != nil { + require.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedCert, cert) + } + + agentCall.Unset() + if repoCall != nil { + repoCall.Unset() + } + }) + } +} + +func TestGenerateCRL(t *testing.T) { + agent := new(mocks.Agent) + repo := new(mocks.Repository) + svc, err := certs.NewService(context.Background(), agent, repo) + require.NoError(t, err) + + testCases := []struct { + desc string + crlBytes []byte + agentErr error + err error + }{ + { + desc: "generate CRL successfully", + crlBytes: []byte("test-crl-data"), + err: nil, + }, + { + desc: "failed with agent error", + crlBytes: nil, + agentErr: errors.New("agent error"), + err: certs.ErrFailedCertCreation, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + agentCall := agent.On("GetCRL").Return(tc.crlBytes, tc.agentErr) + + crlBytes, err := svc.GenerateCRL(context.Background()) + if tc.err != nil { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.crlBytes, crlBytes) + } + + agentCall.Unset() + }) + } +} diff --git a/certs/client/client.go b/certs/client/client.go new file mode 100644 index 000000000..629ba9c2d --- /dev/null +++ b/certs/client/client.go @@ -0,0 +1,28 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package client + +import ( + "context" + + grpcCertsV1 "github.com/absmach/supermq/api/grpc/certs/v1" + grpc "github.com/absmach/supermq/pkg/grpcclient" + grpchealth "google.golang.org/grpc/health/grpc_health_v1" +) + +func NewCertsClient(ctx context.Context, cfg grpc.Config) (grpc.Handler, grpcCertsV1.CertsServiceClient, error) { + client, err := grpc.NewHandler(cfg) + if err != nil { + return nil, nil, err + } + + health := grpchealth.NewHealthClient(client.Connection()) + resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{ + Service: "certs", + }) + if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING { + return nil, nil, grpc.ErrSvcNotServing + } + return client, grpcCertsV1.NewCertsServiceClient(client.Connection()), nil +} diff --git a/certs/health.go b/certs/health.go new file mode 100644 index 000000000..bd1dabf94 --- /dev/null +++ b/certs/health.go @@ -0,0 +1,76 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package certs + +import ( + "encoding/json" + "net/http" +) + +const ( + contentType = "Content-Type" + contentTypeJSON = "application/health+json" + svcStatus = "pass" + description = " service" +) + +var ( + // Version represents the last service git tag in git history. + // It's meant to be set using go build ldflags. + Version = "0.0.0" + // Commit represents the service git commit hash. + // It's meant to be set using go build ldflags. + + Commit = "ffffffff" + // BuildTime represetns the service build time. + // It's meant to be set using go build ldflags. + BuildTime = "1970-01-01_00:00:00" +) + +// HealthInfo contains version endpoint response. +type HealthInfo struct { + // Status contains service status. + Status string `json:"status"` + + // Version contains current service version. + Version string `json:"version"` + + // Commit represents the git hash commit. + Commit string `json:"commit"` + + // Description contains service description. + Description string `json:"description"` + + // BuildTime contains service build time. + BuildTime string `json:"build_time"` + + // InstanceID contains the ID of the current service instance + InstanceID string `json:"instance_id"` +} + +// Health exposes an HTTP handler for retrieving service health. +func Health(service, instanceID string) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add(contentType, contentTypeJSON) + if r.Method != http.MethodGet && r.Method != http.MethodHead { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + res := HealthInfo{ + Status: svcStatus, + Version: Version, + Commit: Commit, + Description: service + description, + BuildTime: BuildTime, + InstanceID: instanceID, + } + + w.WriteHeader(http.StatusOK) + + if err := json.NewEncoder(w).Encode(res); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + }) +} diff --git a/certs/middleware/authorization.go b/certs/middleware/authorization.go new file mode 100644 index 000000000..d2e225e66 --- /dev/null +++ b/certs/middleware/authorization.go @@ -0,0 +1,111 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + crt "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/authz" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/policies" +) + +var _ crt.Service = (*authorizationMiddleware)(nil) + +type authorizationMiddleware struct { + authz authz.Authorization + svc crt.Service +} + +func AuthorizationMiddleware(authz authz.Authorization, svc crt.Service) crt.Service { + return &authorizationMiddleware{authz, svc} +} + +func (am *authorizationMiddleware) RenewCert(ctx context.Context, session authn.Session, serialNumber string) (crt.Certificate, error) { + if err := am.checkUserDomainPermission(ctx, session, policies.MembershipPermission); err != nil { + return crt.Certificate{}, err + } + return am.svc.RenewCert(ctx, session, serialNumber) +} + +func (am *authorizationMiddleware) RevokeBySerial(ctx context.Context, session authn.Session, serialNumber string) error { + if err := am.checkUserDomainPermission(ctx, session, policies.AdminPermission); err != nil { + return err + } + return am.svc.RevokeBySerial(ctx, session, serialNumber) +} + +func (am *authorizationMiddleware) RevokeAll(ctx context.Context, session authn.Session, entityID string) error { + if err := am.checkUserDomainPermission(ctx, session, policies.AdminPermission); err != nil { + return err + } + return am.svc.RevokeAll(ctx, session, entityID) +} + +func (am *authorizationMiddleware) IssueCert(ctx context.Context, session authn.Session, entityID, ttl string, ipAddrs []string, options crt.SubjectOptions) (crt.Certificate, error) { + if err := am.checkUserDomainPermission(ctx, session, policies.MembershipPermission); err != nil { + return crt.Certificate{}, err + } + return am.svc.IssueCert(ctx, session, entityID, ttl, ipAddrs, options) +} + +func (am *authorizationMiddleware) ListCerts(ctx context.Context, session authn.Session, pm crt.PageMetadata) (crt.CertificatePage, error) { + if err := am.checkUserDomainPermission(ctx, session, policies.MembershipPermission); err != nil { + return crt.CertificatePage{}, err + } + return am.svc.ListCerts(ctx, session, pm) +} + +func (am *authorizationMiddleware) ViewCert(ctx context.Context, session authn.Session, serialNumber string) (crt.Certificate, error) { + if err := am.checkUserDomainPermission(ctx, session, policies.MembershipPermission); err != nil { + return crt.Certificate{}, err + } + return am.svc.ViewCert(ctx, session, serialNumber) +} + +func (am *authorizationMiddleware) GetEntityID(ctx context.Context, serialNumber string) (string, error) { + return am.svc.GetEntityID(ctx, serialNumber) +} + +func (am *authorizationMiddleware) OCSP(ctx context.Context, serialNumber string, ocspRequestDER []byte) ([]byte, error) { + return am.svc.OCSP(ctx, serialNumber, ocspRequestDER) +} + +func (am *authorizationMiddleware) GenerateCRL(ctx context.Context) ([]byte, error) { + return am.svc.GenerateCRL(ctx) +} + +func (am *authorizationMiddleware) RetrieveCAChain(ctx context.Context) (crt.Certificate, error) { + return am.svc.RetrieveCAChain(ctx) +} + +func (am *authorizationMiddleware) IssueFromCSR(ctx context.Context, session authn.Session, entityID, ttl string, csr crt.CSR) (crt.Certificate, error) { + if err := am.checkUserDomainPermission(ctx, session, policies.MembershipPermission); err != nil { + return crt.Certificate{}, err + } + return am.svc.IssueFromCSR(ctx, session, entityID, ttl, csr) +} + +func (am *authorizationMiddleware) IssueFromCSRInternal(ctx context.Context, entityID, ttl string, csr crt.CSR) (crt.Certificate, error) { + return am.svc.IssueFromCSRInternal(ctx, entityID, ttl, csr) +} + +func (am *authorizationMiddleware) checkUserDomainPermission(ctx context.Context, session authn.Session, permission string) error { + req := authz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Permission: permission, + ObjectType: policies.DomainType, + Object: session.DomainID, + } + if err := am.authz.Authorize(ctx, req, nil); err != nil { + return errors.Wrap(svcerr.ErrAuthorization, err) + } + return nil +} diff --git a/certs/middleware/logging.go b/certs/middleware/logging.go new file mode 100644 index 000000000..040184328 --- /dev/null +++ b/certs/middleware/logging.go @@ -0,0 +1,174 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/pkg/authn" +) + +var _ certs.Service = (*loggingMiddleware)(nil) + +type loggingMiddleware struct { + logger *slog.Logger + svc certs.Service +} + +// LoggingMiddleware adds logging facilities to the core service. +func LoggingMiddleware(svc certs.Service, logger *slog.Logger) certs.Service { + return &loggingMiddleware{logger, svc} +} + +func (lm *loggingMiddleware) RenewCert(ctx context.Context, session authn.Session, serialNumber string) (cert certs.Certificate, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method renew_cert for cert %s took %s to complete", serialNumber, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(fmt.Sprintf("%s and returned new cert %s.", message, cert.SerialNumber)) + }(time.Now()) + return lm.svc.RenewCert(ctx, session, serialNumber) +} + +func (lm *loggingMiddleware) IssueCert(ctx context.Context, session authn.Session, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (cert certs.Certificate, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method issue_cert for entity %s took %s to complete", entityID, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.IssueCert(ctx, session, entityID, ttl, ipAddrs, options) +} + +func (lm *loggingMiddleware) ListCerts(ctx context.Context, session authn.Session, pm certs.PageMetadata) (cp certs.CertificatePage, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method list_certs took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.ListCerts(ctx, session, pm) +} + +func (lm *loggingMiddleware) RevokeBySerial(ctx context.Context, session authn.Session, serialNumber string) (err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method revoke_by_serial took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.RevokeBySerial(ctx, session, serialNumber) +} + +func (lm *loggingMiddleware) RevokeAll(ctx context.Context, session authn.Session, entityId string) (err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method revoke_all took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.RevokeAll(ctx, session, entityId) +} + +func (lm *loggingMiddleware) ViewCert(ctx context.Context, session authn.Session, serialNumber string) (cert certs.Certificate, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method view_cert for serial number %s took %s to complete", serialNumber, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.ViewCert(ctx, session, serialNumber) +} + +func (lm *loggingMiddleware) OCSP(ctx context.Context, serialNumber string, ocspRequestDER []byte) (ocspBytes []byte, err error) { + defer func(begin time.Time) { + requestType := "serial_number" + if len(ocspRequestDER) > 0 { + requestType = "raw_request" + } + message := fmt.Sprintf("Method ocsp (%s) for serial number %s took %s to complete", requestType, serialNumber, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.OCSP(ctx, serialNumber, ocspRequestDER) +} + +func (lm *loggingMiddleware) GetEntityID(ctx context.Context, serialNumber string) (entityID string, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method get_entity_id for serial number %s took %s to complete", serialNumber, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.GetEntityID(ctx, serialNumber) +} + +func (lm *loggingMiddleware) GenerateCRL(ctx context.Context) (crl []byte, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method generate_crl took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.GenerateCRL(ctx) +} + +func (lm *loggingMiddleware) RetrieveCAChain(ctx context.Context) (cert certs.Certificate, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method retrieve_ca_chain took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(fmt.Sprintf("%s without errors.", message)) + }(time.Now()) + return lm.svc.RetrieveCAChain(ctx) +} + +func (lm *loggingMiddleware) IssueFromCSR(ctx context.Context, session authn.Session, entityID, ttl string, csr certs.CSR) (c certs.Certificate, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method issue_from_csr took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.IssueFromCSR(ctx, session, entityID, ttl, csr) +} + +func (lm *loggingMiddleware) IssueFromCSRInternal(ctx context.Context, entityID, ttl string, csr certs.CSR) (c certs.Certificate, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method issue_from_csr_internal for entity %s took %s to complete", entityID, time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.IssueFromCSRInternal(ctx, entityID, ttl, csr) +} diff --git a/certs/middleware/metrics.go b/certs/middleware/metrics.go new file mode 100644 index 000000000..763c6c8bd --- /dev/null +++ b/certs/middleware/metrics.go @@ -0,0 +1,127 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "time" + + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/pkg/authn" + "github.com/go-kit/kit/metrics" +) + +var _ certs.Service = (*metricsMiddleware)(nil) + +type metricsMiddleware struct { + counter metrics.Counter + latency metrics.Histogram + svc certs.Service +} + +// MetricsMiddleware instruments core service by tracking request count and latency. +func MetricsMiddleware(svc certs.Service, counter metrics.Counter, latency metrics.Histogram) certs.Service { + return &metricsMiddleware{ + counter: counter, + latency: latency, + svc: svc, + } +} + +func (mm *metricsMiddleware) RenewCert(ctx context.Context, session authn.Session, serialNumber string) (certs.Certificate, error) { + defer func(begin time.Time) { + mm.counter.With("method", "renew_certificate").Add(1) + mm.latency.With("method", "renew_certificate").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.RenewCert(ctx, session, serialNumber) +} + +func (mm *metricsMiddleware) IssueCert(ctx context.Context, session authn.Session, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { + defer func(begin time.Time) { + mm.counter.With("method", "issue_certificate").Add(1) + mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.IssueCert(ctx, session, entityID, ttl, ipAddrs, options) +} + +func (mm *metricsMiddleware) ListCerts(ctx context.Context, session authn.Session, pm certs.PageMetadata) (certs.CertificatePage, error) { + defer func(begin time.Time) { + mm.counter.With("method", "list_certificates").Add(1) + mm.latency.With("method", "list_certificates").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.ListCerts(ctx, session, pm) +} + +func (mm *metricsMiddleware) RevokeBySerial(ctx context.Context, session authn.Session, serialNumber string) error { + defer func(begin time.Time) { + mm.counter.With("method", "revoke_by_serial").Add(1) + mm.latency.With("method", "revoke_by_serial").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.RevokeBySerial(ctx, session, serialNumber) +} + +func (mm *metricsMiddleware) RevokeAll(ctx context.Context, session authn.Session, entityId string) error { + defer func(begin time.Time) { + mm.counter.With("method", "revoke_all").Add(1) + mm.latency.With("method", "revoke_all").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.RevokeAll(ctx, session, entityId) +} + +func (mm *metricsMiddleware) ViewCert(ctx context.Context, session authn.Session, serialNumber string) (certs.Certificate, error) { + defer func(begin time.Time) { + mm.counter.With("method", "view_certificate").Add(1) + mm.latency.With("method", "view_certificate").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.ViewCert(ctx, session, serialNumber) +} + +func (mm *metricsMiddleware) OCSP(ctx context.Context, serialNumber string, ocspRequestDER []byte) ([]byte, error) { + defer func(begin time.Time) { + mm.counter.With("method", "ocsp").Add(1) + mm.latency.With("method", "ocsp").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.OCSP(ctx, serialNumber, ocspRequestDER) +} + +func (mm *metricsMiddleware) GetEntityID(ctx context.Context, serialNumber string) (string, error) { + defer func(begin time.Time) { + mm.counter.With("method", "get_entity_id").Add(1) + mm.latency.With("method", "get_entity_id").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.GetEntityID(ctx, serialNumber) +} + +func (mm *metricsMiddleware) GenerateCRL(ctx context.Context) ([]byte, error) { + defer func(begin time.Time) { + mm.counter.With("method", "generate_crl").Add(1) + mm.latency.With("method", "generate_crl").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.GenerateCRL(ctx) +} + +func (mm *metricsMiddleware) RetrieveCAChain(ctx context.Context) (certs.Certificate, error) { + defer func(begin time.Time) { + mm.counter.With("method", "retrieve_ca_chain").Add(1) + mm.latency.With("method", "retrieve_ca_chain").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.RetrieveCAChain(ctx) +} + +func (mm *metricsMiddleware) IssueFromCSR(ctx context.Context, session authn.Session, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { + defer func(begin time.Time) { + mm.counter.With("method", "issue_from_csr").Add(1) + mm.latency.With("method", "issue_from_csr").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.IssueFromCSR(ctx, session, entityID, ttl, csr) +} + +func (mm *metricsMiddleware) IssueFromCSRInternal(ctx context.Context, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { + defer func(begin time.Time) { + mm.counter.With("method", "issue_from_csr_internal").Add(1) + mm.latency.With("method", "issue_from_csr_internal").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.IssueFromCSRInternal(ctx, entityID, ttl, csr) +} diff --git a/certs/middleware/tracing.go b/certs/middleware/tracing.go new file mode 100644 index 000000000..9dae63773 --- /dev/null +++ b/certs/middleware/tracing.go @@ -0,0 +1,96 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/pkg/authn" + "go.opentelemetry.io/otel/trace" +) + +var _ certs.Service = (*tracingMiddleware)(nil) + +type tracingMiddleware struct { + tracer trace.Tracer + svc certs.Service +} + +// New returns a new auth service with tracing capabilities. +func New(svc certs.Service, tracer trace.Tracer) certs.Service { + return &tracingMiddleware{tracer, svc} +} + +func (tm *tracingMiddleware) RenewCert(ctx context.Context, session authn.Session, serialNumber string) (certs.Certificate, error) { + ctx, span := tm.tracer.Start(ctx, "renew_cert") + defer span.End() + return tm.svc.RenewCert(ctx, session, serialNumber) +} + +func (tm *tracingMiddleware) RevokeBySerial(ctx context.Context, session authn.Session, serialNumber string) error { + ctx, span := tm.tracer.Start(ctx, "revoke_by_serial") + defer span.End() + return tm.svc.RevokeBySerial(ctx, session, serialNumber) +} + +func (tm *tracingMiddleware) RevokeAll(ctx context.Context, session authn.Session, entityID string) error { + ctx, span := tm.tracer.Start(ctx, "revoke_all") + defer span.End() + return tm.svc.RevokeAll(ctx, session, entityID) +} + +func (tm *tracingMiddleware) IssueCert(ctx context.Context, session authn.Session, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { + ctx, span := tm.tracer.Start(ctx, "issue_cert") + defer span.End() + return tm.svc.IssueCert(ctx, session, entityID, ttl, ipAddrs, options) +} + +func (tm *tracingMiddleware) ListCerts(ctx context.Context, session authn.Session, pm certs.PageMetadata) (certs.CertificatePage, error) { + ctx, span := tm.tracer.Start(ctx, "list_certs") + defer span.End() + return tm.svc.ListCerts(ctx, session, pm) +} + +func (tm *tracingMiddleware) ViewCert(ctx context.Context, session authn.Session, serialNumber string) (certs.Certificate, error) { + ctx, span := tm.tracer.Start(ctx, "view_cert") + defer span.End() + return tm.svc.ViewCert(ctx, session, serialNumber) +} + +func (tm *tracingMiddleware) OCSP(ctx context.Context, serialNumber string, ocspRequestDER []byte) ([]byte, error) { + ctx, span := tm.tracer.Start(ctx, "ocsp") + defer span.End() + return tm.svc.OCSP(ctx, serialNumber, ocspRequestDER) +} + +func (tm *tracingMiddleware) GetEntityID(ctx context.Context, serialNumber string) (string, error) { + ctx, span := tm.tracer.Start(ctx, "get_entity_id") + defer span.End() + return tm.svc.GetEntityID(ctx, serialNumber) +} + +func (tm *tracingMiddleware) GenerateCRL(ctx context.Context) ([]byte, error) { + ctx, span := tm.tracer.Start(ctx, "generate_crl") + defer span.End() + return tm.svc.GenerateCRL(ctx) +} + +func (tm *tracingMiddleware) RetrieveCAChain(ctx context.Context) (certs.Certificate, error) { + ctx, span := tm.tracer.Start(ctx, "retrieve_ca_chain") + defer span.End() + return tm.svc.RetrieveCAChain(ctx) +} + +func (tm *tracingMiddleware) IssueFromCSR(ctx context.Context, session authn.Session, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { + ctx, span := tm.tracer.Start(ctx, "issue_from_csr") + defer span.End() + return tm.svc.IssueFromCSR(ctx, session, entityID, ttl, csr) +} + +func (tm *tracingMiddleware) IssueFromCSRInternal(ctx context.Context, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { + ctx, span := tm.tracer.Start(ctx, "issue_from_csr_internal") + defer span.End() + return tm.svc.IssueFromCSRInternal(ctx, entityID, ttl, csr) +} diff --git a/certs/mocks/agent.go b/certs/mocks/agent.go new file mode 100644 index 000000000..3461017f6 --- /dev/null +++ b/certs/mocks/agent.go @@ -0,0 +1,702 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/certs" + mock "github.com/stretchr/testify/mock" +) + +// NewAgent creates a new instance of Agent. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewAgent(t interface { + mock.TestingT + Cleanup(func()) +}) *Agent { + mock := &Agent{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Agent is an autogenerated mock type for the Agent type +type Agent struct { + mock.Mock +} + +type Agent_Expecter struct { + mock *mock.Mock +} + +func (_m *Agent) EXPECT() *Agent_Expecter { + return &Agent_Expecter{mock: &_m.Mock} +} + +// GetCA provides a mock function for the type Agent +func (_mock *Agent) GetCA() ([]byte, error) { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for GetCA") + } + + var r0 []byte + var r1 error + if returnFunc, ok := ret.Get(0).(func() ([]byte, error)); ok { + return returnFunc() + } + if returnFunc, ok := ret.Get(0).(func() []byte); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + if returnFunc, ok := ret.Get(1).(func() error); ok { + r1 = returnFunc() + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Agent_GetCA_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCA' +type Agent_GetCA_Call struct { + *mock.Call +} + +// GetCA is a helper method to define mock.On call +func (_e *Agent_Expecter) GetCA() *Agent_GetCA_Call { + return &Agent_GetCA_Call{Call: _e.mock.On("GetCA")} +} + +func (_c *Agent_GetCA_Call) Run(run func()) *Agent_GetCA_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Agent_GetCA_Call) Return(bytes []byte, err error) *Agent_GetCA_Call { + _c.Call.Return(bytes, err) + return _c +} + +func (_c *Agent_GetCA_Call) RunAndReturn(run func() ([]byte, error)) *Agent_GetCA_Call { + _c.Call.Return(run) + return _c +} + +// GetCAChain provides a mock function for the type Agent +func (_mock *Agent) GetCAChain() ([]byte, error) { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for GetCAChain") + } + + var r0 []byte + var r1 error + if returnFunc, ok := ret.Get(0).(func() ([]byte, error)); ok { + return returnFunc() + } + if returnFunc, ok := ret.Get(0).(func() []byte); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + if returnFunc, ok := ret.Get(1).(func() error); ok { + r1 = returnFunc() + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Agent_GetCAChain_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCAChain' +type Agent_GetCAChain_Call struct { + *mock.Call +} + +// GetCAChain is a helper method to define mock.On call +func (_e *Agent_Expecter) GetCAChain() *Agent_GetCAChain_Call { + return &Agent_GetCAChain_Call{Call: _e.mock.On("GetCAChain")} +} + +func (_c *Agent_GetCAChain_Call) Run(run func()) *Agent_GetCAChain_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Agent_GetCAChain_Call) Return(bytes []byte, err error) *Agent_GetCAChain_Call { + _c.Call.Return(bytes, err) + return _c +} + +func (_c *Agent_GetCAChain_Call) RunAndReturn(run func() ([]byte, error)) *Agent_GetCAChain_Call { + _c.Call.Return(run) + return _c +} + +// GetCRL provides a mock function for the type Agent +func (_mock *Agent) GetCRL() ([]byte, error) { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for GetCRL") + } + + var r0 []byte + var r1 error + if returnFunc, ok := ret.Get(0).(func() ([]byte, error)); ok { + return returnFunc() + } + if returnFunc, ok := ret.Get(0).(func() []byte); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + if returnFunc, ok := ret.Get(1).(func() error); ok { + r1 = returnFunc() + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Agent_GetCRL_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCRL' +type Agent_GetCRL_Call struct { + *mock.Call +} + +// GetCRL is a helper method to define mock.On call +func (_e *Agent_Expecter) GetCRL() *Agent_GetCRL_Call { + return &Agent_GetCRL_Call{Call: _e.mock.On("GetCRL")} +} + +func (_c *Agent_GetCRL_Call) Run(run func()) *Agent_GetCRL_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Agent_GetCRL_Call) Return(bytes []byte, err error) *Agent_GetCRL_Call { + _c.Call.Return(bytes, err) + return _c +} + +func (_c *Agent_GetCRL_Call) RunAndReturn(run func() ([]byte, error)) *Agent_GetCRL_Call { + _c.Call.Return(run) + return _c +} + +// Issue provides a mock function for the type Agent +func (_mock *Agent) Issue(ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { + ret := _mock.Called(ttl, ipAddrs, options) + + if len(ret) == 0 { + panic("no return value specified for Issue") + } + + var r0 certs.Certificate + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, []string, certs.SubjectOptions) (certs.Certificate, error)); ok { + return returnFunc(ttl, ipAddrs, options) + } + if returnFunc, ok := ret.Get(0).(func(string, []string, certs.SubjectOptions) certs.Certificate); ok { + r0 = returnFunc(ttl, ipAddrs, options) + } else { + r0 = ret.Get(0).(certs.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(string, []string, certs.SubjectOptions) error); ok { + r1 = returnFunc(ttl, ipAddrs, options) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Agent_Issue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Issue' +type Agent_Issue_Call struct { + *mock.Call +} + +// Issue is a helper method to define mock.On call +// - ttl string +// - ipAddrs []string +// - options certs.SubjectOptions +func (_e *Agent_Expecter) Issue(ttl interface{}, ipAddrs interface{}, options interface{}) *Agent_Issue_Call { + return &Agent_Issue_Call{Call: _e.mock.On("Issue", ttl, ipAddrs, options)} +} + +func (_c *Agent_Issue_Call) Run(run func(ttl string, ipAddrs []string, options certs.SubjectOptions)) *Agent_Issue_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 []string + if args[1] != nil { + arg1 = args[1].([]string) + } + var arg2 certs.SubjectOptions + if args[2] != nil { + arg2 = args[2].(certs.SubjectOptions) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Agent_Issue_Call) Return(certificate certs.Certificate, err error) *Agent_Issue_Call { + _c.Call.Return(certificate, err) + return _c +} + +func (_c *Agent_Issue_Call) RunAndReturn(run func(ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error)) *Agent_Issue_Call { + _c.Call.Return(run) + return _c +} + +// ListCerts provides a mock function for the type Agent +func (_mock *Agent) ListCerts(pm certs.PageMetadata) (certs.CertificatePage, error) { + ret := _mock.Called(pm) + + if len(ret) == 0 { + panic("no return value specified for ListCerts") + } + + var r0 certs.CertificatePage + var r1 error + if returnFunc, ok := ret.Get(0).(func(certs.PageMetadata) (certs.CertificatePage, error)); ok { + return returnFunc(pm) + } + if returnFunc, ok := ret.Get(0).(func(certs.PageMetadata) certs.CertificatePage); ok { + r0 = returnFunc(pm) + } else { + r0 = ret.Get(0).(certs.CertificatePage) + } + if returnFunc, ok := ret.Get(1).(func(certs.PageMetadata) error); ok { + r1 = returnFunc(pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Agent_ListCerts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCerts' +type Agent_ListCerts_Call struct { + *mock.Call +} + +// ListCerts is a helper method to define mock.On call +// - pm certs.PageMetadata +func (_e *Agent_Expecter) ListCerts(pm interface{}) *Agent_ListCerts_Call { + return &Agent_ListCerts_Call{Call: _e.mock.On("ListCerts", pm)} +} + +func (_c *Agent_ListCerts_Call) Run(run func(pm certs.PageMetadata)) *Agent_ListCerts_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 certs.PageMetadata + if args[0] != nil { + arg0 = args[0].(certs.PageMetadata) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *Agent_ListCerts_Call) Return(certificatePage certs.CertificatePage, err error) *Agent_ListCerts_Call { + _c.Call.Return(certificatePage, err) + return _c +} + +func (_c *Agent_ListCerts_Call) RunAndReturn(run func(pm certs.PageMetadata) (certs.CertificatePage, error)) *Agent_ListCerts_Call { + _c.Call.Return(run) + return _c +} + +// OCSP provides a mock function for the type Agent +func (_mock *Agent) OCSP(serialNumber string, ocspRequestDER []byte) ([]byte, error) { + ret := _mock.Called(serialNumber, ocspRequestDER) + + if len(ret) == 0 { + panic("no return value specified for OCSP") + } + + var r0 []byte + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, []byte) ([]byte, error)); ok { + return returnFunc(serialNumber, ocspRequestDER) + } + if returnFunc, ok := ret.Get(0).(func(string, []byte) []byte); ok { + r0 = returnFunc(serialNumber, ocspRequestDER) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + if returnFunc, ok := ret.Get(1).(func(string, []byte) error); ok { + r1 = returnFunc(serialNumber, ocspRequestDER) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Agent_OCSP_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OCSP' +type Agent_OCSP_Call struct { + *mock.Call +} + +// OCSP is a helper method to define mock.On call +// - serialNumber string +// - ocspRequestDER []byte +func (_e *Agent_Expecter) OCSP(serialNumber interface{}, ocspRequestDER interface{}) *Agent_OCSP_Call { + return &Agent_OCSP_Call{Call: _e.mock.On("OCSP", serialNumber, ocspRequestDER)} +} + +func (_c *Agent_OCSP_Call) Run(run func(serialNumber string, ocspRequestDER []byte)) *Agent_OCSP_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 []byte + if args[1] != nil { + arg1 = args[1].([]byte) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Agent_OCSP_Call) Return(bytes []byte, err error) *Agent_OCSP_Call { + _c.Call.Return(bytes, err) + return _c +} + +func (_c *Agent_OCSP_Call) RunAndReturn(run func(serialNumber string, ocspRequestDER []byte) ([]byte, error)) *Agent_OCSP_Call { + _c.Call.Return(run) + return _c +} + +// Renew provides a mock function for the type Agent +func (_mock *Agent) Renew(cert certs.Certificate, increment string) (certs.Certificate, error) { + ret := _mock.Called(cert, increment) + + if len(ret) == 0 { + panic("no return value specified for Renew") + } + + var r0 certs.Certificate + var r1 error + if returnFunc, ok := ret.Get(0).(func(certs.Certificate, string) (certs.Certificate, error)); ok { + return returnFunc(cert, increment) + } + if returnFunc, ok := ret.Get(0).(func(certs.Certificate, string) certs.Certificate); ok { + r0 = returnFunc(cert, increment) + } else { + r0 = ret.Get(0).(certs.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(certs.Certificate, string) error); ok { + r1 = returnFunc(cert, increment) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Agent_Renew_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Renew' +type Agent_Renew_Call struct { + *mock.Call +} + +// Renew is a helper method to define mock.On call +// - cert certs.Certificate +// - increment string +func (_e *Agent_Expecter) Renew(cert interface{}, increment interface{}) *Agent_Renew_Call { + return &Agent_Renew_Call{Call: _e.mock.On("Renew", cert, increment)} +} + +func (_c *Agent_Renew_Call) Run(run func(cert certs.Certificate, increment string)) *Agent_Renew_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 certs.Certificate + if args[0] != nil { + arg0 = args[0].(certs.Certificate) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Agent_Renew_Call) Return(certificate certs.Certificate, err error) *Agent_Renew_Call { + _c.Call.Return(certificate, err) + return _c +} + +func (_c *Agent_Renew_Call) RunAndReturn(run func(cert certs.Certificate, increment string) (certs.Certificate, error)) *Agent_Renew_Call { + _c.Call.Return(run) + return _c +} + +// Revoke provides a mock function for the type Agent +func (_mock *Agent) Revoke(serialNumber string) error { + ret := _mock.Called(serialNumber) + + if len(ret) == 0 { + panic("no return value specified for Revoke") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(string) error); ok { + r0 = returnFunc(serialNumber) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Agent_Revoke_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Revoke' +type Agent_Revoke_Call struct { + *mock.Call +} + +// Revoke is a helper method to define mock.On call +// - serialNumber string +func (_e *Agent_Expecter) Revoke(serialNumber interface{}) *Agent_Revoke_Call { + return &Agent_Revoke_Call{Call: _e.mock.On("Revoke", serialNumber)} +} + +func (_c *Agent_Revoke_Call) Run(run func(serialNumber string)) *Agent_Revoke_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *Agent_Revoke_Call) Return(err error) *Agent_Revoke_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Agent_Revoke_Call) RunAndReturn(run func(serialNumber string) error) *Agent_Revoke_Call { + _c.Call.Return(run) + return _c +} + +// SignCSR provides a mock function for the type Agent +func (_mock *Agent) SignCSR(csr []byte, ttl string) (certs.Certificate, error) { + ret := _mock.Called(csr, ttl) + + if len(ret) == 0 { + panic("no return value specified for SignCSR") + } + + var r0 certs.Certificate + var r1 error + if returnFunc, ok := ret.Get(0).(func([]byte, string) (certs.Certificate, error)); ok { + return returnFunc(csr, ttl) + } + if returnFunc, ok := ret.Get(0).(func([]byte, string) certs.Certificate); ok { + r0 = returnFunc(csr, ttl) + } else { + r0 = ret.Get(0).(certs.Certificate) + } + if returnFunc, ok := ret.Get(1).(func([]byte, string) error); ok { + r1 = returnFunc(csr, ttl) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Agent_SignCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignCSR' +type Agent_SignCSR_Call struct { + *mock.Call +} + +// SignCSR is a helper method to define mock.On call +// - csr []byte +// - ttl string +func (_e *Agent_Expecter) SignCSR(csr interface{}, ttl interface{}) *Agent_SignCSR_Call { + return &Agent_SignCSR_Call{Call: _e.mock.On("SignCSR", csr, ttl)} +} + +func (_c *Agent_SignCSR_Call) Run(run func(csr []byte, ttl string)) *Agent_SignCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []byte + if args[0] != nil { + arg0 = args[0].([]byte) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Agent_SignCSR_Call) Return(certificate certs.Certificate, err error) *Agent_SignCSR_Call { + _c.Call.Return(certificate, err) + return _c +} + +func (_c *Agent_SignCSR_Call) RunAndReturn(run func(csr []byte, ttl string) (certs.Certificate, error)) *Agent_SignCSR_Call { + _c.Call.Return(run) + return _c +} + +// StartSecretRenewal provides a mock function for the type Agent +func (_mock *Agent) StartSecretRenewal(ctx context.Context) error { + ret := _mock.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for StartSecretRenewal") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = returnFunc(ctx) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Agent_StartSecretRenewal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StartSecretRenewal' +type Agent_StartSecretRenewal_Call struct { + *mock.Call +} + +// StartSecretRenewal is a helper method to define mock.On call +// - ctx context.Context +func (_e *Agent_Expecter) StartSecretRenewal(ctx interface{}) *Agent_StartSecretRenewal_Call { + return &Agent_StartSecretRenewal_Call{Call: _e.mock.On("StartSecretRenewal", ctx)} +} + +func (_c *Agent_StartSecretRenewal_Call) Run(run func(ctx context.Context)) *Agent_StartSecretRenewal_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *Agent_StartSecretRenewal_Call) Return(err error) *Agent_StartSecretRenewal_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Agent_StartSecretRenewal_Call) RunAndReturn(run func(ctx context.Context) error) *Agent_StartSecretRenewal_Call { + _c.Call.Return(run) + return _c +} + +// View provides a mock function for the type Agent +func (_mock *Agent) View(serialNumber string) (certs.Certificate, error) { + ret := _mock.Called(serialNumber) + + if len(ret) == 0 { + panic("no return value specified for View") + } + + var r0 certs.Certificate + var r1 error + if returnFunc, ok := ret.Get(0).(func(string) (certs.Certificate, error)); ok { + return returnFunc(serialNumber) + } + if returnFunc, ok := ret.Get(0).(func(string) certs.Certificate); ok { + r0 = returnFunc(serialNumber) + } else { + r0 = ret.Get(0).(certs.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(string) error); ok { + r1 = returnFunc(serialNumber) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Agent_View_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'View' +type Agent_View_Call struct { + *mock.Call +} + +// View is a helper method to define mock.On call +// - serialNumber string +func (_e *Agent_Expecter) View(serialNumber interface{}) *Agent_View_Call { + return &Agent_View_Call{Call: _e.mock.On("View", serialNumber)} +} + +func (_c *Agent_View_Call) Run(run func(serialNumber string)) *Agent_View_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *Agent_View_Call) Return(certificate certs.Certificate, err error) *Agent_View_Call { + _c.Call.Return(certificate, err) + return _c +} + +func (_c *Agent_View_Call) RunAndReturn(run func(serialNumber string) (certs.Certificate, error)) *Agent_View_Call { + _c.Call.Return(run) + return _c +} diff --git a/certs/mocks/certs_client.go b/certs/mocks/certs_client.go new file mode 100644 index 000000000..af6675014 --- /dev/null +++ b/certs/mocks/certs_client.go @@ -0,0 +1,211 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/api/grpc/certs/v1" + mock "github.com/stretchr/testify/mock" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/emptypb" +) + +// NewCertsServiceClient creates a new instance of CertsServiceClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewCertsServiceClient(t interface { + mock.TestingT + Cleanup(func()) +}) *CertsServiceClient { + mock := &CertsServiceClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// CertsServiceClient is an autogenerated mock type for the CertsServiceClient type +type CertsServiceClient struct { + mock.Mock +} + +type CertsServiceClient_Expecter struct { + mock *mock.Mock +} + +func (_m *CertsServiceClient) EXPECT() *CertsServiceClient_Expecter { + return &CertsServiceClient_Expecter{mock: &_m.Mock} +} + +// GetEntityID provides a mock function for the type CertsServiceClient +func (_mock *CertsServiceClient) GetEntityID(ctx context.Context, in *v1.EntityReq, opts ...grpc.CallOption) (*v1.EntityRes, error) { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(ctx, in, opts) + } else { + tmpRet = _mock.Called(ctx, in) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for GetEntityID") + } + + var r0 *v1.EntityRes + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.EntityReq, ...grpc.CallOption) (*v1.EntityRes, error)); ok { + return returnFunc(ctx, in, opts...) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.EntityReq, ...grpc.CallOption) *v1.EntityRes); ok { + r0 = returnFunc(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.EntityRes) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.EntityReq, ...grpc.CallOption) error); ok { + r1 = returnFunc(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// CertsServiceClient_GetEntityID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetEntityID' +type CertsServiceClient_GetEntityID_Call struct { + *mock.Call +} + +// GetEntityID is a helper method to define mock.On call +// - ctx context.Context +// - in *v1.EntityReq +// - opts ...grpc.CallOption +func (_e *CertsServiceClient_Expecter) GetEntityID(ctx interface{}, in interface{}, opts ...interface{}) *CertsServiceClient_GetEntityID_Call { + return &CertsServiceClient_GetEntityID_Call{Call: _e.mock.On("GetEntityID", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *CertsServiceClient_GetEntityID_Call) Run(run func(ctx context.Context, in *v1.EntityReq, opts ...grpc.CallOption)) *CertsServiceClient_GetEntityID_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *v1.EntityReq + if args[1] != nil { + arg1 = args[1].(*v1.EntityReq) + } + var arg2 []grpc.CallOption + var variadicArgs []grpc.CallOption + if len(args) > 2 { + variadicArgs = args[2].([]grpc.CallOption) + } + arg2 = variadicArgs + run( + arg0, + arg1, + arg2..., + ) + }) + return _c +} + +func (_c *CertsServiceClient_GetEntityID_Call) Return(entityRes *v1.EntityRes, err error) *CertsServiceClient_GetEntityID_Call { + _c.Call.Return(entityRes, err) + return _c +} + +func (_c *CertsServiceClient_GetEntityID_Call) RunAndReturn(run func(ctx context.Context, in *v1.EntityReq, opts ...grpc.CallOption) (*v1.EntityRes, error)) *CertsServiceClient_GetEntityID_Call { + _c.Call.Return(run) + return _c +} + +// RevokeCerts provides a mock function for the type CertsServiceClient +func (_mock *CertsServiceClient) RevokeCerts(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption) (*emptypb.Empty, error) { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(ctx, in, opts) + } else { + tmpRet = _mock.Called(ctx, in) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for RevokeCerts") + } + + var r0 *emptypb.Empty + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) (*emptypb.Empty, error)); ok { + return returnFunc(ctx, in, opts...) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) *emptypb.Empty); ok { + r0 = returnFunc(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*emptypb.Empty) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) error); ok { + r1 = returnFunc(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// CertsServiceClient_RevokeCerts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeCerts' +type CertsServiceClient_RevokeCerts_Call struct { + *mock.Call +} + +// RevokeCerts is a helper method to define mock.On call +// - ctx context.Context +// - in *v1.RevokeReq +// - opts ...grpc.CallOption +func (_e *CertsServiceClient_Expecter) RevokeCerts(ctx interface{}, in interface{}, opts ...interface{}) *CertsServiceClient_RevokeCerts_Call { + return &CertsServiceClient_RevokeCerts_Call{Call: _e.mock.On("RevokeCerts", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *CertsServiceClient_RevokeCerts_Call) Run(run func(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption)) *CertsServiceClient_RevokeCerts_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *v1.RevokeReq + if args[1] != nil { + arg1 = args[1].(*v1.RevokeReq) + } + var arg2 []grpc.CallOption + var variadicArgs []grpc.CallOption + if len(args) > 2 { + variadicArgs = args[2].([]grpc.CallOption) + } + arg2 = variadicArgs + run( + arg0, + arg1, + arg2..., + ) + }) + return _c +} + +func (_c *CertsServiceClient_RevokeCerts_Call) Return(empty *emptypb.Empty, err error) *CertsServiceClient_RevokeCerts_Call { + _c.Call.Return(empty, err) + return _c +} + +func (_c *CertsServiceClient_RevokeCerts_Call) RunAndReturn(run func(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption) (*emptypb.Empty, error)) *CertsServiceClient_RevokeCerts_Call { + _c.Call.Return(run) + return _c +} diff --git a/certs/mocks/repository.go b/certs/mocks/repository.go new file mode 100644 index 000000000..c54082afd --- /dev/null +++ b/certs/mocks/repository.go @@ -0,0 +1,296 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + mock "github.com/stretchr/testify/mock" +) + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +type Repository_Expecter struct { + mock *mock.Mock +} + +func (_m *Repository) EXPECT() *Repository_Expecter { + return &Repository_Expecter{mock: &_m.Mock} +} + +// GetEntityIDBySerial provides a mock function for the type Repository +func (_mock *Repository) GetEntityIDBySerial(ctx context.Context, serialNumber string) (string, error) { + ret := _mock.Called(ctx, serialNumber) + + if len(ret) == 0 { + panic("no return value specified for GetEntityIDBySerial") + } + + var r0 string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (string, error)); ok { + return returnFunc(ctx, serialNumber) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) string); ok { + r0 = returnFunc(ctx, serialNumber) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, serialNumber) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_GetEntityIDBySerial_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetEntityIDBySerial' +type Repository_GetEntityIDBySerial_Call struct { + *mock.Call +} + +// GetEntityIDBySerial is a helper method to define mock.On call +// - ctx context.Context +// - serialNumber string +func (_e *Repository_Expecter) GetEntityIDBySerial(ctx interface{}, serialNumber interface{}) *Repository_GetEntityIDBySerial_Call { + return &Repository_GetEntityIDBySerial_Call{Call: _e.mock.On("GetEntityIDBySerial", ctx, serialNumber)} +} + +func (_c *Repository_GetEntityIDBySerial_Call) Run(run func(ctx context.Context, serialNumber string)) *Repository_GetEntityIDBySerial_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_GetEntityIDBySerial_Call) Return(s string, err error) *Repository_GetEntityIDBySerial_Call { + _c.Call.Return(s, err) + return _c +} + +func (_c *Repository_GetEntityIDBySerial_Call) RunAndReturn(run func(ctx context.Context, serialNumber string) (string, error)) *Repository_GetEntityIDBySerial_Call { + _c.Call.Return(run) + return _c +} + +// ListCertsByEntityID provides a mock function for the type Repository +func (_mock *Repository) ListCertsByEntityID(ctx context.Context, entityID string) ([]string, error) { + ret := _mock.Called(ctx, entityID) + + if len(ret) == 0 { + panic("no return value specified for ListCertsByEntityID") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) ([]string, error)); ok { + return returnFunc(ctx, entityID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) []string); ok { + r0 = returnFunc(ctx, entityID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, entityID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListCertsByEntityID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCertsByEntityID' +type Repository_ListCertsByEntityID_Call struct { + *mock.Call +} + +// ListCertsByEntityID is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +func (_e *Repository_Expecter) ListCertsByEntityID(ctx interface{}, entityID interface{}) *Repository_ListCertsByEntityID_Call { + return &Repository_ListCertsByEntityID_Call{Call: _e.mock.On("ListCertsByEntityID", ctx, entityID)} +} + +func (_c *Repository_ListCertsByEntityID_Call) Run(run func(ctx context.Context, entityID string)) *Repository_ListCertsByEntityID_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_ListCertsByEntityID_Call) Return(strings []string, err error) *Repository_ListCertsByEntityID_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *Repository_ListCertsByEntityID_Call) RunAndReturn(run func(ctx context.Context, entityID string) ([]string, error)) *Repository_ListCertsByEntityID_Call { + _c.Call.Return(run) + return _c +} + +// RemoveCertEntityMapping provides a mock function for the type Repository +func (_mock *Repository) RemoveCertEntityMapping(ctx context.Context, serialNumber string) error { + ret := _mock.Called(ctx, serialNumber) + + if len(ret) == 0 { + panic("no return value specified for RemoveCertEntityMapping") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, serialNumber) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RemoveCertEntityMapping_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCertEntityMapping' +type Repository_RemoveCertEntityMapping_Call struct { + *mock.Call +} + +// RemoveCertEntityMapping is a helper method to define mock.On call +// - ctx context.Context +// - serialNumber string +func (_e *Repository_Expecter) RemoveCertEntityMapping(ctx interface{}, serialNumber interface{}) *Repository_RemoveCertEntityMapping_Call { + return &Repository_RemoveCertEntityMapping_Call{Call: _e.mock.On("RemoveCertEntityMapping", ctx, serialNumber)} +} + +func (_c *Repository_RemoveCertEntityMapping_Call) Run(run func(ctx context.Context, serialNumber string)) *Repository_RemoveCertEntityMapping_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RemoveCertEntityMapping_Call) Return(err error) *Repository_RemoveCertEntityMapping_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RemoveCertEntityMapping_Call) RunAndReturn(run func(ctx context.Context, serialNumber string) error) *Repository_RemoveCertEntityMapping_Call { + _c.Call.Return(run) + return _c +} + +// SaveCertEntityMapping provides a mock function for the type Repository +func (_mock *Repository) SaveCertEntityMapping(ctx context.Context, serialNumber string, entityID string) error { + ret := _mock.Called(ctx, serialNumber, entityID) + + if len(ret) == 0 { + panic("no return value specified for SaveCertEntityMapping") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = returnFunc(ctx, serialNumber, entityID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_SaveCertEntityMapping_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveCertEntityMapping' +type Repository_SaveCertEntityMapping_Call struct { + *mock.Call +} + +// SaveCertEntityMapping is a helper method to define mock.On call +// - ctx context.Context +// - serialNumber string +// - entityID string +func (_e *Repository_Expecter) SaveCertEntityMapping(ctx interface{}, serialNumber interface{}, entityID interface{}) *Repository_SaveCertEntityMapping_Call { + return &Repository_SaveCertEntityMapping_Call{Call: _e.mock.On("SaveCertEntityMapping", ctx, serialNumber, entityID)} +} + +func (_c *Repository_SaveCertEntityMapping_Call) Run(run func(ctx context.Context, serialNumber string, entityID string)) *Repository_SaveCertEntityMapping_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_SaveCertEntityMapping_Call) Return(err error) *Repository_SaveCertEntityMapping_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_SaveCertEntityMapping_Call) RunAndReturn(run func(ctx context.Context, serialNumber string, entityID string) error) *Repository_SaveCertEntityMapping_Call { + _c.Call.Return(run) + return _c +} diff --git a/certs/mocks/service.go b/certs/mocks/service.go new file mode 100644 index 000000000..937f09939 --- /dev/null +++ b/certs/mocks/service.go @@ -0,0 +1,900 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/pkg/authn" + mock "github.com/stretchr/testify/mock" +) + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +type Service_Expecter struct { + mock *mock.Mock +} + +func (_m *Service) EXPECT() *Service_Expecter { + return &Service_Expecter{mock: &_m.Mock} +} + +// GenerateCRL provides a mock function for the type Service +func (_mock *Service) GenerateCRL(ctx context.Context) ([]byte, error) { + ret := _mock.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GenerateCRL") + } + + var r0 []byte + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context) ([]byte, error)); ok { + return returnFunc(ctx) + } + if returnFunc, ok := ret.Get(0).(func(context.Context) []byte); ok { + r0 = returnFunc(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = returnFunc(ctx) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_GenerateCRL_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GenerateCRL' +type Service_GenerateCRL_Call struct { + *mock.Call +} + +// GenerateCRL is a helper method to define mock.On call +// - ctx context.Context +func (_e *Service_Expecter) GenerateCRL(ctx interface{}) *Service_GenerateCRL_Call { + return &Service_GenerateCRL_Call{Call: _e.mock.On("GenerateCRL", ctx)} +} + +func (_c *Service_GenerateCRL_Call) Run(run func(ctx context.Context)) *Service_GenerateCRL_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *Service_GenerateCRL_Call) Return(bytes []byte, err error) *Service_GenerateCRL_Call { + _c.Call.Return(bytes, err) + return _c +} + +func (_c *Service_GenerateCRL_Call) RunAndReturn(run func(ctx context.Context) ([]byte, error)) *Service_GenerateCRL_Call { + _c.Call.Return(run) + return _c +} + +// GetEntityID provides a mock function for the type Service +func (_mock *Service) GetEntityID(ctx context.Context, serialNumber string) (string, error) { + ret := _mock.Called(ctx, serialNumber) + + if len(ret) == 0 { + panic("no return value specified for GetEntityID") + } + + var r0 string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (string, error)); ok { + return returnFunc(ctx, serialNumber) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) string); ok { + r0 = returnFunc(ctx, serialNumber) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, serialNumber) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_GetEntityID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetEntityID' +type Service_GetEntityID_Call struct { + *mock.Call +} + +// GetEntityID is a helper method to define mock.On call +// - ctx context.Context +// - serialNumber string +func (_e *Service_Expecter) GetEntityID(ctx interface{}, serialNumber interface{}) *Service_GetEntityID_Call { + return &Service_GetEntityID_Call{Call: _e.mock.On("GetEntityID", ctx, serialNumber)} +} + +func (_c *Service_GetEntityID_Call) Run(run func(ctx context.Context, serialNumber string)) *Service_GetEntityID_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_GetEntityID_Call) Return(s string, err error) *Service_GetEntityID_Call { + _c.Call.Return(s, err) + return _c +} + +func (_c *Service_GetEntityID_Call) RunAndReturn(run func(ctx context.Context, serialNumber string) (string, error)) *Service_GetEntityID_Call { + _c.Call.Return(run) + return _c +} + +// IssueCert provides a mock function for the type Service +func (_mock *Service) IssueCert(ctx context.Context, session authn.Session, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions) (certs.Certificate, error) { + ret := _mock.Called(ctx, session, entityID, ttl, ipAddrs, option) + + if len(ret) == 0 { + panic("no return value specified for IssueCert") + } + + var r0 certs.Certificate + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string, certs.SubjectOptions) (certs.Certificate, error)); ok { + return returnFunc(ctx, session, entityID, ttl, ipAddrs, option) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string, certs.SubjectOptions) certs.Certificate); ok { + r0 = returnFunc(ctx, session, entityID, ttl, ipAddrs, option) + } else { + r0 = ret.Get(0).(certs.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, []string, certs.SubjectOptions) error); ok { + r1 = returnFunc(ctx, session, entityID, ttl, ipAddrs, option) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_IssueCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IssueCert' +type Service_IssueCert_Call struct { + *mock.Call +} + +// IssueCert is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - ttl string +// - ipAddrs []string +// - option certs.SubjectOptions +func (_e *Service_Expecter) IssueCert(ctx interface{}, session interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}) *Service_IssueCert_Call { + return &Service_IssueCert_Call{Call: _e.mock.On("IssueCert", ctx, session, entityID, ttl, ipAddrs, option)} +} + +func (_c *Service_IssueCert_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions)) *Service_IssueCert_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + var arg5 certs.SubjectOptions + if args[5] != nil { + arg5 = args[5].(certs.SubjectOptions) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + ) + }) + return _c +} + +func (_c *Service_IssueCert_Call) Return(certificate certs.Certificate, err error) *Service_IssueCert_Call { + _c.Call.Return(certificate, err) + return _c +} + +func (_c *Service_IssueCert_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions) (certs.Certificate, error)) *Service_IssueCert_Call { + _c.Call.Return(run) + return _c +} + +// IssueFromCSR provides a mock function for the type Service +func (_mock *Service) IssueFromCSR(ctx context.Context, session authn.Session, entityID string, ttl string, csr certs.CSR) (certs.Certificate, error) { + ret := _mock.Called(ctx, session, entityID, ttl, csr) + + if len(ret) == 0 { + panic("no return value specified for IssueFromCSR") + } + + var r0 certs.Certificate + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, certs.CSR) (certs.Certificate, error)); ok { + return returnFunc(ctx, session, entityID, ttl, csr) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, certs.CSR) certs.Certificate); ok { + r0 = returnFunc(ctx, session, entityID, ttl, csr) + } else { + r0 = ret.Get(0).(certs.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, certs.CSR) error); ok { + r1 = returnFunc(ctx, session, entityID, ttl, csr) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_IssueFromCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IssueFromCSR' +type Service_IssueFromCSR_Call struct { + *mock.Call +} + +// IssueFromCSR is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - ttl string +// - csr certs.CSR +func (_e *Service_Expecter) IssueFromCSR(ctx interface{}, session interface{}, entityID interface{}, ttl interface{}, csr interface{}) *Service_IssueFromCSR_Call { + return &Service_IssueFromCSR_Call{Call: _e.mock.On("IssueFromCSR", ctx, session, entityID, ttl, csr)} +} + +func (_c *Service_IssueFromCSR_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, ttl string, csr certs.CSR)) *Service_IssueFromCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 certs.CSR + if args[4] != nil { + arg4 = args[4].(certs.CSR) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_IssueFromCSR_Call) Return(certificate certs.Certificate, err error) *Service_IssueFromCSR_Call { + _c.Call.Return(certificate, err) + return _c +} + +func (_c *Service_IssueFromCSR_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, ttl string, csr certs.CSR) (certs.Certificate, error)) *Service_IssueFromCSR_Call { + _c.Call.Return(run) + return _c +} + +// IssueFromCSRInternal provides a mock function for the type Service +func (_mock *Service) IssueFromCSRInternal(ctx context.Context, entityID string, ttl string, csr certs.CSR) (certs.Certificate, error) { + ret := _mock.Called(ctx, entityID, ttl, csr) + + if len(ret) == 0 { + panic("no return value specified for IssueFromCSRInternal") + } + + var r0 certs.Certificate + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, certs.CSR) (certs.Certificate, error)); ok { + return returnFunc(ctx, entityID, ttl, csr) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, certs.CSR) certs.Certificate); ok { + r0 = returnFunc(ctx, entityID, ttl, csr) + } else { + r0 = ret.Get(0).(certs.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, certs.CSR) error); ok { + r1 = returnFunc(ctx, entityID, ttl, csr) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_IssueFromCSRInternal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IssueFromCSRInternal' +type Service_IssueFromCSRInternal_Call struct { + *mock.Call +} + +// IssueFromCSRInternal is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - ttl string +// - csr certs.CSR +func (_e *Service_Expecter) IssueFromCSRInternal(ctx interface{}, entityID interface{}, ttl interface{}, csr interface{}) *Service_IssueFromCSRInternal_Call { + return &Service_IssueFromCSRInternal_Call{Call: _e.mock.On("IssueFromCSRInternal", ctx, entityID, ttl, csr)} +} + +func (_c *Service_IssueFromCSRInternal_Call) Run(run func(ctx context.Context, entityID string, ttl string, csr certs.CSR)) *Service_IssueFromCSRInternal_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 certs.CSR + if args[3] != nil { + arg3 = args[3].(certs.CSR) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_IssueFromCSRInternal_Call) Return(certificate certs.Certificate, err error) *Service_IssueFromCSRInternal_Call { + _c.Call.Return(certificate, err) + return _c +} + +func (_c *Service_IssueFromCSRInternal_Call) RunAndReturn(run func(ctx context.Context, entityID string, ttl string, csr certs.CSR) (certs.Certificate, error)) *Service_IssueFromCSRInternal_Call { + _c.Call.Return(run) + return _c +} + +// ListCerts provides a mock function for the type Service +func (_mock *Service) ListCerts(ctx context.Context, session authn.Session, pm certs.PageMetadata) (certs.CertificatePage, error) { + ret := _mock.Called(ctx, session, pm) + + if len(ret) == 0 { + panic("no return value specified for ListCerts") + } + + var r0 certs.CertificatePage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, certs.PageMetadata) (certs.CertificatePage, error)); ok { + return returnFunc(ctx, session, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, certs.PageMetadata) certs.CertificatePage); ok { + r0 = returnFunc(ctx, session, pm) + } else { + r0 = ret.Get(0).(certs.CertificatePage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, certs.PageMetadata) error); ok { + r1 = returnFunc(ctx, session, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListCerts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCerts' +type Service_ListCerts_Call struct { + *mock.Call +} + +// ListCerts is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - pm certs.PageMetadata +func (_e *Service_Expecter) ListCerts(ctx interface{}, session interface{}, pm interface{}) *Service_ListCerts_Call { + return &Service_ListCerts_Call{Call: _e.mock.On("ListCerts", ctx, session, pm)} +} + +func (_c *Service_ListCerts_Call) Run(run func(ctx context.Context, session authn.Session, pm certs.PageMetadata)) *Service_ListCerts_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 certs.PageMetadata + if args[2] != nil { + arg2 = args[2].(certs.PageMetadata) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_ListCerts_Call) Return(certificatePage certs.CertificatePage, err error) *Service_ListCerts_Call { + _c.Call.Return(certificatePage, err) + return _c +} + +func (_c *Service_ListCerts_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, pm certs.PageMetadata) (certs.CertificatePage, error)) *Service_ListCerts_Call { + _c.Call.Return(run) + return _c +} + +// OCSP provides a mock function for the type Service +func (_mock *Service) OCSP(ctx context.Context, serialNumber string, ocspRequestDER []byte) ([]byte, error) { + ret := _mock.Called(ctx, serialNumber, ocspRequestDER) + + if len(ret) == 0 { + panic("no return value specified for OCSP") + } + + var r0 []byte + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte) ([]byte, error)); ok { + return returnFunc(ctx, serialNumber, ocspRequestDER) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []byte) []byte); ok { + r0 = returnFunc(ctx, serialNumber, ocspRequestDER) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, []byte) error); ok { + r1 = returnFunc(ctx, serialNumber, ocspRequestDER) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_OCSP_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OCSP' +type Service_OCSP_Call struct { + *mock.Call +} + +// OCSP is a helper method to define mock.On call +// - ctx context.Context +// - serialNumber string +// - ocspRequestDER []byte +func (_e *Service_Expecter) OCSP(ctx interface{}, serialNumber interface{}, ocspRequestDER interface{}) *Service_OCSP_Call { + return &Service_OCSP_Call{Call: _e.mock.On("OCSP", ctx, serialNumber, ocspRequestDER)} +} + +func (_c *Service_OCSP_Call) Run(run func(ctx context.Context, serialNumber string, ocspRequestDER []byte)) *Service_OCSP_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 []byte + if args[2] != nil { + arg2 = args[2].([]byte) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_OCSP_Call) Return(bytes []byte, err error) *Service_OCSP_Call { + _c.Call.Return(bytes, err) + return _c +} + +func (_c *Service_OCSP_Call) RunAndReturn(run func(ctx context.Context, serialNumber string, ocspRequestDER []byte) ([]byte, error)) *Service_OCSP_Call { + _c.Call.Return(run) + return _c +} + +// RenewCert provides a mock function for the type Service +func (_mock *Service) RenewCert(ctx context.Context, session authn.Session, serialNumber string) (certs.Certificate, error) { + ret := _mock.Called(ctx, session, serialNumber) + + if len(ret) == 0 { + panic("no return value specified for RenewCert") + } + + var r0 certs.Certificate + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (certs.Certificate, error)); ok { + return returnFunc(ctx, session, serialNumber) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) certs.Certificate); ok { + r0 = returnFunc(ctx, session, serialNumber) + } else { + r0 = ret.Get(0).(certs.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = returnFunc(ctx, session, serialNumber) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RenewCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RenewCert' +type Service_RenewCert_Call struct { + *mock.Call +} + +// RenewCert is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - serialNumber string +func (_e *Service_Expecter) RenewCert(ctx interface{}, session interface{}, serialNumber interface{}) *Service_RenewCert_Call { + return &Service_RenewCert_Call{Call: _e.mock.On("RenewCert", ctx, session, serialNumber)} +} + +func (_c *Service_RenewCert_Call) Run(run func(ctx context.Context, session authn.Session, serialNumber string)) *Service_RenewCert_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_RenewCert_Call) Return(certificate certs.Certificate, err error) *Service_RenewCert_Call { + _c.Call.Return(certificate, err) + return _c +} + +func (_c *Service_RenewCert_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, serialNumber string) (certs.Certificate, error)) *Service_RenewCert_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveCAChain provides a mock function for the type Service +func (_mock *Service) RetrieveCAChain(ctx context.Context) (certs.Certificate, error) { + ret := _mock.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for RetrieveCAChain") + } + + var r0 certs.Certificate + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context) (certs.Certificate, error)); ok { + return returnFunc(ctx) + } + if returnFunc, ok := ret.Get(0).(func(context.Context) certs.Certificate); ok { + r0 = returnFunc(ctx) + } else { + r0 = ret.Get(0).(certs.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = returnFunc(ctx) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RetrieveCAChain_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveCAChain' +type Service_RetrieveCAChain_Call struct { + *mock.Call +} + +// RetrieveCAChain is a helper method to define mock.On call +// - ctx context.Context +func (_e *Service_Expecter) RetrieveCAChain(ctx interface{}) *Service_RetrieveCAChain_Call { + return &Service_RetrieveCAChain_Call{Call: _e.mock.On("RetrieveCAChain", ctx)} +} + +func (_c *Service_RetrieveCAChain_Call) Run(run func(ctx context.Context)) *Service_RetrieveCAChain_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *Service_RetrieveCAChain_Call) Return(certificate certs.Certificate, err error) *Service_RetrieveCAChain_Call { + _c.Call.Return(certificate, err) + return _c +} + +func (_c *Service_RetrieveCAChain_Call) RunAndReturn(run func(ctx context.Context) (certs.Certificate, error)) *Service_RetrieveCAChain_Call { + _c.Call.Return(run) + return _c +} + +// RevokeAll provides a mock function for the type Service +func (_mock *Service) RevokeAll(ctx context.Context, session authn.Session, entityID string) error { + ret := _mock.Called(ctx, session, entityID) + + if len(ret) == 0 { + panic("no return value specified for RevokeAll") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, entityID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RevokeAll_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeAll' +type Service_RevokeAll_Call struct { + *mock.Call +} + +// RevokeAll is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +func (_e *Service_Expecter) RevokeAll(ctx interface{}, session interface{}, entityID interface{}) *Service_RevokeAll_Call { + return &Service_RevokeAll_Call{Call: _e.mock.On("RevokeAll", ctx, session, entityID)} +} + +func (_c *Service_RevokeAll_Call) Run(run func(ctx context.Context, session authn.Session, entityID string)) *Service_RevokeAll_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_RevokeAll_Call) Return(err error) *Service_RevokeAll_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RevokeAll_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string) error) *Service_RevokeAll_Call { + _c.Call.Return(run) + return _c +} + +// RevokeBySerial provides a mock function for the type Service +func (_mock *Service) RevokeBySerial(ctx context.Context, session authn.Session, serialNumber string) error { + ret := _mock.Called(ctx, session, serialNumber) + + if len(ret) == 0 { + panic("no return value specified for RevokeBySerial") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, serialNumber) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RevokeBySerial_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeBySerial' +type Service_RevokeBySerial_Call struct { + *mock.Call +} + +// RevokeBySerial is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - serialNumber string +func (_e *Service_Expecter) RevokeBySerial(ctx interface{}, session interface{}, serialNumber interface{}) *Service_RevokeBySerial_Call { + return &Service_RevokeBySerial_Call{Call: _e.mock.On("RevokeBySerial", ctx, session, serialNumber)} +} + +func (_c *Service_RevokeBySerial_Call) Run(run func(ctx context.Context, session authn.Session, serialNumber string)) *Service_RevokeBySerial_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_RevokeBySerial_Call) Return(err error) *Service_RevokeBySerial_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RevokeBySerial_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, serialNumber string) error) *Service_RevokeBySerial_Call { + _c.Call.Return(run) + return _c +} + +// ViewCert provides a mock function for the type Service +func (_mock *Service) ViewCert(ctx context.Context, session authn.Session, serialNumber string) (certs.Certificate, error) { + ret := _mock.Called(ctx, session, serialNumber) + + if len(ret) == 0 { + panic("no return value specified for ViewCert") + } + + var r0 certs.Certificate + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (certs.Certificate, error)); ok { + return returnFunc(ctx, session, serialNumber) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) certs.Certificate); ok { + r0 = returnFunc(ctx, session, serialNumber) + } else { + r0 = ret.Get(0).(certs.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = returnFunc(ctx, session, serialNumber) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ViewCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewCert' +type Service_ViewCert_Call struct { + *mock.Call +} + +// ViewCert is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - serialNumber string +func (_e *Service_Expecter) ViewCert(ctx interface{}, session interface{}, serialNumber interface{}) *Service_ViewCert_Call { + return &Service_ViewCert_Call{Call: _e.mock.On("ViewCert", ctx, session, serialNumber)} +} + +func (_c *Service_ViewCert_Call) Run(run func(ctx context.Context, session authn.Session, serialNumber string)) *Service_ViewCert_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_ViewCert_Call) Return(certificate certs.Certificate, err error) *Service_ViewCert_Call { + _c.Call.Return(certificate, err) + return _c +} + +func (_c *Service_ViewCert_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, serialNumber string) (certs.Certificate, error)) *Service_ViewCert_Call { + _c.Call.Return(run) + return _c +} diff --git a/certs/pki.go b/certs/pki.go new file mode 100644 index 000000000..1af0d1b17 --- /dev/null +++ b/certs/pki.go @@ -0,0 +1,21 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package certs + +import "context" + +// Agent represents the PKI interface that all PKI implementations must satisfy. +type Agent interface { + Issue(ttl string, ipAddrs []string, options SubjectOptions) (Certificate, error) + View(serialNumber string) (Certificate, error) + Revoke(serialNumber string) error + ListCerts(pm PageMetadata) (CertificatePage, error) + GetCA() ([]byte, error) + GetCAChain() ([]byte, error) + GetCRL() ([]byte, error) + SignCSR(csr []byte, ttl string) (Certificate, error) + Renew(cert Certificate, increment string) (Certificate, error) + OCSP(serialNumber string, ocspRequestDER []byte) ([]byte, error) + StartSecretRenewal(ctx context.Context) error +} diff --git a/readers/mocks/doc.go b/certs/pki/doc.go similarity index 52% rename from readers/mocks/doc.go rename to certs/pki/doc.go index 16ed198af..213c2f817 100644 --- a/readers/mocks/doc.go +++ b/certs/pki/doc.go @@ -1,5 +1,4 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -// Package mocks contains mocks for testing purposes. -package mocks +package pki diff --git a/certs/pki/openbao.go b/certs/pki/openbao.go new file mode 100644 index 000000000..8b99a5831 --- /dev/null +++ b/certs/pki/openbao.go @@ -0,0 +1,983 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package pki wraps OpenBao client for PKI operations +package pki + +import ( + "context" + "crypto" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + "sync" + "time" + + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/pkg/errors" + "github.com/mitchellh/mapstructure" + "github.com/openbao/openbao/api/v2" + "golang.org/x/crypto/ocsp" +) + +const ( + issue = "issue" + sign = "sign" + cert = "cert" + revoke = "revoke" + ca = "ca" + caChain = "ca_chain" + crl = "crl" + ocspPath = "ocsp" + certsList = "certs" + signVerbatim = "sign-verbatim" + + defaultRenewThreshold = 24 * time.Hour + defaultSecretIDTTL = 72 * time.Hour + defaultSecretCheckInterval = 30 * time.Second +) + +var ( + errFailedToLogin = errors.New("failed to login to OpenBao") + errNoAuthInfo = errors.New("no auth information from OpenBao") + errRenewWatcher = errors.New("unable to initialize new lifetime watcher for renewing auth token") +) + +type openbaoPKIAgent struct { + appRole string + appSecret string + namespace string + path string + intermediatePath string + role string + host string + issueURL string + signURL string + signVerbatimURL string + readURL string + revokeURL string + caURL string + caChainURL string + rootCAURL string + rootCAChainURL string + crlURL string + ocspURL string + certsURL string + secretIDPath string + client *api.Client + secret *api.Secret + logger *slog.Logger + serviceToken string + renewThreshold time.Duration + secretCheckInterval time.Duration + mu sync.RWMutex + secretAccessor string + secretCreatedAt time.Time + secretTTL time.Duration +} + +// NewAgent instantiates an OpenBao PKI client that implements certs.Agent. +func NewAgent(appRole, appSecret, host, namespace, path, role, serviceToken, renewThreshold, secretIDTTL, secretCheckInterval string, logger *slog.Logger) (certs.Agent, error) { + conf := api.DefaultConfig() + conf.Address = host + + client, err := api.NewClient(conf) + if err != nil { + return nil, err + } + if namespace != "" { + client.SetNamespace(namespace) + } + + intermediatePath := path + "_int" + + renewDuration, err := time.ParseDuration(renewThreshold) + if err != nil { + logger.Warn("Invalid renew threshold duration, using default 24h", "error", err, "provided", renewThreshold) + renewDuration = defaultRenewThreshold + } + + ttlDuration, err := time.ParseDuration(secretIDTTL) + if err != nil { + logger.Warn("Invalid secret ID TTL duration, using default 72h", "error", err, "provided", secretIDTTL) + ttlDuration = defaultSecretIDTTL + } + + checkInterval, err := time.ParseDuration(secretCheckInterval) + if err != nil { + logger.Warn("Invalid secret check interval duration, using default 30s", "error", err, "provided", secretCheckInterval) + checkInterval = defaultSecretCheckInterval + } + + p := openbaoPKIAgent{ + appRole: appRole, + appSecret: appSecret, + host: host, + namespace: namespace, + role: role, + path: path, + intermediatePath: intermediatePath, + client: client, + logger: logger, + serviceToken: serviceToken, + renewThreshold: renewDuration, + secretCheckInterval: checkInterval, + secretTTL: ttlDuration, + issueURL: fmt.Sprintf("%s/%s/%s", intermediatePath, issue, role), + signURL: fmt.Sprintf("%s/%s/%s", intermediatePath, sign, role), + signVerbatimURL: fmt.Sprintf("%s/%s/%s", intermediatePath, signVerbatim, role), + readURL: fmt.Sprintf("%s/%s/", intermediatePath, cert), + revokeURL: fmt.Sprintf("%s/%s", intermediatePath, revoke), + caURL: fmt.Sprintf("%s/%s", intermediatePath, ca), + caChainURL: fmt.Sprintf("%s/%s", intermediatePath, caChain), + rootCAURL: fmt.Sprintf("%s/%s", path, ca), + rootCAChainURL: fmt.Sprintf("%s/%s", path, caChain), + crlURL: fmt.Sprintf("%s/%s", intermediatePath, crl), + ocspURL: fmt.Sprintf("%s/%s", intermediatePath, ocspPath), + certsURL: fmt.Sprintf("%s/%s", intermediatePath, certsList), + secretIDPath: fmt.Sprintf("auth/approle/role/%s/secret-id", role), + } + return &p, nil +} + +func (agent *openbaoPKIAgent) getIntermediateCADefaultSANs() ([]string, []string, error) { + err := agent.LoginAndRenew() + if err != nil { + return nil, nil, err + } + + certData, err := agent.GetCA() + if err != nil { + return nil, nil, fmt.Errorf("failed to get intermediate CA certificate: %w", err) + } + + block, _ := pem.Decode(certData) + if block == nil { + return nil, nil, fmt.Errorf("failed to decode intermediate CA certificate PEM") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse intermediate CA certificate: %w", err) + } + + var ipSANs []string + for _, ip := range cert.IPAddresses { + ipSANs = append(ipSANs, ip.String()) + } + + return cert.DNSNames, ipSANs, nil +} + +func (agent *openbaoPKIAgent) Issue(ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { + err := agent.LoginAndRenew() + if err != nil { + return certs.Certificate{}, err + } + + secretValues := map[string]any{ + "common_name": options.CommonName, + "ttl": ttl, + "exclude_cn_from_sans": true, + } + + if len(options.Organization) > 0 { + secretValues["organization"] = options.Organization + } + if len(options.OrganizationalUnit) > 0 { + secretValues["ou"] = options.OrganizationalUnit + } + if len(options.Country) > 0 { + secretValues["country"] = options.Country + } + if len(options.Province) > 0 { + secretValues["province"] = options.Province + } + if len(options.Locality) > 0 { + secretValues["locality"] = options.Locality + } + if len(options.StreetAddress) > 0 { + secretValues["street_address"] = options.StreetAddress + } + if len(options.PostalCode) > 0 { + secretValues["postal_code"] = options.PostalCode + } + + allDNSNames := make([]string, 0) + allDNSNames = append(allDNSNames, options.DnsNames...) + + defaultDNSNames, defaultIPSANs, err := agent.getIntermediateCADefaultSANs() + if err != nil { + agent.logger.Warn("failed to get default SANs from intermediate CA", "error", err) + } else { + for _, defaultDNS := range defaultDNSNames { + found := false + for _, existing := range allDNSNames { + if existing == defaultDNS { + found = true + break + } + } + if !found { + allDNSNames = append(allDNSNames, defaultDNS) + } + } + } + + if len(allDNSNames) > 0 { + altNamesValue := strings.Join(allDNSNames, ",") + secretValues["alt_names"] = altNamesValue + } + + allIPs := make([]string, 0) + allIPs = append(allIPs, ipAddrs...) + for _, ip := range options.IpAddresses { + allIPs = append(allIPs, ip.String()) + } + + for _, defaultIP := range defaultIPSANs { + found := false + for _, existing := range allIPs { + if existing == defaultIP { + found = true + break + } + } + if !found { + allIPs = append(allIPs, defaultIP) + } + } + + if len(allIPs) > 0 { + ipSansValue := strings.Join(allIPs, ",") + secretValues["ip_sans"] = ipSansValue + } + + secret, err := agent.client.Logical().Write(agent.issueURL, secretValues) + if err != nil { + return certs.Certificate{}, err + } + + if secret == nil || secret.Data == nil { + return certs.Certificate{}, fmt.Errorf("no certificate data returned from OpenBao") + } + + cert := certs.Certificate{} + + if certData, ok := secret.Data["certificate"].(string); ok { + cert.Certificate = []byte(certData) + } + + if keyData, ok := secret.Data["private_key"].(string); ok { + cert.Key = []byte(keyData) + } + + if serialNumber, ok := secret.Data["serial_number"].(string); ok { + cert.SerialNumber = serialNumber + } + + if expirationInterface, ok := secret.Data["expiration"]; ok { + switch exp := expirationInterface.(type) { + case int64: + cert.ExpiryTime = time.Unix(exp, 0) + case float64: + cert.ExpiryTime = time.Unix(int64(exp), 0) + case json.Number: + if expInt, err := exp.Int64(); err == nil { + cert.ExpiryTime = time.Unix(expInt, 0) + } + } + } + + return cert, nil +} + +func (agent *openbaoPKIAgent) View(serialNumber string) (certs.Certificate, error) { + err := agent.LoginAndRenew() + if err != nil { + return certs.Certificate{}, err + } + + secret, err := agent.client.Logical().Read(fmt.Sprintf("%s%s", agent.readURL, serialNumber)) + if err != nil { + return certs.Certificate{}, err + } + + if secret == nil || secret.Data == nil { + return certs.Certificate{}, fmt.Errorf("certificate not found") + } + + cert := certs.Certificate{ + SerialNumber: serialNumber, + } + + if certData, ok := secret.Data["certificate"].(string); ok { + cert.Certificate = []byte(certData) + } + + cert.Revoked = false + if revokedTimeStr, ok := secret.Data["revocation_time_rfc3339"].(string); ok && revokedTimeStr != "" { + cert.Revoked = true + } + + if len(cert.Certificate) > 0 { + if expiry, err := agent.parseCertificateExpiry(string(cert.Certificate)); err == nil { + cert.ExpiryTime = expiry + } + } + return cert, nil +} + +func (agent *openbaoPKIAgent) parseCertificateExpiry(certPEM string) (time.Time, error) { + block, _ := pem.Decode([]byte(certPEM)) + if block == nil { + return time.Time{}, fmt.Errorf("failed to decode PEM certificate") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return time.Time{}, fmt.Errorf("failed to parse X509 certificate: %w", err) + } + + return cert.NotAfter, nil +} + +func (agent *openbaoPKIAgent) Renew(existingCert certs.Certificate, increment string) (certs.Certificate, error) { + err := agent.LoginAndRenew() + if err != nil { + return certs.Certificate{}, err + } + + block, _ := pem.Decode(existingCert.Certificate) + if block == nil { + return certs.Certificate{}, fmt.Errorf("failed to decode existing certificate PEM") + } + + x509Cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return certs.Certificate{}, fmt.Errorf("failed to parse existing certificate: %w", err) + } + + options := certs.SubjectOptions{ + CommonName: x509Cert.Subject.CommonName, + DnsNames: x509Cert.DNSNames, + } + + options.IpAddresses = append(options.IpAddresses, x509Cert.IPAddresses...) + + if len(x509Cert.Subject.Organization) > 0 { + options.Organization = x509Cert.Subject.Organization + } + if len(x509Cert.Subject.OrganizationalUnit) > 0 { + options.OrganizationalUnit = x509Cert.Subject.OrganizationalUnit + } + if len(x509Cert.Subject.Country) > 0 { + options.Country = x509Cert.Subject.Country + } + if len(x509Cert.Subject.Province) > 0 { + options.Province = x509Cert.Subject.Province + } + if len(x509Cert.Subject.Locality) > 0 { + options.Locality = x509Cert.Subject.Locality + } + if len(x509Cert.Subject.StreetAddress) > 0 { + options.StreetAddress = x509Cert.Subject.StreetAddress + } + if len(x509Cert.Subject.PostalCode) > 0 { + options.PostalCode = x509Cert.Subject.PostalCode + } + + var ipAddrs []string + for _, ip := range x509Cert.IPAddresses { + ipAddrs = append(ipAddrs, ip.String()) + } + + newCert, err := agent.Issue(increment, ipAddrs, options) + if err != nil { + return certs.Certificate{}, fmt.Errorf("failed to issue renewed certificate: %w", err) + } + + return newCert, nil +} + +func (agent *openbaoPKIAgent) Revoke(serialNumber string) error { + err := agent.LoginAndRenew() + if err != nil { + return err + } + + secretValues := map[string]any{ + "serial_number": serialNumber, + } + + _, err = agent.client.Logical().Write(agent.revokeURL, secretValues) + if err != nil { + return err + } + return nil +} + +func (agent *openbaoPKIAgent) ListCerts(pm certs.PageMetadata) (certs.CertificatePage, error) { + err := agent.LoginAndRenew() + if err != nil { + return certs.CertificatePage{}, err + } + + secret, err := agent.client.Logical().List(agent.certsURL) + if err != nil { + return certs.CertificatePage{}, err + } + + certPage := certs.CertificatePage{ + Certificates: []certs.Certificate{}, + PageMetadata: pm, + } + + if secret == nil || secret.Data == nil { + return certPage, nil + } + + keysInterface, ok := secret.Data["keys"] + if !ok { + return certPage, nil + } + + var serialNumbers []string + if err := mapstructure.Decode(keysInterface, &serialNumbers); err != nil { + return certPage, fmt.Errorf("failed to decode certificate serial numbers: %w", err) + } + + var allCerts []certs.Certificate + for _, serialNumber := range serialNumbers { + cert, err := agent.View(serialNumber) + if err != nil { + agent.logger.Warn("failed to retrieve certificate details", "serial", serialNumber, "error", err) + continue + } + + allCerts = append(allCerts, cert) + } + + certPage.Total = uint64(len(allCerts)) + + start := pm.Offset + end := pm.Offset + pm.Limit + if pm.Limit == 0 { + end = uint64(len(allCerts)) + } + if start >= uint64(len(allCerts)) { + return certPage, nil + } + if end > uint64(len(allCerts)) { + end = uint64(len(allCerts)) + } + + for i := start; i < end; i++ { + certPage.Certificates = append(certPage.Certificates, allCerts[i]) + } + + return certPage, nil +} + +func (agent *openbaoPKIAgent) LoginAndRenew() error { + agent.mu.RLock() + hasValidToken := agent.secret != nil && agent.secret.Auth != nil && agent.secret.Auth.ClientToken != "" + agent.mu.RUnlock() + + if hasValidToken { + _, err := agent.client.Auth().Token().LookupSelf() + if err == nil { + return nil + } + agent.logger.Warn("Token validation failed, re-authenticating", "error", err) + } + + agent.mu.RLock() + roleID := agent.appRole + secretID := agent.appSecret + agent.mu.RUnlock() + + authData := map[string]any{ + "role_id": roleID, + "secret_id": secretID, + } + + authResp, err := agent.client.Logical().Write("auth/approle/login", authData) + if err != nil { + return fmt.Errorf("%s: %w", errFailedToLogin, err) + } + + if authResp == nil || authResp.Auth == nil { + return errNoAuthInfo + } + + agent.mu.Lock() + agent.secret = authResp + agent.client.SetToken(authResp.Auth.ClientToken) + agent.mu.Unlock() + + if authResp.Auth.Renewable { + watcher, err := agent.client.NewLifetimeWatcher(&api.LifetimeWatcherInput{ + Secret: authResp, + }) + if err != nil { + return fmt.Errorf("%s: %w", errRenewWatcher, err) + } + + go agent.renewToken(watcher) + } + + return nil +} + +func (agent *openbaoPKIAgent) renewToken(watcher *api.LifetimeWatcher) { + defer watcher.Stop() + + watcher.Start() + for { + select { + case err := <-watcher.DoneCh(): + if err != nil { + agent.logger.Error("token renewal failed", "error", err) + } + return + case renewal := <-watcher.RenewCh(): + agent.logger.Info("token renewed successfully", "lease_duration", renewal.Secret.LeaseDuration) + } + } +} + +func (agent *openbaoPKIAgent) GetCA() ([]byte, error) { + err := agent.LoginAndRenew() + if err != nil { + return nil, err + } + + secret, err := agent.client.Logical().ReadRaw(agent.caURL) + if err != nil { + return nil, fmt.Errorf("failed to get CA certificate: %w", err) + } + defer secret.Body.Close() + + if secret.StatusCode != http.StatusOK { + body, _ := io.ReadAll(secret.Body) + return nil, fmt.Errorf("failed to get CA certificate: HTTP %d - %s", secret.StatusCode, string(body)) + } + + certData, err := io.ReadAll(secret.Body) + if err != nil { + return nil, fmt.Errorf("failed to read CA response: %w", err) + } + + if len(certData) == 0 { + return nil, fmt.Errorf("CA certificate response is empty - PKI may not be initialized") + } + + _, err = x509.ParseCertificate(certData) + if err != nil { + return nil, fmt.Errorf("failed to parse DER certificate: %w", err) + } + + pemBlock := &pem.Block{ + Type: "CERTIFICATE", + Bytes: certData, + } + + pemData := pem.EncodeToMemory(pemBlock) + return pemData, nil +} + +func (agent *openbaoPKIAgent) GetCAChain() ([]byte, error) { + err := agent.LoginAndRenew() + if err != nil { + return nil, err + } + + secret, err := agent.client.Logical().ReadRaw(agent.caChainURL) + if err != nil { + return nil, fmt.Errorf("failed to get CA chain: %w", err) + } + defer secret.Body.Close() + + if secret.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get CA chain: HTTP %d", secret.StatusCode) + } + + chainData, err := io.ReadAll(secret.Body) + if err != nil { + return nil, fmt.Errorf("failed to read CA chain response: %w", err) + } + + return chainData, nil +} + +func (agent *openbaoPKIAgent) GetCRL() ([]byte, error) { + err := agent.LoginAndRenew() + if err != nil { + return nil, err + } + + secret, err := agent.client.Logical().ReadRaw(agent.crlURL) + if err != nil { + return nil, fmt.Errorf("failed to get CRL: %w", err) + } + defer secret.Body.Close() + + if secret.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get CRL: HTTP %d", secret.StatusCode) + } + + crlData, err := io.ReadAll(secret.Body) + if err != nil { + return nil, fmt.Errorf("failed to read CRL response: %w", err) + } + + return crlData, nil +} + +func (agent *openbaoPKIAgent) SignCSR(csr []byte, ttl string) (certs.Certificate, error) { + err := agent.LoginAndRenew() + if err != nil { + return certs.Certificate{}, err + } + + block, _ := pem.Decode(csr) + if block == nil { + return certs.Certificate{}, fmt.Errorf("failed to decode CSR PEM") + } + + csrData, err := x509.ParseCertificateRequest(block.Bytes) + if err != nil { + return certs.Certificate{}, fmt.Errorf("failed to parse CSR: %w", err) + } + + secretValues := map[string]any{ + "csr": string(csr), + "ttl": ttl, + "use_csr_values": true, + } + + existingDNSNames := csrData.DNSNames + var existingIPs []string + for _, ip := range csrData.IPAddresses { + existingIPs = append(existingIPs, ip.String()) + } + + defaultDNSNames, defaultIPSANs, err := agent.getIntermediateCADefaultSANs() + if err != nil { + defaultDNSNames = []string{} + defaultIPSANs = []string{} + } + + allDNSNames := make([]string, 0) + allDNSNames = append(allDNSNames, existingDNSNames...) + + for _, defaultDNS := range defaultDNSNames { + found := false + for _, existing := range allDNSNames { + if existing == defaultDNS { + found = true + break + } + } + if !found { + allDNSNames = append(allDNSNames, defaultDNS) + } + } + + allIPs := make([]string, 0) + allIPs = append(allIPs, existingIPs...) + + for _, defaultIP := range defaultIPSANs { + found := false + for _, existing := range allIPs { + if existing == defaultIP { + found = true + break + } + } + if !found { + allIPs = append(allIPs, defaultIP) + } + } + + if len(allDNSNames) > 0 { + altNamesValue := strings.Join(allDNSNames, ",") + secretValues["alt_names"] = altNamesValue + } + + if len(allIPs) > 0 { + ipSansValue := strings.Join(allIPs, ",") + secretValues["ip_sans"] = ipSansValue + } + + path := agent.signVerbatimURL + + secret, err := agent.client.Logical().Write(path, secretValues) + if err != nil { + return certs.Certificate{}, err + } + + if secret == nil || secret.Data == nil { + return certs.Certificate{}, fmt.Errorf("no certificate data returned from OpenBao") + } + + cert := certs.Certificate{} + + if certData, ok := secret.Data["certificate"].(string); ok { + cert.Certificate = []byte(certData) + } + + if serialNumber, ok := secret.Data["serial_number"].(string); ok { + cert.SerialNumber = serialNumber + } + + if expirationInterface, ok := secret.Data["expiration"]; ok { + switch exp := expirationInterface.(type) { + case int64: + cert.ExpiryTime = time.Unix(exp, 0) + case float64: + cert.ExpiryTime = time.Unix(int64(exp), 0) + case json.Number: + if expInt, err := exp.Int64(); err == nil { + cert.ExpiryTime = time.Unix(expInt, 0) + } + } + } + + return cert, nil +} + +func (agent *openbaoPKIAgent) OCSP(serialNumber string, ocspRequestDER []byte) ([]byte, error) { + err := agent.LoginAndRenew() + if err != nil { + return nil, err + } + + var requestDER []byte + + if len(ocspRequestDER) > 0 { + requestDER = ocspRequestDER + } else { + issuerCert, err := agent.getIssuerCertificate() + if err != nil { + return nil, fmt.Errorf("failed to get issuer certificate for OCSP: %w", err) + } + + cert, err := agent.View(serialNumber) + if err != nil { + return nil, fmt.Errorf("failed to get certificate for OCSP: %w", err) + } + + block, _ := pem.Decode(cert.Certificate) + if block == nil { + return nil, fmt.Errorf("failed to decode certificate PEM") + } + + subject, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + requestDER, err = ocsp.CreateRequest(subject, issuerCert, &ocsp.RequestOptions{ + Hash: crypto.SHA1, + }) + if err != nil { + return nil, fmt.Errorf("failed to create OCSP request: %w", err) + } + } + + url := fmt.Sprintf("%s/v1/%s/ocsp", agent.host, agent.intermediatePath) + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(string(requestDER))) + if err != nil { + return nil, fmt.Errorf("failed to create OCSP POST request: %w", err) + } + + req.Header.Set("Content-Type", "application/ocsp-request") + if agent.secret != nil && agent.secret.Auth != nil && agent.secret.Auth.ClientToken != "" { + req.Header.Set("X-Vault-Token", agent.secret.Auth.ClientToken) + } + if agent.namespace != "" { + req.Header.Set("X-Vault-Namespace", agent.namespace) + } + + httpClient := agent.client.CloneConfig().HttpClient + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to query OpenBao OCSP: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("OCSP request failed: HTTP %d - %s", resp.StatusCode, string(body)) + } + + der, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read OCSP response: %w", err) + } + + _, err = ocsp.ParseResponse(der, nil) + if err != nil { + return nil, fmt.Errorf("invalid OCSP response from OpenBao: %w", err) + } + + return der, nil +} + +func (agent *openbaoPKIAgent) getIssuerCertificate() (*x509.Certificate, error) { + certData, err := agent.GetCA() + if err != nil { + return nil, fmt.Errorf("failed to get CA certificate: %w", err) + } + + block, _ := pem.Decode(certData) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM certificate") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse X509 certificate: %w", err) + } + + return cert, nil +} + +func (agent *openbaoPKIAgent) StartSecretRenewal(ctx context.Context) error { + if agent.serviceToken == "" { + agent.logger.Info("No service token provided, secret renewal is disabled") + return nil + } + + if err := agent.LoginAndRenew(); err != nil { + return fmt.Errorf("initial login failed: %w", err) + } + + if err := agent.lookupSecretMetadata(); err != nil { + agent.logger.Warn("Failed to lookup secret metadata, secret renewal may not work properly", "error", err) + } + + go agent.monitorSecretExpiration(ctx) + + return nil +} + +func (agent *openbaoPKIAgent) lookupSecretMetadata() error { + tempClient, err := api.NewClient(agent.client.CloneConfig()) + if err != nil { + return fmt.Errorf("failed to create temp client: %w", err) + } + if agent.namespace != "" { + tempClient.SetNamespace(agent.namespace) + } + tempClient.SetToken(agent.serviceToken) + + agent.mu.Lock() + agent.secretCreatedAt = time.Now().UTC() + agent.mu.Unlock() + + agent.logger.Info("Secret metadata initialized", "created_at", agent.secretCreatedAt, "ttl", agent.secretTTL.String()) + return nil +} + +func (agent *openbaoPKIAgent) monitorSecretExpiration(ctx context.Context) { + ticker := time.NewTicker(agent.secretCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + agent.logger.Info("Secret renewal monitoring stopped") + return + case <-ticker.C: + agent.mu.RLock() + createdAt := agent.secretCreatedAt + ttl := agent.secretTTL + agent.mu.RUnlock() + + if createdAt.IsZero() { + continue + } + + expiryTime := createdAt.Add(ttl) + timeUntilExpiry := time.Until(expiryTime) + + if timeUntilExpiry <= agent.renewThreshold { + agent.logger.Warn("Secret ID approaching expiration and will be renewed", "time_until_expiry", timeUntilExpiry, "renew_threshold", agent.renewThreshold, "expiry_time", expiryTime) + + if err := agent.renewSecretID(); err != nil { + agent.logger.Error("Failed to renew secret ID", "error", err) + continue + } + + agent.logger.Info("Successfully renewed secret ID") + } else { + agent.logger.Debug("Secret ID still valid", "time_until_expiry", timeUntilExpiry, "expiry_time", expiryTime) + } + } + } +} + +func (agent *openbaoPKIAgent) renewSecretID() error { + tempClient, err := api.NewClient(agent.client.CloneConfig()) + if err != nil { + return fmt.Errorf("failed to create temp client: %w", err) + } + if agent.namespace != "" { + tempClient.SetNamespace(agent.namespace) + } + tempClient.SetToken(agent.serviceToken) + + if err := agent.renewServiceToken(tempClient); err != nil { + agent.logger.Warn("Failed to renew service token", "error", err) + } + + secret, err := tempClient.Logical().Write(agent.secretIDPath, map[string]interface{}{}) + if err != nil { + return fmt.Errorf("failed to generate new secret ID: %w", err) + } + + if secret == nil || secret.Data == nil { + return fmt.Errorf("no data returned when generating secret ID") + } + + newSecretID, ok := secret.Data["secret_id"].(string) + if !ok || newSecretID == "" { + return fmt.Errorf("secret_id not found in response") + } + + secretAccessor, _ := secret.Data["secret_id_accessor"].(string) + + agent.mu.Lock() + agent.appSecret = newSecretID + agent.secretAccessor = secretAccessor + agent.secretCreatedAt = time.Now().UTC() + agent.secret = nil + agent.mu.Unlock() + + if err := agent.LoginAndRenew(); err != nil { + return fmt.Errorf("authentication failed with new secret: %w", err) + } + + return nil +} + +func (agent *openbaoPKIAgent) renewServiceToken(client *api.Client) error { + renewable, err := client.Auth().Token().RenewSelf(0) + if err != nil { + return fmt.Errorf("failed to renew service token: %w", err) + } + + if renewable == nil || renewable.Auth == nil { + return fmt.Errorf("no auth information returned from token renewal") + } + + return nil +} diff --git a/certs/postgres/certs.go b/certs/postgres/certs.go new file mode 100644 index 000000000..44453f209 --- /dev/null +++ b/certs/postgres/certs.go @@ -0,0 +1,121 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/postgres" + "github.com/jackc/pgx/v5/pgconn" +) + +// Postgres error codes: +// https://www.postgresql.org/docs/current/errcodes-appendix.html +const ( + errDuplicate = "23505" // unique_violation + errTruncation = "22001" // string_data_right_truncation + errFK = "23503" // foreign_key_violation + errInvalid = "22P02" // invalid_text_representation + errUntranslatable = "22P05" // untranslatable_character + errInvalidChar = "22021" // character_not_in_repertoire +) + +var ( + ErrConflict = errors.New("entity already exists") + ErrMalformedEntity = errors.New("malformed entity") + ErrCreateEntity = errors.New("failed to create entity") + ErrNotFound = errors.New("entity not found") +) + +type certsRepo struct { + db postgres.Database +} + +func NewRepository(db postgres.Database) certs.Repository { + return certsRepo{ + db: db, + } +} + +// SaveCertEntityMapping saves the mapping between certificate serial number and entity ID. +func (repo certsRepo) SaveCertEntityMapping(ctx context.Context, serialNumber, entityID string) error { + q := `INSERT INTO cert_entity_mappings (serial_number, entity_id) VALUES ($1, $2)` + _, err := repo.db.ExecContext(ctx, q, serialNumber, entityID) + if err != nil { + return handleError(ErrCreateEntity, err) + } + return nil +} + +// GetEntityIDBySerial retrieves the entity ID for a given certificate serial number. +func (repo certsRepo) GetEntityIDBySerial(ctx context.Context, serialNumber string) (string, error) { + q := `SELECT entity_id FROM cert_entity_mappings WHERE serial_number = $1` + var entityID string + if err := repo.db.QueryRowxContext(ctx, q, serialNumber).Scan(&entityID); err != nil { + if err == sql.ErrNoRows { + return "", errors.Wrap(ErrNotFound, err) + } + return "", handleError(ErrNotFound, err) + } + return entityID, nil +} + +// ListCertsByEntityID lists all certificate serial numbers for a given entity ID. +func (repo certsRepo) ListCertsByEntityID(ctx context.Context, entityID string) ([]string, error) { + q := `SELECT serial_number FROM cert_entity_mappings WHERE entity_id = $1 ORDER BY created_at DESC` + rows, err := repo.db.QueryContext(ctx, q, entityID) + if err != nil { + return nil, handleError(ErrNotFound, err) + } + defer rows.Close() + + var serialNumbers []string + for rows.Next() { + var serialNumber string + if err := rows.Scan(&serialNumber); err != nil { + return nil, handleError(ErrNotFound, err) + } + serialNumbers = append(serialNumbers, serialNumber) + } + + if err := rows.Err(); err != nil { + return nil, handleError(ErrNotFound, err) + } + + return serialNumbers, nil +} + +// RemoveCertEntityMapping removes the mapping between certificate and entity ID. +func (repo certsRepo) RemoveCertEntityMapping(ctx context.Context, serialNumber string) error { + q := `DELETE FROM cert_entity_mappings WHERE serial_number = $1` + result, err := repo.db.ExecContext(ctx, q, serialNumber) + if err != nil { + return handleError(ErrNotFound, err) + } + + if rows, _ := result.RowsAffected(); rows == 0 { + return ErrNotFound + } + + return nil +} + +func handleError(wrapper, err error) error { + pqErr, ok := err.(*pgconn.PgError) + if ok { + switch pqErr.Code { + case errDuplicate: + return errors.Wrap(ErrConflict, err) + case errInvalid, errInvalidChar, errTruncation, errUntranslatable: + return errors.Wrap(ErrMalformedEntity, err) + case errFK: + return errors.Wrap(ErrCreateEntity, err) + } + } + + return errors.Wrap(wrapper, err) +} diff --git a/certs/postgres/certs_test.go b/certs/postgres/certs_test.go new file mode 100644 index 000000000..ac6fd6577 --- /dev/null +++ b/certs/postgres/certs_test.go @@ -0,0 +1,244 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "testing" + + "github.com/absmach/supermq/certs/postgres" + "github.com/absmach/supermq/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + entityID = "bfead30d-5a1d-40f3-be21-fd8ffad49db0" + serialNumber = "20:f4:bd:43:2c:c7:06:82:c7:f2:00:47:51:b6:81:6f:fa:c4:46:0c" +) + +func TestSaveCertEntityMapping(t *testing.T) { + repo := postgres.NewRepository(database) + + testCases := []struct { + desc string + serialNumber string + entityID string + err error + }{ + { + desc: "successful save", + serialNumber: serialNumber, + entityID: entityID, + err: nil, + }, + { + desc: "save duplicate mapping", + serialNumber: serialNumber, + entityID: entityID, + err: postgres.ErrConflict, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.SaveCertEntityMapping(context.Background(), tc.serialNumber, tc.entityID) + if tc.err != nil { + require.Error(t, err) + assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestGetEntityIDBySerial(t *testing.T) { + repo := postgres.NewRepository(database) + + // Setup: save a mapping first + testSerial := "test-serial-456" + testEntityID := "test-entity-789" + err := repo.SaveCertEntityMapping(context.Background(), testSerial, testEntityID) + require.NoError(t, err) + + testCases := []struct { + desc string + serialNumber string + expectedID string + err error + }{ + { + desc: "successful retrieval", + serialNumber: testSerial, + expectedID: testEntityID, + err: nil, + }, + { + desc: "serial number not found", + serialNumber: "non-existent-serial", + expectedID: "", + err: postgres.ErrNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + entityID, err := repo.GetEntityIDBySerial(context.Background(), tc.serialNumber) + if tc.err != nil { + require.Error(t, err) + assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) + assert.Empty(t, entityID) + } else { + require.NoError(t, err) + assert.Equal(t, tc.expectedID, entityID) + } + }) + } +} + +func TestListCertsByEntityID(t *testing.T) { + repo := postgres.NewRepository(database) + + // Setup: save multiple mappings for the same entity + testEntityID := "test-entity-list" + testSerials := []string{"serial-1", "serial-2", "serial-3"} + + for _, serial := range testSerials { + err := repo.SaveCertEntityMapping(context.Background(), serial, testEntityID) + require.NoError(t, err) + } + + testCases := []struct { + desc string + entityID string + expectedCount int + expectedContains []string + err error + }{ + { + desc: "successful list with multiple certs", + entityID: testEntityID, + expectedCount: 3, + expectedContains: testSerials, + err: nil, + }, + { + desc: "entity with no certificates", + entityID: "non-existent-entity", + expectedCount: 0, + expectedContains: []string{}, + err: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + serials, err := repo.ListCertsByEntityID(context.Background(), tc.entityID) + if tc.err != nil { + require.Error(t, err) + assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) + } else { + require.NoError(t, err) + assert.Len(t, serials, tc.expectedCount) + + if tc.expectedCount > 0 { + for _, expectedSerial := range tc.expectedContains { + assert.Contains(t, serials, expectedSerial) + } + } + } + }) + } +} + +func TestRemoveCertEntityMapping(t *testing.T) { + repo := postgres.NewRepository(database) + + testSerial := "test-serial-remove" + testEntityID := "test-entity-remove" + err := repo.SaveCertEntityMapping(context.Background(), testSerial, testEntityID) + require.NoError(t, err) + + testCases := []struct { + desc string + serialNumber string + err error + }{ + { + desc: "successful removal", + serialNumber: testSerial, + err: nil, + }, + { + desc: "remove non-existent mapping", + serialNumber: "non-existent-serial", + err: postgres.ErrNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.RemoveCertEntityMapping(context.Background(), tc.serialNumber) + if tc.err != nil { + require.Error(t, err) + assert.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) + } else { + require.NoError(t, err) + + // Verify the mapping was actually removed + _, err := repo.GetEntityIDBySerial(context.Background(), tc.serialNumber) + assert.True(t, errors.Contains(err, postgres.ErrNotFound)) + } + }) + } +} + +func TestCertEntityMappingWorkflow(t *testing.T) { + repo := postgres.NewRepository(database) + + // Test complete workflow: save -> get -> list -> remove + entityID := "workflow-entity" + serials := []string{"workflow-serial-1", "workflow-serial-2"} + + // Save mappings + for _, serial := range serials { + err := repo.SaveCertEntityMapping(context.Background(), serial, entityID) + require.NoError(t, err) + } + + // Verify we can get entity IDs by serial + for _, serial := range serials { + retrievedID, err := repo.GetEntityIDBySerial(context.Background(), serial) + require.NoError(t, err) + assert.Equal(t, entityID, retrievedID) + } + + // Verify we can list all serials for the entity + listedSerials, err := repo.ListCertsByEntityID(context.Background(), entityID) + require.NoError(t, err) + assert.Len(t, listedSerials, 2) + for _, serial := range serials { + assert.Contains(t, listedSerials, serial) + } + + // Remove one mapping + err = repo.RemoveCertEntityMapping(context.Background(), serials[0]) + require.NoError(t, err) + + // Verify it's removed + _, err = repo.GetEntityIDBySerial(context.Background(), serials[0]) + assert.True(t, errors.Contains(err, postgres.ErrNotFound)) + + // Verify the other mapping still exists + retrievedID, err := repo.GetEntityIDBySerial(context.Background(), serials[1]) + require.NoError(t, err) + assert.Equal(t, entityID, retrievedID) + + // Verify list now shows only one serial + listedSerials, err = repo.ListCertsByEntityID(context.Background(), entityID) + require.NoError(t, err) + assert.Len(t, listedSerials, 1) + assert.Contains(t, listedSerials, serials[1]) +} diff --git a/certs/postgres/init.go b/certs/postgres/init.go new file mode 100644 index 000000000..9b1e6e73f --- /dev/null +++ b/certs/postgres/init.go @@ -0,0 +1,32 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + _ "github.com/jackc/pgx/v5/stdlib" + migrate "github.com/rubenv/sql-migrate" +) + +func Migration() *migrate.MemoryMigrationSource { + return &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "certs_1", + Up: []string{ + `CREATE TABLE IF NOT EXISTS cert_entity_mappings ( + serial_number VARCHAR(255) UNIQUE NOT NULL, + entity_id VARCHAR(255) NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (serial_number) + )`, + `CREATE INDEX IF NOT EXISTS idx_cert_entity_mappings_entity_id ON cert_entity_mappings(entity_id)`, + }, + Down: []string{ + "DROP INDEX IF EXISTS idx_cert_entity_mappings_entity_id", + "DROP TABLE cert_entity_mappings", + }, + }, + }, + } +} diff --git a/certs/postgres/setup_test.go b/certs/postgres/setup_test.go new file mode 100644 index 000000000..947551e5d --- /dev/null +++ b/certs/postgres/setup_test.go @@ -0,0 +1,98 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "database/sql" + "fmt" + "log" + "os" + "testing" + "time" + + "github.com/absmach/supermq/certs/postgres" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/jmoiron/sqlx" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "go.opentelemetry.io/otel" +) + +var ( + db *sqlx.DB + database pgclient.Database + tracer = otel.Tracer("repo_tests") +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16.2-alpine", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + + // exponential backoff-retry, because the application in the container might not be ready to accept connections yet + pool.MaxWait = 120 * time.Second + if err := pool.Retry(func() error { + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + db, err := sql.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + dbConfig := pgclient.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + mig := postgres.Migration() + if db, err = pgclient.Setup(dbConfig, *mig); err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + if db, err = pgclient.Connect(dbConfig); err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + database = pgclient.NewDatabase(db, dbConfig, tracer) + + code := m.Run() + + // Defers will not be run when using os.Exit + db.Close() + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/certs/service.go b/certs/service.go new file mode 100644 index 000000000..c1b7c49d6 --- /dev/null +++ b/certs/service.go @@ -0,0 +1,310 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package certs + +import ( + "context" + "crypto/x509" + "encoding/pem" + "fmt" + "time" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" +) + +const ( + PrivateKeyBytes = 2048 + RootCAValidityPeriod = time.Hour * 24 * 365 + IntermediateCAValidityPeriod = time.Hour * 24 * 90 + certValidityPeriod = time.Hour * 24 * 30 + PrivateKey = "PRIVATE KEY" + RSAPrivateKey = "RSA PRIVATE KEY" + ECPrivateKey = "EC PRIVATE KEY" + PKCS8PrivateKey = "PKCS8 PRIVATE KEY" + EDPrivateKey = "ED25519 PRIVATE KEY" +) + +var ( + ErrNotFound = errors.New("entity not found") + ErrConflict = errors.New("entity already exists") + ErrCreateEntity = errors.New("failed to create entity") + ErrViewEntity = errors.New("view entity failed") + ErrUpdateEntity = errors.New("update entity failed") + ErrDeleteEntity = errors.New("delete entity failed") + ErrMalformedEntity = errors.New("malformed entity specification") + ErrRootCANotFound = errors.New("root CA not found") + ErrIntermediateCANotFound = errors.New("intermediate CA not found") + ErrCertExpired = errors.New("certificate expired before renewal") + ErrCertRevoked = errors.New("certificate has been revoked and cannot be renewed") + ErrCertInvalidType = errors.New("invalid cert type") + ErrInvalidLength = errors.New("invalid length of serial numbers") + ErrPrivKeyType = errors.New("unsupported private key type") + ErrPubKeyType = errors.New("unsupported public key type") + ErrFailedParse = errors.New("failed to parse key PEM") + ErrFailedCertCreation = errors.New("failed to create certificate") + ErrInvalidIP = errors.New("invalid IP address") +) + +type service struct { + pki Agent + repo Repository +} + +var _ Service = (*service)(nil) + +func NewService(ctx context.Context, pki Agent, repo Repository) (Service, error) { + var svc service + + svc.pki = pki + svc.repo = repo + + return &svc, nil +} + +// IssueCert generates and issues a certificate for a given entityID. +// It uses the PKI agent to generate and issue a certificate. +// The certificate is managed by OpenBao PKI internally. +// EntityType is used to customize certificate properties based on the entity type. +func (s *service) IssueCert(ctx context.Context, session authn.Session, entityID, ttl string, ipAddrs []string, options SubjectOptions) (Certificate, error) { + cert, err := s.pki.Issue(ttl, ipAddrs, options) + if err != nil { + return Certificate{}, errors.Wrap(ErrFailedCertCreation, err) + } + + if err := s.repo.SaveCertEntityMapping(ctx, cert.SerialNumber, entityID); err != nil { + return Certificate{}, errors.Wrap(ErrFailedCertCreation, err) + } + + cert.EntityID = entityID + + return cert, nil +} + +func (s *service) ListCerts(ctx context.Context, session authn.Session, pm PageMetadata) (CertificatePage, error) { + if pm.EntityID != "" { + serialNumbers, err := s.repo.ListCertsByEntityID(ctx, pm.EntityID) + if err != nil { + return CertificatePage{}, errors.Wrap(ErrViewEntity, err) + } + + certPg := CertificatePage{ + PageMetadata: pm, + Certificates: make([]Certificate, 0), + } + + start := pm.Offset + end := pm.Offset + pm.Limit + if pm.Limit == 0 { + end = uint64(len(serialNumbers)) + } + if start >= uint64(len(serialNumbers)) { + return certPg, nil + } + if end > uint64(len(serialNumbers)) { + end = uint64(len(serialNumbers)) + } + + for i := start; i < end; i++ { + cert, err := s.pki.View(serialNumbers[i]) + if err != nil { + continue + } + cert.EntityID = pm.EntityID + certPg.Certificates = append(certPg.Certificates, cert) + } + + certPg.Total = uint64(len(serialNumbers)) + return certPg, nil + } + + certPg, err := s.pki.ListCerts(pm) + if err != nil { + return CertificatePage{}, errors.Wrap(ErrViewEntity, err) + } + + for i, cert := range certPg.Certificates { + if entityID, err := s.repo.GetEntityIDBySerial(ctx, cert.SerialNumber); err == nil { + certPg.Certificates[i].EntityID = entityID + } + } + + return certPg, nil +} + +func (s *service) RevokeBySerial(ctx context.Context, session authn.Session, serialNumber string) error { + err := s.pki.Revoke(serialNumber) + if err != nil { + return errors.Wrap(ErrUpdateEntity, err) + } + return nil +} + +// RevokeAll revokes all certificates for a given entity ID. +// It uses the repository to find all certificates for the entity ID, then revokes each one. +func (s *service) RevokeAll(ctx context.Context, session authn.Session, entityID string) error { + serialNumbers, err := s.repo.ListCertsByEntityID(ctx, entityID) + if err != nil { + return errors.Wrap(ErrViewEntity, err) + } + + if len(serialNumbers) == 0 { + return errors.Wrap(ErrNotFound, fmt.Errorf("no certificates found for entity ID: %s", entityID)) + } + + for _, serialNumber := range serialNumbers { + if err := s.pki.Revoke(serialNumber); err != nil { + return errors.Wrap(ErrUpdateEntity, err) + } + if err := s.repo.RemoveCertEntityMapping(ctx, serialNumber); err != nil { + return errors.Wrap(ErrDeleteEntity, err) + } + } + + return nil +} + +func (s *service) ViewCert(ctx context.Context, session authn.Session, serialNumber string) (Certificate, error) { + cert, err := s.pki.View(serialNumber) + if err != nil { + return Certificate{}, errors.Wrap(ErrViewEntity, err) + } + + if entityID, err := s.repo.GetEntityIDBySerial(ctx, serialNumber); err == nil { + cert.EntityID = entityID + } + + return cert, nil +} + +func (s *service) ViewCA(ctx context.Context) (Certificate, error) { + caPEM, err := s.pki.GetCA() + if err != nil { + return Certificate{}, errors.Wrap(ErrViewEntity, err) + } + + if len(caPEM) == 0 { + return Certificate{}, errors.New("CA certificate PEM is empty") + } + + block, _ := pem.Decode(caPEM) + if block == nil { + caPreview := string(caPEM) + if len(caPreview) > 100 { + caPreview = caPreview[:100] + "..." + } + return Certificate{}, errors.New("failed to decode CA certificate PEM - received: " + caPreview) + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return Certificate{}, errors.Wrap(ErrViewEntity, err) + } + + return Certificate{ + SerialNumber: cert.SerialNumber.String(), + Certificate: caPEM, + Key: nil, + Revoked: false, + ExpiryTime: cert.NotAfter, + EntityID: cert.Subject.CommonName, + Type: IntermediateCA, + }, nil +} + +// RenewCert renews a certificate by issuing a new certificate with the same parameters. +// Returns the new certificate with extended TTL and a new serial number. +func (s *service) RenewCert(ctx context.Context, session authn.Session, serialNumber string) (Certificate, error) { + cert, err := s.pki.View(serialNumber) + if err != nil { + return Certificate{}, errors.Wrap(ErrViewEntity, err) + } + if cert.Revoked { + return Certificate{}, ErrCertRevoked + } + newCert, err := s.pki.Renew(cert, certValidityPeriod.String()) + if err != nil { + return Certificate{}, errors.Wrap(ErrUpdateEntity, err) + } + + return newCert, nil +} + +// OCSP forwards OCSP requests to OpenBao's OCSP endpoint. +// If ocspRequestDER is provided, it will be used directly; otherwise, a request will be built from the serialNumber. +func (s *service) OCSP(ctx context.Context, serialNumber string, ocspRequestDER []byte) ([]byte, error) { + return s.pki.OCSP(serialNumber, ocspRequestDER) +} + +func (s *service) GetEntityID(ctx context.Context, serialNumber string) (string, error) { + entityID, err := s.repo.GetEntityIDBySerial(ctx, serialNumber) + if err != nil { + return "", errors.Wrap(ErrViewEntity, err) + } + return entityID, nil +} + +func (s *service) GenerateCRL(ctx context.Context) ([]byte, error) { + crl, err := s.pki.GetCRL() + if err != nil { + return nil, errors.Wrap(ErrFailedCertCreation, err) + } + return crl, nil +} + +func (s *service) RetrieveCAChain(ctx context.Context) (Certificate, error) { + return s.getConcatCAs(ctx) +} + +func (s *service) IssueFromCSR(ctx context.Context, session authn.Session, entityID, ttl string, csr CSR) (Certificate, error) { + cert, err := s.pki.SignCSR(csr.CSR, ttl) + if err != nil { + return Certificate{}, errors.Wrap(ErrFailedCertCreation, err) + } + + if err := s.repo.SaveCertEntityMapping(ctx, cert.SerialNumber, entityID); err != nil { + return Certificate{}, errors.Wrap(ErrFailedCertCreation, err) + } + + cert.EntityID = entityID + + return cert, nil +} + +func (s *service) IssueFromCSRInternal(ctx context.Context, entityID, ttl string, csr CSR) (Certificate, error) { + cert, err := s.pki.SignCSR(csr.CSR, ttl) + if err != nil { + return Certificate{}, errors.Wrap(ErrFailedCertCreation, err) + } + + if err := s.repo.SaveCertEntityMapping(ctx, cert.SerialNumber, entityID); err != nil { + return Certificate{}, errors.Wrap(ErrFailedCertCreation, err) + } + + cert.EntityID = entityID + + return cert, nil +} + +func (s *service) getConcatCAs(_ context.Context) (Certificate, error) { + caChain, err := s.pki.GetCAChain() + if err != nil { + return Certificate{}, errors.Wrap(ErrViewEntity, err) + } + + block, _ := pem.Decode(caChain) + if block == nil { + return Certificate{}, errors.New("failed to decode CA chain PEM") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return Certificate{}, errors.Wrap(ErrViewEntity, err) + } + + return Certificate{ + Certificate: caChain, + ExpiryTime: cert.NotAfter, + }, nil +} diff --git a/channels/README.md b/channels/README.md index a61c7e15a..85cbff7ed 100644 --- a/channels/README.md +++ b/channels/README.md @@ -8,22 +8,22 @@ The service is configured using the following environment variables (unset varia | Variable | Description | Default | |-----------------------------|--------------------------------------------------------------|-------------| -| `SMQ_CHANNELS_LOG_LEVEL` | Log level (debug, info, warn, error) | info | -| `SMQ_CHANNELS_HTTP_HOST` | HTTP host for Channels service | localhost | -| `SMQ_CHANNELS_HTTP_PORT` | HTTP port for Channels service | 9005 | -| `SMQ_CHANNELS_SERVER_CERT` | Path to PEM encoded server certificate | "" | -| `SMQ_CHANNELS_SERVER_KEY` | Path to PEM encoded server key file | "" | -| `SMQ_CHANNELS_GRPC_HOST` | gRPC host for Channels service | localhost | -| `SMQ_CHANNELS_GRPC_PORT` | gRPC port for Channels service | 7005 | -| `SMQ_CHANNELS_DB_HOST` | Database host address | localhost | -| `SMQ_CHANNELS_DB_PORT` | Database port | 5432 | -| `SMQ_CHANNELS_DB_USER` | Database user | supermq | -| `SMQ_CHANNELS_DB_PASS` | Database password | supermq | -| `SMQ_CHANNELS_DB_NAME` | Name of the database used by the service | channels | -| `SMQ_CHANNELS_DB_SSL_MODE` | Database connection SSL mode | disable | -| `SMQ_CHANNELS_CACHE_URL` | Cache database URL | | -| `SMQ_JAEGER_URL` | Jaeger tracing server URL | | -| `SMQ_SEND_TELEMETRY` | Send telemetry to SuperMQ call-home server | true | +| `MG_CHANNELS_LOG_LEVEL` | Log level (debug, info, warn, error) | info | +| `MG_CHANNELS_HTTP_HOST` | HTTP host for Channels service | localhost | +| `MG_CHANNELS_HTTP_PORT` | HTTP port for Channels service | 9005 | +| `MG_CHANNELS_SERVER_CERT` | Path to PEM encoded server certificate | "" | +| `MG_CHANNELS_SERVER_KEY` | Path to PEM encoded server key file | "" | +| `MG_CHANNELS_GRPC_HOST` | gRPC host for Channels service | localhost | +| `MG_CHANNELS_GRPC_PORT` | gRPC port for Channels service | 7005 | +| `MG_CHANNELS_DB_HOST` | Database host address | localhost | +| `MG_CHANNELS_DB_PORT` | Database port | 5432 | +| `MG_CHANNELS_DB_USER` | Database user | supermq | +| `MG_CHANNELS_DB_PASS` | Database password | supermq | +| `MG_CHANNELS_DB_NAME` | Name of the database used by the service | channels | +| `MG_CHANNELS_DB_SSL_MODE` | Database connection SSL mode | disable | +| `MG_CHANNELS_CACHE_URL` | Cache database URL | | +| `MG_JAEGER_URL` | Jaeger tracing server URL | | +| `MG_SEND_TELEMETRY` | Send telemetry to SuperMQ call-home server | true | ## Features @@ -86,13 +86,13 @@ make channels make install -SMQ_CHANNELS_HTTP_HOST=localhost \ -SMQ_CHANNELS_HTTP_PORT=9005 \ -SMQ_CHANNELS_DB_HOST=localhost \ -SMQ_CHANNELS_DB_PORT=5432 \ -SMQ_CHANNELS_DB_USER=supermq \ -SMQ_CHANNELS_DB_PASS=supermq \ -SMQ_CHANNELS_DB_NAME=channels \ +MG_CHANNELS_HTTP_HOST=localhost \ +MG_CHANNELS_HTTP_PORT=9005 \ +MG_CHANNELS_DB_HOST=localhost \ +MG_CHANNELS_DB_PORT=5432 \ +MG_CHANNELS_DB_USER=supermq \ +MG_CHANNELS_DB_PASS=supermq \ +MG_CHANNELS_DB_NAME=channels \ $GOBIN/supermq-channels ``` diff --git a/channels/events/streams.go b/channels/events/streams.go index 6160f0c5c..0ab1f1839 100644 --- a/channels/events/streams.go +++ b/channels/events/streams.go @@ -44,7 +44,7 @@ type eventStore struct { // NewEventStoreMiddleware returns wrapper around clients service that sends // events to event store. func NewEventStoreMiddleware(ctx context.Context, svc channels.Service, url string) (channels.Service, error) { - publisher, err := store.NewPublisher(ctx, url) + publisher, err := store.NewPublisher(ctx, url, "channels-es-pub") if err != nil { return nil, err } diff --git a/channels/middleware/authorization.go b/channels/middleware/authorization.go index c68b1f41a..a33c7d672 100644 --- a/channels/middleware/authorization.go +++ b/channels/middleware/authorization.go @@ -362,7 +362,7 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session Subject: session.UserID, Permission: policies.AdminPermission, ObjectType: policies.PlatformType, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, }, nil); err != nil { return err } diff --git a/channels/postgres/channels.go b/channels/postgres/channels.go index 1e3bd0c7d..43b44a0d8 100644 --- a/channels/postgres/channels.go +++ b/channels/postgres/channels.go @@ -500,10 +500,10 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u FROM final_channels c ` + connCountJoinQuery := connJoinQuery if pm.Client != "" { - connJoinQuery = ` - ,conn.connection_types + connCountJoinQuery = ` FROM final_channels c LEFT JOIN ( @@ -517,6 +517,9 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u conn.client_id, conn.channel_id ) conn ON c.id = conn.channel_id ` + connJoinQuery = ` + ,conn.connection_types + ` + connCountJoinQuery } dbPage, err := toDBChannelsPage(pm) @@ -529,9 +532,9 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u if pm.OnlyTotal { cq := fmt.Sprintf(`%s SELECT COUNT(*) AS total_count - FROM final_channels c + %s %s; - `, bq, pageQuery) + `, bq, connCountJoinQuery, pageQuery) total, err := postgres.Total(ctx, repo.db, cq, dbPage) if err != nil { @@ -607,9 +610,9 @@ func (repo *channelRepository) retrieveChannels(ctx context.Context, domainID, u if len(items) == 0 { cq := fmt.Sprintf(`%s SELECT COUNT(*) AS total_count - FROM final_channels c + %s %s; - `, bq, pageQuery) + `, bq, connCountJoinQuery, pageQuery) total, err = postgres.Total(ctx, repo.db, cq, dbPage) if err != nil { diff --git a/cli/bootstrap.go b/cli/bootstrap.go new file mode 100644 index 000000000..1fbd1584e --- /dev/null +++ b/cli/bootstrap.go @@ -0,0 +1,216 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cli + +import ( + "encoding/json" + + mgsdk "github.com/absmach/supermq/pkg/sdk" + "github.com/spf13/cobra" +) + +var cmdBootstrap = []cobra.Command{ + { + Use: "create ", + Short: "Create config", + Long: `Create new Client Bootstrap Config to the user identified by the provided key`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + + var cfg mgsdk.BootstrapConfig + if err := json.Unmarshal([]byte(args[0]), &cfg); err != nil { + logErrorCmd(*cmd, err) + return + } + + id, err := sdk.AddBootstrap(cmd.Context(), cfg, args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + logCreatedCmd(*cmd, id) + }, + }, + { + Use: "get [all | ] ", + Short: "Get config", + Long: `Get Client Config with given ID belonging to the user identified by the given key. + all - lists all config + - view config of `, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + pageMetadata := mgsdk.PageMetadata{ + Offset: Offset, + Limit: Limit, + State: State, + Name: Name, + } + if args[0] == all { + l, err := sdk.Bootstraps(cmd.Context(), pageMetadata, args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, l) + return + } + + c, err := sdk.ViewBootstrap(cmd.Context(), args[0], args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + logJSONCmd(*cmd, c) + }, + }, + { + Use: "update [config | connection | certs ] ", + Short: "Update config", + Long: `Updates editable fields of the provided Config. + config - Updates editable fields of the provided Config. + connection - Updates connections performs update of the channel list corresponding Client is connected to. + channel_ids - '["channel_id1", ...]' + certs - Update bootstrap config certificates.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) < 4 { + logUsageCmd(*cmd, cmd.Use) + return + } + if args[0] == "config" { + var cfg mgsdk.BootstrapConfig + if err := json.Unmarshal([]byte(args[1]), &cfg); err != nil { + logErrorCmd(*cmd, err) + return + } + + if err := sdk.UpdateBootstrap(cmd.Context(), cfg, args[1], args[2]); err != nil { + logErrorCmd(*cmd, err) + return + } + + logOKCmd(*cmd) + return + } + if args[0] == "connection" { + var ids []string + if err := json.Unmarshal([]byte(args[2]), &ids); err != nil { + logErrorCmd(*cmd, err) + return + } + if err := sdk.UpdateBootstrapConnection(cmd.Context(), args[1], ids, args[3], args[4]); err != nil { + logErrorCmd(*cmd, err) + return + } + + logOKCmd(*cmd) + return + } + if args[0] == "certs" { + cfg, err := sdk.UpdateBootstrapCerts(cmd.Context(), args[0], args[1], args[2], args[3], args[4], args[5]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + logJSONCmd(*cmd, cfg) + return + } + logUsageCmd(*cmd, cmd.Use) + }, + }, + { + Use: "remove ", + Short: "Remove config", + Long: `Removes Config with specified key that belongs to the user identified by the given key`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + + if err := sdk.RemoveBootstrap(cmd.Context(), args[0], args[1], args[2]); err != nil { + logErrorCmd(*cmd, err) + return + } + + logOKCmd(*cmd) + }, + }, + { + Use: "bootstrap [ | secure ]", + Short: "Bootstrap config", + Long: `Returns Config to the Client with provided external ID using external key. + secure - Retrieves a configuration with given external ID and encrypted external key.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) < 2 { + logUsageCmd(*cmd, cmd.Use) + return + } + if args[0] == "secure" { + c, err := sdk.BootstrapSecure(cmd.Context(), args[1], args[2], args[3]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + logJSONCmd(*cmd, c) + return + } + c, err := sdk.Bootstrap(cmd.Context(), args[0], args[1]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + logJSONCmd(*cmd, c) + }, + }, + { + Use: "whitelist ", + Short: "Whitelist config", + Long: `Whitelist updates client state config with given id from the authenticated user`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + + var cfg mgsdk.BootstrapConfig + if err := json.Unmarshal([]byte(args[0]), &cfg); err != nil { + logErrorCmd(*cmd, err) + return + } + + if err := sdk.Whitelist(cmd.Context(), cfg.ClientID, cfg.State, args[1], args[2]); err != nil { + logErrorCmd(*cmd, err) + return + } + + logOKCmd(*cmd) + }, + }, +} + +// NewBootstrapCmd returns bootstrap command. +func NewBootstrapCmd() *cobra.Command { + cmd := cobra.Command{ + Use: "bootstrap [create | get | update | remove | bootstrap | whitelist]", + Short: "Bootstrap management", + Long: `Bootstrap management: create, get, update, delete or whitelist Bootstrap config`, + } + + for i := range cmdBootstrap { + cmd.AddCommand(&cmdBootstrap[i]) + } + + return &cmd +} diff --git a/cli/bootstrap_test.go b/cli/bootstrap_test.go new file mode 100644 index 000000000..5174d368c --- /dev/null +++ b/cli/bootstrap_test.go @@ -0,0 +1,633 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cli_test + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "testing" + + "github.com/absmach/supermq/cli" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + mgsdk "github.com/absmach/supermq/pkg/sdk" + sdkmocks "github.com/absmach/supermq/pkg/sdk/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + clientID = testsutil.GenerateUUID(&testing.T{}) + channelID = testsutil.GenerateUUID(&testing.T{}) + domainID = testsutil.GenerateUUID(&testing.T{}) + bootConfig = mgsdk.BootstrapConfig{ + ClientID: clientID, + Channels: []string{channelID}, + Name: "Test Bootstrap", + ExternalID: "09:6:0:sb:sa", + ExternalKey: "key", + } + validToken = "validToken" + invalidToken = "invalidToken" + extraArg = "extra-arg" + invalidID = "invalidID" + all = "all" +) + +func TestCreateBootstrapConfigCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + bootCmd := cli.NewBootstrapCmd() + rootCmd := setFlags(bootCmd) + + jsonConfig := fmt.Sprintf("{\"external_id\":\"09:6:0:sb:sa\", \"client_id\": \"%s\", \"external_key\":\"key\", \"name\": \"%s\", \"channels\":[\"%s\"]}", clientID, "Test Bootstrap", channelID) + invalidJson := fmt.Sprintf("{\"external_id\":\"09:6:0:sb:sa\", \"client_id\": \"%s\", \"external_key\":\"key\", \"name\": \"%s\", \"channels\":[\"%s\"]", clientID, "Test Bootstrap", channelID) + cases := []struct { + desc string + args []string + logType outputLog + response string + sdkErr errors.SDKError + errLogMessage string + id string + }{ + { + desc: "create bootstrap config successfully", + args: []string{ + jsonConfig, + domainID, + validToken, + }, + logType: createLog, + id: clientID, + response: fmt.Sprintf("\ncreated: %s\n\n", clientID), + }, + { + desc: "create bootstrap config with invald args", + args: []string{ + jsonConfig, + domainID, + validToken, + extraArg, + }, + logType: usageLog, + }, + { + desc: "create bootstrap config with invald json", + args: []string{ + invalidJson, + domainID, + validToken, + }, + sdkErr: errors.NewSDKError(errors.New("unexpected end of JSON input")), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.New("unexpected end of JSON input")), + logType: errLog, + }, + { + desc: "create bootstrap config with invald token", + args: []string{ + jsonConfig, + domainID, + invalidToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("AddBootstrap", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.id, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{createCmd}, tc.args...)...) + + switch tc.logType { + case createLog: + assert.Equal(t, tc.response, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.response, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + } + sdkCall.Unset() + }) + } +} + +func TestGetBootstrapConfigCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + bootCmd := cli.NewBootstrapCmd() + rootCmd := setFlags(bootCmd) + + var boot mgsdk.BootstrapConfig + var page mgsdk.BootstrapPage + + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + page mgsdk.BootstrapPage + boot mgsdk.BootstrapConfig + logType outputLog + errLogMessage string + }{ + { + desc: "get all bootstrap config successfully", + args: []string{ + all, + domainID, + validToken, + }, + page: mgsdk.BootstrapPage{ + PageRes: mgsdk.PageRes{ + Total: 1, + Offset: 0, + Limit: 10, + }, + Configs: []mgsdk.BootstrapConfig{bootConfig}, + }, + logType: entityLog, + }, + { + desc: "get bootstrap config with id", + args: []string{ + channelID, + domainID, + validToken, + }, + logType: entityLog, + boot: bootConfig, + }, + { + desc: "get bootstrap config with invalid args", + args: []string{ + all, + domainID, + validToken, + extraArg, + }, + logType: usageLog, + }, + { + desc: "get all bootstrap config with invalid token", + args: []string{ + all, + domainID, + invalidToken, + }, + logType: errLog, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)), + }, + { + desc: "get bootstrap config with invalid id", + args: []string{ + invalidID, + domainID, + validToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("ViewBootstrap", mock.Anything, tc.args[0], tc.args[1], tc.args[2]).Return(tc.boot, tc.sdkErr) + sdkCall1 := sdkMock.On("Bootstraps", mock.Anything, mock.Anything, tc.args[1], tc.args[2]).Return(tc.page, tc.sdkErr) + + out := executeCommand(t, rootCmd, append([]string{getCmd}, tc.args...)...) + + switch tc.logType { + case entityLog: + if tc.args[0] == all { + err := json.Unmarshal([]byte(out), &page) + assert.Nil(t, err) + assert.Equal(t, tc.page, page, fmt.Sprintf("%v unexpected response, expected: %v, got: %v", tc.desc, tc.page, page)) + } else { + err := json.Unmarshal([]byte(out), &boot) + assert.Nil(t, err) + assert.Equal(t, tc.boot, boot, fmt.Sprintf("%v unexpected response, expected: %v, got: %v", tc.desc, tc.boot, boot)) + } + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + } + sdkCall.Unset() + sdkCall1.Unset() + }) + } +} + +func TestRemoveBootstrapConfigCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + bootCmd := cli.NewBootstrapCmd() + rootCmd := setFlags(bootCmd) + + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + logType outputLog + errLogMessage string + }{ + { + desc: "remove bootstrap config successfully", + args: []string{ + clientID, + domainID, + validToken, + }, + logType: okLog, + }, + { + desc: "remove bootstrap config with invalid args", + args: []string{ + clientID, + domainID, + validToken, + extraArg, + }, + logType: usageLog, + }, + { + desc: "remove bootstrap config with invalid client id", + args: []string{ + invalidID, + domainID, + validToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)), + logType: errLog, + }, + { + desc: "remove bootstrap config with invalid token", + args: []string{ + clientID, + domainID, + invalidToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("RemoveBootstrap", mock.Anything, tc.args[0], tc.args[1], tc.args[2]).Return(tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{rmCmd}, tc.args...)...) + + switch tc.logType { + case okLog: + assert.True(t, strings.Contains(out, "ok"), fmt.Sprintf("%s unexpected response: expected success message, got: %v", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + } + sdkCall.Unset() + }) + } +} + +func TestUpdateBootstrapConfigCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + bootCmd := cli.NewBootstrapCmd() + rootCmd := setFlags(bootCmd) + + config := "config" + connection := "connection" + + newConfigJson := "{\"name\" : \"New Bootstrap\"}" + chanIDsJson := fmt.Sprintf("[\"%s\"]", channelID) + cases := []struct { + desc string + args []string + boot mgsdk.BootstrapConfig + sdkErr errors.SDKError + errLogMessage string + logType outputLog + }{ + { + desc: "update bootstrap config successfully", + args: []string{ + config, + newConfigJson, + domainID, + validToken, + }, + logType: okLog, + }, + { + desc: "update bootstrap config with invalid token", + args: []string{ + config, + newConfigJson, + domainID, + invalidToken, + }, + logType: errLog, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)), + }, + { + desc: "update bootstrap connections successfully", + args: []string{ + connection, + clientID, + chanIDsJson, + domainID, + validToken, + }, + logType: okLog, + }, + { + desc: "update bootstrap connections with invalid json", + args: []string{ + connection, + clientID, + fmt.Sprintf("[\"%s\"", clientID), + domainID, + validToken, + }, + sdkErr: errors.NewSDKError(errors.New("unexpected end of JSON input")), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.New("unexpected end of JSON input")), + logType: errLog, + }, + { + desc: "update bootstrap connections with invalid token", + args: []string{ + connection, + clientID, + chanIDsJson, + domainID, + invalidToken, + }, + logType: errLog, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)), + }, + { + desc: "update bootstrap certs successfully", + args: []string{ + "certs", + clientID, + "client cert", + "client key", + "ca", + domainID, + validToken, + }, + boot: bootConfig, + logType: entityLog, + }, + { + desc: "update bootstrap certs with invalid token", + args: []string{ + "certs", + clientID, + "client cert", + "client key", + "ca", + domainID, + invalidToken, + }, + logType: errLog, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)), + }, + { + desc: "update bootstrap config with invalid args", + args: []string{ + newConfigJson, + domainID, + validToken, + }, + logType: usageLog, + }, + { + desc: "update bootstrap config with invalid json", + args: []string{ + config, + "{\"name\" : \"New Bootstrap\"", + domainID, + validToken, + }, + sdkErr: errors.NewSDKError(errors.New("unexpected end of JSON input")), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.New("unexpected end of JSON input")), + logType: errLog, + }, + { + desc: "update bootstrap with invalid args", + args: []string{ + extraArg, + extraArg, + extraArg, + extraArg, + extraArg, + }, + logType: usageLog, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + var boot mgsdk.BootstrapConfig + sdkCall := sdkMock.On("UpdateBootstrap", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) + sdkCall1 := sdkMock.On("UpdateBootstrapConnection", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) + sdkCall2 := sdkMock.On("UpdateBootstrapCerts", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{updCmd}, tc.args...)...) + + switch tc.logType { + case entityLog: + err := json.Unmarshal([]byte(out), &boot) + assert.Nil(t, err) + assert.Equal(t, tc.boot, boot, fmt.Sprintf("%s unexpected response: expected: %v, got: %v", tc.desc, tc.boot, boot)) + case okLog: + assert.True(t, strings.Contains(out, "ok"), fmt.Sprintf("%s unexpected response: expected success message, got: %v", tc.desc, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + } + sdkCall.Unset() + sdkCall1.Unset() + sdkCall2.Unset() + }) + } +} + +func TestWhitelistConfigCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + bootCmd := cli.NewBootstrapCmd() + rootCmd := setFlags(bootCmd) + + jsonConfig := fmt.Sprintf("{\"client_id\": \"%s\", \"state\":%d}", clientID, 1) + + cases := []struct { + desc string + args []string + logType outputLog + errLogMessage string + sdkErr errors.SDKError + }{ + { + desc: "whitelist config successfully", + args: []string{ + jsonConfig, + domainID, + validToken, + }, + logType: okLog, + }, + { + desc: "whitelist config with invalid args", + args: []string{ + jsonConfig, + domainID, + validToken, + extraArg, + }, + logType: usageLog, + }, + { + desc: "whitelist config with invalid json", + args: []string{ + fmt.Sprintf("{\"client_id\": \"%s\", \"state\":%d", clientID, 1), + domainID, + validToken, + }, + sdkErr: errors.NewSDKError(errors.New("unexpected end of JSON input")), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.New("unexpected end of JSON input")), + logType: errLog, + }, + { + desc: "whitelist config with invalid token", + args: []string{ + jsonConfig, + domainID, + invalidToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("Whitelist", mock.Anything, mock.Anything, mock.Anything, tc.args[1], tc.args[2]).Return(tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{whitelistCmd}, tc.args...)...) + switch tc.logType { + case okLog: + assert.True(t, strings.Contains(out, "ok"), fmt.Sprintf("%s unexpected response: expected success message, got: %v", tc.desc, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + } + sdkCall.Unset() + }) + } +} + +func TestBootstrapConfigCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + bootCmd := cli.NewBootstrapCmd() + rootCmd := setFlags(bootCmd) + + var boot mgsdk.BootstrapConfig + cryptoKey := "v7aT0HGxJxt2gULzr3RHwf4WIf6DusPp" + invalidKey := "invalid key" + cases := []struct { + desc string + args []string + logType outputLog + errLogMessage string + sdkErr errors.SDKError + boot mgsdk.BootstrapConfig + }{ + { + desc: "bootstrap secure config successfully", + args: []string{ + "secure", + bootConfig.ExternalID, + bootConfig.ExternalKey, + cryptoKey, + }, + boot: bootConfig, + logType: entityLog, + }, + { + desc: "bootstrap config successfully", + args: []string{ + bootConfig.ExternalID, + bootConfig.ExternalKey, + }, + boot: bootConfig, + logType: entityLog, + }, + { + desc: "bootstrap secure config with invalid args", + args: []string{ + cryptoKey, + }, + + logType: usageLog, + }, + { + desc: "bootstrap secure config with invalid key", + args: []string{ + "secure", + bootConfig.ExternalID, + invalidKey, + cryptoKey, + }, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized)), + logType: errLog, + }, + { + desc: "bootstrap config with invalid key", + args: []string{ + bootConfig.ExternalID, + invalidKey, + }, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("BootstrapSecure", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) + sdkCall1 := sdkMock.On("Bootstrap", mock.Anything, mock.Anything, mock.Anything).Return(tc.boot, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{bootStrapCmd}, tc.args...)...) + switch tc.logType { + case entityLog: + err := json.Unmarshal([]byte(out), &boot) + assert.Nil(t, err) + assert.Equal(t, tc.boot, boot, fmt.Sprintf("%s unexpected response: expected: %v, got: %v", tc.desc, tc.boot, boot)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + } + sdkCall.Unset() + sdkCall1.Unset() + }) + } +} diff --git a/cli/certs.go b/cli/certs.go new file mode 100644 index 000000000..0f675b5eb --- /dev/null +++ b/cli/certs.go @@ -0,0 +1,342 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cli + +import ( + "encoding/json" + "os" + + "github.com/absmach/supermq/certs" + smqsdk "github.com/absmach/supermq/pkg/sdk" + "github.com/spf13/cobra" +) + +var cmdCerts = []cobra.Command{ + { + Use: "get [all | ] ", + Short: "Get certificate", + Long: `Gets a certificate for a given entity ID or all certificates.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + if args[0] == all { + pm := smqsdk.PageMetadata{ + Limit: Limit, + Offset: Offset, + } + page, err := sdk.ListCerts(cmd.Context(), pm, args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, page) + return + } + pm := smqsdk.PageMetadata{ + EntityID: args[0], + Limit: Limit, + Offset: Offset, + } + page, err := sdk.ListCerts(cmd.Context(), pm, args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, page) + }, + }, + { + Use: "revoke ", + Short: "Revoke certificate", + Long: `Revokes a certificate for a given serial number.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + err := sdk.RevokeCert(cmd.Context(), args[0], args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logOKCmd(*cmd) + }, + }, + { + Use: "delete ", + Short: "Delete certificate", + Long: `Deletes certificates for a given entity id.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + err := sdk.DeleteCert(cmd.Context(), args[0], args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logOKCmd(*cmd) + }, + }, + { + Use: "renew ", + Short: "Renew certificate", + Long: `Renews a certificate for a given serial number.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + _, err := sdk.RenewCert(cmd.Context(), args[0], args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logOKCmd(*cmd) + }, + }, + { + Use: "ocsp ", + Short: "OCSP", + Long: `OCSP for a given serial number or certificate.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 1 { + logUsageCmd(*cmd, cmd.Use) + return + } + var serialNumber, certContent string + if _, statErr := os.Stat(args[0]); statErr == nil { + certBytes, err := os.ReadFile(args[0]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + certContent = string(certBytes) + } else { + serialNumber = args[0] + } + response, err := sdk.OCSP(cmd.Context(), serialNumber, certContent) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, response) + }, + }, + { + Use: "view ", + Short: "View certificate", + Long: `Views a certificate for a given serial number.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + cert, err := sdk.ViewCert(cmd.Context(), args[0], args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, cert) + }, + }, + { + Use: "view-ca", + Short: "View-ca certificate", + Long: `Views ca certificate.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 0 { + logUsageCmd(*cmd, cmd.Use) + return + } + cert, err := sdk.ViewCA(cmd.Context()) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, cert) + }, + }, + { + Use: "download-ca", + Short: "Download signing CA", + Long: `Download intermediate cert and ca.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 0 { + logUsageCmd(*cmd, cmd.Use) + return + } + bundle, err := sdk.DownloadCA(cmd.Context()) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logSaveCAFiles(*cmd, bundle) + }, + }, + { + Use: "csr ", + Short: "Create CSR", + Long: `Creates a CSR.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 2 { + logUsageCmd(*cmd, cmd.Use) + return + } + var pm certs.CSRMetadata + if err := json.Unmarshal([]byte(args[0]), &pm); err != nil { + logErrorCmd(*cmd, err) + return + } + data, err := os.ReadFile(args[1]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + csr, err := sdk.CreateCSR(cmd.Context(), pm, data) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logSaveCSRFiles(*cmd, csr) + }, + }, + { + Use: "issue-csr ", + Short: "Issue from CSR", + Long: `issues a certificate for a given csr.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 5 { + logUsageCmd(*cmd, cmd.Use) + return + } + csrData, err := os.ReadFile(args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + cert, err := sdk.IssueFromCSR(cmd.Context(), args[0], args[1], string(csrData), args[3], args[4]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, cert) + logSaveCertFiles(*cmd, cert) + }, + }, + { + Use: "issue-csr-internal ", + Short: "Issue from CSR Internal (Agent)", + Long: `Issues a certificate for a given CSR using agent authentication.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 4 { + logUsageCmd(*cmd, cmd.Use) + return + } + csrData, err := os.ReadFile(args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + cert, err := sdk.IssueFromCSRInternal(cmd.Context(), args[0], args[1], string(csrData), args[3]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, cert) + logSaveCertFiles(*cmd, cert) + }, + }, + { + Use: "crl", + Short: "Generate CRL", + Long: `Generates a Certificate Revocation List (CRL).`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 0 { + logUsageCmd(*cmd, cmd.Use) + return + } + crlBytes, err := sdk.GenerateCRL(cmd.Context()) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logSaveCRLFile(*cmd, crlBytes) + }, + }, + { + Use: "entity-id ", + Short: "Get entity ID by serial number", + Long: `Gets the entity ID for a certificate by its serial number.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + entityID, err := sdk.EntityID(cmd.Context(), args[0], args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, map[string]string{"entity_id": entityID}) + }, + }, +} + +// NewCertsCmd returns certificate command. +func NewCertsCmd() *cobra.Command { + var ttl string + issueCmd := cobra.Command{ + Use: "issue [] [--ttl=8760h]", + Short: "Issue certificate", + Long: `Issues a certificate for a given entity ID.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) < 5 || len(args) > 6 { + logUsageCmd(*cmd, cmd.Use) + return + } + var ipAddrs []string + if err := json.Unmarshal([]byte(args[2]), &ipAddrs); err != nil { + logErrorCmd(*cmd, err) + return + } + var option smqsdk.Options + option.CommonName = args[1] + var domainID, token string + if len(args) == 5 { + domainID = args[3] + token = args[4] + } else { + if err := json.Unmarshal([]byte(args[3]), &option); err != nil { + logErrorCmd(*cmd, err) + return + } + domainID = args[4] + token = args[5] + } + cert, err := sdk.IssueCert(cmd.Context(), args[0], ttl, ipAddrs, option, domainID, token) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, cert) + logSaveCertFiles(*cmd, cert) + }, + } + issueCmd.Flags().StringVar(&ttl, "ttl", "8760h", "certificate time to live in duration") + + cmd := cobra.Command{ + Use: "certs [issue | get | revoke | renew | ocsp | view | download-ca | view-ca | csr | issue-csr | issue-csr-internal | crl | entity-id]", + Short: "Certificates management", + Long: `Certificates management: issue, get all, get by entity ID, revoke, renew, OCSP, view, CRL generation, entity ID lookup, agent CSR issuing, and CA operations.`, + } + cmd.AddCommand(&issueCmd) + for i := range cmdCerts { + cmd.AddCommand(&cmdCerts[i]) + } + return &cmd +} diff --git a/cli/certs_test.go b/cli/certs_test.go new file mode 100644 index 000000000..4523f33c0 --- /dev/null +++ b/cli/certs_test.go @@ -0,0 +1,905 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cli_test + +import ( + "encoding/json" + "fmt" + "net/http" + "os" + "strings" + "testing" + + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/cli" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/sdk" + sdkmocks "github.com/absmach/supermq/pkg/sdk/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const ( + revokeCmd = "revoke" + deleteCmd = "delete" + issueCmd = "issue" + renewCmd = "renew" + certsListCmd = "get" + downloadCACmd = "download-ca" + CATokenCmd = "certsToken-ca" + viewCACmd = "view-ca" + filePermission = 0o644 +) + +var ( + serialNumber = "39054620502613157373429341617471746606" + id = "5b4c9ee3-e719-4a0a-9ee5-354932c5e6a4" + commonName = "test-name" + certsToken = "certsToken" + certsDomainID = "domain-id" +) + +func TestIssueCertCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + ipAddrs := "[\"192.168.100.22\"]" + + var cert sdk.Certificate + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + cert sdk.Certificate + }{ + { + desc: "issue cert successfully", + args: []string{ + id, + commonName, + ipAddrs, + certsDomainID, + certsToken, + }, + logType: entityLog, + cert: sdk.Certificate{SerialNumber: serialNumber}, + }, + { + desc: "issue cert with invalid args", + args: []string{ + id, + ipAddrs, + }, + logType: usageLog, + }, + { + desc: "issue cert failed", + args: []string{ + id, + commonName, + ipAddrs, + certsDomainID, + certsToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrCreateEntity, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrCreateEntity, http.StatusUnprocessableEntity)), + logType: errLog, + }, + { + desc: "issue cert with 6 args", + args: []string{ + id, + commonName, + ipAddrs, + "{\"organization\":[\"organization_name\"]}", + certsDomainID, + certsToken, + }, + logType: entityLog, + cert: sdk.Certificate{SerialNumber: serialNumber}, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + defer func() { + cleanupFiles(t, []string{"cert.pem", "key.pem"}) + }() + sdkCall := sdkMock.On("IssueCert", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.cert, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{issueCmd}, tc.args...)...) + switch tc.logType { + case entityLog: + lines := strings.Split(out, "\n") + var jsonLines []string + var inJSON bool + + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "{") { + inJSON = true + jsonLines = append(jsonLines, line) + } else if inJSON && strings.HasSuffix(line, "}") { + jsonLines = append(jsonLines, line) + break + } else if inJSON { + jsonLines = append(jsonLines, line) + } + } + + if len(jsonLines) == 0 { + t.Fatalf("No JSON found in output: %s", out) + } + + jsonPart := strings.Join(jsonLines, "") + + err := json.Unmarshal([]byte(jsonPart), &cert) + assert.Nil(t, err) + assert.Equal(t, tc.cert, cert, fmt.Sprintf("%s unexpected response: expected: %v, got: %v", tc.desc, tc.cert, cert)) + assert.True(t, strings.Contains(out, "All certificate files have been saved successfully"), fmt.Sprintf("%s should save files", tc.desc)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + } + sdkCall.Unset() + }) + } +} + +func TestRevokeCertCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + }{ + { + desc: "revoke cert successfully", + args: []string{ + serialNumber, + certsDomainID, + certsToken, + }, + logType: okLog, + }, + { + desc: "revoke cert with invalid args", + args: []string{ + serialNumber, + extraArg, + }, + logType: usageLog, + }, + { + desc: "revoke cert failed", + args: []string{ + serialNumber, + certsDomainID, + certsToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("RevokeCert", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{revokeCmd}, tc.args...)...) + switch tc.logType { + case okLog: + assert.True(t, strings.Contains(out, "ok"), fmt.Sprintf("%s unexpected response: expected success message, got: %v", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + } + sdkCall.Unset() + }) + } +} + +func TestDeleteCertCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + }{ + { + desc: "delete certs successfully", + args: []string{ + id, + certsDomainID, + certsToken, + }, + logType: okLog, + }, + { + desc: "delete certs with invalid args", + args: []string{ + id, + extraArg, + }, + logType: usageLog, + }, + { + desc: "delete certs failed", + args: []string{ + id, + certsDomainID, + certsToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("DeleteCert", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{deleteCmd}, tc.args...)...) + switch tc.logType { + case okLog: + assert.True(t, strings.Contains(out, "ok"), fmt.Sprintf("%s unexpected response: expected success message, got: %v", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + } + sdkCall.Unset() + }) + } +} + +func TestRenewCertCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + }{ + { + desc: "renew cert successfully", + args: []string{ + serialNumber, + certsDomainID, + certsToken, + }, + logType: okLog, + }, + { + desc: "renew cert with invalid args", + args: []string{ + serialNumber, + extraArg, + }, + logType: usageLog, + }, + { + desc: "renew cert failed", + args: []string{ + serialNumber, + certsDomainID, + certsToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("RenewCert", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(sdk.Certificate{}, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{renewCmd}, tc.args...)...) + switch tc.logType { + case okLog: + assert.True(t, strings.Contains(out, "ok"), fmt.Sprintf("%s unexpected response: expected success message, got: %v", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + } + sdkCall.Unset() + }) + } +} + +func TestListCertsCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + var page sdk.CertificatePage + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + page sdk.CertificatePage + }{ + { + desc: "list certs successfully", + args: []string{ + all, + certsDomainID, + certsToken, + }, + logType: entityLog, + page: sdk.CertificatePage{ + Total: 1, + Offset: 0, + Limit: 10, + Certificates: []sdk.Certificate{ + {SerialNumber: serialNumber}, + }, + }, + }, + { + desc: "list certs successfully with entity ID", + args: []string{ + id, + certsDomainID, + certsToken, + }, + logType: entityLog, + page: sdk.CertificatePage{ + Total: 1, + Offset: 0, + Limit: 10, + Certificates: []sdk.Certificate{ + {SerialNumber: serialNumber}, + }, + }, + }, + { + desc: "list certs with invalid args", + args: []string{ + all, + extraArg, + }, + logType: usageLog, + }, + { + desc: "failed list certs with all", + args: []string{ + all, + certsDomainID, + certsToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity)), + logType: errLog, + }, + { + desc: "failed list certs with entity ID", + args: []string{ + id, + certsDomainID, + certsToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("ListCerts", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.page, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{certsListCmd}, tc.args...)...) + + switch tc.logType { + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + case entityLog: + err := json.Unmarshal([]byte(out), &page) + if err != nil { + t.Fatalf("Failed to unmarshal JSON: %v", err) + } + assert.Equal(t, tc.page, page, fmt.Sprintf("%v unexpected response, expected: %v, got: %v", tc.desc, tc.page, page)) + } + + sdkCall.Unset() + }) + } +} + +func TestDownloadCACmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logMessage string + logType outputLog + certBundle sdk.CertificateBundle + }{ + { + desc: "download CA successfully", + args: []string{}, + logType: entityLog, + certBundle: sdk.CertificateBundle{ + Certificate: []byte("certificate"), + }, + logMessage: "Saved ca.crt\n\nAll certificate files have been saved successfully.\n", + }, + { + desc: "download CA with invalid args", + args: []string{ + extraArg, + }, + logType: usageLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + defer func() { + cleanupFiles(t, []string{"ca.crt"}) + }() + sdkCall := sdkMock.On("DownloadCA", mock.Anything).Return(tc.certBundle, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{downloadCACmd}, tc.args...)...) + switch tc.logType { + case entityLog: + assert.True(t, strings.Contains(out, "Saved ca.crt"), fmt.Sprintf("%s invalid output: %s", tc.desc, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + } + sdkCall.Unset() + }) + } +} + +func TestViewCACmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + var cert sdk.Certificate + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + cert sdk.Certificate + }{ + { + desc: "view cert successfully", + args: []string{}, + logType: entityLog, + cert: sdk.Certificate{ + Certificate: "certificate", + Key: "privatekey", + }, + }, + { + desc: "view cert failed", + args: []string{}, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity)), + logType: errLog, + cert: sdk.Certificate{}, + }, + { + desc: "view cert with invalid args", + args: []string{extraArg}, + logType: usageLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("ViewCA", mock.Anything).Return(tc.cert, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{viewCACmd}, tc.args...)...) + switch tc.logType { + case entityLog: + err := json.Unmarshal([]byte(out), &cert) + assert.Nil(t, err) + assert.Equal(t, tc.cert, cert, fmt.Sprintf("%s unexpected response: expected: %v, got: %v", tc.desc, tc.cert, cert)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + } + sdkCall.Unset() + }) + } +} + +func TestGenerateCRLCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + crlBytes []byte + }{ + { + desc: "generate CRL successfully", + args: []string{}, + logType: entityLog, + crlBytes: []byte("crl-data"), + }, + { + desc: "generate CRL failed", + args: []string{}, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrFailedCertCreation, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrFailedCertCreation, http.StatusUnprocessableEntity)), + logType: errLog, + }, + { + desc: "generate CRL with invalid args", + args: []string{"invalid"}, + logType: usageLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + defer func() { + cleanupFiles(t, []string{"ca.crl"}) + }() + + sdkCall := sdkMock.On("GenerateCRL", mock.Anything).Return(tc.crlBytes, tc.sdkErr) + defer sdkCall.Unset() + + out := executeCommand(t, rootCmd, append([]string{"crl"}, tc.args...)...) + + switch tc.logType { + case entityLog: + assert.True(t, strings.Contains(out, "CRL file has been saved successfully"), fmt.Sprintf("%s invalid output: %s", tc.desc, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + } + }) + } +} + +func TestGetEntityIDCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + entityID := "test-entity-id" + + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + entityID string + }{ + { + desc: "get entity ID successfully", + args: []string{serialNumber, certsDomainID, certsToken}, + logType: entityLog, + entityID: entityID, + }, + { + desc: "get entity ID with invalid args", + args: []string{serialNumber, extraArg}, + logType: usageLog, + }, + { + desc: "get entity ID failed", + args: []string{serialNumber, certsDomainID, certsToken}, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("EntityID", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.entityID, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{"entity-id"}, tc.args...)...) + + switch tc.logType { + case entityLog: + assert.True(t, strings.Contains(out, tc.entityID), fmt.Sprintf("%s invalid output: %s", tc.desc, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + } + sdkCall.Unset() + }) + } +} + +func cleanupFiles(t *testing.T, filenames []string) { + for _, filename := range filenames { + err := os.Remove(filename) + if err != nil && !os.IsNotExist(err) { + t.Logf("Failed to remove file %s: %v", filename, err) + } + } +} + +func TestIssueFromCSRInternalCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + agentToken := "agent-certsToken-123" + csrPath := "test.csr" + bytes := []byte("-----BEGIN CERTIFICATE REQUEST-----\n-csr-content\n-----END CERTIFICATE REQUEST-----") + + err := os.WriteFile(csrPath, bytes, filePermission) + if err != nil { + t.Fatalf("Failed to create test CSR file: %v", err) + } + defer os.Remove(csrPath) + + var cert sdk.Certificate + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + cert sdk.Certificate + }{ + { + desc: "issue cert from CSR internal successfully", + args: []string{ + id, + "10h", + csrPath, + agentToken, + }, + logType: entityLog, + cert: sdk.Certificate{SerialNumber: serialNumber}, + }, + { + desc: "issue cert from CSR internal with invalid args", + args: []string{ + id, + extraArg, + }, + logType: usageLog, + }, + { + desc: "issue cert from CSR internal failed", + args: []string{ + id, + "10h", + csrPath, + agentToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrFailedCertCreation, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrFailedCertCreation, http.StatusUnprocessableEntity)), + logType: errLog, + }, + { + desc: "issue cert from CSR internal with non-existent file", + args: []string{ + id, + "10h", + "non-existent.csr", + agentToken, + }, + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + defer func() { + cleanupFiles(t, []string{"cert.pem", "key.pem"}) + }() + sdkCall := sdkMock.On("IssueFromCSRInternal", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.cert, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{"issue-csr-internal"}, tc.args...)...) + switch tc.logType { + case entityLog: + lines := strings.Split(out, "\n") + var jsonLines []string + var inJSON bool + + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "{") { + inJSON = true + jsonLines = append(jsonLines, line) + } else if inJSON && strings.HasSuffix(line, "}") { + jsonLines = append(jsonLines, line) + break + } else if inJSON { + jsonLines = append(jsonLines, line) + } + } + + if len(jsonLines) == 0 { + t.Fatalf("No JSON found in output: %s", out) + } + + jsonPart := strings.Join(jsonLines, "") + + err := json.Unmarshal([]byte(jsonPart), &cert) + assert.Nil(t, err) + assert.Equal(t, tc.cert, cert, fmt.Sprintf("%s unexpected response: expected: %v, got: %v", tc.desc, tc.cert, cert)) + assert.True(t, strings.Contains(out, "All certificate files have been saved successfully"), fmt.Sprintf("%s should save files", tc.desc)) + case errLog: + if tc.errLogMessage != "" { + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + } else { + assert.True(t, strings.Contains(out, "error"), fmt.Sprintf("%s should contain error message: %s", tc.desc, out)) + } + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + } + sdkCall.Unset() + }) + } +} + +func TestIssueFromCSRCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + csrPath := "test.csr" + bytes := []byte("-----BEGIN CERTIFICATE REQUEST-----\n-csr-content\n-----END CERTIFICATE REQUEST-----") + + err := os.WriteFile(csrPath, bytes, filePermission) + if err != nil { + t.Fatalf("Failed to create test CSR file: %v", err) + } + defer os.Remove(csrPath) + + var cert sdk.Certificate + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + cert sdk.Certificate + }{ + { + desc: "issue cert from CSR successfully", + args: []string{ + id, + "10h", + csrPath, + certsDomainID, + certsToken, + }, + logType: entityLog, + cert: sdk.Certificate{SerialNumber: serialNumber}, + }, + { + desc: "issue cert from CSR with invalid args", + args: []string{ + id, + extraArg, + }, + logType: usageLog, + }, + { + desc: "issue cert from CSR failed", + args: []string{ + id, + "10h", + csrPath, + certsDomainID, + certsToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrFailedCertCreation, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrFailedCertCreation, http.StatusUnprocessableEntity)), + logType: errLog, + }, + { + desc: "issue cert from CSR with non-existent file", + args: []string{ + id, + "10h", + "non-existent.csr", + certsDomainID, + certsToken, + }, + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + defer func() { + cleanupFiles(t, []string{"cert.pem", "key.pem"}) + }() + sdkCall := sdkMock.On("IssueFromCSR", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.cert, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{"issue-csr"}, tc.args...)...) + switch tc.logType { + case entityLog: + lines := strings.Split(out, "\n") + var jsonLines []string + var inJSON bool + + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "{") { + inJSON = true + jsonLines = append(jsonLines, line) + } else if inJSON && strings.HasSuffix(line, "}") { + jsonLines = append(jsonLines, line) + break + } else if inJSON { + jsonLines = append(jsonLines, line) + } + } + + if len(jsonLines) == 0 { + t.Fatalf("No JSON found in output: %s", out) + } + + jsonPart := strings.Join(jsonLines, "") + + err := json.Unmarshal([]byte(jsonPart), &cert) + assert.Nil(t, err) + assert.Equal(t, tc.cert, cert, fmt.Sprintf("%s unexpected response: expected: %v, got: %v", tc.desc, tc.cert, cert)) + assert.True(t, strings.Contains(out, "All certificate files have been saved successfully"), fmt.Sprintf("%s should save files", tc.desc)) + case errLog: + if tc.errLogMessage != "" { + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + } else { + assert.True(t, strings.Contains(out, "error"), fmt.Sprintf("%s should contain error message: %s", tc.desc, out)) + } + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + } + sdkCall.Unset() + }) + } +} diff --git a/cli/clients_test.go b/cli/clients_test.go index df4c153ac..c6f003f67 100644 --- a/cli/clients_test.go +++ b/cli/clients_test.go @@ -24,9 +24,7 @@ import ( var ( token = "valid" + "domaintoken" - domainID = "domain-id" relation = "administrator" - all = "all" conntype = `["publish","subscribe"]` errEndJSONInput = errors.New("unexpected end of JSON input") diff --git a/cli/commands_test.go b/cli/commands_test.go index 6169ea891..16ecfb259 100644 --- a/cli/commands_test.go +++ b/cli/commands_test.go @@ -51,3 +51,11 @@ const ( listCmd = "list" membersCmd = "members" ) + +// Bootstrap commands +const ( + updCmd = "update" + rmCmd = "remove" + whitelistCmd = "whitelist" + bootStrapCmd = "bootstrap" +) diff --git a/cli/consumers.go b/cli/consumers.go new file mode 100644 index 000000000..2fd461469 --- /dev/null +++ b/cli/consumers.go @@ -0,0 +1,100 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cli + +import ( + mgsdk "github.com/absmach/supermq/pkg/sdk" + "github.com/spf13/cobra" +) + +var cmdSubscription = []cobra.Command{ + { + Use: "create ", + Short: "Create subscription", + Long: `Create new subscription`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + + id, err := sdk.CreateSubscription(cmd.Context(), args[0], args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + logCreatedCmd(*cmd, id) + }, + }, + { + Use: "get [all | ] ", + Short: "Get subscription", + Long: `Get subscription. + all - lists all subscriptions + - view subscription of `, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 2 { + logUsageCmd(*cmd, cmd.Use) + return + } + pageMetadata := mgsdk.PageMetadata{ + Offset: Offset, + Limit: Limit, + Topic: Topic, + Contact: Contact, + } + if args[0] == all { + sub, err := sdk.ListSubscriptions(cmd.Context(), pageMetadata, args[1]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, sub) + return + } + + c, err := sdk.ViewSubscription(cmd.Context(), args[0], args[1]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + logJSONCmd(*cmd, c) + }, + }, + { + Use: "remove ", + Short: "Remove subscription", + Long: `Removes removes a subscription with the provided id`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 2 { + logUsageCmd(*cmd, cmd.Use) + return + } + + if err := sdk.DeleteSubscription(cmd.Context(), args[0], args[1]); err != nil { + logErrorCmd(*cmd, err) + return + } + + logOKCmd(*cmd) + }, + }, +} + +// NewSubscriptionCmd returns subscription command. +func NewSubscriptionCmd() *cobra.Command { + cmd := cobra.Command{ + Use: "subscription [create | get | remove ]", + Short: "Subscription management", + Long: `Subscription management: create, get, or delete subscription`, + } + + for i := range cmdSubscription { + cmd.AddCommand(&cmdSubscription[i]) + } + + return &cmd +} diff --git a/cli/consumers_test.go b/cli/consumers_test.go new file mode 100644 index 000000000..1b5f94c83 --- /dev/null +++ b/cli/consumers_test.go @@ -0,0 +1,266 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cli_test + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "testing" + + "github.com/absmach/supermq/cli" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + mgsdk "github.com/absmach/supermq/pkg/sdk" + sdkmocks "github.com/absmach/supermq/pkg/sdk/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + userID = testsutil.GenerateUUID(&testing.T{}) + subscription = mgsdk.Subscription{ + ID: testsutil.GenerateUUID(&testing.T{}), + OwnerID: userID, + Topic: "topic", + Contact: "identity@example.com", + } +) + +func TestCreateSubscriptionCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + subCmd := cli.NewSubscriptionCmd() + rootCmd := setFlags(subCmd) + + cases := []struct { + desc string + args []string + logType outputLog + errLogMessage string + sdkErr errors.SDKError + response string + id string + }{ + { + desc: "create subscription successfully", + args: []string{ + subscription.Topic, + subscription.Contact, + validToken, + }, + id: userID, + response: fmt.Sprintf("\ncreated: %s\n\n", userID), + logType: createLog, + }, + { + desc: "create subscription with invalid args", + args: []string{ + subscription.Topic, + subscription.Contact, + validToken, + extraArg, + }, + logType: usageLog, + }, + { + desc: "create subscription with invalid token", + args: []string{ + subscription.Topic, + subscription.Contact, + invalidToken, + }, + logType: errLog, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("CreateSubscription", mock.Anything, tc.args[0], tc.args[1], tc.args[2]).Return(tc.id, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{createCmd}, tc.args...)...) + + switch tc.logType { + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case createLog: + assert.Equal(t, tc.response, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.response, out)) + } + sdkCall.Unset() + }) + } +} + +func TestGetSubscriptionsCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + subCmd := cli.NewSubscriptionCmd() + rootCmd := setFlags(subCmd) + + var sub mgsdk.Subscription + var page mgsdk.SubscriptionPage + + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + page mgsdk.SubscriptionPage + subscription mgsdk.Subscription + logType outputLog + errLogMessage string + }{ + { + desc: "get all subscriptions successfully", + args: []string{ + all, + validToken, + }, + page: mgsdk.SubscriptionPage{ + Subscriptions: []mgsdk.Subscription{subscription}, + }, + logType: entityLog, + }, + { + desc: "get subscription with id", + args: []string{ + subscription.ID, + validToken, + }, + logType: entityLog, + subscription: subscription, + }, + { + desc: "get subscriptions with invalid args", + args: []string{ + all, + validToken, + extraArg, + }, + logType: usageLog, + }, + { + desc: "get all subscriptions with invalid token", + args: []string{ + all, + invalidToken, + }, + logType: errLog, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)), + }, + { + desc: "get subscription with invalid id", + args: []string{ + invalidID, + validToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("ViewSubscription", mock.Anything, tc.args[0], tc.args[1]).Return(tc.subscription, tc.sdkErr) + sdkCall1 := sdkMock.On("ListSubscriptions", mock.Anything, mock.Anything, tc.args[1]).Return(tc.page, tc.sdkErr) + + out := executeCommand(t, rootCmd, append([]string{getCmd}, tc.args...)...) + + switch tc.logType { + case entityLog: + if tc.args[1] == all { + err := json.Unmarshal([]byte(out), &page) + assert.Nil(t, err) + assert.Equal(t, tc.page, page, fmt.Sprintf("%v unexpected response, expected: %v, got: %v", tc.desc, tc.page, page)) + } else { + err := json.Unmarshal([]byte(out), &sub) + assert.Nil(t, err) + assert.Equal(t, tc.subscription, sub, fmt.Sprintf("%v unexpected response, expected: %v, got: %v", tc.desc, tc.subscription, sub)) + } + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + } + sdkCall.Unset() + sdkCall1.Unset() + }) + } +} + +func TestRemoveSubscriptionCmd(t *testing.T) { + sdkMock := new(sdkmocks.SDK) + cli.SetSDK(sdkMock) + subCmd := cli.NewSubscriptionCmd() + rootCmd := setFlags(subCmd) + + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + logType outputLog + errLogMessage string + }{ + { + desc: "remove subscription successfully", + args: []string{ + subscription.ID, + validToken, + }, + logType: okLog, + }, + { + desc: "remove subscription with invalid args", + args: []string{ + subscription.ID, + validToken, + extraArg, + }, + logType: usageLog, + }, + { + desc: "remove subscription with invalid subscription id", + args: []string{ + invalidID, + validToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)), + logType: errLog, + }, + { + desc: "remove subscription with invalid token", + args: []string{ + subscription.ID, + invalidToken, + }, + sdkErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("DeleteSubscription", mock.Anything, tc.args[0], tc.args[1]).Return(tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{rmCmd}, tc.args...)...) + + switch tc.logType { + case okLog: + assert.True(t, strings.Contains(out, "ok"), fmt.Sprintf("%s unexpected response: expected success message, got: %v", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + } + sdkCall.Unset() + }) + } +} diff --git a/cli/message.go b/cli/message.go index c97a39e60..356a72dc2 100644 --- a/cli/message.go +++ b/cli/message.go @@ -31,7 +31,7 @@ func NewMessagesCmd() *cobra.Command { cmd := cobra.Command{ Use: "messages [send]", Short: "Send messages", - Long: `Send messages using the http-adapter`, + Long: `Send messages using the HTTP API`, } for i := range cmdMessages { diff --git a/cli/provision.go b/cli/provision.go new file mode 100644 index 000000000..282474303 --- /dev/null +++ b/cli/provision.go @@ -0,0 +1,410 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cli + +import ( + "encoding/csv" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "os" + "path/filepath" + "time" + + "github.com/0x6flab/namegenerator" + smqsdk "github.com/absmach/supermq/pkg/sdk" + "github.com/spf13/cobra" +) + +const ( + jsonExt = ".json" + csvExt = ".csv" + PublishType = "publish" + SubscribeType = "subscribe" +) + +var ( + msgFormat = `[{"bn":"provision:", "bu":"V", "t": %d, "bver":5, "n":"voltage", "u":"V", "v":%d}]` + namesgenerator = namegenerator.NewGenerator() +) + +var cmdProvision = []cobra.Command{ + { + Use: "clients ", + Short: "Provision clients", + Long: `Bulk create clients`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + + if _, err := os.Stat(args[0]); os.IsNotExist(err) { + logErrorCmd(*cmd, err) + return + } + + clients, err := clientsFromFile(args[0]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + clients, err = sdk.CreateClients(cmd.Context(), clients, args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + logJSONCmd(*cmd, clients) + }, + }, + { + Use: "channels ", + Short: "Provision channels", + Long: `Bulk create channels`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + + channels, err := channelsFromFile(args[0]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + var chs []smqsdk.Channel + for _, c := range channels { + c, err = sdk.CreateChannel(cmd.Context(), c, args[1], args[2]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + chs = append(chs, c) + } + channels = chs + + logJSONCmd(*cmd, channels) + }, + }, + { + Use: "connect ", + Short: "Provision connections", + Long: `Bulk connect clients to channels`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 3 { + logUsageCmd(*cmd, cmd.Use) + return + } + + connIDs, err := connectionsFromFile(args[0]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + for _, conn := range connIDs { + if err := sdk.Connect(cmd.Context(), conn, args[1], args[2]); err != nil { + logErrorCmd(*cmd, err) + return + } + } + + logOKCmd(*cmd) + }, + }, + { + Use: "test", + Short: "test", + Long: `Provisions test setup: one test user, two clients and two channels. \ + Connect both clients to one of the channels, \ + and only on client to other channel.`, + Run: func(cmd *cobra.Command, args []string) { + numClients := 2 + numChan := 2 + clients := []smqsdk.Client{} + channels := []smqsdk.Channel{} + + if len(args) != 0 { + logUsageCmd(*cmd, cmd.Use) + return + } + + // Create test user + name := namesgenerator.Generate() + user := smqsdk.User{ + FirstName: name, + Email: fmt.Sprintf("%s@email.com", name), + Credentials: smqsdk.Credentials{ + Username: name, + Secret: "12345678", + }, + Status: smqsdk.EnabledStatus, + } + user, err := sdk.CreateUser(cmd.Context(), user, "") + if err != nil { + logErrorCmd(*cmd, err) + return + } + + ut, err := sdk.CreateToken(cmd.Context(), smqsdk.Login{Username: user.Credentials.Username, Password: user.Credentials.Secret}) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + // create domain + domain := smqsdk.Domain{ + Name: fmt.Sprintf("%s-domain", name), + Status: smqsdk.EnabledStatus, + } + domain, err = sdk.CreateDomain(cmd.Context(), domain, ut.AccessToken) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + ut, err = sdk.CreateToken(cmd.Context(), smqsdk.Login{Username: user.Email, Password: user.Credentials.Secret}) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + // Create clients + for i := 0; i < numClients; i++ { + t := smqsdk.Client{ + Name: fmt.Sprintf("%s-client-%d", name, i), + Status: smqsdk.EnabledStatus, + } + + clients = append(clients, t) + } + clients, err = sdk.CreateClients(cmd.Context(), clients, domain.ID, ut.AccessToken) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + // Create channels + for i := 0; i < numChan; i++ { + c := smqsdk.Channel{ + Name: fmt.Sprintf("%s-channel-%d", name, i), + Status: smqsdk.EnabledStatus, + } + c, err = sdk.CreateChannel(cmd.Context(), c, domain.ID, ut.AccessToken) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + channels = append(channels, c) + } + + // Connect clients to channels - first client to both channels, second only to first + conIDs := smqsdk.Connection{ + ChannelIDs: []string{channels[0].ID}, + ClientIDs: []string{clients[0].ID}, + Types: []string{PublishType, SubscribeType}, + } + if err := sdk.Connect(cmd.Context(), conIDs, domain.ID, ut.AccessToken); err != nil { + logErrorCmd(*cmd, err) + return + } + + conIDs = smqsdk.Connection{ + ChannelIDs: []string{channels[1].ID}, + ClientIDs: []string{clients[0].ID}, + Types: []string{PublishType, SubscribeType}, + } + if err := sdk.Connect(cmd.Context(), conIDs, domain.ID, ut.AccessToken); err != nil { + logErrorCmd(*cmd, err) + return + } + + conIDs = smqsdk.Connection{ + ChannelIDs: []string{channels[0].ID}, + ClientIDs: []string{clients[1].ID}, + Types: []string{PublishType, SubscribeType}, + } + if err := sdk.Connect(cmd.Context(), conIDs, domain.ID, ut.AccessToken); err != nil { + logErrorCmd(*cmd, err) + return + } + + // send message to test connectivity + if err := sdk.SendMessage(cmd.Context(), domain.ID, channels[0].ID, clients[0].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { + logErrorCmd(*cmd, err) + return + } + if err := sdk.SendMessage(cmd.Context(), domain.ID, channels[0].ID, clients[1].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { + logErrorCmd(*cmd, err) + return + } + if err := sdk.SendMessage(cmd.Context(), domain.ID, channels[1].ID, clients[0].Credentials.Secret, fmt.Sprintf(msgFormat, time.Now().Unix(), rand.Int())); err != nil { + logErrorCmd(*cmd, err) + return + } + + logJSONCmd(*cmd, user, ut, clients, channels) + }, + }, +} + +// NewProvisionCmd returns provision command. +func NewProvisionCmd() *cobra.Command { + cmd := cobra.Command{ + Use: "provision [clients | channels | connect | test]", + Short: "Provision clients and channels from a config file", + Long: `Provision clients and channels: use json or csv file to bulk provision clients and channels`, + } + + for i := range cmdProvision { + cmd.AddCommand(&cmdProvision[i]) + } + + return &cmd +} + +func clientsFromFile(path string) ([]smqsdk.Client, error) { + if _, err := os.Stat(path); os.IsNotExist(err) { + return []smqsdk.Client{}, err + } + + file, err := os.OpenFile(path, os.O_RDONLY, os.ModePerm) + if err != nil { + return []smqsdk.Client{}, err + } + defer file.Close() + + clients := []smqsdk.Client{} + switch filepath.Ext(path) { + case csvExt: + reader := csv.NewReader(file) + + for { + l, err := reader.Read() + if err == io.EOF { + break + } + if err != nil { + return []smqsdk.Client{}, err + } + + if len(l) < 1 { + return []smqsdk.Client{}, errors.New("empty line found in file") + } + + client := smqsdk.Client{ + Name: l[0], + } + + clients = append(clients, client) + } + case jsonExt: + err := json.NewDecoder(file).Decode(&clients) + if err != nil { + return []smqsdk.Client{}, err + } + default: + return []smqsdk.Client{}, err + } + + return clients, nil +} + +func channelsFromFile(path string) ([]smqsdk.Channel, error) { + if _, err := os.Stat(path); os.IsNotExist(err) { + return []smqsdk.Channel{}, err + } + + file, err := os.OpenFile(path, os.O_RDONLY, os.ModePerm) + if err != nil { + return []smqsdk.Channel{}, err + } + defer file.Close() + + channels := []smqsdk.Channel{} + switch filepath.Ext(path) { + case csvExt: + reader := csv.NewReader(file) + + for { + l, err := reader.Read() + if err == io.EOF { + break + } + if err != nil { + return []smqsdk.Channel{}, err + } + + if len(l) < 1 { + return []smqsdk.Channel{}, errors.New("empty line found in file") + } + + channel := smqsdk.Channel{ + Name: l[0], + } + + channels = append(channels, channel) + } + case jsonExt: + err := json.NewDecoder(file).Decode(&channels) + if err != nil { + return []smqsdk.Channel{}, err + } + default: + return []smqsdk.Channel{}, err + } + + return channels, nil +} + +func connectionsFromFile(path string) ([]smqsdk.Connection, error) { + if _, err := os.Stat(path); os.IsNotExist(err) { + return []smqsdk.Connection{}, err + } + + file, err := os.OpenFile(path, os.O_RDONLY, os.ModePerm) + if err != nil { + return []smqsdk.Connection{}, err + } + defer file.Close() + + connections := []smqsdk.Connection{} + switch filepath.Ext(path) { + case csvExt: + reader := csv.NewReader(file) + + for { + l, err := reader.Read() + if err == io.EOF { + break + } + if err != nil { + return []smqsdk.Connection{}, err + } + + if len(l) < 1 { + return []smqsdk.Connection{}, errors.New("empty line found in file") + } + connections = append(connections, smqsdk.Connection{ + ClientIDs: []string{l[0]}, + ChannelIDs: []string{l[1]}, + Types: []string{PublishType, SubscribeType}, + }) + } + case jsonExt: + err := json.NewDecoder(file).Decode(&connections) + if err != nil { + return []smqsdk.Connection{}, err + } + default: + return []smqsdk.Connection{}, err + } + + return connections, nil +} diff --git a/cli/users_test.go b/cli/users_test.go index 583b198f2..ad64039c3 100644 --- a/cli/users_test.go +++ b/cli/users_test.go @@ -33,13 +33,6 @@ var user = mgsdk.User{ Status: users.EnabledStatus.String(), } -var ( - validToken = "valid" - invalidToken = "" - invalidID = "invalidID" - extraArg = "extra-arg" -) - func TestCreateUsersCmd(t *testing.T) { sdkMock := new(sdkmocks.SDK) cli.SetSDK(sdkMock) diff --git a/cli/utils.go b/cli/utils.go index c5542e5ed..ea424aa37 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -6,7 +6,11 @@ package cli import ( "encoding/json" "fmt" + "os" + "path/filepath" + "github.com/absmach/supermq/certs" + smqsdk "github.com/absmach/supermq/pkg/sdk" "github.com/fatih/color" "github.com/hokaccha/go-prettyjson" "github.com/spf13/cobra" @@ -76,6 +80,14 @@ func logOKCmd(cmd cobra.Command) { fmt.Fprintf(cmd.OutOrStdout(), "\n%s\n\n", color.BlueString("ok")) } +func logCreatedCmd(cmd cobra.Command, e string) { + if RawOutput { + fmt.Fprintln(cmd.OutOrStdout(), e) + } else { + fmt.Fprintf(cmd.OutOrStdout(), color.BlueString("\ncreated: %s\n\n"), e) + } +} + func convertMetadata(m string) (map[string]any, error) { var metadata map[string]any if m == "" { @@ -86,3 +98,72 @@ func convertMetadata(m string) (map[string]any, error) { } return nil, nil } + +const certFileMode = 0o644 + +func logSaveCertFiles(cmd cobra.Command, cert smqsdk.Certificate) { + files := map[string][]byte{ + "cert.pem": []byte(cert.Certificate), + } + if cert.Key != "" { + files["key.pem"] = []byte(cert.Key) + } + for filename, content := range files { + if err := saveToFile(filename, content); err != nil { + logErrorCmd(cmd, err) + return + } + fmt.Fprintf(cmd.OutOrStdout(), "Saved %s\n", filename) + } + fmt.Fprintf(cmd.OutOrStdout(), "\nAll certificate files have been saved successfully.\n") +} + +func logSaveCAFiles(cmd cobra.Command, certBundle smqsdk.CertificateBundle) { + files := map[string][]byte{ + "ca.crt": certBundle.Certificate, + } + for filename, content := range files { + if err := saveToFile(filename, content); err != nil { + logErrorCmd(cmd, err) + return + } + fmt.Fprintf(cmd.OutOrStdout(), "Saved %s\n", filename) + } + fmt.Fprintf(cmd.OutOrStdout(), "\nAll certificate files have been saved successfully.\n") +} + +func logSaveCSRFiles(cmd cobra.Command, csr certs.CSR) { + files := map[string][]byte{ + "file.csr": csr.CSR, + } + for filename, content := range files { + if err := saveToFile(filename, content); err != nil { + logErrorCmd(cmd, err) + return + } + fmt.Fprintf(cmd.OutOrStdout(), "Saved %s\n", filename) + } + fmt.Fprintf(cmd.OutOrStdout(), "\nCSR file have been saved successfully.\n") +} + +func logSaveCRLFile(cmd cobra.Command, crlBytes []byte) { + filename := "ca.crl" + if err := saveToFile(filename, crlBytes); err != nil { + logErrorCmd(cmd, err) + return + } + fmt.Fprintf(cmd.OutOrStdout(), "Saved %s\n", filename) + fmt.Fprintf(cmd.OutOrStdout(), "\nCRL file has been saved successfully.\n") +} + +func saveToFile(filename string, content []byte) error { + cwd, err := os.Getwd() + if err != nil { + return fmt.Errorf("failed to get current working directory: %w", err) + } + filePath := filepath.Join(cwd, filename) + if err := os.WriteFile(filePath, content, certFileMode); err != nil { + return fmt.Errorf("failed to write file %s: %w", filename, err) + } + return nil +} diff --git a/clients/README.md b/clients/README.md index 752cd632a..05d2dc530 100644 --- a/clients/README.md +++ b/clients/README.md @@ -18,37 +18,37 @@ default values. | Variable | Description | Default | | ------------------------------ | ----------------------------------------------------------------------- | ------------------------------ | -| SMQ_CLIENTS_LOG_LEVEL | Log level for Clients (debug, info, warn, error) | info | -| SMQ_CLIENTS_HTTP_HOST | Clients service HTTP host | localhost | -| SMQ_CLIENTS_HTTP_PORT | Clients service HTTP port | 9000 | -| SMQ_CLIENTS_SERVER_CERT | Path to the PEM encoded server certificate file | "" | -| SMQ_CLIENTS_SERVER_KEY | Path to the PEM encoded server key file | "" | -| SMQ_CLIENTS_GRPC_HOST | Clients service gRPC host | localhost | -| SMQ_CLIENTS_GRPC_PORT | Clients service gRPC port | 7000 | -| SMQ_CLIENTS_GRPC_SERVER_CERT | Path to the PEM encoded server certificate file | "" | -| SMQ_CLIENTS_GRPC_SERVER_KEY | Path to the PEM encoded server key file | "" | -| SMQ_CLIENTS_DB_HOST | Database host address | localhost | -| SMQ_CLIENTS_DB_PORT | Database host port | 5432 | -| SMQ_CLIENTS_DB_USER | Database user | supermq | -| SMQ_CLIENTS_DB_PASS | Database password | supermq | -| SMQ_CLIENTS_DB_NAME | Name of the database used by the service | clients | -| SMQ_CLIENTS_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | -| SMQ_CLIENTS_DB_SSL_CERT | Path to the PEM encoded certificate file | "" | -| SMQ_CLIENTS_DB_SSL_KEY | Path to the PEM encoded key file | "" | -| SMQ_CLIENTS_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" | -| SMQ_CLIENTS_CACHE_URL | Cache database URL | | -| SMQ_CLIENTS_CACHE_KEY_DURATION | Cache key duration in seconds | 3600 | -| SMQ_CLIENTS_ES_URL | Event store URL | | -| SMQ_CLIENTS_ES_PASS | Event store password | "" | -| SMQ_CLIENTS_ES_DB | Event store instance name | 0 | -| SMQ_CLIENTS_STANDALONE_ID | User ID for standalone mode (no gRPC communication with Auth) | "" | -| SMQ_CLIENTS_STANDALONE_TOKEN | User token for standalone mode that should be passed in auth header | "" | -| SMQ_JAEGER_URL | Jaeger server URL | | -| SMQ_AUTH_GRPC_URL | Auth service gRPC URL | localhost:7001 | -| SMQ_AUTH_GRPC_TIMEOUT | Auth service gRPC request timeout in seconds | 1s | -| SMQ_AUTH_GRPC_CLIENT_TLS | Enable TLS for gRPC client | false | -| SMQ_AUTH_GRPC_CA_CERT | Path to the CA certificate file | "" | -| SMQ_SEND_TELEMETRY | Send telemetry to supermq call home server. | true | +| MG_CLIENTS_LOG_LEVEL | Log level for Clients (debug, info, warn, error) | info | +| MG_CLIENTS_HTTP_HOST | Clients service HTTP host | localhost | +| MG_CLIENTS_HTTP_PORT | Clients service HTTP port | 9000 | +| MG_CLIENTS_SERVER_CERT | Path to the PEM encoded server certificate file | "" | +| MG_CLIENTS_SERVER_KEY | Path to the PEM encoded server key file | "" | +| MG_CLIENTS_GRPC_HOST | Clients service gRPC host | localhost | +| MG_CLIENTS_GRPC_PORT | Clients service gRPC port | 7000 | +| MG_CLIENTS_GRPC_SERVER_CERT | Path to the PEM encoded server certificate file | "" | +| MG_CLIENTS_GRPC_SERVER_KEY | Path to the PEM encoded server key file | "" | +| MG_CLIENTS_DB_HOST | Database host address | localhost | +| MG_CLIENTS_DB_PORT | Database host port | 5432 | +| MG_CLIENTS_DB_USER | Database user | supermq | +| MG_CLIENTS_DB_PASS | Database password | supermq | +| MG_CLIENTS_DB_NAME | Name of the database used by the service | clients | +| MG_CLIENTS_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | +| MG_CLIENTS_DB_SSL_CERT | Path to the PEM encoded certificate file | "" | +| MG_CLIENTS_DB_SSL_KEY | Path to the PEM encoded key file | "" | +| MG_CLIENTS_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" | +| MG_CLIENTS_CACHE_URL | Cache database URL | | +| MG_CLIENTS_CACHE_KEY_DURATION | Cache key duration in seconds | 3600 | +| MG_CLIENTS_ES_URL | Event store URL | | +| MG_CLIENTS_ES_PASS | Event store password | "" | +| MG_CLIENTS_ES_DB | Event store instance name | 0 | +| MG_CLIENTS_STANDALONE_ID | User ID for standalone mode (no gRPC communication with Auth) | "" | +| MG_CLIENTS_STANDALONE_TOKEN | User token for standalone mode that should be passed in auth header | "" | +| MG_JAEGER_URL | Jaeger server URL | | +| MG_AUTH_GRPC_URL | Auth service gRPC URL | localhost:7001 | +| MG_AUTH_GRPC_TIMEOUT | Auth service gRPC request timeout in seconds | 1s | +| MG_AUTH_GRPC_CLIENT_TLS | Enable TLS for gRPC client | false | +| MG_AUTH_GRPC_CA_CERT | Path to the CA certificate file | "" | +| MG_SEND_TELEMETRY | Send telemetry to supermq call home server. | true | | Clients_INSTANCE_ID | Clients instance ID | "" | **Note** that if you want `clients` service to have only one user locally, you should use `CLIENTS_STANDALONE` env vars. By specifying these, you don't need `auth` service in your deployment for users' authorization. @@ -98,12 +98,12 @@ Clients_CACHE_URL=[Cache database URL] \ Clients_ES_URL=[Event store URL] \ Clients_ES_PASS=[Event store password] \ Clients_ES_DB=[Event store instance name] \ -SMQ_AUTH_GRPC_URL=[Auth service gRPC URL] \ -SMQ_AUTH_GRPC_TIMEOUT=[Auth service gRPC request timeout in seconds] \ -SMQ_AUTH_GRPC_CLIENT_TLS=[Enable TLS for gRPC client] \ -SMQ_AUTH_GRPC_CA_CERT=[Path to trusted CA certificate file] \ -SMQ_JAEGER_URL=[Jaeger server URL] \ -SMQ_SEND_TELEMETRY=[Send telemetry to supermq call home server] \ +MG_AUTH_GRPC_URL=[Auth service gRPC URL] \ +MG_AUTH_GRPC_TIMEOUT=[Auth service gRPC request timeout in seconds] \ +MG_AUTH_GRPC_CLIENT_TLS=[Enable TLS for gRPC client] \ +MG_AUTH_GRPC_CA_CERT=[Path to trusted CA certificate file] \ +MG_JAEGER_URL=[Jaeger server URL] \ +MG_SEND_TELEMETRY=[Send telemetry to supermq call home server] \ Clients_INSTANCE_ID=[Clients instance ID] \ $GOBIN/supermq-clients ``` diff --git a/clients/events/streams.go b/clients/events/streams.go index 62afcc296..2fef6abc7 100644 --- a/clients/events/streams.go +++ b/clients/events/streams.go @@ -42,7 +42,7 @@ type eventStore struct { // NewEventStoreMiddleware returns wrapper around clients service that sends // events to event store. func NewEventStoreMiddleware(ctx context.Context, svc clients.Service, url string) (clients.Service, error) { - publisher, err := store.NewPublisher(ctx, url) + publisher, err := store.NewPublisher(ctx, url, "clients-es-pub") if err != nil { return nil, err } diff --git a/clients/middleware/authorization.go b/clients/middleware/authorization.go index a71aefecd..4fd3058ce 100644 --- a/clients/middleware/authorization.go +++ b/clients/middleware/authorization.go @@ -299,7 +299,7 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session Subject: session.UserID, Permission: policies.AdminPermission, ObjectType: policies.PlatformType, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, }, nil); err != nil { return err } diff --git a/cmd/alarms/main.go b/cmd/alarms/main.go new file mode 100644 index 000000000..3f1314dad --- /dev/null +++ b/cmd/alarms/main.go @@ -0,0 +1,263 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "fmt" + "log" + "net/url" + "os" + + "github.com/absmach/supermq/alarms" + httpAPI "github.com/absmach/supermq/alarms/api" + "github.com/absmach/supermq/alarms/brokers" + "github.com/absmach/supermq/alarms/consumer" + "github.com/absmach/supermq/alarms/middleware" + "github.com/absmach/supermq/alarms/operations" + alarmsRepo "github.com/absmach/supermq/alarms/postgres" + dpostgres "github.com/absmach/supermq/domains/postgres" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/authn/authsvc" + authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" + dconsumer "github.com/absmach/supermq/pkg/domains/events/consumer" + domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" + "github.com/absmach/supermq/pkg/grpcclient" + "github.com/absmach/supermq/pkg/jaeger" + "github.com/absmach/supermq/pkg/messaging" + brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" + "github.com/absmach/supermq/pkg/permissions" + "github.com/absmach/supermq/pkg/postgres" + "github.com/absmach/supermq/pkg/prometheus" + rconsumer "github.com/absmach/supermq/pkg/re/events/consumer" + "github.com/absmach/supermq/pkg/server" + httpserver "github.com/absmach/supermq/pkg/server/http" + "github.com/absmach/supermq/pkg/uuid" + rpostgres "github.com/absmach/supermq/re/postgres" + "github.com/caarlos0/env/v11" + "golang.org/x/sync/errgroup" +) + +const ( + svcName = "alarms" + envPrefixDB = "MG_ALARMS_DB_" + envPrefixHTTP = "MG_ALARMS_HTTP_" + envPrefixAuth = "MG_AUTH_GRPC_" + defDB = "alarms" + defSvcHTTPPort = "8050" + envPrefixDomains = "MG_DOMAINS_GRPC_" + alarmEntity = "alarm" +) + +type config struct { + LogLevel string `env:"MG_ALARMS_LOG_LEVEL" envDefault:"info"` + BrokerURL string `env:"MG_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"` + InstanceID string `env:"MG_ALARMS_INSTANCE_ID" envDefault:""` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + ESURL string `env:"MG_ES_URL" envDefault:"nats://localhost:4222"` + ESConsumerName string `env:"MG_ALARMS_EVENT_CONSUMER" envDefault:"alarms"` + PermissionsFile string `env:"MG_PERMISSIONS_FILE" envDefault:"permission.yaml"` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg := config{} + if err := env.Parse(&cfg); err != nil { + log.Fatalf("failed to load %s configuration : %s", svcName, err.Error()) + } + + logger, err := smqlog.New(os.Stdout, cfg.LogLevel) + if err != nil { + log.Fatalf("failed to init logger: %s", err.Error()) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + tp, err := jaeger.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) + if err != nil { + logger.Error(fmt.Sprintf("failed to init Jaeger: %s", err)) + exitCode = 1 + return + } + defer func() { + if err := tp.Shutdown(ctx); err != nil { + logger.Error(fmt.Sprintf("error shutting down tracer provider: %v", err)) + } + }() + tracer := tp.Tracer(svcName) + + dbConfig := postgres.Config{Name: defDB} + if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { + logger.Error(err.Error()) + } + + migrations, err := alarmsRepo.Migration() + if err != nil { + logger.Error(fmt.Sprintf("failed to load migrations: %s", err)) + exitCode = 1 + return + } + + db, err := postgres.Setup(dbConfig, *migrations) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer db.Close() + + repo := alarmsRepo.NewAlarmsRepo(db) + + authConfig := grpcclient.Config{} + if err := env.ParseWithOptions(&authConfig, env.Options{Prefix: envPrefixAuth}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s auth configuration : %s", svcName, err)) + exitCode = 1 + return + } + authn, authnClient, err := authsvc.NewAuthentication(ctx, authConfig) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + am := smqauthn.NewAuthNMiddleware(authn) + defer authnClient.Close() + logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) + + domsGrpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { + logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) + exitCode = 1 + return + } + + domAuthz, _, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer domainsHandler.Close() + + authz, authzHandler, err := authsvcAuthz.NewAuthorization(ctx, authConfig, domAuthz) + if err != nil { + logger.Error("failed to create authz " + err.Error()) + exitCode = 1 + return + } + defer authzHandler.Close() + + logger.Info("AuthZ successfully connected to auth gRPC server " + authzHandler.Secure()) + + ddatabase := postgres.NewDatabase(db, dbConfig, tracer) + drepo := dpostgres.NewRepository(ddatabase) + + if err := dconsumer.DomainsEventsSubscribe(ctx, drepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil { + logger.Error(fmt.Sprintf("failed to create domains event store : %s", err)) + exitCode = 1 + return + } + + rdatabase := postgres.NewDatabase(db, dbConfig, tracer) + rrepo := rpostgres.NewRepository(rdatabase) + + if err := rconsumer.RulesEventsSubscribe(ctx, rrepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil { + logger.Error(fmt.Sprintf("failed to subscribe to rules events: %s", err)) + exitCode = 1 + return + } + + idp := uuid.New() + + svc := alarms.NewService(idp, repo) + + permConfig, err := permissions.ParsePermissionsFile(cfg.PermissionsFile) + if err != nil { + logger.Error(fmt.Sprintf("failed to parse permissions file: %s", err)) + exitCode = 1 + return + } + + alarmOps, _, err := permConfig.GetEntityPermissions(alarmEntity) + if err != nil { + logger.Error(fmt.Sprintf("failed to get alarm permissions: %s", err)) + exitCode = 1 + return + } + + entitiesOps, err := permissions.NewEntitiesOperations( + permissions.EntitiesPermission{ + operations.EntityType: alarmOps, + }, + permissions.EntitiesOperationDetails[permissions.Operation]{ + operations.EntityType: operations.OperationDetails(), + }, + ) + if err != nil { + logger.Error(fmt.Sprintf("failed to create entity operations: %s", err)) + exitCode = 1 + return + } + + svc, err = middleware.NewAuthorizationMiddleware(svc, authz, entitiesOps) + if err != nil { + logger.Error(fmt.Sprintf("failed to create authorization middleware: %s", err)) + exitCode = 1 + return + } + + svc = middleware.NewLoggingMiddleware(logger, svc) + counter, latency := prometheus.MakeMetrics("alarms", "api") + svc = middleware.NewMetricsMiddleware(counter, latency, svc) + svc = middleware.NewTracingMiddleware(tracer, svc) + + httpServerConfig := server.Config{Port: defSvcHTTPPort} + if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) + exitCode = 1 + return + } + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpAPI.MakeHandler(svc, logger, idp, cfg.InstanceID, am), logger) + + pubSub, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to message broker: %s", err)) + exitCode = 1 + return + } + defer pubSub.Close() + pubSub = brokerstracing.NewPubSub(httpServerConfig, tracer, pubSub) + + consumer := consumer.NewHandler(svc, logger) + + subCfg := messaging.SubscriberConfig{ + ID: svcName, + Topic: brokers.AllTopic, + DeliveryPolicy: messaging.DeliverAllPolicy, + Handler: consumer, + } + if err := pubSub.Subscribe(ctx, subCfg); err != nil { + logger.Error(fmt.Sprintf("failed to subscribe to message broker: %s", err)) + exitCode = 1 + + return + } + + g.Go(func() error { + return hs.Start() + }) + + g.Go(func() error { + return server.StopSignalHandler(ctx, cancel, logger, svcName, hs) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("billing service terminated: %s", err)) + } +} diff --git a/cmd/auth/main.go b/cmd/auth/main.go index 983c66067..c1f13d296 100644 --- a/cmd/auth/main.go +++ b/cmd/auth/main.go @@ -51,36 +51,36 @@ import ( const ( svcName = "auth" - envPrefixHTTP = "SMQ_AUTH_HTTP_" - envPrefixGrpc = "SMQ_AUTH_GRPC_" - envPrefixDB = "SMQ_AUTH_DB_" + envPrefixHTTP = "MG_AUTH_HTTP_" + envPrefixGrpc = "MG_AUTH_GRPC_" + envPrefixDB = "MG_AUTH_DB_" defDB = "auth" defSvcHTTPPort = "8189" defSvcGRPCPort = "8181" ) type config struct { - LogLevel string `env:"SMQ_AUTH_LOG_LEVEL" envDefault:"info"` - SecretKey string `env:"SMQ_AUTH_SECRET_KEY" envDefault:"secret"` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - InstanceID string `env:"SMQ_AUTH_ADAPTER_INSTANCE_ID" envDefault:""` - AccessDuration time.Duration `env:"SMQ_AUTH_ACCESS_TOKEN_DURATION" envDefault:"1h"` - RefreshDuration time.Duration `env:"SMQ_AUTH_REFRESH_TOKEN_DURATION" envDefault:"24h"` - KeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"EdDSA"` - ActiveKeyPath string `env:"SMQ_AUTH_KEYS_ACTIVE_KEY_PATH" envDefault:"./keys/active.key"` - RetiringKeyPath string `env:"SMQ_AUTH_KEYS_RETIRING_KEY_PATH" envDefault:""` - InvitationDuration time.Duration `env:"SMQ_AUTH_INVITATION_DURATION" envDefault:"168h"` - SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"` - SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"` - SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"./docker/spicedb/schema.zed"` - SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` - CacheURL string `env:"SMQ_AUTH_CACHE_URL" envDefault:"redis://localhost:6379/0"` - CacheKeyDuration time.Duration `env:"SMQ_AUTH_CACHE_KEY_DURATION" envDefault:"10m"` - JWKSCacheMaxAge int `env:"SMQ_AUTH_JWKS_CACHE_MAX_AGE" envDefault:"900"` - JWKSCacheStaleWhileRevalidate int `env:"SMQ_AUTH_JWKS_CACHE_STALE_WHILE_REVALIDATE" envDefault:"60"` + LogLevel string `env:"MG_AUTH_LOG_LEVEL" envDefault:"info"` + SecretKey string `env:"MG_AUTH_SECRET_KEY" envDefault:"secret"` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + InstanceID string `env:"MG_AUTH_ADAPTER_INSTANCE_ID" envDefault:""` + AccessDuration time.Duration `env:"MG_AUTH_ACCESS_TOKEN_DURATION" envDefault:"1h"` + RefreshDuration time.Duration `env:"MG_AUTH_REFRESH_TOKEN_DURATION" envDefault:"24h"` + KeyAlgorithm string `env:"MG_AUTH_KEYS_ALGORITHM" envDefault:"EdDSA"` + ActiveKeyPath string `env:"MG_AUTH_KEYS_ACTIVE_KEY_PATH" envDefault:"./keys/active.key"` + RetiringKeyPath string `env:"MG_AUTH_KEYS_RETIRING_KEY_PATH" envDefault:""` + InvitationDuration time.Duration `env:"MG_AUTH_INVITATION_DURATION" envDefault:"168h"` + SpicedbHost string `env:"MG_SPICEDB_HOST" envDefault:"localhost"` + SpicedbPort string `env:"MG_SPICEDB_PORT" envDefault:"50051"` + SpicedbSchemaFile string `env:"MG_SPICEDB_SCHEMA_FILE" envDefault:"./docker/spicedb/schema.zed"` + SpicedbPreSharedKey string `env:"MG_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + ESURL string `env:"MG_ES_URL" envDefault:"amqp://guest:guest@localhost:5682/"` + CacheURL string `env:"MG_AUTH_CACHE_URL" envDefault:"redis://localhost:6379/0"` + CacheKeyDuration time.Duration `env:"MG_AUTH_CACHE_KEY_DURATION" envDefault:"10m"` + JWKSCacheMaxAge int `env:"MG_AUTH_JWKS_CACHE_MAX_AGE" envDefault:"900"` + JWKSCacheStaleWhileRevalidate int `env:"MG_AUTH_JWKS_CACHE_STALE_WHILE_REVALIDATE" envDefault:"60"` } func main() { @@ -267,7 +267,7 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch func validateKeyConfig(isSymmetric bool, cfg config, l *slog.Logger) error { if isSymmetric { if cfg.SecretKey == "secret" { - return fmt.Errorf("default secret key is insecure - please set SMQ_AUTH_SECRET_KEY environment variable") + return fmt.Errorf("default secret key is insecure - please set MG_AUTH_SECRET_KEY environment variable") } return nil } @@ -276,7 +276,7 @@ func validateKeyConfig(isSymmetric bool, cfg config, l *slog.Logger) error { _, err := os.Stat(cfg.ActiveKeyPath) if err != nil { if os.IsNotExist(err) { - return fmt.Errorf("active key file not found: %s - please set SMQ_AUTH_KEYS_ACTIVE_KEY_PATH", cfg.ActiveKeyPath) + return fmt.Errorf("active key file not found: %s - please set MG_AUTH_KEYS_ACTIVE_KEY_PATH", cfg.ActiveKeyPath) } return fmt.Errorf("failed to access active key file: %w", err) } diff --git a/cmd/bootstrap/main.go b/cmd/bootstrap/main.go new file mode 100644 index 000000000..e64552d16 --- /dev/null +++ b/cmd/bootstrap/main.go @@ -0,0 +1,277 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package main contains bootstrap main function to start the bootstrap service. +package main + +import ( + "context" + "fmt" + "log" + "log/slog" + "net/url" + "os" + + chclient "github.com/absmach/callhome/pkg/client" + "github.com/absmach/supermq" + "github.com/absmach/supermq/bootstrap" + httpapi "github.com/absmach/supermq/bootstrap/api" + "github.com/absmach/supermq/bootstrap/events/consumer" + "github.com/absmach/supermq/bootstrap/events/producer" + "github.com/absmach/supermq/bootstrap/middleware" + bootstrappg "github.com/absmach/supermq/bootstrap/postgres" + "github.com/absmach/supermq/bootstrap/tracing" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" + smqauthz "github.com/absmach/supermq/pkg/authz" + authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" + domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" + "github.com/absmach/supermq/pkg/events" + "github.com/absmach/supermq/pkg/events/store" + "github.com/absmach/supermq/pkg/grpcclient" + "github.com/absmach/supermq/pkg/jaeger" + "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/pkg/policies/spicedb" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/absmach/supermq/pkg/prometheus" + mgsdk "github.com/absmach/supermq/pkg/sdk" + "github.com/absmach/supermq/pkg/server" + httpserver "github.com/absmach/supermq/pkg/server/http" + "github.com/absmach/supermq/pkg/uuid" + "github.com/authzed/authzed-go/v1" + "github.com/authzed/grpcutil" + "github.com/caarlos0/env/v11" + "github.com/jmoiron/sqlx" + "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +const ( + svcName = "bootstrap" + envPrefixDB = "MG_BOOTSTRAP_DB_" + envPrefixHTTP = "MG_BOOTSTRAP_HTTP_" + envPrefixAuth = "MG_AUTH_GRPC_" + envPrefixDomains = "MG_DOMAINS_GRPC_" + defDB = "bootstrap" + defSvcHTTPPort = "9013" + + stream = "events.supermq.clients" + streamID = "supermq.bootstrap" +) + +type config struct { + LogLevel string `env:"MG_BOOTSTRAP_LOG_LEVEL" envDefault:"info"` + EncKey string `env:"MG_BOOTSTRAP_ENCRYPT_KEY" envDefault:"12345678910111213141516171819202"` + ESConsumerName string `env:"MG_BOOTSTRAP_EVENT_CONSUMER" envDefault:"bootstrap"` + ClientsURL string `env:"MG_CLIENTS_URL" envDefault:"http://localhost:9006"` + ChannelsURL string `env:"MG_CHANNELS_URL" envDefault:"http://localhost:9005"` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + InstanceID string `env:"MG_BOOTSTRAP_INSTANCE_ID" envDefault:""` + ESURL string `env:"MG_ES_URL" envDefault:"nats://localhost:4222"` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + SpicedbHost string `env:"MG_SPICEDB_HOST" envDefault:"localhost"` + SpicedbPort string `env:"MG_SPICEDB_PORT" envDefault:"50051"` + SpicedbPreSharedKey string `env:"MG_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg := config{} + if err := env.Parse(&cfg); err != nil { + log.Fatalf("failed to load %s configuration : %s", svcName, err) + } + + logger, err := smqlog.New(os.Stdout, cfg.LogLevel) + if err != nil { + log.Fatalf("failed to init logger: %s", err.Error()) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + if cfg.InstanceID == "" { + if cfg.InstanceID, err = uuid.New().ID(); err != nil { + logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) + exitCode = 1 + return + } + } + + // Create new postgres client + dbConfig := pgclient.Config{Name: defDB} + if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { + logger.Error(err.Error()) + } + db, err := pgclient.Setup(dbConfig, *bootstrappg.Migration()) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer db.Close() + + policySvc, err := newPolicyService(cfg, logger) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + logger.Info("Policy client successfully connected to spicedb gRPC server") + + tp, err := jaeger.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) + if err != nil { + logger.Error(fmt.Sprintf("failed to init Jaeger: %s", err)) + exitCode = 1 + return + } + defer func() { + if err := tp.Shutdown(ctx); err != nil { + logger.Error(fmt.Sprintf("error shutting down tracer provider: %v", err)) + } + }() + tracer := tp.Tracer(svcName) + + grpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&grpcCfg, env.Options{Prefix: envPrefixAuth}); err != nil { + logger.Error(fmt.Sprintf("failed to load auth gRPC client configuration : %s", err)) + exitCode = 1 + return + } + authn, authnClient, err := authsvcAuthn.NewAuthentication(ctx, grpcCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + am := smqauthn.NewAuthNMiddleware(authn) + logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) + defer authnClient.Close() + + domsGrpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { + logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) + exitCode = 1 + return + } + domainsAuthz, _, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer domainsHandler.Close() + + authz, authzClient, err := authsvcAuthz.NewAuthorization(ctx, grpcCfg, domainsAuthz) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer authzClient.Close() + logger.Info("AuthZ successfully connected to auth gRPC server " + authzClient.Secure()) + + // Create new service + svc, err := newService(ctx, authz, policySvc, db, tracer, logger, cfg, dbConfig) + if err != nil { + logger.Error(fmt.Sprintf("failed to create %s service: %s", svcName, err)) + exitCode = 1 + return + } + + if err = subscribeToClientsES(ctx, svc, cfg, logger); err != nil { + logger.Error(fmt.Sprintf("failed to subscribe to clients event store: %s", err)) + exitCode = 1 + return + } + + logger.Info("Subscribed to Event Store") + + httpServerConfig := server.Config{Port: defSvcHTTPPort} + if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) + exitCode = 1 + return + } + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, am, bootstrap.NewConfigReader([]byte(cfg.EncKey)), logger, cfg.InstanceID), logger) + + if cfg.SendTelemetry { + chc := chclient.New(svcName, supermq.Version, logger, cancel) + go chc.CallHome(ctx) + } + + // Start servers + g.Go(func() error { + return hs.Start() + }) + g.Go(func() error { + return server.StopSignalHandler(ctx, cancel, logger, svcName, hs) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("Bootstrap service terminated: %s", err)) + } +} + +func newService(ctx context.Context, authz smqauthz.Authorization, policySvc policies.Service, db *sqlx.DB, tracer trace.Tracer, logger *slog.Logger, cfg config, dbConfig pgclient.Config) (bootstrap.Service, error) { + database := pgclient.NewDatabase(db, dbConfig, tracer) + + repoConfig := bootstrappg.NewConfigRepository(database, logger) + + config := mgsdk.Config{ + ClientsURL: cfg.ClientsURL, + ChannelsURL: cfg.ChannelsURL, + } + + sdk := mgsdk.NewSDK(config) + idp := uuid.New() + + svc := bootstrap.New(policySvc, repoConfig, sdk, []byte(cfg.EncKey), idp) + + publisher, err := store.NewPublisher(ctx, cfg.ESURL, "bootstrap-es-pub") + if err != nil { + return nil, err + } + + svc = middleware.AuthorizationMiddleware(svc, authz) + svc = producer.NewEventStoreMiddleware(svc, publisher) + svc = middleware.LoggingMiddleware(svc, logger) + counter, latency := prometheus.MakeMetrics(svcName, "api") + svc = middleware.MetricsMiddleware(svc, counter, latency) + svc = tracing.New(svc, tracer) + + return svc, nil +} + +func subscribeToClientsES(ctx context.Context, svc bootstrap.Service, cfg config, logger *slog.Logger) error { + subscriber, err := store.NewSubscriber(ctx, cfg.ESURL, "bootstrap-es-sub", logger) + if err != nil { + return err + } + + subConfig := events.SubscriberConfig{ + Stream: stream, + Consumer: cfg.ESConsumerName, + Handler: consumer.NewEventHandler(svc), + } + return subscriber.Subscribe(ctx, subConfig) +} + +func newPolicyService(cfg config, logger *slog.Logger) (policies.Service, error) { + client, err := authzed.NewClientWithExperimentalAPIs( + fmt.Sprintf("%s:%s", cfg.SpicedbHost, cfg.SpicedbPort), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpcutil.WithInsecureBearerToken(cfg.SpicedbPreSharedKey), + ) + if err != nil { + return nil, err + } + policySvc := spicedb.NewPolicyService(client, logger) + + return policySvc, nil +} diff --git a/cmd/certs/main.go b/cmd/certs/main.go new file mode 100644 index 000000000..6285eb0e9 --- /dev/null +++ b/cmd/certs/main.go @@ -0,0 +1,294 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "fmt" + "log" + "log/slog" + "net/url" + "os" + "strings" + "time" + + grpcCertsV1 "github.com/absmach/supermq/api/grpc/certs/v1" + "github.com/absmach/supermq/certs" + certsgrpc "github.com/absmach/supermq/certs/api/grpc" + httpapi "github.com/absmach/supermq/certs/api/http" + "github.com/absmach/supermq/certs/middleware" + "github.com/absmach/supermq/certs/pki" + "github.com/absmach/supermq/certs/postgres" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" + smqauthz "github.com/absmach/supermq/pkg/authz" + authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" + domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" + "github.com/absmach/supermq/pkg/grpcclient" + "github.com/absmach/supermq/pkg/jaeger" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/absmach/supermq/pkg/prometheus" + smq "github.com/absmach/supermq/pkg/server" + grpcserver "github.com/absmach/supermq/pkg/server/grpc" + httpserver "github.com/absmach/supermq/pkg/server/http" + "github.com/absmach/supermq/pkg/uuid" + "github.com/caarlos0/env/v10" + "github.com/jmoiron/sqlx" + "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" +) + +const ( + svcName = "certs" + envPrefixHTTP = "MG_CERTS_HTTP_" + envPrefixDB = "MG_CERTS_DB_" + envPrefixGRPC = "MG_CERTS_GRPC_" + envPrefixAuth = "MG_AUTH_GRPC_" + envPrefixDomains = "MG_DOMAINS_GRPC_" + defSvcHTTPPort = "9010" + defSvcGRPCPort = "7012" + defDB = "certs" + serviceTokenKey = "SERVICE_TOKEN=" +) + +type config struct { + LogLevel string `env:"MG_CERTS_LOG_LEVEL" envDefault:"info"` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://jaeger:4318"` + InstanceID string `env:"MG_CERTS_INSTANCE_ID" envDefault:""` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + Secret string `env:"MG_CERTS_SECRET" envDefault:""` + + // OpenBao PKI settings + OpenBaoHost string `env:"MG_CERTS_OPENBAO_HOST" envDefault:"http://localhost:8200"` + OpenBaoAppRole string `env:"MG_CERTS_OPENBAO_APP_ROLE" envDefault:""` + OpenBaoAppSecret string `env:"MG_CERTS_OPENBAO_APP_SECRET" envDefault:""` + OpenBaoNamespace string `env:"MG_CERTS_OPENBAO_NAMESPACE" envDefault:""` + OpenBaoPKIPath string `env:"MG_CERTS_OPENBAO_PKI_PATH" envDefault:"pki"` + OpenBaoRole string `env:"MG_CERTS_OPENBAO_ROLE" envDefault:"certs"` + OpenBaoServiceToken string `env:"MG_CERTS_SERVICE_TOKEN" envDefault:""` + ServiceTokenPath string `env:"MG_CERTS_SERVICE_TOKEN_PATH" envDefault:""` + SecretIDPath string `env:"MG_CERTS_SECRET_ID_PATH" envDefault:""` + SecretRenewThreshold string `env:"MG_CERTS_SECRET_RENEW_THRESHOLD" envDefault:"24h"` + SecretIDTTL string `env:"MG_CERTS_OPENBAO_SECRET_ID_TTL" envDefault:"72h"` + SecretCheckInterval string `env:"MG_CERTS_SECRET_CHECK_INTERVAL" envDefault:"30s"` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg := config{} + if err := env.Parse(&cfg); err != nil { + log.Fatalf("failed to load %s configuration : %s", svcName, err) + } + + logger, err := initLogger(cfg.LogLevel) + if err != nil { + log.Fatalf("failed to initialize logger: %v", err) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + if cfg.InstanceID == "" { + cfg.InstanceID, err = uuid.New().ID() + if err != nil { + logger.Error(fmt.Sprintf("failed to generate instance ID: %v", err)) + exitCode = 1 + return + } + } + + if cfg.OpenBaoHost == "" { + logger.Error("No host specified for OpenBao PKI engine") + exitCode = 1 + return + } + + if cfg.OpenBaoAppRole == "" { + logger.Error("OpenBao AppRole not specified") + exitCode = 1 + return + } + + secretID := cfg.OpenBaoAppSecret + if secretID == "" && cfg.SecretIDPath != "" { + secretData, err := os.ReadFile(cfg.SecretIDPath) + if err != nil { + logger.Error("Failed to read secret ID from file", "path", cfg.SecretIDPath, "error", err) + exitCode = 1 + return + } + secretID = strings.TrimSpace(string(secretData)) + } + + if secretID == "" { + logger.Error("OpenBao secret ID not specified (provide via MG_CERTS_OPENBAO_APP_SECRET or MG_CERTS_SECRET_ID_PATH)") + exitCode = 1 + return + } + + serviceToken := cfg.OpenBaoServiceToken + if serviceToken == "" && cfg.ServiceTokenPath != "" { + tokenData, err := os.ReadFile(cfg.ServiceTokenPath) + if err != nil { + logger.Warn("Failed to read service token from file, secret renewal will be disabled", "path", cfg.ServiceTokenPath, "error", err) + } else { + tokenLine := string(tokenData) + if strings.HasPrefix(tokenLine, serviceTokenKey) { + serviceToken = strings.TrimSpace(strings.TrimPrefix(tokenLine, serviceTokenKey)) + } + } + } + + pkiAgent, err := pki.NewAgent(cfg.OpenBaoAppRole, secretID, cfg.OpenBaoHost, cfg.OpenBaoNamespace, cfg.OpenBaoPKIPath, cfg.OpenBaoRole, serviceToken, cfg.SecretRenewThreshold, cfg.SecretIDTTL, cfg.SecretCheckInterval, logger) + if err != nil { + logger.Error("failed to configure client for OpenBao PKI engine") + exitCode = 1 + return + } + + if err := pkiAgent.StartSecretRenewal(ctx); err != nil { + logger.Warn("Failed to start secret renewal, service may lose access when secret expires", "error", err) + } + + dbConfig := pgclient.Config{Name: defDB} + if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + migrations := postgres.Migration() + db, err := pgclient.Setup(dbConfig, *migrations) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer db.Close() + + tp, err := jaeger.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) + if err != nil { + logger.Error(fmt.Sprintf("Failed to init Jaeger: %s", err)) + exitCode = 1 + return + } + defer func() { + if err := tp.Shutdown(ctx); err != nil { + logger.Error(fmt.Sprintf("Error shutting down tracer provider: %v", err)) + } + }() + tracer := tp.Tracer(svcName) + + domsGrpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { + logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) + exitCode = 1 + return + } + domAuthz, _, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer domainsHandler.Close() + + authClientConfig := grpcclient.Config{} + if err := env.ParseWithOptions(&authClientConfig, env.Options{Prefix: envPrefixAuth}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s auth configuration : %s", svcName, err)) + exitCode = 1 + return + } + + authn, authnHandler, err := authsvcAuthn.NewAuthentication(ctx, authClientConfig) + if err != nil { + logger.Error("failed to create authn " + err.Error()) + exitCode = 1 + return + } + defer authnHandler.Close() + logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure()) + authnMiddleware := smqauthn.NewAuthNMiddleware(authn) + authz, authzHandler, err := authsvcAuthz.NewAuthorization(ctx, authClientConfig, domAuthz) + if err != nil { + logger.Error("failed to create authz " + err.Error()) + exitCode = 1 + return + } + defer authzHandler.Close() + logger.Info("Authz successfully connected to auth gRPC server " + authzHandler.Secure()) + httpServerConfig := smq.Config{Port: defSvcHTTPPort} + if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s gRPC server configuration : %s", svcName, err)) + exitCode = 1 + return + } + + svc := newService(ctx, db, dbConfig, tracer, logger, pkiAgent, authz) + + grpcServerConfig := smq.Config{Port: defSvcGRPCPort} + if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGRPC}); err != nil { + log.Printf("failed to load %s gRPC server configuration : %s", svcName, err.Error()) + exitCode = 1 + return + } + + registerCertsServiceServer := func(srv *grpc.Server) { + reflection.Register(srv) + grpcCertsV1.RegisterCertsServiceServer(srv, certsgrpc.NewServer(svc)) + } + gs := grpcserver.NewServer(ctx, cancel, svcName, grpcServerConfig, registerCertsServiceServer, logger) + + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, logger, cfg.InstanceID, cfg.Secret), logger) + + g.Go(func() error { + return hs.Start() + }) + + g.Go(func() error { + return gs.Start() + }) + + g.Go(func() error { + return smq.StopSignalHandler(ctx, cancel, logger, svcName, hs, gs) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("%s service terminated: %s", svcName, err)) + } +} + +func newService(ctx context.Context, db *sqlx.DB, dbConfig pgclient.Config, tracer trace.Tracer, logger *slog.Logger, pkiAgent certs.Agent, authz smqauthz.Authorization) certs.Service { + database := pgclient.NewDatabase(db, dbConfig, tracer) + repo := postgres.NewRepository(database) + svc, err := certs.NewService(ctx, pkiAgent, repo) + if err != nil { + logger.Error(fmt.Sprintf("failed to create service: %s", err)) + return nil + } + svc = middleware.AuthorizationMiddleware(authz, svc) + svc = middleware.LoggingMiddleware(svc, logger) + counter, latency := prometheus.MakeMetrics(svcName, "api") + svc = middleware.MetricsMiddleware(svc, counter, latency) + svc = middleware.New(svc, tracer) + + return svc +} + +func initLogger(levelText string) (*slog.Logger, error) { + var level slog.Level + if err := level.UnmarshalText([]byte(levelText)); err != nil { + return &slog.Logger{}, fmt.Errorf(`{"level":"error","message":"%s: %s","ts":"%s"}`, err, levelText, time.RFC3339Nano) + } + + logHandler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ + Level: level, + }) + + return slog.New(logHandler), nil +} diff --git a/cmd/channels/main.go b/cmd/channels/main.go index 47cfef7e7..244af342c 100644 --- a/cmd/channels/main.go +++ b/cmd/channels/main.go @@ -74,36 +74,36 @@ import ( const ( svcName = "channels" - envPrefixDB = "SMQ_CHANNELS_DB_" - envPrefixHTTP = "SMQ_CHANNELS_HTTP_" - envPrefixGRPC = "SMQ_CHANNELS_GRPC_" - envPrefixAuth = "SMQ_AUTH_GRPC_" - envPrefixClients = "SMQ_CLIENTS_GRPC_" - envPrefixGroups = "SMQ_GROUPS_GRPC_" - envPrefixDomains = "SMQ_DOMAINS_GRPC_" - envPrefixChannelCallout = "SMQ_CHANNELS_CALLOUT_" + envPrefixDB = "MG_CHANNELS_DB_" + envPrefixHTTP = "MG_CHANNELS_HTTP_" + envPrefixGRPC = "MG_CHANNELS_GRPC_" + envPrefixAuth = "MG_AUTH_GRPC_" + envPrefixClients = "MG_CLIENTS_GRPC_" + envPrefixGroups = "MG_GROUPS_GRPC_" + envPrefixDomains = "MG_DOMAINS_GRPC_" + envPrefixChannelCallout = "MG_CHANNELS_CALLOUT_" defDB = "channels" defSvcHTTPPort = "9005" defSvcGRPCPort = "7005" ) type config struct { - LogLevel string `env:"SMQ_CHANNELS_LOG_LEVEL" envDefault:"info"` - InstanceID string `env:"SMQ_CHANNELS_INSTANCE_ID" envDefault:""` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - CacheURL string `env:"SMQ_CHANNELS_CACHE_URL" envDefault:"redis://localhost:6379/0"` - CacheKeyDuration time.Duration `env:"SMQ_CHANNELS_CACHE_KEY_DURATION" envDefault:"10m"` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` - ESConsumerName string `env:"SMQ_CHANNELS_EVENT_CONSUMER" envDefault:"channels"` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"` - SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"` - SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` - SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"` - AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` - JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` - PermissionsFile string `env:"SMQ_PERMISSIONS_FILE" envDefault:"permission.yaml"` + LogLevel string `env:"MG_CHANNELS_LOG_LEVEL" envDefault:"info"` + InstanceID string `env:"MG_CHANNELS_INSTANCE_ID" envDefault:""` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + CacheURL string `env:"MG_CHANNELS_CACHE_URL" envDefault:"redis://localhost:6379/0"` + CacheKeyDuration time.Duration `env:"MG_CHANNELS_CACHE_KEY_DURATION" envDefault:"10m"` + ESURL string `env:"MG_ES_URL" envDefault:"amqp://guest:guest@localhost:5682/"` + ESConsumerName string `env:"MG_CHANNELS_EVENT_CONSUMER" envDefault:"channels"` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + SpicedbHost string `env:"MG_SPICEDB_HOST" envDefault:"localhost"` + SpicedbPort string `env:"MG_SPICEDB_PORT" envDefault:"50051"` + SpicedbPreSharedKey string `env:"MG_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` + SpicedbSchemaFile string `env:"MG_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"` + AuthKeyAlgorithm string `env:"MG_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` + JWKSURL string `env:"MG_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` + PermissionsFile string `env:"MG_PERMISSIONS_FILE" envDefault:"permission.yaml"` } func main() { diff --git a/cmd/cli/main.go b/cmd/cli/main.go index ab3ffee58..d81f108e2 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -44,6 +44,7 @@ func main() { configCmd := cli.NewConfigCmd() invitationsCmd := cli.NewInvitationsCmd() journalCmd := cli.NewJournalCmd() + certsCmd := cli.NewCertsCmd() // Root Commands rootCmd.AddCommand(healthCmd) @@ -56,6 +57,7 @@ func main() { rootCmd.AddCommand(configCmd) rootCmd.AddCommand(invitationsCmd) rootCmd.AddCommand(journalCmd) + rootCmd.AddCommand(certsCmd) // Root Flags rootCmd.PersistentFlags().StringVarP( @@ -98,6 +100,14 @@ func main() { "Journal Log URL", ) + rootCmd.PersistentFlags().StringVarP( + &sdkConf.CertsURL, + "certs-url", + "", + sdkConf.CertsURL, + "Certs service URL", + ) + rootCmd.PersistentFlags().StringVarP( &sdkConf.HostURL, "host-url", diff --git a/cmd/clients/main.go b/cmd/clients/main.go index 276422076..d9d8e7c46 100644 --- a/cmd/clients/main.go +++ b/cmd/clients/main.go @@ -73,38 +73,38 @@ import ( const ( svcName = "clients" - envPrefixDB = "SMQ_CLIENTS_DB_" - envPrefixHTTP = "SMQ_CLIENTS_HTTP_" - envPrefixGRPC = "SMQ_CLIENTS_GRPC_" - envPrefixAuth = "SMQ_AUTH_GRPC_" - envPrefixChannels = "SMQ_CHANNELS_GRPC_" - envPrefixGroups = "SMQ_GROUPS_GRPC_" - envPrefixDomains = "SMQ_DOMAINS_GRPC_" - envPrefixClientCallout = "SMQ_CLIENTS_CALLOUT_" + envPrefixDB = "MG_CLIENTS_DB_" + envPrefixHTTP = "MG_CLIENTS_HTTP_" + envPrefixGRPC = "MG_CLIENTS_GRPC_" + envPrefixAuth = "MG_AUTH_GRPC_" + envPrefixChannels = "MG_CHANNELS_GRPC_" + envPrefixGroups = "MG_GROUPS_GRPC_" + envPrefixDomains = "MG_DOMAINS_GRPC_" + envPrefixClientCallout = "MG_CLIENTS_CALLOUT_" defDB = "clients" defSvcHTTPPort = "9000" defSvcAuthGRPCPort = "7000" ) type config struct { - InstanceID string `env:"SMQ_CLIENTS_INSTANCE_ID" envDefault:""` - LogLevel string `env:"SMQ_CLIENTS_LOG_LEVEL" envDefault:"info"` - StandaloneID string `env:"SMQ_CLIENTS_STANDALONE_ID" envDefault:""` - StandaloneToken string `env:"SMQ_CLIENTS_STANDALONE_TOKEN" envDefault:""` - CacheURL string `env:"SMQ_CLIENTS_CACHE_URL" envDefault:"redis://localhost:6379/0"` - CacheKeyDuration time.Duration `env:"SMQ_CLIENTS_CACHE_KEY_DURATION" envDefault:"10m"` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` - ESConsumerName string `env:"SMQ_CLIENTS_EVENT_CONSUMER" envDefault:"clients"` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"` - SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"` - SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` - SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"` - AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` - JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` - PermissionsFile string `env:"SMQ_PERMISSIONS_FILE" envDefault:"permission.yaml"` + InstanceID string `env:"MG_CLIENTS_INSTANCE_ID" envDefault:""` + LogLevel string `env:"MG_CLIENTS_LOG_LEVEL" envDefault:"info"` + StandaloneID string `env:"MG_CLIENTS_STANDALONE_ID" envDefault:""` + StandaloneToken string `env:"MG_CLIENTS_STANDALONE_TOKEN" envDefault:""` + CacheURL string `env:"MG_CLIENTS_CACHE_URL" envDefault:"redis://localhost:6379/0"` + CacheKeyDuration time.Duration `env:"MG_CLIENTS_CACHE_KEY_DURATION" envDefault:"10m"` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + ESURL string `env:"MG_ES_URL" envDefault:"amqp://guest:guest@localhost:5682/"` + ESConsumerName string `env:"MG_CLIENTS_EVENT_CONSUMER" envDefault:"clients"` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + SpicedbHost string `env:"MG_SPICEDB_HOST" envDefault:"localhost"` + SpicedbPort string `env:"MG_SPICEDB_PORT" envDefault:"50051"` + SpicedbPreSharedKey string `env:"MG_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` + SpicedbSchemaFile string `env:"MG_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"` + AuthKeyAlgorithm string `env:"MG_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` + JWKSURL string `env:"MG_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` + PermissionsFile string `env:"MG_PERMISSIONS_FILE" envDefault:"permission.yaml"` } func main() { diff --git a/cmd/coap/main.go b/cmd/coap/main.go deleted file mode 100644 index 96e753270..000000000 --- a/cmd/coap/main.go +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package main contains coap-adapter main function to start the coap-adapter service. -package main - -import ( - "context" - "fmt" - "log" - "log/slog" - "net/url" - "os" - - chclient "github.com/absmach/callhome/pkg/client" - "github.com/absmach/mgate" - mgatecoap "github.com/absmach/mgate/pkg/coap" - "github.com/absmach/mgate/pkg/session" - mgtls "github.com/absmach/mgate/pkg/tls" - "github.com/absmach/supermq" - "github.com/absmach/supermq/coap" - httpapi "github.com/absmach/supermq/coap/api" - "github.com/absmach/supermq/coap/middleware" - smqlog "github.com/absmach/supermq/logger" - domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" - "github.com/absmach/supermq/pkg/grpcclient" - jaegerclient "github.com/absmach/supermq/pkg/jaeger" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/brokers" - brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" - msgevents "github.com/absmach/supermq/pkg/messaging/events" - "github.com/absmach/supermq/pkg/prometheus" - "github.com/absmach/supermq/pkg/server" - coapserver "github.com/absmach/supermq/pkg/server/coap" - httpserver "github.com/absmach/supermq/pkg/server/http" - "github.com/absmach/supermq/pkg/uuid" - "github.com/caarlos0/env/v11" - "github.com/pion/dtls/v3" - "golang.org/x/sync/errgroup" -) - -const ( - svcName = "coap_adapter" - envPrefix = "SMQ_COAP_ADAPTER_" - envPrefixHTTP = "SMQ_COAP_ADAPTER_HTTP_" - envPrefixDTLS = "SMQ_COAP_ADAPTER_SERVER_" - envPrefixCache = "SMQ_COAP_CACHE_" - envPrefixClients = "SMQ_CLIENTS_GRPC_" - envPrefixChannels = "SMQ_CHANNELS_GRPC_" - envPrefixDomains = "SMQ_DOMAINS_GRPC_" - defSvcHTTPPort = "5683" - defSvcCoAPPort = "5683" - targetProtocol = "coap" - targetCoapPort = "5682" -) - -type config struct { - LogLevel string `env:"SMQ_COAP_ADAPTER_LOG_LEVEL" envDefault:"info"` - BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - InstanceID string `env:"SMQ_COAP_ADAPTER_INSTANCE_ID" envDefault:""` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` -} - -func main() { - ctx, cancel := context.WithCancel(context.Background()) - g, ctx := errgroup.WithContext(ctx) - - cfg := config{} - if err := env.Parse(&cfg); err != nil { - log.Fatalf("failed to load %s configuration : %s", svcName, err) - } - - logger, err := smqlog.New(os.Stdout, cfg.LogLevel) - if err != nil { - log.Fatalf("failed to init logger: %s", err.Error()) - } - - var exitCode int - defer smqlog.ExitWithError(&exitCode) - - if cfg.InstanceID == "" { - if cfg.InstanceID, err = uuid.New().ID(); err != nil { - logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) - exitCode = 1 - return - } - } - - httpServerConfig := server.Config{Port: defSvcHTTPPort} - if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { - logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) - exitCode = 1 - return - } - - coapServerConfig := server.Config{Port: defSvcCoAPPort} - if err := env.ParseWithOptions(&coapServerConfig, env.Options{Prefix: envPrefix}); err != nil { - logger.Error(fmt.Sprintf("failed to load %s CoAP server configuration : %s", svcName, err)) - exitCode = 1 - return - } - - dtlsCfg, err := mgtls.NewConfig(env.Options{Prefix: envPrefixDTLS}) - if err != nil { - logger.Error(fmt.Sprintf("failed to load %s DTLS configuration : %s", svcName, err)) - exitCode = 1 - return - } - - cacheConfig := messaging.CacheConfig{} - if err := env.ParseWithOptions(&cacheConfig, env.Options{Prefix: envPrefixCache}); err != nil { - logger.Error(fmt.Sprintf("failed to load cache configuration : %s", err)) - exitCode = 1 - return - } - - domsGrpcCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { - logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) - exitCode = 1 - return - } - _, domainsClient, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer domainsHandler.Close() - - logger.Info("Domains service gRPC client successfully connected to domains gRPC server " + domainsHandler.Secure()) - - clientsClientCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&clientsClientCfg, env.Options{Prefix: envPrefixClients}); err != nil { - logger.Error(fmt.Sprintf("failed to load %s auth configuration : %s", svcName, err)) - exitCode = 1 - return - } - - clientsClient, clientsHandler, err := grpcclient.SetupClientsClient(ctx, clientsClientCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer clientsHandler.Close() - - logger.Info("Clients service gRPC client successfully connected to clients gRPC server " + clientsHandler.Secure()) - - channelsClientCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&channelsClientCfg, env.Options{Prefix: envPrefixChannels}); err != nil { - logger.Error(fmt.Sprintf("failed to load channels gRPC client configuration : %s", err)) - exitCode = 1 - return - } - - channelsClient, channelsHandler, err := grpcclient.SetupChannelsClient(ctx, channelsClientCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer channelsHandler.Close() - logger.Info("Channels service gRPC client successfully connected to channels gRPC server " + channelsHandler.Secure()) - - tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) - if err != nil { - logger.Error(fmt.Sprintf("Failed to init Jaeger: %s", err)) - exitCode = 1 - return - } - defer func() { - if err := tp.Shutdown(ctx); err != nil { - logger.Error(fmt.Sprintf("Error shutting down tracer provider: %v", err)) - } - }() - tracer := tp.Tracer(svcName) - - nps, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger) - if err != nil { - logger.Error(fmt.Sprintf("failed to connect to message broker: %s", err)) - exitCode = 1 - return - } - defer nps.Close() - nps = brokerstracing.NewPubSub(coapServerConfig, tracer, nps) - - nps, err = msgevents.NewPubSubMiddleware(ctx, nps, cfg.ESURL) - if err != nil { - logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) - exitCode = 1 - return - } - - svc := coap.New(clientsClient, channelsClient, nps) - - svc = middleware.NewTracing(tracer, svc) - - svc = middleware.NewLogging(svc, logger) - - counter, latency := prometheus.MakeMetrics(svcName, "api") - svc = middleware.NewMetrics(svc, counter, latency) - - hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(cfg.InstanceID), logger) - - parser, err := messaging.NewTopicParser(cacheConfig, channelsClient, domainsClient) - if err != nil { - logger.Error(fmt.Sprintf("failed to create topic parsers: %s", err)) - exitCode = 1 - return - } - cs := coapserver.NewServer(ctx, cancel, svcName, server.Config{Host: coapServerConfig.Host, Port: targetCoapPort}, httpapi.MakeCoAPHandler(svc, channelsClient, parser, logger), logger) - - if cfg.SendTelemetry { - chc := chclient.New(svcName, supermq.Version, logger, cancel) - go chc.CallHome(ctx) - } - - g.Go(func() error { - return hs.Start() - }) - g.Go(func() error { - g.Go(func() error { - return cs.Start() - }) - handler := coap.NewHandler(logger, clientsClient, channelsClient, parser) - return proxyCoAP(ctx, coapServerConfig, dtlsCfg, handler, logger) - }) - g.Go(func() error { - return server.StopSignalHandler(ctx, cancel, logger, svcName, hs, cs) - }) - - if err := g.Wait(); err != nil { - logger.Error(fmt.Sprintf("CoAP adapter service terminated: %s", err)) - } -} - -func proxyCoAP(ctx context.Context, cfg server.Config, dtlsCfg mgtls.Config, handler session.Handler, logger *slog.Logger) error { - var err error - config := mgate.Config{ - Host: "", - Port: cfg.Port, - TargetProtocol: targetProtocol, - TargetHost: cfg.Host, - TargetPort: targetCoapPort, - } - - mg := mgatecoap.NewProxy(config, handler, logger) - - errCh := make(chan error) - - config.DTLSConfig, err = mgtls.LoadTLSConfig(&dtlsCfg, &dtls.Config{}) - if err != nil { - return err - } - - switch { - case config.DTLSConfig != nil: - dltsCfg := config - mgDtls := mgatecoap.NewProxy(dltsCfg, handler, logger) - logger.Info(fmt.Sprintf("Starting COAP with DTLS proxy on port %s", cfg.Port)) - go func() { - errCh <- mgDtls.Listen(ctx) - }() - default: - logger.Info(fmt.Sprintf("Starting COAP without DTLS proxy on port %s", cfg.Port)) - go func() { - errCh <- mg.Listen(ctx) - }() - } - select { - case <-ctx.Done(): - logger.Info(fmt.Sprintf("proxy COAP shutdown at %s:%s", config.Host, config.Port)) - return nil - case err := <-errCh: - return err - } -} diff --git a/cmd/domains/main.go b/cmd/domains/main.go index 6ba36fdb6..0555e8ff0 100644 --- a/cmd/domains/main.go +++ b/cmd/domains/main.go @@ -63,32 +63,32 @@ import ( const ( svcName = "domains" - envPrefixHTTP = "SMQ_DOMAINS_HTTP_" - envPrefixGrpc = "SMQ_DOMAINS_GRPC_" - envPrefixDB = "SMQ_DOMAINS_DB_" - envPrefixAuth = "SMQ_AUTH_GRPC_" - envPrefixDomainCallout = "SMQ_DOMAINS_CALLOUT_" + envPrefixHTTP = "MG_DOMAINS_HTTP_" + envPrefixGrpc = "MG_DOMAINS_GRPC_" + envPrefixDB = "MG_DOMAINS_DB_" + envPrefixAuth = "MG_AUTH_GRPC_" + envPrefixDomainCallout = "MG_DOMAINS_CALLOUT_" defDB = "domains" defSvcHTTPPort = "9004" defSvcGRPCPort = "7004" ) type config struct { - LogLevel string `env:"SMQ_DOMAINS_LOG_LEVEL" envDefault:"info"` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - CacheURL string `env:"SMQ_DOMAINS_CACHE_URL" envDefault:"redis://localhost:6379/0"` - CacheKeyDuration time.Duration `env:"SMQ_DOMAINS_CACHE_KEY_DURATION" envDefault:"10m"` - InstanceID string `env:"SMQ_DOMAINS_INSTANCE_ID" envDefault:""` - SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"` - SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"` - SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"` - SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` - AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` - JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` - PermissionsFile string `env:"SMQ_PERMISSIONS_FILE" envDefault:"permission.yaml"` + LogLevel string `env:"MG_DOMAINS_LOG_LEVEL" envDefault:"info"` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + CacheURL string `env:"MG_DOMAINS_CACHE_URL" envDefault:"redis://localhost:6379/0"` + CacheKeyDuration time.Duration `env:"MG_DOMAINS_CACHE_KEY_DURATION" envDefault:"10m"` + InstanceID string `env:"MG_DOMAINS_INSTANCE_ID" envDefault:""` + SpicedbHost string `env:"MG_SPICEDB_HOST" envDefault:"localhost"` + SpicedbPort string `env:"MG_SPICEDB_PORT" envDefault:"50051"` + SpicedbSchemaFile string `env:"MG_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"` + SpicedbPreSharedKey string `env:"MG_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + ESURL string `env:"MG_ES_URL" envDefault:"amqp://guest:guest@localhost:5682/"` + AuthKeyAlgorithm string `env:"MG_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` + JWKSURL string `env:"MG_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` + PermissionsFile string `env:"MG_PERMISSIONS_FILE" envDefault:"permission.yaml"` } func main() { diff --git a/cmd/fluxmq/main.go b/cmd/fluxmq/main.go new file mode 100644 index 000000000..bd81aaab7 --- /dev/null +++ b/cmd/fluxmq/main.go @@ -0,0 +1,232 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package main contains the FluxMQ auth bridge service entry point. +// This service implements the FluxMQ auth callout server using ConnectRPC, +// bridging authentication requests to SuperMQ's Clients service and +// authorization requests to SuperMQ's Channels service. +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "net/url" + "os" + "os/signal" + "syscall" + + "connectrpc.com/connect" + "connectrpc.com/otelconnect" + "github.com/absmach/fluxmq/pkg/proto/auth/v1/authv1connect" + fluxmqgrpc "github.com/absmach/supermq/fluxmq/api/grpc" + smqlog "github.com/absmach/supermq/logger" + domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" + "github.com/absmach/supermq/pkg/grpcclient" + jaegerclient "github.com/absmach/supermq/pkg/jaeger" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/server" + "github.com/absmach/supermq/pkg/uuid" + "github.com/caarlos0/env/v11" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + "golang.org/x/sync/errgroup" +) + +const ( + svcName = "fluxmq-auth" + defSvcGRPCPort = "7016" + envPrefixClients = "MG_CLIENTS_GRPC_" + envPrefixChannels = "MG_CHANNELS_GRPC_" + envPrefixDomains = "MG_DOMAINS_GRPC_" + envPrefixCache = "MG_FLUXMQ_CACHE_" + envPrefixGRPC = "MG_FLUXMQ_GRPC_" +) + +type config struct { + LogLevel string `env:"MG_FLUXMQ_LOG_LEVEL" envDefault:"info"` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + InstanceID string `env:"MG_FLUXMQ_INSTANCE_ID" envDefault:""` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg := config{} + if err := env.Parse(&cfg); err != nil { + log.Fatalf("failed to load %s configuration: %s", svcName, err) + } + + logger, err := smqlog.New(os.Stdout, cfg.LogLevel) + if err != nil { + log.Fatalf("failed to init logger: %s", err.Error()) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + if cfg.InstanceID == "" { + if cfg.InstanceID, err = uuid.New().ID(); err != nil { + logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) + exitCode = 1 + return + } + } + + tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) + if err != nil { + logger.Error(fmt.Sprintf("failed to init Jaeger: %s", err)) + exitCode = 1 + return + } + defer func() { + if err := tp.Shutdown(ctx); err != nil { + logger.Error(fmt.Sprintf("error shutting down tracer provider: %v", err)) + } + }() + + // Connect to Domains gRPC service (needed for topic route resolution). + domsGrpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { + logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration: %s", err)) + exitCode = 1 + return + } + _, domainsClient, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer domainsHandler.Close() + logger.Info("Domains gRPC client connected " + domainsHandler.Secure()) + + // Connect to Clients gRPC service (authentication). + clientsClientCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&clientsClientCfg, env.Options{Prefix: envPrefixClients}); err != nil { + logger.Error(fmt.Sprintf("failed to load clients gRPC client configuration: %s", err)) + exitCode = 1 + return + } + clientsClient, clientsHandler, err := grpcclient.SetupClientsClient(ctx, clientsClientCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer clientsHandler.Close() + logger.Info("Clients gRPC client connected " + clientsHandler.Secure()) + + // Connect to Channels gRPC service (authorization + route resolution). + channelsClientCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&channelsClientCfg, env.Options{Prefix: envPrefixChannels}); err != nil { + logger.Error(fmt.Sprintf("failed to load channels gRPC client configuration: %s", err)) + exitCode = 1 + return + } + channelsClient, channelsHandler, err := grpcclient.SetupChannelsClient(ctx, channelsClientCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer channelsHandler.Close() + logger.Info("Channels gRPC client connected " + channelsHandler.Secure()) + + // Topic parser with cache for route resolution. + cacheConfig := messaging.CacheConfig{} + if err := env.ParseWithOptions(&cacheConfig, env.Options{Prefix: envPrefixCache}); err != nil { + logger.Error(fmt.Sprintf("failed to load cache configuration: %s", err)) + exitCode = 1 + return + } + parser, err := messaging.NewTopicParser(cacheConfig, channelsClient, domainsClient) + if err != nil { + logger.Error(fmt.Sprintf("failed to create topic parser: %s", err)) + exitCode = 1 + return + } + + // Start FluxMQ auth Connect/gRPC server over h2c. + grpcServerConfig := server.Config{Port: defSvcGRPCPort} + if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGRPC}); err != nil { + logger.Error(fmt.Sprintf("failed to load gRPC server configuration: %s", err)) + exitCode = 1 + return + } + + mux := http.NewServeMux() + otelInterceptor, err := otelconnect.NewInterceptor() + if err != nil { + logger.Error(fmt.Sprintf("failed to create OTel interceptor: %s", err)) + exitCode = 1 + return + } + path, handler := authv1connect.NewAuthServiceHandler( + fluxmqgrpc.NewServer(clientsClient, channelsClient, parser), + connect.WithInterceptors(otelInterceptor), + ) + mux.Handle(path, handler) + mux.HandleFunc("/health", func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) //nolint:errcheck // HTTP response write; client disconnect is non-fatal. + }) + + address := fmt.Sprintf("%s:%s", grpcServerConfig.Host, grpcServerConfig.Port) + httpServer := &http.Server{ + Addr: address, + Handler: h2c.NewHandler(mux, &http2.Server{}), + ReadTimeout: grpcServerConfig.ReadTimeout, + WriteTimeout: grpcServerConfig.WriteTimeout, + ReadHeaderTimeout: grpcServerConfig.ReadHeaderTimeout, + IdleTimeout: grpcServerConfig.IdleTimeout, + MaxHeaderBytes: grpcServerConfig.MaxHeaderBytes, + } + + g.Go(func() error { + logger.Info(fmt.Sprintf("%s service h2c server listening at %s", svcName, address)) + var err error + switch { + case grpcServerConfig.CertFile != "" || grpcServerConfig.KeyFile != "": + err = httpServer.ListenAndServeTLS(grpcServerConfig.CertFile, grpcServerConfig.KeyFile) + default: + err = httpServer.ListenAndServe() + } + if err != nil && err != http.ErrServerClosed { + cancel() + return err + } + return nil + }) + + g.Go(func() error { + <-ctx.Done() + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), server.StopWaitTime) //nolint:contextcheck + defer shutdownCancel() + if err := httpServer.Shutdown(shutdownCtx); err != nil { //nolint:contextcheck + return fmt.Errorf("failed to shutdown %s server: %w", svcName, err) + } + logger.Info(fmt.Sprintf("%s service shutdown at %s", svcName, address)) + return nil + }) + + g.Go(func() error { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + select { + case sig := <-c: + cancel() + logger.Info(fmt.Sprintf("%s service shutdown by signal: %s", svcName, sig)) + return nil + case <-ctx.Done(): + return nil + } + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("%s service terminated: %s", svcName, err)) + } +} diff --git a/cmd/groups/main.go b/cmd/groups/main.go index e7a33cd35..fc5ed7340 100644 --- a/cmd/groups/main.go +++ b/cmd/groups/main.go @@ -67,34 +67,34 @@ import ( const ( svcName = "groups" - envPrefixDB = "SMQ_GROUPS_DB_" - envPrefixHTTP = "SMQ_GROUPS_HTTP_" - envPrefixgRPC = "SMQ_GROUPS_GRPC_" - envPrefixAuth = "SMQ_AUTH_GRPC_" - envPrefixDomains = "SMQ_DOMAINS_GRPC_" - envPrefixChannels = "SMQ_CHANNELS_GRPC_" - envPrefixClients = "SMQ_CLIENTS_GRPC_" - envPrefixGroupCallout = "SMQ_GROUPS_CALLOUT_" + envPrefixDB = "MG_GROUPS_DB_" + envPrefixHTTP = "MG_GROUPS_HTTP_" + envPrefixgRPC = "MG_GROUPS_GRPC_" + envPrefixAuth = "MG_AUTH_GRPC_" + envPrefixDomains = "MG_DOMAINS_GRPC_" + envPrefixChannels = "MG_CHANNELS_GRPC_" + envPrefixClients = "MG_CLIENTS_GRPC_" + envPrefixGroupCallout = "MG_GROUPS_CALLOUT_" defDB = "groups" defSvcHTTPPort = "9004" defSvcgRPCPort = "7004" ) type config struct { - LogLevel string `env:"SMQ_GROUPS_LOG_LEVEL" envDefault:"info"` - InstanceID string `env:"SMQ_GROUPS_INSTANCE_ID" envDefault:""` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` - ESConsumerName string `env:"SMQ_GROUPS_EVENT_CONSUMER" envDefault:"groups"` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"` - SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"` - SpicedbSchemaFile string `env:"SMQ_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"` - SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` - AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` - JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` - PermissionsFile string `env:"SMQ_PERMISSIONS_FILE" envDefault:"permission.yaml"` + LogLevel string `env:"MG_GROUPS_LOG_LEVEL" envDefault:"info"` + InstanceID string `env:"MG_GROUPS_INSTANCE_ID" envDefault:""` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + ESURL string `env:"MG_ES_URL" envDefault:"amqp://guest:guest@localhost:5682/"` + ESConsumerName string `env:"MG_GROUPS_EVENT_CONSUMER" envDefault:"groups"` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + SpicedbHost string `env:"MG_SPICEDB_HOST" envDefault:"localhost"` + SpicedbPort string `env:"MG_SPICEDB_PORT" envDefault:"50051"` + SpicedbSchemaFile string `env:"MG_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"` + SpicedbPreSharedKey string `env:"MG_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` + AuthKeyAlgorithm string `env:"MG_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` + JWKSURL string `env:"MG_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` + PermissionsFile string `env:"MG_PERMISSIONS_FILE" envDefault:"permission.yaml"` } func main() { diff --git a/cmd/http/main.go b/cmd/http/main.go deleted file mode 100644 index 886e83abd..000000000 --- a/cmd/http/main.go +++ /dev/null @@ -1,331 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package main contains http-adapter main function to start the http-adapter service. -package main - -import ( - "context" - "crypto/tls" - "fmt" - "log" - "log/slog" - "net/http" - "net/url" - "os" - - chclient "github.com/absmach/callhome/pkg/client" - "github.com/absmach/mgate" - mgatehttp "github.com/absmach/mgate/pkg/http" - "github.com/absmach/mgate/pkg/session" - "github.com/absmach/supermq" - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1" - "github.com/absmach/supermq/auth" - adapter "github.com/absmach/supermq/http" - httpapi "github.com/absmach/supermq/http/api" - "github.com/absmach/supermq/http/middleware" - smqlog "github.com/absmach/supermq/logger" - smqauthn "github.com/absmach/supermq/pkg/authn" - authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" - jwksAuthn "github.com/absmach/supermq/pkg/authn/jwks" - domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" - "github.com/absmach/supermq/pkg/grpcclient" - jaegerclient "github.com/absmach/supermq/pkg/jaeger" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/brokers" - brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" - msgevents "github.com/absmach/supermq/pkg/messaging/events" - "github.com/absmach/supermq/pkg/messaging/handler" - "github.com/absmach/supermq/pkg/prometheus" - "github.com/absmach/supermq/pkg/server" - httpserver "github.com/absmach/supermq/pkg/server/http" - "github.com/absmach/supermq/pkg/uuid" - "github.com/caarlos0/env/v11" - "go.opentelemetry.io/otel/trace" - "golang.org/x/sync/errgroup" -) - -const ( - svcName = "http_adapter" - envPrefix = "SMQ_HTTP_ADAPTER_" - envPrefixCache = "SMQ_HTTP_ADAPTER_CACHE_" - envPrefixClients = "SMQ_CLIENTS_GRPC_" - envPrefixChannels = "SMQ_CHANNELS_GRPC_" - envPrefixAuth = "SMQ_AUTH_GRPC_" - envPrefixDomains = "SMQ_DOMAINS_GRPC_" - defSvcHTTPPort = "80" - targetHTTPProtocol = "http" - targetHTTPHost = "localhost" - targetHTTPPort = "81" - targetHTTPPath = "" -) - -type config struct { - LogLevel string `env:"SMQ_HTTP_ADAPTER_LOG_LEVEL" envDefault:"info"` - BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - InstanceID string `env:"SMQ_HTTP_ADAPTER_INSTANCE_ID" envDefault:""` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` - AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` - JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` -} - -func main() { - ctx, cancel := context.WithCancel(context.Background()) - g, ctx := errgroup.WithContext(ctx) - - cfg := config{} - if err := env.Parse(&cfg); err != nil { - log.Fatalf("failed to load %s configuration : %s", svcName, err) - } - - logger, err := smqlog.New(os.Stdout, cfg.LogLevel) - if err != nil { - log.Fatalf("failed to init logger: %s", err.Error()) - } - - var exitCode int - defer smqlog.ExitWithError(&exitCode) - - if cfg.InstanceID == "" { - if cfg.InstanceID, err = uuid.New().ID(); err != nil { - logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) - exitCode = 1 - return - } - } - - httpServerConfig := server.Config{Port: defSvcHTTPPort} - if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefix}); err != nil { - logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) - exitCode = 1 - return - } - - cacheConfig := messaging.CacheConfig{} - if err := env.ParseWithOptions(&cacheConfig, env.Options{Prefix: envPrefixCache}); err != nil { - logger.Error(fmt.Sprintf("failed to load cache configuration : %s", err)) - exitCode = 1 - return - } - - domsGrpcCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { - logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) - exitCode = 1 - return - } - _, domainsClient, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer domainsHandler.Close() - - logger.Info("Domains service gRPC client successfully connected to domains gRPC server " + domainsHandler.Secure()) - - clientsClientCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&clientsClientCfg, env.Options{Prefix: envPrefixClients}); err != nil { - logger.Error(fmt.Sprintf("failed to load clients gRPC client configuration : %s", err)) - exitCode = 1 - return - } - - clientsClient, clientsHandler, err := grpcclient.SetupClientsClient(ctx, clientsClientCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer clientsHandler.Close() - logger.Info("Clients service gRPC client successfully connected to clients gRPC server " + clientsHandler.Secure()) - - channelsClientCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&channelsClientCfg, env.Options{Prefix: envPrefixChannels}); err != nil { - logger.Error(fmt.Sprintf("failed to load channels gRPC client configuration : %s", err)) - exitCode = 1 - return - } - - channelsClient, channelsHandler, err := grpcclient.SetupChannelsClient(ctx, channelsClientCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer channelsHandler.Close() - logger.Info("Channels service gRPC client successfully connected to channels gRPC server " + channelsHandler.Secure()) - - authnCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&authnCfg, env.Options{Prefix: envPrefixAuth}); err != nil { - logger.Error(fmt.Sprintf("failed to load auth gRPC client configuration : %s", err)) - exitCode = 1 - return - } - - isSymmetric, err := auth.IsSymmetricAlgorithm(cfg.AuthKeyAlgorithm) - if err != nil { - logger.Error(fmt.Sprintf("failed to parse auth key algorithm : %s", err)) - exitCode = 1 - return - } - var authn smqauthn.Authentication - var authnClient grpcclient.Handler - switch { - case !isSymmetric: - authn, authnClient, err = jwksAuthn.NewAuthentication(ctx, cfg.JWKSURL, authnCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer authnClient.Close() - logger.Info("AuthN successfully set up jwks authentication on " + cfg.JWKSURL) - default: - authn, authnClient, err = authsvcAuthn.NewAuthentication(ctx, authnCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer authnClient.Close() - logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) - } - - tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) - if err != nil { - logger.Error(fmt.Sprintf("Failed to init Jaeger: %s", err)) - exitCode = 1 - return - } - defer func() { - if err := tp.Shutdown(ctx); err != nil { - logger.Error(fmt.Sprintf("Error shutting down tracer provider: %v", err)) - } - }() - tracer := tp.Tracer(svcName) - - nps, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger) - if err != nil { - logger.Error(fmt.Sprintf("Failed to connect to message broker: %s", err)) - exitCode = 1 - return - } - defer nps.Close() - nps = brokerstracing.NewPubSub(httpServerConfig, tracer, nps) - - nps, err = msgevents.NewPubSubMiddleware(ctx, nps, cfg.ESURL) - if err != nil { - logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) - exitCode = 1 - return - } - - resolver := messaging.NewTopicResolver(channelsClient, domainsClient) - handler, err := newHandler(nps, authn, cacheConfig, clientsClient, channelsClient, domainsClient, logger, tracer) - if err != nil { - logger.Error(fmt.Sprintf("failed to create service: %s", err)) - exitCode = 1 - return - } - svc := newService(clientsClient, channelsClient, authn, nps, logger, tracer) - - targetServerCfg := server.Config{Port: targetHTTPPort} - - hs := httpserver.NewServer(ctx, cancel, svcName, targetServerCfg, httpapi.MakeHandler(ctx, svc, resolver, logger, cfg.InstanceID), logger) - - if cfg.SendTelemetry { - chc := chclient.New(svcName, supermq.Version, logger, cancel) - go chc.CallHome(ctx) - } - - g.Go(func() error { - return hs.Start() - }) - - g.Go(func() error { - return proxyHTTP(ctx, httpServerConfig, logger, handler) - }) - - g.Go(func() error { - return server.StopSignalHandler(ctx, cancel, logger, svcName, hs) - }) - - if err := g.Wait(); err != nil { - logger.Error(fmt.Sprintf("HTTP adapter service terminated: %s", err)) - } -} - -func newHandler(pubsub messaging.PubSub, authn smqauthn.Authentication, cacheCfg messaging.CacheConfig, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient, logger *slog.Logger, tracer trace.Tracer) (session.Handler, error) { - parser, err := messaging.NewTopicParser(cacheCfg, channels, domains) - if err != nil { - return nil, err - } - h := adapter.NewHandler(pubsub, logger, authn, clients, channels, parser) - h = handler.NewTracing(tracer, h) - h = handler.NewLogging(h, logger) - counter, latency := prometheus.MakeMetrics(svcName, "handler") - h = handler.NewMetrics(h, counter, latency) - - return h, nil -} - -func newService(clientsClient grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, authn smqauthn.Authentication, nps messaging.PubSub, logger *slog.Logger, tracer trace.Tracer) adapter.Service { - svc := adapter.NewService(clientsClient, channels, authn, nps) - svc = middleware.NewTracing(tracer, svc) - svc = middleware.NewLogging(svc, logger) - counter, latency := prometheus.MakeMetrics(svcName, "api") - svc = middleware.NewMetrics(svc, counter, latency) - return svc -} - -func proxyHTTP(ctx context.Context, cfg server.Config, logger *slog.Logger, sessionHandler session.Handler) error { - config := mgate.Config{ - Port: cfg.Port, - TargetProtocol: targetHTTPProtocol, - TargetHost: targetHTTPHost, - TargetPort: targetHTTPPort, - TargetPath: targetHTTPPath, - } - if cfg.CertFile != "" || cfg.KeyFile != "" { - tlsCert, err := server.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) - if err != nil { - return err - } - config.TLSConfig = &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - } - } - mp, err := mgatehttp.NewProxy(config, sessionHandler, logger, []string{}, []string{"/health", "/metrics"}) - if err != nil { - return err - } - http.HandleFunc("/", mp.ServeHTTP) - - errCh := make(chan error) - switch { - case cfg.CertFile != "" || cfg.KeyFile != "": - go func() { - errCh <- mp.Listen(ctx) - }() - logger.Info(fmt.Sprintf("%s service HTTPS server listening at %s:%s with TLS", svcName, cfg.Host, cfg.Port)) - default: - go func() { - errCh <- mp.Listen(ctx) - }() - logger.Info(fmt.Sprintf("%s service HTTP server listening at %s:%s without TLS", svcName, cfg.Host, cfg.Port)) - } - - select { - case <-ctx.Done(): - logger.Info(fmt.Sprintf("proxy HTTP shutdown at %s:%s", config.Host, config.Port)) - return nil - case err := <-errCh: - return err - } -} diff --git a/cmd/journal/main.go b/cmd/journal/main.go index 33a6e1fad..7e14fa7a7 100644 --- a/cmd/journal/main.go +++ b/cmd/journal/main.go @@ -44,23 +44,23 @@ import ( const ( svcName = "journal" - envPrefixDB = "SMQ_JOURNAL_DB_" - envPrefixHTTP = "SMQ_JOURNAL_HTTP_" - envPrefixAuth = "SMQ_AUTH_GRPC_" - envPrefixDomains = "SMQ_DOMAINS_GRPC_" + envPrefixDB = "MG_JOURNAL_DB_" + envPrefixHTTP = "MG_JOURNAL_HTTP_" + envPrefixAuth = "MG_AUTH_GRPC_" + envPrefixDomains = "MG_DOMAINS_GRPC_" defDB = "journal" defSvcHTTPPort = "9021" ) type config struct { - LogLevel string `env:"SMQ_JOURNAL_LOG_LEVEL" envDefault:"info"` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - InstanceID string `env:"SMQ_JOURNAL_INSTANCE_ID" envDefault:""` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` - JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` + LogLevel string `env:"MG_JOURNAL_LOG_LEVEL" envDefault:"info"` + ESURL string `env:"MG_ES_URL" envDefault:"amqp://guest:guest@localhost:5682/"` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + InstanceID string `env:"MG_JOURNAL_INSTANCE_ID" envDefault:""` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + AuthKeyAlgorithm string `env:"MG_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` + JWKSURL string `env:"MG_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` } func main() { @@ -177,7 +177,7 @@ func main() { svc := newService(db, dbConfig, authz, logger, tracer) - subscriber, err := store.NewSubscriber(ctx, cfg.ESURL, logger) + subscriber, err := store.NewSubscriber(ctx, cfg.ESURL, "journal-es-sub", logger) if err != nil { logger.Error(fmt.Sprintf("failed to create subscriber: %s", err)) exitCode = 1 diff --git a/cmd/mqtt/main.go b/cmd/mqtt/main.go deleted file mode 100644 index 4779d5e57..000000000 --- a/cmd/mqtt/main.go +++ /dev/null @@ -1,448 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package main contains mqtt-adapter main function to start the mqtt-adapter service. -package main - -import ( - "context" - "crypto/tls" - "fmt" - "io" - "log" - "log/slog" - "net/http" - "net/url" - "os" - "os/signal" - "syscall" - "time" - - chclient "github.com/absmach/callhome/pkg/client" - mgate "github.com/absmach/mgate" - mgatemqtt "github.com/absmach/mgate/pkg/mqtt" - "github.com/absmach/mgate/pkg/mqtt/websocket" - "github.com/absmach/mgate/pkg/session" - mgtls "github.com/absmach/mgate/pkg/tls" - "github.com/absmach/supermq" - smqlog "github.com/absmach/supermq/logger" - "github.com/absmach/supermq/mqtt" - "github.com/absmach/supermq/mqtt/events" - mqtttracing "github.com/absmach/supermq/mqtt/tracing" - domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" - "github.com/absmach/supermq/pkg/errors" - "github.com/absmach/supermq/pkg/grpcclient" - jaegerclient "github.com/absmach/supermq/pkg/jaeger" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/brokers" - brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" - msgevents "github.com/absmach/supermq/pkg/messaging/events" - "github.com/absmach/supermq/pkg/messaging/handler" - mqttpub "github.com/absmach/supermq/pkg/messaging/mqtt" - "github.com/absmach/supermq/pkg/server" - "github.com/absmach/supermq/pkg/uuid" - "github.com/caarlos0/env/v11" - "github.com/cenkalti/backoff/v4" - "github.com/eclipse/paho.mqtt.golang/packets" - "golang.org/x/sync/errgroup" -) - -const ( - svcName = "mqtt" - envPrefixCache = "SMQ_MQTT_ADAPTER_CACHE_" - envPrefixClients = "SMQ_CLIENTS_GRPC_" - envPrefixChannels = "SMQ_CHANNELS_GRPC_" - envPrefixDomains = "SMQ_DOMAINS_GRPC_" - envPrefixMQTT = "SMQ_MQTT_ADAPTER_" - wsPathPrefix = "/mqtt" -) - -type config struct { - LogLevel string `env:"SMQ_MQTT_ADAPTER_LOG_LEVEL" envDefault:"info"` - MQTTPort string `env:"SMQ_MQTT_ADAPTER_MQTT_PORT" envDefault:"1883"` - MQTTTargetProtocol string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_PROTOCOL" envDefault:"mqtt"` - MQTTTargetHost string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_HOST" envDefault:"localhost"` - MQTTTargetPort string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_PORT" envDefault:"1883"` - MQTTTargetUsername string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_USERNAME" envDefault:""` - MQTTTargetPassword string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_PASSWORD" envDefault:""` - MQTTForwarderTimeout time.Duration `env:"SMQ_MQTT_ADAPTER_FORWARDER_TIMEOUT" envDefault:"30s"` - MQTTTargetHealthCheck string `env:"SMQ_MQTT_ADAPTER_MQTT_TARGET_HEALTH_CHECK" envDefault:""` - MQTTQoS uint8 `env:"SMQ_MQTT_ADAPTER_MQTT_QOS" envDefault:"1"` - HTTPPort string `env:"SMQ_MQTT_ADAPTER_WS_PORT" envDefault:"8080"` - HTTPTargetProtocol string `env:"SMQ_MQTT_ADAPTER_WS_TARGET_PROTOCOL" envDefault:"http"` - HTTPTargetHost string `env:"SMQ_MQTT_ADAPTER_WS_TARGET_HOST" envDefault:"localhost"` - HTTPTargetPort string `env:"SMQ_MQTT_ADAPTER_WS_TARGET_PORT" envDefault:"8080"` - HTTPTargetPath string `env:"SMQ_MQTT_ADAPTER_WS_TARGET_PATH" envDefault:"/mqtt"` - Instance string `env:"SMQ_MQTT_ADAPTER_INSTANCE" envDefault:""` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - BrokerURL string `env:"SMQ_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - InstanceID string `env:"SMQ_MQTT_ADAPTER_INSTANCE_ID" envDefault:""` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` -} - -func main() { - ctx, cancel := context.WithCancel(context.Background()) - g, ctx := errgroup.WithContext(ctx) - - cfg := config{} - if err := env.Parse(&cfg); err != nil { - log.Fatalf("failed to load %s configuration : %s", svcName, err) - } - - logger, err := smqlog.New(os.Stdout, cfg.LogLevel) - if err != nil { - log.Fatalf("failed to init logger: %s", err.Error()) - } - - var exitCode int - defer smqlog.ExitWithError(&exitCode) - - if cfg.InstanceID == "" { - if cfg.InstanceID, err = uuid.New().ID(); err != nil { - logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) - exitCode = 1 - return - } - } - - if cfg.MQTTTargetHealthCheck != "" { - notify := func(e error, next time.Duration) { - logger.Info(fmt.Sprintf("Broker not ready: %s, next try in %s", e.Error(), next)) - } - - err := backoff.RetryNotify(healthcheck(cfg), backoff.NewExponentialBackOff(), notify) - if err != nil { - logger.Error(fmt.Sprintf("MQTT healthcheck limit exceeded, exiting. %s ", err)) - exitCode = 1 - return - } - } - - serverConfig := server.Config{ - Host: cfg.HTTPTargetHost, - Port: cfg.HTTPTargetPort, - } - - tlsCfg, err := mgtls.NewConfig(env.Options{Prefix: envPrefixMQTT}) - if err != nil { - logger.Error(fmt.Sprintf("Failed to load TLS config: %s", err)) - exitCode = 1 - return - } - - tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) - if err != nil { - logger.Error(fmt.Sprintf("Failed to init Jaeger: %s", err)) - exitCode = 1 - return - } - defer func() { - if err := tp.Shutdown(ctx); err != nil { - logger.Error(fmt.Sprintf("Error shutting down tracer provider: %v", err)) - } - }() - tracer := tp.Tracer(svcName) - - bsub, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger) - if err != nil { - logger.Error(fmt.Sprintf("failed to connect to message broker: %s", err)) - exitCode = 1 - return - } - defer bsub.Close() - bsub = brokerstracing.NewPubSub(serverConfig, tracer, bsub) - - mpub, err := mqttpub.NewPublisher(fmt.Sprintf("mqtt://%s:%s", cfg.MQTTTargetHost, cfg.MQTTTargetPort), cfg.MQTTTargetUsername, cfg.MQTTTargetPassword, cfg.MQTTQoS, cfg.MQTTForwarderTimeout) - if err != nil { - logger.Error(fmt.Sprintf("failed to create MQTT publisher: %s", err)) - exitCode = 1 - return - } - defer mpub.Close() - - fwd := mqtt.NewForwarder(brokers.SubjectAllMessages, logger) - fwd = mqtttracing.New(serverConfig, tracer, fwd, brokers.SubjectAllMessages) - if err := fwd.Forward(ctx, svcName, bsub, mpub); err != nil { - logger.Error(fmt.Sprintf("failed to forward message broker messages: %s", err)) - exitCode = 1 - return - } - - np, err := brokers.NewPublisher(ctx, cfg.BrokerURL) - if err != nil { - logger.Error(fmt.Sprintf("failed to connect to message broker: %s", err)) - exitCode = 1 - return - } - defer np.Close() - np = brokerstracing.NewPublisher(serverConfig, tracer, np) - - np, err = msgevents.NewPublisherMiddleware(ctx, np, cfg.ESURL) - if err != nil { - logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) - exitCode = 1 - return - } - - domsGrpcCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { - logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) - exitCode = 1 - return - } - _, domainsClient, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer domainsHandler.Close() - - clientsClientCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&clientsClientCfg, env.Options{Prefix: envPrefixClients}); err != nil { - logger.Error(fmt.Sprintf("failed to load %s auth configuration : %s", svcName, err)) - exitCode = 1 - return - } - - clientsClient, clientsHandler, err := grpcclient.SetupClientsClient(ctx, clientsClientCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer clientsHandler.Close() - logger.Info("Clients service gRPC client successfully connected to clients gRPC server " + clientsHandler.Secure()) - - channelsClientCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&channelsClientCfg, env.Options{Prefix: envPrefixChannels}); err != nil { - logger.Error(fmt.Sprintf("failed to load channels gRPC client configuration : %s", err)) - exitCode = 1 - return - } - - channelsClient, channelsHandler, err := grpcclient.SetupChannelsClient(ctx, channelsClientCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer channelsHandler.Close() - logger.Info("Channels service gRPC client successfully connected to channels gRPC server " + channelsHandler.Secure()) - - cacheConfig := messaging.CacheConfig{} - if err := env.ParseWithOptions(&cacheConfig, env.Options{Prefix: envPrefixCache}); err != nil { - logger.Error(fmt.Sprintf("failed to load cache configuration : %s", err)) - exitCode = 1 - return - } - parser, err := messaging.NewTopicParser(cacheConfig, channelsClient, domainsClient) - if err != nil { - logger.Error(fmt.Sprintf("failed to create topic parsers: %s", err)) - exitCode = 1 - return - } - - h := mqtt.NewHandler(np, logger, clientsClient, channelsClient, parser) - - h, err = events.NewEventStoreMiddleware(ctx, h, cfg.ESURL, cfg.Instance) - if err != nil { - logger.Error(fmt.Sprintf("failed to create event store middleware: %s", err)) - exitCode = 1 - return - } - - h = handler.NewTracing(tracer, h) - - if cfg.SendTelemetry { - chc := chclient.New(svcName, supermq.Version, logger, cancel) - go chc.CallHome(ctx) - } - - beforeHandler := beforeHandler{ - resolver: messaging.NewTopicResolver(channelsClient, domainsClient), - } - - afterHandler := afterHandler{ - username: cfg.MQTTTargetUsername, - password: cfg.MQTTTargetPassword, - } - logger.Info(fmt.Sprintf("Starting MQTT proxy on port %s", cfg.MQTTPort)) - g.Go(func() error { - return proxyMQTT(ctx, cfg, tlsCfg, logger, h, beforeHandler, afterHandler) - }) - - logger.Info(fmt.Sprintf("Starting MQTT over WS proxy on port %s", cfg.HTTPPort)) - g.Go(func() error { - return proxyWS(ctx, cfg, tlsCfg, logger, h, afterHandler) - }) - - g.Go(func() error { - return stopSignalHandler(ctx, cancel, logger) - }) - - if err := g.Wait(); err != nil { - logger.Error(fmt.Sprintf("mProxy terminated: %s", err)) - } -} - -func proxyMQTT(ctx context.Context, cfg config, tlsCfg mgtls.Config, logger *slog.Logger, sessionHandler session.Handler, beforeHandler, afterHandler session.Interceptor) error { - var err error - config := mgate.Config{ - Port: cfg.MQTTPort, - TargetHost: cfg.MQTTTargetHost, - TargetPort: cfg.MQTTTargetPort, - } - errCh := make(chan error) - - config.TLSConfig, err = mgtls.LoadTLSConfig(&tlsCfg, &tls.Config{}) - if err != nil { - return err - } - - mgate := mgatemqtt.New(config, sessionHandler, beforeHandler, afterHandler, logger) - - go func() { - errCh <- mgate.Listen(ctx) - }() - - select { - case <-ctx.Done(): - logger.Info(fmt.Sprintf("proxy MQTT shutdown at %s:%s", config.Host, config.Port)) - return nil - case err := <-errCh: - return err - } -} - -func proxyWS(ctx context.Context, cfg config, tlsCfg mgtls.Config, logger *slog.Logger, sessionHandler session.Handler, interceptor session.Interceptor) error { - var err error - config := mgate.Config{ - Port: cfg.HTTPPort, - TargetProtocol: "ws", - TargetHost: cfg.HTTPTargetHost, - TargetPort: cfg.HTTPTargetPort, - TargetPath: cfg.HTTPTargetPath, - PathPrefix: wsPathPrefix, - } - config.TLSConfig, err = mgtls.LoadTLSConfig(&tlsCfg, &tls.Config{}) - if err != nil { - return err - } - - wp := websocket.New(config, sessionHandler, nil, interceptor, logger) - http.HandleFunc(wsPathPrefix, wp.ServeHTTP) - - errCh := make(chan error) - - go func() { - errCh <- wp.Listen(ctx) - }() - - select { - case <-ctx.Done(): - logger.Info(fmt.Sprintf("proxy MQTT WS shutdown at %s:%s", config.Host, config.Port)) - return nil - case err := <-errCh: - return err - } -} - -func healthcheck(cfg config) func() error { - client := &http.Client{ - Timeout: 30 * time.Second, - } - return func() error { - res, err := client.Get(cfg.MQTTTargetHealthCheck) - if err != nil { - return err - } - defer res.Body.Close() - body, err := io.ReadAll(res.Body) - if err != nil { - return err - } - if res.StatusCode != http.StatusOK { - return errors.New(string(body)) - } - return nil - } -} - -func stopSignalHandler(ctx context.Context, cancel context.CancelFunc, logger *slog.Logger) error { - c := make(chan os.Signal, 2) - signal.Notify(c, syscall.SIGINT, syscall.SIGABRT) - select { - case sig := <-c: - defer cancel() - logger.Info(fmt.Sprintf("%s service shutdown by signal: %s", svcName, sig)) - return nil - case <-ctx.Done(): - return nil - } -} - -type afterHandler struct { - username string - password string -} - -// This interceptor adds the correct credentials to upstream MQTT broker since the downstream clients -// are authenticated to the MQTT adapter but not upstream MQTT broker. -func (ah afterHandler) Intercept(ctx context.Context, pkt packets.ControlPacket, dir session.Direction) (packets.ControlPacket, error) { - if connectPkt, ok := pkt.(*packets.ConnectPacket); ok { - if ah.username != "" { - connectPkt.Username = ah.username - connectPkt.UsernameFlag = true - } - if ah.password != "" { - connectPkt.Password = []byte(ah.password) - connectPkt.PasswordFlag = true - } - - return connectPkt, nil - } - - return pkt, nil -} - -type beforeHandler struct { - resolver messaging.TopicResolver -} - -// This interceptor is used to replace domain and channel routes with relevant domain and channel IDs in the message topic. -func (bh beforeHandler) Intercept(ctx context.Context, pkt packets.ControlPacket, dir session.Direction) (packets.ControlPacket, error) { - switch pt := pkt.(type) { - case *packets.SubscribePacket: - for i, topic := range pt.Topics { - ft, err := bh.resolver.ResolveTopic(ctx, topic) - if err != nil { - return nil, err - } - pt.Topics[i] = ft - } - - return pt, nil - case *packets.UnsubscribePacket: - for i, topic := range pt.Topics { - ft, err := bh.resolver.ResolveTopic(ctx, topic) - if err != nil { - return nil, err - } - pt.Topics[i] = ft - } - return pt, nil - case *packets.PublishPacket: - ft, err := bh.resolver.ResolveTopic(ctx, pt.TopicName) - if err != nil { - return nil, err - } - pt.TopicName = ft - - return pt, nil - } - - return pkt, nil -} diff --git a/cmd/notifications/main.go b/cmd/notifications/main.go index fd8bb90f0..86e65ac54 100644 --- a/cmd/notifications/main.go +++ b/cmd/notifications/main.go @@ -29,27 +29,27 @@ import ( const ( svcName = "notifications" - envPrefixUsers = "SMQ_USERS_GRPC_" + envPrefixUsers = "MG_USERS_GRPC_" defEmailPort = "25" ) type config struct { - LogLevel string `env:"SMQ_NOTIFICATIONS_LOG_LEVEL" envDefault:"info"` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - InstanceID string `env:"SMQ_NOTIFICATIONS_INSTANCE_ID" envDefault:""` - DomainAltName string `env:"SMQ_NOTIFICATIONS_DOMAIN_ALT_NAME" envDefault:"domain"` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - EmailHost string `env:"SMQ_EMAIL_HOST" envDefault:"localhost"` - EmailPort string `env:"SMQ_EMAIL_PORT" envDefault:"25"` - EmailUsername string `env:"SMQ_EMAIL_USERNAME" envDefault:""` - EmailPassword string `env:"SMQ_EMAIL_PASSWORD" envDefault:""` - EmailFromAddress string `env:"SMQ_EMAIL_FROM_ADDRESS" envDefault:"noreply@supermq.com"` - EmailFromName string `env:"SMQ_EMAIL_FROM_NAME" envDefault:"SuperMQ Notifications"` - InvitationTemplate string `env:"SMQ_EMAIL_INVITATION_TEMPLATE" envDefault:"docker/templates/invitation-sent-email.tmpl"` - AcceptanceTemplate string `env:"SMQ_EMAIL_ACCEPTANCE_TEMPLATE" envDefault:"docker/templates/invitation-accepted-email.tmpl"` - RejectionTemplate string `env:"SMQ_EMAIL_REJECTION_TEMPLATE" envDefault:"docker/templates/invitation-rejected-email.tmpl"` + LogLevel string `env:"MG_NOTIFICATIONS_LOG_LEVEL" envDefault:"info"` + ESURL string `env:"MG_ES_URL" envDefault:"amqp://guest:guest@localhost:5682/"` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + InstanceID string `env:"MG_NOTIFICATIONS_INSTANCE_ID" envDefault:""` + DomainAltName string `env:"MG_NOTIFICATIONS_DOMAIN_ALT_NAME" envDefault:"domain"` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + EmailHost string `env:"MG_EMAIL_HOST" envDefault:"localhost"` + EmailPort string `env:"MG_EMAIL_PORT" envDefault:"25"` + EmailUsername string `env:"MG_EMAIL_USERNAME" envDefault:""` + EmailPassword string `env:"MG_EMAIL_PASSWORD" envDefault:""` + EmailFromAddress string `env:"MG_EMAIL_FROM_ADDRESS" envDefault:"noreply@supermq.com"` + EmailFromName string `env:"MG_EMAIL_FROM_NAME" envDefault:"SuperMQ Notifications"` + InvitationTemplate string `env:"MG_EMAIL_INVITATION_TEMPLATE" envDefault:"docker/templates/invitation-sent-email.tmpl"` + AcceptanceTemplate string `env:"MG_EMAIL_ACCEPTANCE_TEMPLATE" envDefault:"docker/templates/invitation-accepted-email.tmpl"` + RejectionTemplate string `env:"MG_EMAIL_REJECTION_TEMPLATE" envDefault:"docker/templates/invitation-rejected-email.tmpl"` } func main() { @@ -131,7 +131,7 @@ func main() { notifier = middleware.NewMetrics(notifier, counter, latency) notifier = middleware.NewTracing(notifier, tp.Tracer(svcName)) - subscriber, err := store.NewSubscriber(ctx, cfg.ESURL, logger) + subscriber, err := store.NewSubscriber(ctx, cfg.ESURL, "notifications-es-sub", logger) if err != nil { logger.Error(fmt.Sprintf("failed to create subscriber: %s", err)) exitCode = 1 diff --git a/cmd/postgres-reader/main.go b/cmd/postgres-reader/main.go new file mode 100644 index 000000000..1b5f1f0b4 --- /dev/null +++ b/cmd/postgres-reader/main.go @@ -0,0 +1,197 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package main contains postgres-reader main function to start the postgres-reader service. +package main + +import ( + "context" + "fmt" + "log" + "log/slog" + "os" + + chclient "github.com/absmach/callhome/pkg/client" + "github.com/absmach/supermq" + grpcReadersV1 "github.com/absmach/supermq/api/grpc/readers/v1" + smqlog "github.com/absmach/supermq/logger" + "github.com/absmach/supermq/pkg/authn/authsvc" + "github.com/absmach/supermq/pkg/grpcclient" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/absmach/supermq/pkg/prometheus" + "github.com/absmach/supermq/pkg/server" + grpcserver "github.com/absmach/supermq/pkg/server/grpc" + httpserver "github.com/absmach/supermq/pkg/server/http" + "github.com/absmach/supermq/pkg/uuid" + "github.com/absmach/supermq/readers" + readersgrpcapi "github.com/absmach/supermq/readers/api/grpc" + httpapi "github.com/absmach/supermq/readers/api/http" + middleware "github.com/absmach/supermq/readers/middleware" + "github.com/absmach/supermq/readers/postgres" + "github.com/caarlos0/env/v11" + "github.com/jmoiron/sqlx" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" +) + +const ( + svcName = "postgres-reader" + envPrefixDB = "MG_POSTGRES_" + envPrefixHTTP = "MG_POSTGRES_READER_HTTP_" + envPrefixAuth = "MG_AUTH_GRPC_" + envPrefixClients = "MG_CLIENTS_GRPC_" + envPrefixChannels = "MG_CHANNELS_GRPC_" + defDB = "supermq" + defSvcHTTPPort = "9009" + defSvcGRPCPort = "7009" + envPrefixGrpc = "MG_POSTGRES_READER_GRPC_" +) + +type config struct { + LogLevel string `env:"MG_POSTGRES_READER_LOG_LEVEL" envDefault:"info"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + InstanceID string `env:"MG_POSTGRES_READER_INSTANCE_ID" envDefault:""` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg := config{} + if err := env.Parse(&cfg); err != nil { + log.Fatalf("failed to load %s configuration : %s", svcName, err) + } + + logger, err := smqlog.New(os.Stdout, cfg.LogLevel) + if err != nil { + log.Fatalf("failed to init logger: %s", err.Error()) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + if cfg.InstanceID == "" { + if cfg.InstanceID, err = uuid.New().ID(); err != nil { + logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) + exitCode = 1 + return + } + } + + dbConfig := pgclient.Config{} + if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + db, err := pgclient.Connect(dbConfig) + if err != nil { + logger.Error(fmt.Sprintf("failed to setup postgres database : %s", err)) + exitCode = 1 + return + } + defer db.Close() + + repo := newService(db, logger) + + grpcServerConfig := server.Config{Port: defSvcGRPCPort} + if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGrpc}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s gRPC server configuration : %s", svcName, err.Error())) + exitCode = 1 + return + } + registerReadersServiceServer := func(srv *grpc.Server) { + reflection.Register(srv) + grpcReadersV1.RegisterReadersServiceServer(srv, readersgrpcapi.NewReadersServer(repo)) + } + + clientsClientCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&clientsClientCfg, env.Options{Prefix: envPrefixClients}); err != nil { + logger.Error(fmt.Sprintf("failed to load clients gRPC client configuration : %s", err)) + exitCode = 1 + return + } + + clientsClient, clientsHandler, err := grpcclient.SetupClientsClient(ctx, clientsClientCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer clientsHandler.Close() + + logger.Info("Clients service gRPC client successfully connected to clients gRPC server " + clientsHandler.Secure()) + + channelsClientCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&channelsClientCfg, env.Options{Prefix: envPrefixChannels}); err != nil { + logger.Error(fmt.Sprintf("failed to load channels gRPC client configuration : %s", err)) + exitCode = 1 + return + } + + channelsClient, channelsHandler, err := grpcclient.SetupChannelsClient(ctx, channelsClientCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer channelsHandler.Close() + logger.Info("Channels service gRPC client successfully connected to channels gRPC server " + channelsHandler.Secure()) + + authnCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&authnCfg, env.Options{Prefix: envPrefixAuth}); err != nil { + logger.Error(fmt.Sprintf("failed to load auth gRPC client configuration : %s", err)) + exitCode = 1 + return + } + + authn, authnHandler, err := authsvc.NewAuthentication(ctx, authnCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer authnHandler.Close() + logger.Info("authn successfully connected to auth gRPC server " + authnHandler.Secure()) + + httpServerConfig := server.Config{Port: defSvcHTTPPort} + if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) + exitCode = 1 + return + } + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(repo, authn, clientsClient, channelsClient, svcName, cfg.InstanceID), logger) + + if cfg.SendTelemetry { + chc := chclient.New(svcName, supermq.Version, logger, cancel) + go chc.CallHome(ctx) + } + + gs := grpcserver.NewServer(ctx, cancel, svcName, grpcServerConfig, registerReadersServiceServer, logger) + + g.Go(func() error { + return gs.Start() + }) + + g.Go(func() error { + return hs.Start() + }) + + g.Go(func() error { + return server.StopSignalHandler(ctx, cancel, logger, svcName, hs) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("Postgres reader service terminated: %s", err)) + } +} + +func newService(db *sqlx.DB, logger *slog.Logger) readers.MessageRepository { + svc := postgres.New(db) + svc = middleware.LoggingMiddleware(svc, logger) + counter, latency := prometheus.MakeMetrics("postgres", "message_reader") + svc = middleware.MetricsMiddleware(svc, counter, latency) + + return svc +} diff --git a/cmd/postgres-writer/main.go b/cmd/postgres-writer/main.go new file mode 100644 index 000000000..9f620e84c --- /dev/null +++ b/cmd/postgres-writer/main.go @@ -0,0 +1,154 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package main contains postgres-writer main function to start the postgres-writer service. +package main + +import ( + "context" + "fmt" + "log" + "log/slog" + "net/url" + "os" + + chclient "github.com/absmach/callhome/pkg/client" + "github.com/absmach/supermq" + "github.com/absmach/supermq/consumers" + consumertracing "github.com/absmach/supermq/consumers/tracing" + httpapi "github.com/absmach/supermq/consumers/writers/api" + "github.com/absmach/supermq/consumers/writers/brokers" + writerpg "github.com/absmach/supermq/consumers/writers/postgres" + smqlog "github.com/absmach/supermq/logger" + jaegerclient "github.com/absmach/supermq/pkg/jaeger" + brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/absmach/supermq/pkg/prometheus" + "github.com/absmach/supermq/pkg/server" + httpserver "github.com/absmach/supermq/pkg/server/http" + "github.com/absmach/supermq/pkg/uuid" + "github.com/caarlos0/env/v11" + "github.com/jmoiron/sqlx" + "golang.org/x/sync/errgroup" +) + +const ( + svcName = "postgres-writer" + envPrefixDB = "MG_POSTGRES_" + envPrefixHTTP = "MG_POSTGRES_WRITER_HTTP_" + defDB = "messages" + defSvcHTTPPort = "9010" +) + +type config struct { + LogLevel string `env:"MG_POSTGRES_WRITER_LOG_LEVEL" envDefault:"info"` + ConfigPath string `env:"MG_POSTGRES_WRITER_CONFIG_PATH" envDefault:"/config.toml"` + BrokerURL string `env:"MG_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + InstanceID string `env:"MG_POSTGRES_WRITER_INSTANCE_ID" envDefault:""` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg := config{} + if err := env.Parse(&cfg); err != nil { + log.Fatalf("failed to load %s configuration : %s", svcName, err) + } + + logger, err := smqlog.New(os.Stdout, cfg.LogLevel) + if err != nil { + log.Fatalf("failed to init logger: %s", err.Error()) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + if cfg.InstanceID == "" { + if cfg.InstanceID, err = uuid.New().ID(); err != nil { + logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) + exitCode = 1 + return + } + } + + httpServerConfig := server.Config{Port: defSvcHTTPPort} + if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) + exitCode = 1 + return + } + + dbConfig := pgclient.Config{Name: defDB} + if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s Postgres configuration : %s", svcName, err)) + exitCode = 1 + return + } + db, err := pgclient.Setup(dbConfig, *writerpg.Migration()) + if err != nil { + logger.Error(err.Error()) + } + defer db.Close() + + tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) + if err != nil { + logger.Error(fmt.Sprintf("Failed to init Jaeger: %s", err)) + exitCode = 1 + return + } + defer func() { + if err := tp.Shutdown(ctx); err != nil { + logger.Error(fmt.Sprintf("Error shutting down tracer provider: %v", err)) + } + }() + tracer := tp.Tracer(svcName) + + pubSub, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to message broker: %s", err)) + exitCode = 1 + return + } + defer pubSub.Close() + pubSub = brokerstracing.NewPubSub(httpServerConfig, tracer, pubSub) + + repo := newService(db, logger) + repo = consumertracing.NewBlocking(tracer, repo, httpServerConfig) + + if err = consumers.Start(ctx, svcName, pubSub, repo, cfg.ConfigPath, brokers.AllTopic, logger); err != nil { + logger.Error(fmt.Sprintf("failed to create Postgres writer: %s", err)) + exitCode = 1 + return + } + + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svcName, cfg.InstanceID), logger) + + if cfg.SendTelemetry { + chc := chclient.New(svcName, supermq.Version, logger, cancel) + go chc.CallHome(ctx) + } + + g.Go(func() error { + return hs.Start() + }) + + g.Go(func() error { + return server.StopSignalHandler(ctx, cancel, logger, svcName, hs) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("Postgres writer service terminated: %s", err)) + } +} + +func newService(db *sqlx.DB, logger *slog.Logger) consumers.BlockingConsumer { + svc := writerpg.New(db) + svc = httpapi.LoggingMiddleware(svc, logger) + counter, latency := prometheus.MakeMetrics("postgres", "message_writer") + svc = httpapi.MetricsMiddleware(svc, counter, latency) + return svc +} diff --git a/cmd/provision/main.go b/cmd/provision/main.go new file mode 100644 index 000000000..de0031d41 --- /dev/null +++ b/cmd/provision/main.go @@ -0,0 +1,221 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package main contains provision main function to start the provision service. +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "reflect" + + chclient "github.com/absmach/callhome/pkg/client" + csdk "github.com/absmach/certs/sdk" + "github.com/absmach/supermq" + "github.com/absmach/supermq/channels" + "github.com/absmach/supermq/clients" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnsvc "github.com/absmach/supermq/pkg/authn/authsvc" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/grpcclient" + mgsdk "github.com/absmach/supermq/pkg/sdk" + "github.com/absmach/supermq/pkg/server" + httpserver "github.com/absmach/supermq/pkg/server/http" + "github.com/absmach/supermq/pkg/uuid" + "github.com/absmach/supermq/provision" + httpapi "github.com/absmach/supermq/provision/api" + "github.com/absmach/supermq/provision/middleware" + "github.com/caarlos0/env/v11" + "golang.org/x/sync/errgroup" +) + +const ( + svcName = "provision" + contentType = "application/json" + envPrefixAuth = "MG_AUTH_GRPC_" +) + +var ( + errMissingConfigFile = errors.New("missing config file setting") + errFailLoadingConfigFile = errors.New("failed to load config from file") + errFailedToReadBootstrapContent = errors.New("failed to read bootstrap content from envs") +) + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg, err := loadConfig() + if err != nil { + log.Fatalf("failed to load %s configuration : %s", svcName, err) + } + + logger, err := smqlog.New(os.Stdout, cfg.Server.LogLevel) + if err != nil { + log.Fatalf("failed to init logger: %s", err.Error()) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + if cfg.InstanceID == "" { + if cfg.InstanceID, err = uuid.New().ID(); err != nil { + logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) + exitCode = 1 + return + } + } + + grpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&grpcCfg, env.Options{Prefix: envPrefixAuth}); err != nil { + logger.Error(fmt.Sprintf("failed to load auth gRPC client configuration : %s", err)) + exitCode = 1 + + return + } + authn, authnClient, err := authnsvc.NewAuthentication(ctx, grpcCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + + return + } + defer authnClient.Close() + logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) + am := smqauthn.NewAuthNMiddleware(authn) + + if cfgFromFile, err := loadConfigFromFile(cfg.File); err != nil { + logger.Warn(fmt.Sprintf("Continue with settings from env, failed to load from: %s: %s", cfg.File, err)) + } else { + // Merge environment variables and file settings. + mergeConfigs(&cfgFromFile, &cfg) + cfg = cfgFromFile + logger.Info("Continue with settings from file: " + cfg.File) + } + + SDKCfg := mgsdk.Config{ + UsersURL: cfg.Server.UsersURL, + ChannelsURL: cfg.Server.ChannelsURL, + ClientsURL: cfg.Server.ClientsURL, + BootstrapURL: cfg.Server.MgBSURL, + CertsURL: cfg.Server.CertsURL, + MsgContentType: contentType, + TLSVerification: cfg.Server.TLS, + } + mgSdk := mgsdk.NewSDK(SDKCfg) + + csdkConf := csdk.Config{ + CertsURL: cfg.Server.CertsURL, + } + + cSdk := csdk.NewSDK(csdkConf) + + svc := provision.New(cfg, mgSdk, cSdk, logger) + svc = middleware.NewLogging(svc, logger) + + httpServerConfig := server.Config{Host: "", Port: cfg.Server.Port, KeyFile: cfg.Server.ServerKey, CertFile: cfg.Server.ServerCert} + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, am, logger, cfg.InstanceID), logger) + + if cfg.SendTelemetry { + chc := chclient.New(svcName, supermq.Version, logger, cancel) + go chc.CallHome(ctx) + } + + g.Go(func() error { + return hs.Start() + }) + + g.Go(func() error { + return server.StopSignalHandler(ctx, cancel, logger, svcName, hs) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("Provision service terminated: %s", err)) + } +} + +func loadConfigFromFile(file string) (provision.Config, error) { + _, err := os.Stat(file) + if os.IsNotExist(err) { + return provision.Config{}, errors.Wrap(errMissingConfigFile, err) + } + c, err := provision.Read(file) + if err != nil { + return provision.Config{}, errors.Wrap(errFailLoadingConfigFile, err) + } + return c, nil +} + +func loadConfig() (provision.Config, error) { + cfg := provision.Config{} + if err := env.Parse(&cfg); err != nil { + return provision.Config{}, err + } + + if cfg.Bootstrap.AutoWhiteList && !cfg.Bootstrap.Provision { + return provision.Config{}, errors.New("Can't auto whitelist if auto config save is off") + } + + var content map[string]any + if cfg.BSContent != "" { + if err := json.Unmarshal([]byte(cfg.BSContent), &content); err != nil { + return provision.Config{}, errFailedToReadBootstrapContent + } + } + + cfg.Bootstrap.Content = content + // This is default conf for provision if there is no config file + cfg.Channels = []channels.Channel{ + { + Name: "control-channel", + Metadata: map[string]any{"type": "control"}, + }, { + Name: "data-channel", + Metadata: map[string]any{"type": "data"}, + }, + } + cfg.Clients = []clients.Client{ + { + Name: "client", + Metadata: map[string]any{"external_id": "xxxxxx"}, + }, + } + + return cfg, nil +} + +func mergeConfigs(dst, src any) any { + d := reflect.ValueOf(dst).Elem() + s := reflect.ValueOf(src).Elem() + + for i := 0; i < d.NumField(); i++ { + dField := d.Field(i) + sField := s.Field(i) + switch dField.Kind() { + case reflect.Struct: + dst := dField.Addr().Interface() + src := sField.Addr().Interface() + m := mergeConfigs(dst, src) + val := reflect.ValueOf(m).Elem().Interface() + dField.Set(reflect.ValueOf(val)) + case reflect.Slice: + case reflect.Bool: + if dField.Interface() == false { + dField.Set(reflect.ValueOf(sField.Interface())) + } + case reflect.Int: + if dField.Interface() == 0 { + dField.Set(reflect.ValueOf(sField.Interface())) + } + case reflect.String: + if dField.Interface() == "" { + dField.Set(reflect.ValueOf(sField.Interface())) + } + } + } + return dst +} diff --git a/cmd/re/main.go b/cmd/re/main.go new file mode 100644 index 000000000..6954f3c64 --- /dev/null +++ b/cmd/re/main.go @@ -0,0 +1,445 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package main contains rule engine main function to start the service. +package main + +import ( + "context" + "fmt" + "log" + "log/slog" + "net/url" + "os" + "time" + + chclient "github.com/absmach/callhome/pkg/client" + "github.com/absmach/supermq" + abrokers "github.com/absmach/supermq/alarms/brokers" + grpcReadersV1 "github.com/absmach/supermq/api/grpc/readers/v1" + "github.com/absmach/supermq/consumers/writers/brokers" + dpostgres "github.com/absmach/supermq/domains/postgres" + "github.com/absmach/supermq/internal/email" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnsvc "github.com/absmach/supermq/pkg/authn/authsvc" + mgauthz "github.com/absmach/supermq/pkg/authz" + authzsvc "github.com/absmach/supermq/pkg/authz/authsvc" + "github.com/absmach/supermq/pkg/callout" + dconsumer "github.com/absmach/supermq/pkg/domains/events/consumer" + domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" + "github.com/absmach/supermq/pkg/emailer" + "github.com/absmach/supermq/pkg/grpcclient" + jaegerclient "github.com/absmach/supermq/pkg/jaeger" + pkglog "github.com/absmach/supermq/pkg/logger" + "github.com/absmach/supermq/pkg/messaging" + smqbrokers "github.com/absmach/supermq/pkg/messaging/brokers" + brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" + "github.com/absmach/supermq/pkg/permissions" + "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/pkg/policies/spicedb" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/absmach/supermq/pkg/prometheus" + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/pkg/server" + httpserver "github.com/absmach/supermq/pkg/server/http" + spicedbdecoder "github.com/absmach/supermq/pkg/spicedb" + "github.com/absmach/supermq/pkg/ticker" + "github.com/absmach/supermq/pkg/uuid" + "github.com/absmach/supermq/re" + httpapi "github.com/absmach/supermq/re/api" + "github.com/absmach/supermq/re/events" + "github.com/absmach/supermq/re/middleware" + "github.com/absmach/supermq/re/operations" + repg "github.com/absmach/supermq/re/postgres" + grpcClient "github.com/absmach/supermq/readers/api/grpc" + "github.com/authzed/authzed-go/v1" + "github.com/authzed/grpcutil" + "github.com/caarlos0/env/v11" + "github.com/go-chi/chi/v5" + "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +const ( + svcName = "rules_engine" + envPrefixDB = "MG_RE_DB_" + envPrefixHTTP = "MG_RE_HTTP_" + envPrefixCallout = "MG_RE_CALLOUT_" + envPrefixAuth = "MG_AUTH_GRPC_" + defDB = "r" + defSvcHTTPPort = "9008" + envPrefixGrpc = "MG_TIMESCALE_READER_GRPC_" + envPrefixDomains = "MG_DOMAINS_GRPC_" +) + +// We use a buffered channel to prevent blocking, as logging is an expensive operation. +// A larger buffer size would also work, but we’d likely need another instance of RE in that case. +// A smaller size would probably work too, but there's no need to be that frugal with resources. +const channBuffer = 256 + +type config struct { + LogLevel string `env:"MG_RE_LOG_LEVEL" envDefault:"info"` + InstanceID string `env:"MG_RE_INSTANCE_ID" envDefault:""` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + ESURL string `env:"MG_ES_URL" envDefault:"nats://localhost:4222"` + ESConsumerName string `env:"MG_RE_EVENT_CONSUMER" envDefault:"rules_engine"` + CacheURL string `env:"MG_RE_CACHE_URL" envDefault:"redis://localhost:6379/0"` + CacheKeyDuration time.Duration `env:"MG_RE_CACHE_KEY_DURATION" envDefault:"10m"` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + BrokerURL string `env:"MG_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"` + SpicedbHost string `env:"MG_SPICEDB_HOST" envDefault:"localhost"` + SpicedbPort string `env:"MG_SPICEDB_PORT" envDefault:"50051"` + SpicedbPreSharedKey string `env:"MG_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` + SpicedbSchemaFile string `env:"MG_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"` + PermissionsFile string `env:"MG_PERMISSIONS_FILE" envDefault:"permission.yaml"` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg := config{} + if err := env.Parse(&cfg); err != nil { + log.Fatalf("failed to load %s configuration : %s", svcName, err) + } + + var logger *slog.Logger + logger, err := smqlog.New(os.Stdout, cfg.LogLevel) + if err != nil { + log.Fatalf("failed to init logger: %s", err.Error()) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + if cfg.InstanceID == "" { + if cfg.InstanceID, err = uuid.New().ID(); err != nil { + logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) + exitCode = 1 + return + } + } + + ec := email.Config{} + if err := env.Parse(&ec); err != nil { + logger.Error(fmt.Sprintf("failed to load email configuration : %s", err)) + exitCode = 1 + + return + } + + callCfg := callout.Config{} + if err := env.ParseWithOptions(&callCfg, env.Options{Prefix: envPrefixCallout}); err != nil { + logger.Error(fmt.Sprintf("failed to parse callout config : %s", err)) + exitCode = 1 + return + } + + dbConfig := pgclient.Config{Name: defDB} + if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { + logger.Error(err.Error()) + exitCode = 1 + + return + } + migration, err := repg.Migration() + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + + return + } + db, err := pgclient.Setup(dbConfig, *migration) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + + return + } + defer db.Close() + + tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) + if err != nil { + logger.Error(fmt.Sprintf("Failed to init Jaeger: %s", err)) + exitCode = 1 + + return + } + defer func() { + if err := tp.Shutdown(ctx); err != nil { + logger.Error(fmt.Sprintf("Error shutting down tracer provider: %v", err)) + } + }() + tracer := tp.Tracer(svcName) + + httpServerConfig := server.Config{Port: defSvcHTTPPort} + if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) + exitCode = 1 + + return + } + + callout, err := callout.New(callCfg) + if err != nil { + logger.Error(fmt.Sprintf("failed to create new callout: %s", err)) + exitCode = 1 + return + } + + msgSub, err := smqbrokers.NewPubSub(ctx, cfg.BrokerURL, logger, smqbrokers.ConnectionName("re-msg-pubsub")) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to message broker for mg pubSub: %s", err)) + exitCode = 1 + + return + } + defer msgSub.Close() + msgSub = brokerstracing.NewPubSub(httpServerConfig, tracer, msgSub) + + writersPub, err := brokers.NewPublisher(ctx, cfg.BrokerURL) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to message broker for writers publisher: %s", err)) + exitCode = 1 + + return + } + defer writersPub.Close() + writersPub = brokerstracing.NewPublisher(httpServerConfig, tracer, writersPub) + + alarmsPub, err := abrokers.NewPublisher(ctx, cfg.BrokerURL) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to message broker for alarms publisher: %s", err)) + exitCode = 1 + + return + } + defer alarmsPub.Close() + alarmsPub = brokerstracing.NewPublisher(httpServerConfig, tracer, alarmsPub) + + grpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&grpcCfg, env.Options{Prefix: envPrefixAuth}); err != nil { + logger.Error(fmt.Sprintf("failed to load auth gRPC client configuration : %s", err)) + exitCode = 1 + + return + } + authn, authnClient, err := authnsvc.NewAuthentication(ctx, grpcCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + + return + } + am := smqauthn.NewAuthNMiddleware(authn) + + defer authnClient.Close() + logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) + runInfo := make(chan pkglog.RunInfo, channBuffer) + + domsGrpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { + logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) + exitCode = 1 + return + } + domAuthz, _, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer domainsHandler.Close() + + authz, authzClient, err := authzsvc.NewAuthorization(ctx, grpcCfg, domAuthz) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer authzClient.Close() + logger.Info("AuthZ successfully connected to auth gRPC server " + authnClient.Secure()) + + database := pgclient.NewDatabase(db, dbConfig, tracer) + + ddatabase := pgclient.NewDatabase(db, dbConfig, tracer) + drepo := dpostgres.NewRepository(ddatabase) + + if err := dconsumer.DomainsEventsSubscribe(ctx, drepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil { + logger.Error(fmt.Sprintf("failed to create domains event store : %s", err)) + exitCode = 1 + return + } + + regrpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(®rpcCfg, env.Options{Prefix: envPrefixGrpc}); err != nil { + logger.Error(fmt.Sprintf("failed to load clients gRPC client configuration : %s", err)) + exitCode = 1 + return + } + + client, err := grpcclient.NewHandler(regrpcCfg) + if err != nil { + exitCode = 1 + return + } + defer client.Close() + + readersClient := grpcClient.NewReadersClient(client.Connection(), regrpcCfg.Timeout) + logger.Info("Readers gRPC client successfully connected to readers gRPC server " + client.Secure()) + + svc, err := newService(ctx, cfg, database, runInfo, msgSub, writersPub, alarmsPub, authz, ec, logger, readersClient, callout, tracer) + if err != nil { + logger.Error(fmt.Sprintf("failed to create services: %s", err)) + exitCode = 1 + + return + } + subCfg := messaging.SubscriberConfig{ + ID: svcName, + Topic: smqbrokers.SubjectAllMessages, + DeliveryPolicy: messaging.DeliverAllPolicy, + Handler: svc, + } + if err := msgSub.Subscribe(ctx, subCfg); err != nil { + logger.Error(fmt.Sprintf("failed to subscribe to internal message broker: %s", err)) + exitCode = 1 + + return + } + + go func() { + for info := range runInfo { + logger.LogAttrs(context.Background(), info.Level, info.Message, info.Details...) + } + }() + + mux := chi.NewRouter() + + httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, am, mux, logger, cfg.InstanceID), logger) + + if cfg.SendTelemetry { + chc := chclient.New(svcName, supermq.Version, logger, cancel) + go chc.CallHome(ctx) + } + + g.Go(func() error { + return svc.StartScheduler(ctx) + }) + + g.Go(func() error { + return httpSvc.Start() + }) + + g.Go(func() error { + return server.StopSignalHandler(ctx, cancel, logger, svcName, httpSvc) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("%s service terminated: %s", svcName, err)) + } +} + +func newService(ctx context.Context, cfg config, db pgclient.Database, runInfo chan pkglog.RunInfo, rePubSub messaging.PubSub, writersPub, alarmsPub messaging.Publisher, authz mgauthz.Authorization, ec email.Config, logger *slog.Logger, readersClient grpcReadersV1.ReadersServiceClient, callout callout.Callout, tracer trace.Tracer) (re.Service, error) { + repo := repg.NewRepository(db) + idp := uuid.New() + + emailerClient, err := emailer.New(&ec) + if err != nil { + logger.Error(fmt.Sprintf("failed to configure e-mailing util: %s", err.Error())) + } + + policyService, err := newSpiceDBPolicyServiceEvaluator(cfg, logger) + if err != nil { + return nil, err + } + logger.Info("Policy service successfully connected to SpiceDB gRPC server") + + availableActions, builtInRoles, err := availableActionsAndBuiltInRoles(cfg.SpicedbSchemaFile) + if err != nil { + return nil, fmt.Errorf("failed to get available actions and built-in roles: %w", err) + } + + csvc, err := re.NewService(repo, runInfo, policyService, idp, rePubSub, writersPub, alarmsPub, ticker.NewTicker(time.Second*30), emailerClient, readersClient, availableActions, builtInRoles) + if err != nil { + return nil, fmt.Errorf("failed to create RE service: %w", err) + } + + csvc, err = events.NewEventStoreMiddleware(ctx, csvc, cfg.ESURL) + if err != nil { + return nil, fmt.Errorf("failed to init re event store middleware: %w", err) + } + + permConfig, err := permissions.ParsePermissionsFile(cfg.PermissionsFile) + if err != nil { + return nil, fmt.Errorf("failed to parse permissions file: %w", err) + } + + ruleOps, ruleRoleOps, err := permConfig.GetEntityPermissions(operations.EntityType) + if err != nil { + return nil, fmt.Errorf("failed to get rule permissions: %w", err) + } + + entitiesOps, err := permissions.NewEntitiesOperations( + permissions.EntitiesPermission{ + operations.EntityType: ruleOps, + }, + permissions.EntitiesOperationDetails[permissions.Operation]{ + operations.EntityType: operations.OperationDetails(), + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to create entities operations: %w", err) + } + + roleOps, err := permissions.NewOperations(roles.Operations(), ruleRoleOps) + if err != nil { + return nil, fmt.Errorf("failed to create role operations: %w", err) + } + + csvc, err = middleware.AuthorizationMiddleware(csvc, authz, entitiesOps, roleOps) + if err != nil { + return nil, err + } + csvc, err = middleware.NewCallout(csvc, callout, entitiesOps, roleOps) + if err != nil { + return nil, err + } + csvc = middleware.LoggingMiddleware(csvc, logger) + counter, latency := prometheus.MakeMetrics("re", "api") + csvc = middleware.NewMetricsMiddleware(counter, latency, csvc) + csvc = middleware.NewTracingMiddleware(tracer, csvc) + + return csvc, nil +} + +func newSpiceDBPolicyServiceEvaluator(cfg config, logger *slog.Logger) (policies.Service, error) { + client, err := authzed.NewClientWithExperimentalAPIs( + fmt.Sprintf("%s:%s", cfg.SpicedbHost, cfg.SpicedbPort), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpcutil.WithInsecureBearerToken(cfg.SpicedbPreSharedKey), + ) + if err != nil { + return nil, err + } + ps := spicedb.NewPolicyService(client, logger) + + return ps, nil +} + +func availableActionsAndBuiltInRoles(spicedbSchemaFile string) ([]roles.Action, map[roles.BuiltInRoleName][]roles.Action, error) { + availableActions, err := spicedbdecoder.GetActionsFromSchema(spicedbSchemaFile, operations.EntityType) + if err != nil { + return []roles.Action{}, map[roles.BuiltInRoleName][]roles.Action{}, err + } + + builtInRoles := map[roles.BuiltInRoleName][]roles.Action{ + re.BuiltInRoleAdmin: availableActions, + } + + return availableActions, builtInRoles, err +} diff --git a/cmd/reports/main.go b/cmd/reports/main.go new file mode 100644 index 000000000..13acb3f31 --- /dev/null +++ b/cmd/reports/main.go @@ -0,0 +1,421 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package main contains reports main function to start the service. +package main + +import ( + "context" + "embed" + "fmt" + "log" + "log/slog" + "net/url" + "os" + "time" + + chclient "github.com/absmach/callhome/pkg/client" + "github.com/absmach/supermq" + grpcReadersV1 "github.com/absmach/supermq/api/grpc/readers/v1" + dpostgres "github.com/absmach/supermq/domains/postgres" + "github.com/absmach/supermq/internal/email" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnsvc "github.com/absmach/supermq/pkg/authn/authsvc" + mgauthz "github.com/absmach/supermq/pkg/authz" + authzsvc "github.com/absmach/supermq/pkg/authz/authsvc" + "github.com/absmach/supermq/pkg/callout" + dconsumer "github.com/absmach/supermq/pkg/domains/events/consumer" + domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" + "github.com/absmach/supermq/pkg/emailer" + "github.com/absmach/supermq/pkg/grpcclient" + jaegerclient "github.com/absmach/supermq/pkg/jaeger" + pkglog "github.com/absmach/supermq/pkg/logger" + "github.com/absmach/supermq/pkg/permissions" + "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/pkg/policies/spicedb" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/absmach/supermq/pkg/prometheus" + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/pkg/server" + httpserver "github.com/absmach/supermq/pkg/server/http" + spicedbdecoder "github.com/absmach/supermq/pkg/spicedb" + "github.com/absmach/supermq/pkg/ticker" + "github.com/absmach/supermq/pkg/uuid" + grpcClient "github.com/absmach/supermq/readers/api/grpc" + "github.com/absmach/supermq/reports" + httpapi "github.com/absmach/supermq/reports/api" + "github.com/absmach/supermq/reports/middleware" + "github.com/absmach/supermq/reports/operations" + repg "github.com/absmach/supermq/reports/postgres" + "github.com/authzed/authzed-go/v1" + "github.com/authzed/grpcutil" + "github.com/caarlos0/env/v11" + "github.com/go-chi/chi/v5" + "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +const ( + svcName = "reports" + envPrefixDB = "MG_REPORTS_DB_" + envPrefixHTTP = "MG_REPORTS_HTTP_" + envPrefixCallout = "MG_REPORTS_CALLOUT_" + envPrefixAuth = "MG_AUTH_GRPC_" + defDB = "repo" + defSvcHTTPPort = "9017" + envPrefixGrpc = "MG_TIMESCALE_READER_GRPC_" + envPrefixDomains = "MG_DOMAINS_GRPC_" + templatePath = "template/reports_default_template.html" + reportEntity = "report" +) + +// We use a buffered channel to prevent blocking, as logging is an expensive operation. +const channBuffer = 256 + +//go:embed template/reports_default_template.html +var templateFS embed.FS + +type config struct { + LogLevel string `env:"MG_REPORTS_LOG_LEVEL" envDefault:"info"` + InstanceID string `env:"MG_REPORTS_INSTANCE_ID" envDefault:""` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + ESURL string `env:"MG_ES_URL" envDefault:"nats://localhost:4222"` + ESConsumerName string `env:"MG_REPORTS_EVENT_CONSUMER" envDefault:"reports"` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + BrokerURL string `env:"MG_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"` + DefaultTemplatePath string `env:"MG_REPORTS_DEFAULT_TEMPLATE" envDefault:""` + ConverterURL string `env:"MG_PDF_CONVERTER_URL" envDefault:"http://localhost:4000/pdf"` + SpicedbHost string `env:"MG_SPICEDB_HOST" envDefault:"localhost"` + SpicedbPort string `env:"MG_SPICEDB_PORT" envDefault:"50051"` + SpicedbPreSharedKey string `env:"MG_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` + SpicedbSchemaFile string `env:"MG_SPICEDB_SCHEMA_FILE" envDefault:"schema.zed"` + PermissionsFile string `env:"MG_PERMISSIONS_FILE" envDefault:"permission.yaml"` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg := config{} + if err := env.Parse(&cfg); err != nil { + log.Fatalf("failed to load %s configuration : %s", svcName, err) + } + + var logger *slog.Logger + logger, err := smqlog.New(os.Stdout, cfg.LogLevel) + if err != nil { + log.Fatalf("failed to init logger: %s", err.Error()) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + if cfg.InstanceID == "" { + if cfg.InstanceID, err = uuid.New().ID(); err != nil { + logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) + exitCode = 1 + return + } + } + + var templateData []byte + + switch cfg.DefaultTemplatePath { + case "": + templateData, err = templateFS.ReadFile(templatePath) + default: + templateData, err = os.ReadFile(templatePath) + } + + if err != nil { + logger.Error(fmt.Sprintf("failed to read report template: %s", err)) + exitCode = 1 + return + } + + template := reports.ReportTemplate(string(templateData)) + + if err := template.Validate(); err != nil { + logger.Error(fmt.Sprintf("failed to validate report template: %s", err)) + exitCode = 1 + return + } + logger.Info("Report template validated successfully") + + ec := email.Config{} + if err := env.Parse(&ec); err != nil { + logger.Error(fmt.Sprintf("failed to load email configuration : %s", err)) + exitCode = 1 + + return + } + + callCfg := callout.Config{} + if err := env.ParseWithOptions(&callCfg, env.Options{Prefix: envPrefixCallout}); err != nil { + logger.Error(fmt.Sprintf("failed to parse callout config : %s", err)) + exitCode = 1 + return + } + + dbConfig := pgclient.Config{Name: defDB} + if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { + logger.Error(err.Error()) + exitCode = 1 + + return + } + + migration, err := repg.Migration() + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + + return + } + + db, err := pgclient.Setup(dbConfig, *migration) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + + return + } + defer db.Close() + + tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) + if err != nil { + logger.Error(fmt.Sprintf("Failed to init Jaeger: %s", err)) + exitCode = 1 + + return + } + defer func() { + if err := tp.Shutdown(ctx); err != nil { + logger.Error(fmt.Sprintf("Error shutting down tracer provider: %v", err)) + } + }() + tracer := tp.Tracer(svcName) + + httpServerConfig := server.Config{Port: defSvcHTTPPort} + if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) + exitCode = 1 + + return + } + + callout, err := callout.New(callCfg) + if err != nil { + logger.Error(fmt.Sprintf("failed to create new callout: %s", err)) + exitCode = 1 + return + } + + grpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&grpcCfg, env.Options{Prefix: envPrefixAuth}); err != nil { + logger.Error(fmt.Sprintf("failed to load auth gRPC client configuration : %s", err)) + exitCode = 1 + + return + } + authn, authnClient, err := authnsvc.NewAuthentication(ctx, grpcCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + + return + } + am := smqauthn.NewAuthNMiddleware(authn) + defer authnClient.Close() + logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) + + domsGrpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { + logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) + exitCode = 1 + return + } + domAuthz, _, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer domainsHandler.Close() + + authz, authzClient, err := authzsvc.NewAuthorization(ctx, grpcCfg, domAuthz) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer authzClient.Close() + logger.Info("AuthZ successfully connected to auth gRPC server " + authnClient.Secure()) + + ddatabase := pgclient.NewDatabase(db, dbConfig, tracer) + drepo := dpostgres.NewRepository(ddatabase) + + if err := dconsumer.DomainsEventsSubscribe(ctx, drepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil { + logger.Error(fmt.Sprintf("failed to create domains event store : %s", err)) + exitCode = 1 + return + } + + database := pgclient.NewDatabase(db, dbConfig, tracer) + regrpcCfg := grpcclient.Config{} + if err := env.ParseWithOptions(®rpcCfg, env.Options{Prefix: envPrefixGrpc}); err != nil { + logger.Error(fmt.Sprintf("failed to load clients gRPC client configuration : %s", err)) + exitCode = 1 + return + } + + client, err := grpcclient.NewHandler(regrpcCfg) + if err != nil { + exitCode = 1 + return + } + defer client.Close() + + readersClient := grpcClient.NewReadersClient(client.Connection(), regrpcCfg.Timeout) + logger.Info("Readers gRPC client successfully connected to readers gRPC server " + client.Secure()) + + runInfo := make(chan pkglog.RunInfo, channBuffer) + + svc, err := newService(cfg, database, runInfo, authz, ec, logger, readersClient, template, callout, tracer) + if err != nil { + logger.Error(fmt.Sprintf("failed to create services: %s", err)) + exitCode = 1 + + return + } + + go func() { + for info := range runInfo { + logger.LogAttrs(context.Background(), info.Level, info.Message, info.Details...) + } + }() + + mux := chi.NewRouter() + + httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, am, mux, logger, cfg.InstanceID), logger) + + if cfg.SendTelemetry { + chc := chclient.New(svcName, supermq.Version, logger, cancel) + go chc.CallHome(ctx) + } + + g.Go(func() error { + return svc.StartScheduler(ctx) + }) + + g.Go(func() error { + return httpSvc.Start() + }) + + g.Go(func() error { + return server.StopSignalHandler(ctx, cancel, logger, svcName, httpSvc) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("%s service terminated: %s", svcName, err)) + } +} + +func newService(cfg config, db pgclient.Database, runInfo chan pkglog.RunInfo, authz mgauthz.Authorization, ec email.Config, logger *slog.Logger, readersClient grpcReadersV1.ReadersServiceClient, template reports.ReportTemplate, callout callout.Callout, tracer trace.Tracer) (reports.Service, error) { + repo := repg.NewRepository(db) + idp := uuid.New() + + emailClient, err := emailer.New(&ec) + if err != nil { + logger.Error(fmt.Sprintf("failed to configure e-mailing util: %s", err.Error())) + } + + policyService, err := newSpiceDBPolicyServiceEvaluator(cfg, logger) + if err != nil { + return nil, err + } + logger.Info("Policy service successfully connected to SpiceDB gRPC server") + + availableActions, builtInRoles, err := availableActionsAndBuiltInRoles(cfg.SpicedbSchemaFile) + if err != nil { + return nil, fmt.Errorf("failed to get available actions and built-in roles: %w", err) + } + + csvc, err := reports.NewService(repo, runInfo, policyService, idp, ticker.NewTicker(time.Second*30), emailClient, readersClient, template, cfg.ConverterURL, availableActions, builtInRoles) + if err != nil { + return nil, fmt.Errorf("failed to create reports service: %w", err) + } + + permConfig, err := permissions.ParsePermissionsFile(cfg.PermissionsFile) + if err != nil { + return nil, fmt.Errorf("failed to parse permissions file: %w", err) + } + + reportOps, reportRoleOps, err := permConfig.GetEntityPermissions(reportEntity) + if err != nil { + return nil, fmt.Errorf("failed to get report permissions: %w", err) + } + + entitiesOps, err := permissions.NewEntitiesOperations( + permissions.EntitiesPermission{ + operations.EntityType: reportOps, + }, + permissions.EntitiesOperationDetails[permissions.Operation]{ + operations.EntityType: operations.OperationDetails(), + }, + ) + if err != nil { + return nil, fmt.Errorf("failed to create entities operations: %w", err) + } + + roleOps, err := permissions.NewOperations(roles.Operations(), reportRoleOps) + if err != nil { + return nil, fmt.Errorf("failed to create role operations: %w", err) + } + + csvc, err = middleware.AuthorizationMiddleware(csvc, authz, entitiesOps, roleOps) + if err != nil { + return nil, err + } + csvc, err = middleware.NewCallout(csvc, callout, entitiesOps, roleOps) + if err != nil { + return nil, err + } + csvc = middleware.LoggingMiddleware(csvc, logger) + counter, latency := prometheus.MakeMetrics("reports", "api") + csvc = middleware.NewMetricsMiddleware(counter, latency, csvc) + csvc = middleware.NewTracingMiddleware(tracer, csvc) + + return csvc, nil +} + +func newSpiceDBPolicyServiceEvaluator(cfg config, logger *slog.Logger) (policies.Service, error) { + client, err := authzed.NewClientWithExperimentalAPIs( + fmt.Sprintf("%s:%s", cfg.SpicedbHost, cfg.SpicedbPort), + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpcutil.WithInsecureBearerToken(cfg.SpicedbPreSharedKey), + ) + if err != nil { + return nil, err + } + ps := spicedb.NewPolicyService(client, logger) + + return ps, nil +} + +func availableActionsAndBuiltInRoles(spicedbSchemaFile string) ([]roles.Action, map[roles.BuiltInRoleName][]roles.Action, error) { + availableActions, err := spicedbdecoder.GetActionsFromSchema(spicedbSchemaFile, reportEntity) + if err != nil { + return []roles.Action{}, map[roles.BuiltInRoleName][]roles.Action{}, err + } + + builtInRoles := map[roles.BuiltInRoleName][]roles.Action{ + reports.BuiltInRoleAdmin: availableActions, + } + + return availableActions, builtInRoles, err +} diff --git a/cmd/reports/template/reports_default_template.html b/cmd/reports/template/reports_default_template.html new file mode 100644 index 000000000..1785e0931 --- /dev/null +++ b/cmd/reports/template/reports_default_template.html @@ -0,0 +1,479 @@ + + + + + + + + + {{.Title}} + + + + {{if gt (len .Reports) 0}} + {{$firstPageRows := 24}} + {{$continuationPageRows := 32}} + {{$totalPages := 0}} + + {{/* Calculate total pages across all reports */}} + {{range $report := .Reports}} + {{$totalMessages := len .Messages}} + {{$reportPages := 1}} + {{if gt $totalMessages $firstPageRows}} + {{$remaining := sub $totalMessages $firstPageRows}} + {{$additionalPages := div $remaining $continuationPageRows}} + {{if gt (mod $remaining $continuationPageRows) 0}} + {{$additionalPages = add $additionalPages 1}} + {{end}} + {{$reportPages = add 1 $additionalPages}} + {{end}} + {{$totalPages = add $totalPages $reportPages}} + {{end}} + + {{$globalPage := 0}} + + {{range $reportIndex, $report := .Reports}} + {{$totalMessages := len .Messages}} + {{$pageCount := 1}} + {{if gt $totalMessages $firstPageRows}} + {{$remaining := sub $totalMessages $firstPageRows}} + {{$additionalPages := div $remaining $continuationPageRows}} + {{if gt (mod $remaining $continuationPageRows) 0}} + {{$additionalPages = add $additionalPages 1}} + {{end}} + {{$pageCount = add 1 $additionalPages}} + {{end}} + + {{range $pageNum := iterate $pageCount}} + {{$globalPage = add $globalPage 1}} + {{$isFirstPage := eq $pageNum 0}} + {{$startRow := getStartRow $pageNum $firstPageRows $continuationPageRows}} + {{$endRow := getEndRow $pageNum $firstPageRows $continuationPageRows $totalMessages}} + +
+
+
+
+
+
{{$.Title}}
+
{{$.GeneratedDate}}{{if $.Timezone}} ({{$.Timezone}}){{end}}
+
+
+
+ +
+ {{if $isFirstPage}} +
+
Metrics
+
+
+
Name:
+
{{$report.Metric.Name}}
+
+ {{if $report.Metric.ClientID}} +
+
Device ID:
+
{{$report.Metric.ClientID}}
+
+ {{end}} +
+
Channel ID:
+
{{$report.Metric.ChannelID}}
+
+
+
+ +
+ Total Records: {{$totalMessages}} +
+ {{else}} +
+
Metrics (continued)
+
+ {{end}} + +
+
+ + + + + + + + + + + + {{range $msgIndex, $msg := $report.Messages}} + {{if and (ge $msgIndex $startRow) (lt $msgIndex $endRow)}} + + + + + + + + {{end}} + {{end}} + +
TimeValueUnitProtocolSubtopic
{{formatTime $msg.Time}}{{formatValue $msg}}{{$msg.Unit}}{{$msg.Protocol}}{{$msg.Subtopic}}
+
+
+ + +
+ {{end}} + {{end}} + {{else}} +
+
+
+
+
+
{{.Title}}
+
{{.GeneratedDate}}{{if .Timezone}} ({{.Timezone}}){{end}}
+
+
+
+ +
+
+
Metrics
+
+
+
Name:
+
No Report
+
+
+
Channel ID:
+
N/A
+
+
+
+ +
+ Total Records: 0 +
+ +
+
+ + + + + + + + + + + + + + + +
TimeValueUnitProtocolSubtopic
No data available
+
+
+ + +
+ {{end}} + + + diff --git a/cmd/timescale-reader/main.go b/cmd/timescale-reader/main.go new file mode 100644 index 000000000..d6a5165f1 --- /dev/null +++ b/cmd/timescale-reader/main.go @@ -0,0 +1,197 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package main contains timescale-reader main function to start the timescale-reader service. +package main + +import ( + "context" + "fmt" + "log" + "log/slog" + "os" + + chclient "github.com/absmach/callhome/pkg/client" + "github.com/absmach/supermq" + grpcReadersV1 "github.com/absmach/supermq/api/grpc/readers/v1" + smqlog "github.com/absmach/supermq/logger" + "github.com/absmach/supermq/pkg/authn/authsvc" + "github.com/absmach/supermq/pkg/grpcclient" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/absmach/supermq/pkg/prometheus" + "github.com/absmach/supermq/pkg/server" + grpcserver "github.com/absmach/supermq/pkg/server/grpc" + httpserver "github.com/absmach/supermq/pkg/server/http" + "github.com/absmach/supermq/pkg/uuid" + "github.com/absmach/supermq/readers" + readersgrpcapi "github.com/absmach/supermq/readers/api/grpc" + httpapi "github.com/absmach/supermq/readers/api/http" + middleware "github.com/absmach/supermq/readers/middleware" + "github.com/absmach/supermq/readers/timescale" + "github.com/caarlos0/env/v11" + "github.com/jmoiron/sqlx" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/reflection" +) + +const ( + svcName = "timescaledb-reader" + envPrefixDB = "MG_TIMESCALE_" + envPrefixHTTP = "MG_TIMESCALE_READER_HTTP_" + envPrefixAuth = "MG_AUTH_GRPC_" + envPrefixClients = "MG_CLIENTS_GRPC_" + envPrefixChannels = "MG_CHANNELS_GRPC_" + defDB = "messages" + defSvcHTTPPort = "9011" + defSvcGRPCPort = "7011" + envPrefixGrpc = "MG_TIMESCALE_READER_GRPC_" +) + +type config struct { + LogLevel string `env:"MG_TIMESCALE_READER_LOG_LEVEL" envDefault:"info"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + InstanceID string `env:"MG_TIMESCALE_READER_INSTANCE_ID" envDefault:""` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg := config{} + if err := env.Parse(&cfg); err != nil { + log.Fatalf("failed to load %s configuration : %s", svcName, err) + } + + logger, err := smqlog.New(os.Stdout, cfg.LogLevel) + if err != nil { + log.Fatalf("failed to init logger: %s", err.Error()) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + if cfg.InstanceID == "" { + if cfg.InstanceID, err = uuid.New().ID(); err != nil { + logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) + exitCode = 1 + return + } + } + + dbConfig := pgclient.Config{Name: defDB} + if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + db, err := pgclient.Connect(dbConfig) + if err != nil { + logger.Error(err.Error()) + } + defer db.Close() + + repo := newService(db, logger) + + grpcServerConfig := server.Config{ + Port: defSvcGRPCPort, + } + if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGrpc}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s gRPC server configuration : %s", svcName, err.Error())) + exitCode = 1 + return + } + registerReadersServiceServer := func(srv *grpc.Server) { + reflection.Register(srv) + grpcReadersV1.RegisterReadersServiceServer(srv, readersgrpcapi.NewReadersServer(repo)) + } + + clientsClientCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&clientsClientCfg, env.Options{Prefix: envPrefixClients}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s auth configuration : %s", svcName, err)) + exitCode = 1 + return + } + + clientsClient, clientsHandler, err := grpcclient.SetupClientsClient(ctx, clientsClientCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer clientsHandler.Close() + + logger.Info("Clients service gRPC client successfully connected to clients gRPC server " + clientsHandler.Secure()) + + channelsClientCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&channelsClientCfg, env.Options{Prefix: envPrefixChannels}); err != nil { + logger.Error(fmt.Sprintf("failed to load channels gRPC client configuration : %s", err)) + exitCode = 1 + return + } + + channelsClient, channelsHandler, err := grpcclient.SetupChannelsClient(ctx, channelsClientCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer channelsHandler.Close() + logger.Info("Channels service gRPC client successfully connected to channels gRPC server " + channelsHandler.Secure()) + + authnCfg := grpcclient.Config{} + if err := env.ParseWithOptions(&authnCfg, env.Options{Prefix: envPrefixAuth}); err != nil { + logger.Error(fmt.Sprintf("failed to load auth gRPC client configuration : %s", err)) + exitCode = 1 + return + } + + authn, authnHandler, err := authsvc.NewAuthentication(ctx, authnCfg) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer authnHandler.Close() + logger.Info("authn successfully connected to auth gRPC server " + authnHandler.Secure()) + + httpServerConfig := server.Config{Port: defSvcHTTPPort} + if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) + exitCode = 1 + return + } + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(repo, authn, clientsClient, channelsClient, svcName, cfg.InstanceID), logger) + + if cfg.SendTelemetry { + chc := chclient.New(svcName, supermq.Version, logger, cancel) + go chc.CallHome(ctx) + } + + gs := grpcserver.NewServer(ctx, cancel, svcName, grpcServerConfig, registerReadersServiceServer, logger) + + g.Go(func() error { + return gs.Start() + }) + + g.Go(func() error { + return hs.Start() + }) + + g.Go(func() error { + return server.StopSignalHandler(ctx, cancel, logger, svcName, hs) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("Timescale reader service terminated: %s", err)) + } +} + +func newService(db *sqlx.DB, logger *slog.Logger) readers.MessageRepository { + svc := timescale.New(db) + svc = middleware.LoggingMiddleware(svc, logger) + counter, latency := prometheus.MakeMetrics("timescale", "message_reader") + svc = middleware.MetricsMiddleware(svc, counter, latency) + + return svc +} diff --git a/cmd/timescale-writer/main.go b/cmd/timescale-writer/main.go new file mode 100644 index 000000000..49bb25f18 --- /dev/null +++ b/cmd/timescale-writer/main.go @@ -0,0 +1,156 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package main contains timescale-writer main function to start the timescale-writer service. +package main + +import ( + "context" + "fmt" + "log" + "log/slog" + "net/url" + "os" + + chclient "github.com/absmach/callhome/pkg/client" + "github.com/absmach/supermq" + "github.com/absmach/supermq/consumers" + consumertracing "github.com/absmach/supermq/consumers/tracing" + httpapi "github.com/absmach/supermq/consumers/writers/api" + "github.com/absmach/supermq/consumers/writers/brokers" + "github.com/absmach/supermq/consumers/writers/timescale" + smqlog "github.com/absmach/supermq/logger" + jaegerclient "github.com/absmach/supermq/pkg/jaeger" + brokerstracing "github.com/absmach/supermq/pkg/messaging/brokers/tracing" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/absmach/supermq/pkg/prometheus" + "github.com/absmach/supermq/pkg/server" + httpserver "github.com/absmach/supermq/pkg/server/http" + "github.com/absmach/supermq/pkg/uuid" + "github.com/caarlos0/env/v11" + "github.com/jmoiron/sqlx" + "golang.org/x/sync/errgroup" +) + +const ( + svcName = "timescaledb-writer" + envPrefixDB = "MG_TIMESCALE_" + envPrefixHTTP = "MG_TIMESCALE_WRITER_HTTP_" + defDB = "messages" + defSvcHTTPPort = "9012" +) + +type config struct { + LogLevel string `env:"MG_TIMESCALE_WRITER_LOG_LEVEL" envDefault:"info"` + ConfigPath string `env:"MG_TIMESCALE_WRITER_CONFIG_PATH" envDefault:"/config.toml"` + BrokerURL string `env:"MG_MESSAGE_BROKER_URL" envDefault:"nats://localhost:4222"` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + InstanceID string `env:"MG_TIMESCALE_WRITER_INSTANCE_ID" envDefault:""` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` +} + +func main() { + ctx, cancel := context.WithCancel(context.Background()) + g, ctx := errgroup.WithContext(ctx) + + cfg := config{} + if err := env.Parse(&cfg); err != nil { + log.Fatalf("failed to load %s service configuration : %s", svcName, err) + } + + logger, err := smqlog.New(os.Stdout, cfg.LogLevel) + if err != nil { + log.Fatalf("failed to init logger: %s", err.Error()) + } + + var exitCode int + defer smqlog.ExitWithError(&exitCode) + + if cfg.InstanceID == "" { + if cfg.InstanceID, err = uuid.New().ID(); err != nil { + logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) + exitCode = 1 + return + } + } + + httpServerConfig := server.Config{Port: defSvcHTTPPort} + if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) + exitCode = 1 + return + } + + dbConfig := pgclient.Config{Name: defDB} + if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { + logger.Error(fmt.Sprintf("failed to load %s Postgres configuration : %s", svcName, err)) + exitCode = 1 + return + } + db, err := pgclient.Setup(dbConfig, *timescale.Migration()) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer db.Close() + + tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) + if err != nil { + logger.Error(fmt.Sprintf("Failed to init Jaeger: %s", err)) + exitCode = 1 + return + } + defer func() { + if err := tp.Shutdown(ctx); err != nil { + logger.Error(fmt.Sprintf("Error shutting down tracer provider: %v", err)) + } + }() + tracer := tp.Tracer(svcName) + + repo := newService(db, logger) + repo = consumertracing.NewBlocking(tracer, repo, httpServerConfig) + + pubSub, err := brokers.NewPubSub(ctx, cfg.BrokerURL, logger) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to message broker: %s", err)) + exitCode = 1 + return + } + defer pubSub.Close() + pubSub = brokerstracing.NewPubSub(httpServerConfig, tracer, pubSub) + + if err = consumers.Start(ctx, svcName, pubSub, repo, cfg.ConfigPath, brokers.AllTopic, logger); err != nil { + logger.Error(fmt.Sprintf("failed to create Timescale writer: %s", err)) + exitCode = 1 + return + } + + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svcName, cfg.InstanceID), logger) + + if cfg.SendTelemetry { + chc := chclient.New(svcName, supermq.Version, logger, cancel) + go chc.CallHome(ctx) + } + + g.Go(func() error { + return hs.Start() + }) + + g.Go(func() error { + return server.StopSignalHandler(ctx, cancel, logger, svcName, hs) + }) + + if err := g.Wait(); err != nil { + logger.Error(fmt.Sprintf("Timescale writer service terminated: %s", err)) + } +} + +func newService(db *sqlx.DB, logger *slog.Logger) consumers.BlockingConsumer { + svc := timescale.New(db) + svc = httpapi.LoggingMiddleware(svc, logger) + counter, latency := prometheus.MakeMetrics("timescale", "message_writer") + svc = httpapi.MetricsMiddleware(svc, counter, latency) + return svc +} diff --git a/cmd/users/main.go b/cmd/users/main.go index 43e5d4815..316cb3d09 100644 --- a/cmd/users/main.go +++ b/cmd/users/main.go @@ -28,6 +28,8 @@ import ( smqauthz "github.com/absmach/supermq/pkg/authz" authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" "github.com/absmach/supermq/pkg/grpcclient" jaegerclient "github.com/absmach/supermq/pkg/jaeger" "github.com/absmach/supermq/pkg/oauth2" @@ -63,44 +65,44 @@ import ( const ( svcName = "users" - envPrefixDB = "SMQ_USERS_DB_" - envPrefixHTTP = "SMQ_USERS_HTTP_" - envPrefixGRPC = "SMQ_USERS_GRPC_" - envPrefixAuth = "SMQ_AUTH_GRPC_" - envPrefixDomains = "SMQ_DOMAINS_GRPC_" - envPrefixGoogle = "SMQ_GOOGLE_" + envPrefixDB = "MG_USERS_DB_" + envPrefixHTTP = "MG_USERS_HTTP_" + envPrefixGRPC = "MG_USERS_GRPC_" + envPrefixAuth = "MG_AUTH_GRPC_" + envPrefixDomains = "MG_DOMAINS_GRPC_" + envPrefixGoogle = "MG_GOOGLE_" defDB = "users" defSvcHTTPPort = "9002" defSvcGRPCPort = "7002" ) type config struct { - LogLevel string `env:"SMQ_USERS_LOG_LEVEL" envDefault:"info"` - AdminEmail string `env:"SMQ_USERS_ADMIN_EMAIL" envDefault:"admin@example.com"` - AdminPassword string `env:"SMQ_USERS_ADMIN_PASSWORD" envDefault:"12345678"` - AdminUsername string `env:"SMQ_USERS_ADMIN_USERNAME" envDefault:"admin"` - AdminFirstName string `env:"SMQ_USERS_ADMIN_FIRST_NAME" envDefault:"super"` - AdminLastName string `env:"SMQ_USERS_ADMIN_LAST_NAME" envDefault:"admin"` - PassRegexText string `env:"SMQ_USERS_PASS_REGEX" envDefault:"^.{8,}$"` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - InstanceID string `env:"SMQ_USERS_INSTANCE_ID" envDefault:""` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - SelfRegister bool `env:"SMQ_USERS_ALLOW_SELF_REGISTER" envDefault:"false"` - OAuthUIRedirectURL string `env:"SMQ_OAUTH_UI_REDIRECT_URL" envDefault:"http://localhost:9095/domains"` - OAuthUIErrorURL string `env:"SMQ_OAUTH_UI_ERROR_URL" envDefault:"http://localhost:9095/error"` - DeleteInterval time.Duration `env:"SMQ_USERS_DELETE_INTERVAL" envDefault:"24h"` - DeleteAfter time.Duration `env:"SMQ_USERS_DELETE_AFTER" envDefault:"720h"` - SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"` - SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"` - SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` - PasswordResetURLPrefix string `env:"SMQ_PASSWORD_RESET_URL_PREFIX" envDefault:"http://localhost/password/reset"` - PasswordResetEmailTemplate string `env:"SMQ_PASSWORD_RESET_EMAIL_TEMPLATE" envDefault:"reset-password-email.tmpl"` - VerificationURLPrefix string `env:"SMQ_VERIFICATION_URL_PREFIX" envDefault:"http://localhost/verify-email"` - VerificationEmailTemplate string `env:"SMQ_VERIFICATION_EMAIL_TEMPLATE" envDefault:"verification-email.tmpl"` - AuthKeyAlgorithm string `env:"SMQ_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` - JWKSURL string `env:"SMQ_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` + LogLevel string `env:"MG_USERS_LOG_LEVEL" envDefault:"info"` + AdminEmail string `env:"MG_USERS_ADMIN_EMAIL" envDefault:"admin@example.com"` + AdminPassword string `env:"MG_USERS_ADMIN_PASSWORD" envDefault:"12345678"` + AdminUsername string `env:"MG_USERS_ADMIN_USERNAME" envDefault:"admin"` + AdminFirstName string `env:"MG_USERS_ADMIN_FIRST_NAME" envDefault:"super"` + AdminLastName string `env:"MG_USERS_ADMIN_LAST_NAME" envDefault:"admin"` + PassRegexText string `env:"MG_USERS_PASS_REGEX" envDefault:"^.{8,}$"` + JaegerURL url.URL `env:"MG_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + InstanceID string `env:"MG_USERS_INSTANCE_ID" envDefault:""` + ESURL string `env:"MG_ES_URL" envDefault:"amqp://guest:guest@localhost:5682/"` + TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` + SelfRegister bool `env:"MG_USERS_ALLOW_SELF_REGISTER" envDefault:"false"` + OAuthUIRedirectURL string `env:"MG_OAUTH_UI_REDIRECT_URL" envDefault:"http://localhost:9095/domains"` + OAuthUIErrorURL string `env:"MG_OAUTH_UI_ERROR_URL" envDefault:"http://localhost:9095/error"` + DeleteInterval time.Duration `env:"MG_USERS_DELETE_INTERVAL" envDefault:"24h"` + DeleteAfter time.Duration `env:"MG_USERS_DELETE_AFTER" envDefault:"720h"` + SpicedbHost string `env:"MG_SPICEDB_HOST" envDefault:"localhost"` + SpicedbPort string `env:"MG_SPICEDB_PORT" envDefault:"50051"` + SpicedbPreSharedKey string `env:"MG_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` + PasswordResetURLPrefix string `env:"MG_PASSWORD_RESET_URL_PREFIX" envDefault:"http://localhost/password/reset"` + PasswordResetEmailTemplate string `env:"MG_PASSWORD_RESET_EMAIL_TEMPLATE" envDefault:"reset-password-email.tmpl"` + VerificationURLPrefix string `env:"MG_VERIFICATION_URL_PREFIX" envDefault:"http://localhost/verify-email"` + VerificationEmailTemplate string `env:"MG_VERIFICATION_EMAIL_TEMPLATE" envDefault:"verification-email.tmpl"` + AuthKeyAlgorithm string `env:"MG_AUTH_KEYS_ALGORITHM" envDefault:"RS256"` + JWKSURL string `env:"MG_AUTH_JWKS_URL" envDefault:"http://auth:9001/keys/.well-known/jwks.json"` PassRegex *regexp.Regexp } @@ -354,8 +356,10 @@ func newService(ctx context.Context, authz smqauthz.Authorization, token grpcTok if err != nil { logger.Error(fmt.Sprintf("failed to create admin client: %s", err)) } - if err := createAdminPolicy(ctx, userID, authz, policyService); err != nil { - return nil, err + if userID != "" { + if err := createAdminPolicy(ctx, userID, policyService); err != nil { + return nil, err + } } users.NewDeleteHandler(ctx, repo, policyService, domainsClient, c.DeleteInterval, c.DeleteAfter, logger) @@ -395,34 +399,22 @@ func createAdmin(ctx context.Context, c config, repo users.Repository, hsr users return u.ID, nil } - // Create an admin if _, err = repo.Save(ctx, user); err != nil { return "", err } - if _, err = svc.IssueToken(ctx, c.AdminUsername, c.AdminPassword, ""); err != nil { - return "", err - } return user.ID, nil } -func createAdminPolicy(ctx context.Context, userID string, authz smqauthz.Authorization, policyService policies.Service) error { - if err := authz.Authorize(ctx, smqauthz.PolicyReq{ +func createAdminPolicy(ctx context.Context, userID string, policyService policies.Service) error { + err := policyService.AddPolicy(ctx, policies.Policy{ SubjectType: policies.UserType, Subject: userID, - Permission: policies.AdministratorRelation, - Object: policies.SuperMQObject, + Relation: policies.AdministratorRelation, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, - }, nil); err != nil { - err := policyService.AddPolicy(ctx, policies.Policy{ - SubjectType: policies.UserType, - Subject: userID, - Relation: policies.AdministratorRelation, - Object: policies.SuperMQObject, - ObjectType: policies.PlatformType, - }) - if err != nil { - return err - } + }) + if err != nil && !errors.Contains(err, repoerr.ErrConflict) { + return err } return nil } diff --git a/coap/README.md b/coap/README.md deleted file mode 100644 index 7dfc9ac1d..000000000 --- a/coap/README.md +++ /dev/null @@ -1,126 +0,0 @@ -# SuperMQ CoAP Adapter - -SuperMQ CoAP adapter provides an [CoAP](http://coap.technology/) API for sending messages through the platform. - -## Configuration - -The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values. - -| Variable | Description | Default | -| ------------------------------------- | -------------------------------------------------------------------------------------------- | ------------------------------------- | -| `SMQ_COAP_ADAPTER_LOG_LEVEL` | Log level for the CoAP Adapter (`debug`, `info`, `warn`, `error`) | info | -| `SMQ_COAP_ADAPTER_HOST` | CoAP service listening host | "" | -| `SMQ_COAP_ADAPTER_PORT` | CoAP service listening port | 5683 | -| `SMQ_COAP_ADAPTER_SERVER_CERT` | Path to the PEM-encoded CoAP server certificate | "" | -| `SMQ_COAP_ADAPTER_SERVER_KEY` | Path to the PEM-encoded CoAP server key | "" | -| `SMQ_COAP_ADAPTER_HTTP_HOST` | Service HTTP listening host | "" | -| `SMQ_COAP_ADAPTER_HTTP_PORT` | Service HTTP listening port | 5683 | -| `SMQ_COAP_ADAPTER_HTTP_SERVER_CERT` | Path to the PEM-encoded HTTP server certificate | "" | -| `SMQ_COAP_ADAPTER_HTTP_SERVER_KEY` | Path to the PEM-encoded HTTP server key | "" | -| `SMQ_COAP_ADAPTER_CACHE_NUM_COUNTERS` | Number of cache counters that track topic parsing frequency | 200000 | -| `SMQ_COAP_ADAPTER_CACHE_MAX_COST` | Maximum cache size (bytes) | 1048576 | -| `SMQ_COAP_ADAPTER_CACHE_BUFFER_ITEMS` | Number of cache `Get` buffer items | 64 | -| `SMQ_CLIENTS_GRPC_URL` | Clients service Auth gRPC URL | | -| `SMQ_CLIENTS_GRPC_TIMEOUT` | Clients service Auth gRPC request timeout | 1s | -| `SMQ_CLIENTS_GRPC_CLIENT_CERT` | Path to the PEM-encoded clients service Auth gRPC client certificate file | "" | -| `SMQ_CLIENTS_GRPC_CLIENT_KEY` | Path to the PEM-encoded clients service Auth gRPC client key file | "" | -| `SMQ_CLIENTS_GRPC_SERVER_CERTS` | Path to the PEM-encoded clients server Auth gRPC trusted CA certificate file | "" | -| `SMQ_MESSAGE_BROKER_URL` | Message broker instance URL | | -| `SMQ_JAEGER_URL` | Jaeger server URL | | -| `SMQ_JAEGER_TRACE_RATIO` | Jaeger sampling ratio | 1.0 | -| `SMQ_SEND_TELEMETRY` | Send telemetry to SuperMQ call-home server | true | -| `SMQ_COAP_ADAPTER_INSTANCE_ID` | CoAP adapter instance ID | "" | - -## Deployment - -The service itself is distributed as Docker container. Check the [`coap-adapter`](https://github.com/absmach/supermq/blob/main/docker/docker-compose.yaml) service section in docker-compose file to see how service is deployed. - -Running this service outside of container requires working instance of the message broker service, clients service and Jaeger server. -To start the service outside of the container, execute the following shell script: - -```bash -# download the latest version of the service -git clone https://github.com/absmach/supermq - -cd supermq - -# compile the http -make coap - -# copy binary to bin -make install - -# set the environment variables and run the service -SMQ_COAP_ADAPTER_LOG_LEVEL=info \ -SMQ_COAP_ADAPTER_HOST=localhost \ -SMQ_COAP_ADAPTER_PORT=5683 \ -SMQ_COAP_ADAPTER_SERVER_CERT="" \ -SMQ_COAP_ADAPTER_SERVER_KEY="" \ -SMQ_COAP_ADAPTER_HTTP_HOST=localhost \ -SMQ_COAP_ADAPTER_HTTP_PORT=5683 \ -SMQ_COAP_ADAPTER_HTTP_SERVER_CERT="" \ -SMQ_COAP_ADAPTER_HTTP_SERVER_KEY="" \ -SMQ_COAP_ADAPTER_CACHE_NUM_COUNTERS=200000 \ -SMQ_COAP_ADAPTER_CACHE_MAX_COST=1048576 \ -SMQ_COAP_ADAPTER_CACHE_BUFFER_ITEMS=64 \ -SMQ_CLIENTS_GRPC_URL=localhost:7000 \ -SMQ_CLIENTS_GRPC_TIMEOUT=1s \ -SMQ_CLIENTS_GRPC_CLIENT_CERT="" \ -SMQ_CLIENTS_GRPC_CLIENT_KEY="" \ -SMQ_CLIENTS_GRPC_SERVER_CERTS="" \ -SMQ_MESSAGE_BROKER_URL=amqp://guest:guest@rabbitmq:5672/ \ -SMQ_JAEGER_URL=http://localhost:14268/api/traces \ -SMQ_JAEGER_TRACE_RATIO=1.0 \ -SMQ_SEND_TELEMETRY=true \ -SMQ_COAP_ADAPTER_INSTANCE_ID="" \ -$GOBIN/supermq-coap -``` - -Setting `SMQ_COAP_ADAPTER_SERVER_CERT` and `SMQ_COAP_ADAPTER_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key. Setting `SMQ_COAP_ADAPTER_HTTP_SERVER_CERT` and `SMQ_COAP_ADAPTER_HTTP_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key. - -Setting `SMQ_CLIENTS_GRPC_CLIENT_CERT` and `SMQ_CLIENTS_GRPC_CLIENT_KEY` will enable TLS against the clients service. The service expects a file in PEM format for both the certificate and the key. Setting `SMQ_CLIENTS_GRPC_SERVER_CERTS` will enable TLS against the clients service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. - -## Usage - -If CoAP adapter is running locally (on default 5683 port), a valid URL would be: `coap://localhost/m//c//?auth=`. -Since CoAP protocol does not support `Authorization` header (option) and options have limited size, in order to send CoAP messages, valid `auth` value (a valid Client key) must be present in `Uri-Query` option. - -## Best Practices - -- Use distinct client auth keys and rotate them frequently for better security. - -- Use meaningful channel IDs and subtopics so you know exactly where your messages go. - -- Leverage metadata/tags in channels and clients (via clients service) to filter and manage messaging paths. - -- Ensure the auth query parameter is not exposed publicly (use secure networks or DTLS if available). - -- Monitor message broker load and usage patterns — CoAP traffic can burst. - -- Use the /health endpoint (if exposed) to monitor service status and integrate with your observability stack. - -## Versioning and Health Check - -If the service exposes a /health endpoint, you can use it for monitoring and version readiness checks. - -```bash -curl -X GET coap://localhost/health \ - -H "accept: application/health+json" -``` - -The expected response is: - -```bash -{ - "status": "pass", - "version": "0.xx.x", - "commit": "", - "description": "coap‑adapter service", - "build_time": "YYYY‑MM‑DDT…" -} -``` - -## CLI - -SuperMQ provides a CoAP CLI for testing and interacting with the CoAP Adapter. -To learn more about this visit the [SuperMQ CoAp CLI page](https://github.com/absmach/coap-cli/tree/main). diff --git a/coap/adapter.go b/coap/adapter.go deleted file mode 100644 index e8c44fed8..000000000 --- a/coap/adapter.go +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package coap contains the domain concept definitions needed to support -// SuperMQ CoAP adapter service functionality. All constant values are taken -// from RFC, and could be adjusted based on specific use case. -package coap - -import ( - "context" - - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - "github.com/absmach/supermq/pkg/authn" - "github.com/absmach/supermq/pkg/connections" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/policies" -) - -var errFailedToDisconnectClient = errors.New("failed to disconnect client") - -// Service specifies CoAP service API. -type Service interface { - // Publish publishes message to specified channel. - // Key is used to authorize publisher. - Publish(ctx context.Context, key string, msg *messaging.Message, topicType messaging.TopicType) error - - // Subscribes to channel with specified id, domainID, subtopic and adds subscription to - // service map of subscriptions under given ID. - Subscribe(ctx context.Context, key, domainID, chanID, subtopic string, c Client) error - - // Unsubscribe method is used to stop observing resource. - Unsubscribe(ctx context.Context, key, domainID, chanID, subptopic, token string) error - - // DisconnectHandler method is used to disconnected the client - DisconnectHandler(ctx context.Context, domainID, chanID, subptopic, token string) error -} - -var _ Service = (*adapterService)(nil) - -// Observers is a map of maps,. -type adapterService struct { - clients grpcClientsV1.ClientsServiceClient - channels grpcChannelsV1.ChannelsServiceClient - pubsub messaging.PubSub -} - -// New instantiates the CoAP adapter implementation. -func New(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, pubsub messaging.PubSub) Service { - as := &adapterService{ - clients: clients, - channels: channels, - pubsub: pubsub, - } - - return as -} - -func (svc *adapterService) Publish(ctx context.Context, key string, msg *messaging.Message, topicType messaging.TopicType) error { - authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ - Token: authn.AuthPack(authn.DomainAuth, msg.GetDomain(), key), - }) - if err != nil { - return errors.Wrap(svcerr.ErrAuthentication, err) - } - if !authnRes.Authenticated { - return svcerr.ErrAuthentication - } - - // Health topics do not require channel authorization. - if topicType == messaging.HealthType { - return nil - } - msg.Publisher = authnRes.GetId() - - return svc.pubsub.Publish(ctx, messaging.EncodeMessageTopic(msg), msg) -} - -func (svc *adapterService) Subscribe(ctx context.Context, key, domainID, chanID, subtopic string, c Client) error { - authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ - Token: authn.AuthPack(authn.DomainAuth, domainID, key), - }) - if err != nil { - return errors.Wrap(svcerr.ErrAuthentication, err) - } - if !authnRes.Authenticated { - return svcerr.ErrAuthentication - } - - clientID := authnRes.GetId() - - subject := messaging.EncodeTopic(domainID, chanID, subtopic) - authzc := newAuthzClient(clientID, domainID, chanID, subtopic, svc.channels, c) - subCfg := messaging.SubscriberConfig{ - ID: c.Token(), - ClientID: clientID, - Topic: subject, - Handler: authzc, - } - return svc.pubsub.Subscribe(ctx, subCfg) -} - -func (svc *adapterService) Unsubscribe(ctx context.Context, key, domainID, chanID, subtopic, token string) error { - authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ - Token: authn.AuthPack(authn.DomainAuth, domainID, key), - }) - if err != nil { - return errors.Wrap(svcerr.ErrAuthentication, err) - } - if !authnRes.Authenticated { - return svcerr.ErrAuthentication - } - subject := messaging.EncodeTopic(domainID, chanID, subtopic) - - return svc.pubsub.Unsubscribe(ctx, token, subject) -} - -func (svc *adapterService) DisconnectHandler(ctx context.Context, domainID, chanID, subtopic, token string) error { - subject := messaging.EncodeTopic(domainID, chanID, subtopic) - - return svc.pubsub.Unsubscribe(ctx, token, subject) -} - -type authzClient interface { - // Handle handles incoming messages. - Handle(m *messaging.Message) error - - // Cancel cancels the client. - Cancel() error -} - -type ac struct { - clientID string - channelID string - domainID string - subTopic string - channels grpcChannelsV1.ChannelsServiceClient - client Client -} - -func newAuthzClient(clientID, domainID, channelID, subTopic string, channels grpcChannelsV1.ChannelsServiceClient, client Client) authzClient { - return ac{clientID, channelID, domainID, subTopic, channels, client} -} - -func (a ac) Handle(m *messaging.Message) error { - res, err := a.channels.Authorize(context.Background(), &grpcChannelsV1.AuthzReq{ - ClientId: a.clientID, - ClientType: policies.ClientType, - ChannelId: a.channelID, - DomainId: a.domainID, - Type: uint32(connections.Subscribe), - }) - if err != nil { - if disErr := a.Cancel(); disErr != nil { - return errors.Wrap(err, errors.Wrap(errFailedToDisconnectClient, disErr)) - } - return err - } - if !res.GetAuthorized() { - err := svcerr.ErrAuthorization - if disErr := a.Cancel(); disErr != nil { - return errors.Wrap(err, errors.Wrap(errFailedToDisconnectClient, disErr)) - } - return err - } - return a.client.Handle(m) -} - -func (a ac) Cancel() error { - return a.client.Cancel() -} diff --git a/coap/api/transport.go b/coap/api/transport.go deleted file mode 100644 index 7e0f09560..000000000 --- a/coap/api/transport.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -import ( - "context" - "fmt" - "io" - "log/slog" - "net/http" - "strings" - "time" - - "github.com/absmach/supermq" - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - "github.com/absmach/supermq/coap" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/messaging" - "github.com/go-chi/chi/v5" - "github.com/plgd-dev/go-coap/v3/message" - "github.com/plgd-dev/go-coap/v3/message/codes" - "github.com/plgd-dev/go-coap/v3/message/pool" - "github.com/plgd-dev/go-coap/v3/mux" - "github.com/prometheus/client_golang/prometheus/promhttp" -) - -const ( - protocol = "coap" - authQuery = "auth" - startObserve = 0 // observe option value that indicates start of observation -) - -var ( - errBadOptions = errors.New("bad options") - errMethodNotAllowed = errors.New("method not allowed") -) - -// MakeHandler returns a HTTP handler for API endpoints. -func MakeHandler(instanceID string) http.Handler { - b := chi.NewRouter() - b.Get("/health", supermq.Health(protocol, instanceID)) - b.Handle("/metrics", promhttp.Handler()) - - return b -} - -type CoAPHandler struct { - logger *slog.Logger - service coap.Service - channels grpcChannelsV1.ChannelsServiceClient - parser messaging.TopicParser -} - -// MakeCoAPHandler creates handler for CoAP messages. -func MakeCoAPHandler(svc coap.Service, channelsClient grpcChannelsV1.ChannelsServiceClient, parser messaging.TopicParser, l *slog.Logger) mux.Handler { - return &CoAPHandler{ - logger: l, - service: svc, - channels: channelsClient, - parser: parser, - } -} - -// ServeCOAP implements the mux.Handler interface for handling CoAP messages. -func (h *CoAPHandler) ServeCOAP(w mux.ResponseWriter, m *mux.Message) { - resp := pool.NewMessage(w.Conn().Context()) - resp.SetToken(m.Token()) - for _, opt := range m.Options() { - resp.AddOptionBytes(opt.ID, opt.Value) - } - defer h.sendResp(w, resp) - - msg, topicType, err := h.decodeMessage(m) - if err != nil { - h.logger.Warn(fmt.Sprintf("Error decoding message: %s", err)) - resp.SetCode(codes.BadRequest) - return - } - key, err := parseKey(m) - if err != nil { - h.logger.Warn(fmt.Sprintf("Error parsing auth: %s", err)) - resp.SetCode(codes.Unauthorized) - return - } - - switch m.Code() { - case codes.GET: - resp.SetCode(codes.Content) - err = h.handleGet(m, w, topicType, msg, key) - case codes.POST: - resp.SetCode(codes.Created) - err = h.service.Publish(m.Context(), key, msg, topicType) - default: - err = errMethodNotAllowed - } - - if err != nil { - switch { - case err == errBadOptions: - resp.SetCode(codes.BadOption) - case err == errMethodNotAllowed: - resp.SetCode(codes.MethodNotAllowed) - case errors.Contains(err, svcerr.ErrAuthorization): - resp.SetCode(codes.Forbidden) - case errors.Contains(err, svcerr.ErrAuthentication): - resp.SetCode(codes.Unauthorized) - default: - resp.SetCode(codes.InternalServerError) - } - } -} - -func (h *CoAPHandler) handleGet(m *mux.Message, w mux.ResponseWriter, topicType messaging.TopicType, msg *messaging.Message, key string) error { - var obs uint32 - obs, err := m.Options().Observe() - if err != nil { - h.logger.Warn(fmt.Sprintf("Error reading observe option: %s", err)) - return errBadOptions - } - if obs == startObserve { - c := coap.NewClient(w.Conn(), m.Token(), h.logger) - w.Conn().AddOnClose(func() { - _ = h.service.DisconnectHandler(context.Background(), msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), c.Token()) - }) - return h.service.Subscribe(w.Conn().Context(), key, msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), c) - } - return h.service.Unsubscribe(w.Conn().Context(), key, msg.GetDomain(), msg.GetChannel(), msg.GetSubtopic(), m.Token().String()) -} - -func (h *CoAPHandler) decodeMessage(msg *mux.Message) (*messaging.Message, messaging.TopicType, error) { - if msg.Options() == nil { - return &messaging.Message{}, messaging.InvalidType, errBadOptions - } - path, err := msg.Path() - if err != nil { - return &messaging.Message{}, messaging.InvalidType, err - } - - var domainID, channelID, subTopic string - var topicType messaging.TopicType - switch msg.Code() { - case codes.GET: - domainID, channelID, subTopic, topicType, err = h.parser.ParseSubscribeTopic(msg.Context(), path, true) - case codes.POST: - domainID, channelID, subTopic, topicType, err = h.parser.ParsePublishTopic(msg.Context(), path, true) - } - if err != nil { - return &messaging.Message{}, messaging.InvalidType, err - } - - ret := &messaging.Message{ - Protocol: protocol, - Domain: domainID, - Channel: channelID, - Subtopic: subTopic, - Payload: []byte{}, - Created: time.Now().UnixNano(), - } - - if msg.Body() != nil { - buff, err := io.ReadAll(msg.Body()) - if err != nil { - return ret, messaging.InvalidType, err - } - ret.Payload = buff - } - return ret, topicType, nil -} - -func (h *CoAPHandler) sendResp(w mux.ResponseWriter, resp *pool.Message) { - if err := w.Conn().WriteMessage(resp); err != nil { - h.logger.Warn(fmt.Sprintf("Can't set response: %s", err)) - } -} - -func parseKey(msg *mux.Message) (string, error) { - authKey, err := msg.Options().GetString(message.URIQuery) - if err != nil { - return "", err - } - vars := strings.Split(authKey, "=") - if len(vars) != 2 || vars[0] != authQuery { - return "", svcerr.ErrAuthorization - } - return vars[1], nil -} diff --git a/coap/client.go b/coap/client.go deleted file mode 100644 index 4e0eb2982..000000000 --- a/coap/client.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package coap - -import ( - "bytes" - "fmt" - "log/slog" - "sync/atomic" - - "github.com/absmach/supermq/pkg/errors" - "github.com/absmach/supermq/pkg/messaging" - "github.com/plgd-dev/go-coap/v3/message" - "github.com/plgd-dev/go-coap/v3/message/codes" - mux "github.com/plgd-dev/go-coap/v3/mux" -) - -// Client wraps CoAP client. -type Client interface { - // In CoAP terminology, Token similar to the Session ID. - Token() string - - // Handle handles incoming messages. - Handle(m *messaging.Message) error - - // Cancel cancels the client. - Cancel() error - - // Done returns a channel that's closed when the client is done. - Done() <-chan struct{} -} - -// ErrOption indicates an error when adding an option. -var ErrOption = errors.New("unable to set option") - -type client struct { - conn mux.Conn - token message.Token - observe uint32 - logger *slog.Logger -} - -// NewClient instantiates a new Observer. -func NewClient(conn mux.Conn, tkn message.Token, l *slog.Logger) Client { - return &client{ - conn: conn, - token: tkn, - logger: l, - observe: 0, - } -} - -func (c *client) Done() <-chan struct{} { - return c.conn.Done() -} - -func (c *client) Cancel() error { - pm := c.conn.AcquireMessage(c.conn.Context()) - pm.SetCode(codes.Content) - pm.SetToken(c.token) - if err := c.conn.WriteMessage(pm); err != nil { - c.logger.Error(fmt.Sprintf("Error sending message: %s.", err)) - } - c.conn.ReleaseMessage(pm) - return c.conn.Close() -} - -func (c *client) Token() string { - return c.token.String() -} - -func (c *client) Handle(msg *messaging.Message) error { - pm := c.conn.AcquireMessage(c.conn.Context()) - defer c.conn.ReleaseMessage(pm) - pm.SetCode(codes.Content) - pm.SetToken(c.token) - pm.SetBody(bytes.NewReader(msg.GetPayload())) - - atomic.AddUint32(&c.observe, 1) - var opts message.Options - var buff []byte - opts, n, err := opts.SetContentFormat(buff, message.TextPlain) - if err == message.ErrTooSmall { - buff = append(buff, make([]byte, n)...) - _, _, err = opts.SetContentFormat(buff, message.TextPlain) - } - if err != nil { - c.logger.Error(fmt.Sprintf("Can't set content format: %s.", err)) - return errors.Wrap(ErrOption, err) - } - opts, n, err = opts.SetObserve(buff, c.observe) - if err == message.ErrTooSmall { - buff = append(buff, make([]byte, n)...) - opts, _, err = opts.SetObserve(buff, uint32(c.observe)) - } - if err != nil { - return fmt.Errorf("cannot set options to response: %w", err) - } - - for _, option := range opts { - pm.SetOptionBytes(option.ID, option.Value) - } - return c.conn.WriteMessage(pm) -} diff --git a/coap/handler.go b/coap/handler.go deleted file mode 100644 index c65a568a9..000000000 --- a/coap/handler.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package coap - -import ( - "context" - "fmt" - "log/slog" - "net/http" - "strings" - - mgate "github.com/absmach/mgate/pkg/coap" - "github.com/absmach/mgate/pkg/session" - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - smqauthn "github.com/absmach/supermq/pkg/authn" - "github.com/absmach/supermq/pkg/connections" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/policies" -) - -var _ session.Handler = (*handler)(nil) - -// Log message formats. -const ( - subscribedInfoFmt = "subscribed with client_id %s to topics %s" - publishedInfoFmt = "published with client_id %s to the topic %s" -) - -// Error wrappers for COAP errors. -var ( - errClientNotInitialized = errors.New("client is not initialized") - errMissingTopicPub = errors.New("failed to publish due to missing topic") - errMissingTopicSub = errors.New("failed to subscribe due to missing topic") - errFailedPublish = errors.New("failed to publish") -) - -type handler struct { - clients grpcClientsV1.ClientsServiceClient - channels grpcChannelsV1.ChannelsServiceClient - logger *slog.Logger - parser messaging.TopicParser -} - -// NewHandler creates new Handler entity. -func NewHandler(logger *slog.Logger, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, parser messaging.TopicParser) session.Handler { - return &handler{ - logger: logger, - clients: clients, - channels: channels, - parser: parser, - } -} - -// AuthConnect is called on device connection, -// prior forwarding to the coap server. -func (h *handler) AuthConnect(ctx context.Context) error { - return nil -} - -// AuthPublish is called on device publish, -// prior forwarding to the coap server. -func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error { - if topic == nil { - return errMissingTopicPub - } - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - - domainID, channelID, _, topicType, err := h.parser.ParsePublishTopic(ctx, *topic, true) - if err != nil { - return mgate.NewCOAPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err)) - } - - clientID, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Publish, topicType) - if err != nil { - return err - } - s.Username = clientID - - return nil -} - -// AuthSubscribe is called on device publish, -// prior forwarding to the COAP broker. -func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error { - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - if topics == nil || *topics == nil { - return errMissingTopicSub - } - - for _, topic := range *topics { - domainID, channelID, _, topicType, err := h.parser.ParseSubscribeTopic(ctx, topic, true) - if err != nil { - return err - } - if _, err := h.authAccess(ctx, string(s.Password), domainID, channelID, connections.Subscribe, topicType); err != nil { - return err - } - } - return nil -} - -// Connect - after client successfully connected. -func (h *handler) Connect(ctx context.Context) error { - return nil -} - -// Publish - after client successfully published. -func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) error { - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - - if len(*payload) == 0 { - h.logger.Warn("Empty payload, not publishing to broker", slog.String("client_id", s.Username)) - return nil - } - - h.logger.Info(fmt.Sprintf(publishedInfoFmt, s.Username, *topic)) - - return nil -} - -// Subscribe - after client successfully subscribed. -func (h *handler) Subscribe(ctx context.Context, topics *[]string) error { - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - h.logger.Info(fmt.Sprintf(subscribedInfoFmt, s.Username, strings.Join(*topics, ","))) - return nil -} - -// Unsubscribe - after client unsubscribed. -func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error { - return nil -} - -// Disconnect - connection with broker or client lost. -func (h *handler) Disconnect(ctx context.Context) error { - return nil -} - -func (h *handler) authAccess(ctx context.Context, secret, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) { - authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, secret)}) - if err != nil { - return "", mgate.NewCOAPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) - } - if !authnRes.Authenticated { - return "", mgate.NewCOAPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) - } - - if topicType == messaging.HealthType { - return authnRes.GetId(), nil - } - - ar := &grpcChannelsV1.AuthzReq{ - Type: uint32(msgType), - ClientId: authnRes.GetId(), - ClientType: policies.ClientType, - ChannelId: chanID, - DomainId: domainID, - } - res, err := h.channels.Authorize(ctx, ar) - if err != nil { - return "", mgate.NewCOAPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err)) - } - if !res.GetAuthorized() { - return "", mgate.NewCOAPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) - } - - return authnRes.GetId(), nil -} diff --git a/coap/middleware/doc.go b/coap/middleware/doc.go deleted file mode 100644 index acd084738..000000000 --- a/coap/middleware/doc.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package middleware provides tracing, logging and metrics middleware -// for SuperMQ COAP service. -// -// For more details about tracing instrumentation for SuperMQ messaging refer -// to the documentation at https://docs.supermq.absmach.eu/tracing/. -package middleware diff --git a/coap/middleware/logging.go b/coap/middleware/logging.go deleted file mode 100644 index cc0bd8355..000000000 --- a/coap/middleware/logging.go +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -//go:build !test - -package middleware - -import ( - "context" - "log/slog" - "time" - - "github.com/absmach/supermq/coap" - "github.com/absmach/supermq/pkg/messaging" -) - -var _ coap.Service = (*loggingMiddleware)(nil) - -type loggingMiddleware struct { - logger *slog.Logger - svc coap.Service -} - -// NewLogging adds logging facilities to the adapter. -func NewLogging(svc coap.Service, logger *slog.Logger) coap.Service { - return &loggingMiddleware{logger, svc} -} - -// Publish logs the publish request. It logs the channel ID, subtopic (if any) and the time it took to complete the request. -// If the request fails, it logs the error. -func (lm *loggingMiddleware) Publish(ctx context.Context, key string, msg *messaging.Message, topicType messaging.TopicType) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - slog.String("channel_id", msg.GetChannel()), - slog.String("domain_id", msg.GetDomain()), - } - if msg.GetSubtopic() != "" { - args = append(args, slog.String("subtopic", msg.GetSubtopic())) - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("Publish message failed", args...) - return - } - lm.logger.Info("Publish message completed successfully", args...) - }(time.Now()) - - return lm.svc.Publish(ctx, key, msg, topicType) -} - -// Subscribe logs the subscribe request. It logs the channel ID, subtopic (if any) and the time it took to complete the request. -// If the request fails, it logs the error. -func (lm *loggingMiddleware) Subscribe(ctx context.Context, key, domainID, chanID, subtopic string, c coap.Client) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - slog.String("channel_id", chanID), - slog.String("domain_id", domainID), - } - if subtopic != "" { - args = append(args, slog.String("subtopic", subtopic)) - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("Subscribe failed", args...) - return - } - lm.logger.Info("Subscribe completed successfully", args...) - }(time.Now()) - - return lm.svc.Subscribe(ctx, key, domainID, chanID, subtopic, c) -} - -// Unsubscribe logs the unsubscribe request. It logs the channel ID, subtopic (if any) and the time it took to complete the request. -// If the request fails, it logs the error. -func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, key, domainID, chanID, subtopic, token string) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - slog.String("channel_id", chanID), - slog.String("domain_id", domainID), - } - if subtopic != "" { - args = append(args, slog.String("subtopic", subtopic)) - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("Unsubscribe failed", args...) - return - } - lm.logger.Info("Unsubscribe completed successfully", args...) - }(time.Now()) - - return lm.svc.Unsubscribe(ctx, key, domainID, chanID, subtopic, token) -} - -// DisconnectHandler logs the disconnect handler. It logs the channel ID, subtopic (if any) and the time it took to complete the request. -// If the request fails, it logs the error. -func (lm *loggingMiddleware) DisconnectHandler(ctx context.Context, domainID, chanID, subtopic, token string) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - slog.String("domain_id", domainID), - slog.String("channel_id", chanID), - slog.String("token", token), - } - if subtopic != "" { - args = append(args, slog.String("subtopic", subtopic)) - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("Unsubscribe failed", args...) - return - } - lm.logger.Info("Unsubscribe completed successfully", args...) - }(time.Now()) - - return lm.svc.DisconnectHandler(ctx, domainID, chanID, subtopic, token) -} diff --git a/coap/middleware/metrics.go b/coap/middleware/metrics.go deleted file mode 100644 index 33ae98be8..000000000 --- a/coap/middleware/metrics.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -//go:build !test - -package middleware - -import ( - "context" - "time" - - "github.com/absmach/supermq/coap" - "github.com/absmach/supermq/pkg/messaging" - "github.com/go-kit/kit/metrics" -) - -var _ coap.Service = (*metricsMiddleware)(nil) - -type metricsMiddleware struct { - counter metrics.Counter - latency metrics.Histogram - svc coap.Service -} - -// NewMetrics instruments adapter by tracking request count and latency. -func NewMetrics(svc coap.Service, counter metrics.Counter, latency metrics.Histogram) coap.Service { - return &metricsMiddleware{ - counter: counter, - latency: latency, - svc: svc, - } -} - -// Publish instruments Publish method with metrics. -func (mm *metricsMiddleware) Publish(ctx context.Context, key string, msg *messaging.Message, topicType messaging.TopicType) error { - defer func(begin time.Time) { - mm.counter.With("method", "publish").Add(1) - mm.latency.With("method", "publish").Observe(time.Since(begin).Seconds()) - }(time.Now()) - - return mm.svc.Publish(ctx, key, msg, topicType) -} - -// Subscribe instruments Subscribe method with metrics. -func (mm *metricsMiddleware) Subscribe(ctx context.Context, key, domainID, chanID, subtopic string, c coap.Client) error { - defer func(begin time.Time) { - mm.counter.With("method", "subscribe").Add(1) - mm.latency.With("method", "subscribe").Observe(time.Since(begin).Seconds()) - }(time.Now()) - - return mm.svc.Subscribe(ctx, key, domainID, chanID, subtopic, c) -} - -// Unsubscribe instruments Unsubscribe method with metrics. -func (mm *metricsMiddleware) Unsubscribe(ctx context.Context, key, domainID, chanID, subtopic, token string) error { - defer func(begin time.Time) { - mm.counter.With("method", "unsubscribe").Add(1) - mm.latency.With("method", "unsubscribe").Observe(time.Since(begin).Seconds()) - }(time.Now()) - - return mm.svc.Unsubscribe(ctx, key, domainID, chanID, subtopic, token) -} - -// DisconnectHandler instruments DisconnectHandler method with metrics. -func (mm *metricsMiddleware) DisconnectHandler(ctx context.Context, domainID, chanID, subtopic, token string) error { - defer func(begin time.Time) { - mm.counter.With("method", "disconnect_handler").Add(1) - mm.latency.With("method", "disconnect_handler").Observe(time.Since(begin).Seconds()) - }(time.Now()) - - return mm.svc.DisconnectHandler(ctx, domainID, chanID, subtopic, token) -} diff --git a/coap/middleware/tracing.go b/coap/middleware/tracing.go deleted file mode 100644 index 6d29baf55..000000000 --- a/coap/middleware/tracing.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package middleware - -import ( - "context" - - "github.com/absmach/supermq/coap" - "github.com/absmach/supermq/pkg/messaging" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" -) - -var _ coap.Service = (*tracingServiceMiddleware)(nil) - -// Operation names for tracing CoAP operations. -const ( - publishOP = "publish_op" - subscribeOP = "subscribe_op" - unsubscribeOP = "unsubscribe_op" - disconnectHandlerOp = "disconnect_handler_op" -) - -// tracingServiceMiddleware is a middleware implementation for tracing CoAP service operations using OpenTelemetry. -type tracingServiceMiddleware struct { - tracer trace.Tracer - svc coap.Service -} - -// NewTracing creates a new instance of TracingServiceMiddleware that wraps an existing CoAP service with tracing capabilities. -func NewTracing(tracer trace.Tracer, svc coap.Service) coap.Service { - return &tracingServiceMiddleware{ - tracer: tracer, - svc: svc, - } -} - -// Publish traces a CoAP publish operation. -func (tm *tracingServiceMiddleware) Publish(ctx context.Context, key string, msg *messaging.Message, topicType messaging.TopicType) error { - ctx, span := tm.tracer.Start(ctx, publishOP, trace.WithAttributes( - attribute.String("channel_id", msg.Channel), - attribute.String("domain_id", msg.Domain), - attribute.String("topic_type", string(topicType)), - )) - defer span.End() - return tm.svc.Publish(ctx, key, msg, topicType) -} - -// Subscribe traces a CoAP subscribe operation. -func (tm *tracingServiceMiddleware) Subscribe(ctx context.Context, key, domainID, chanID, subtopic string, c coap.Client) error { - ctx, span := tm.tracer.Start(ctx, subscribeOP, trace.WithAttributes( - attribute.String("channel_id", chanID), - attribute.String("domain_id", domainID), - attribute.String("subtopic", subtopic), - )) - defer span.End() - return tm.svc.Subscribe(ctx, key, domainID, chanID, subtopic, c) -} - -// Unsubscribe traces a CoAP unsubscribe operation. -func (tm *tracingServiceMiddleware) Unsubscribe(ctx context.Context, key, domainID, chanID, subtopic, token string) error { - ctx, span := tm.tracer.Start(ctx, unsubscribeOP, trace.WithAttributes( - attribute.String("channel_id", chanID), - attribute.String("domain_id", domainID), - attribute.String("subtopic", subtopic), - )) - defer span.End() - return tm.svc.Unsubscribe(ctx, key, domainID, chanID, subtopic, token) -} - -// DisconnectHandler traces a CoAP disconnect operation. -func (tm *tracingServiceMiddleware) DisconnectHandler(ctx context.Context, domainID, chanID, subptopic, token string) error { - ctx, span := tm.tracer.Start(ctx, disconnectHandlerOp, trace.WithAttributes( - attribute.String("domain_id", domainID), - attribute.String("channel_id", chanID), - attribute.String("subtopic", subptopic), - )) - defer span.End() - return tm.svc.DisconnectHandler(ctx, domainID, chanID, subptopic, token) -} diff --git a/consumers/messages.go b/consumers/messages.go index 0b5bcfd43..e1d49d3b3 100644 --- a/consumers/messages.go +++ b/consumers/messages.go @@ -13,7 +13,6 @@ import ( apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/pkg/errors" "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/brokers" "github.com/absmach/supermq/pkg/transformers" "github.com/absmach/supermq/pkg/transformers/json" "github.com/absmach/supermq/pkg/transformers/senml" @@ -33,18 +32,18 @@ var ( // Start method starts consuming messages received from Message broker. // This method transforms messages to SenML format before // using MessageRepository to store them. -func Start(ctx context.Context, id string, sub messaging.Subscriber, consumer any, configPath string, logger *slog.Logger) error { - cfg, err := loadConfig(configPath) +func Start(ctx context.Context, id string, sub messaging.Subscriber, consumer any, configPath string, defaultTopic string, logger *slog.Logger) error { + cfg, err := loadConfig(configPath, defaultTopic) if err != nil { logger.Warn(fmt.Sprintf("Failed to load consumer config: %s", err)) } transformer := makeTransformer(cfg.TransformerCfg, logger) - for _, subject := range cfg.SubscriberCfg.Subjects { + for _, topic := range cfg.SubscriberCfg.Topics { subCfg := messaging.SubscriberConfig{ ID: id, - Topic: subject, + Topic: topic, DeliveryPolicy: messaging.DeliverAllPolicy, } switch c := consumer.(type) { @@ -106,7 +105,7 @@ func (h handleFunc) Cancel() error { } type subscriberConfig struct { - Subjects []string `toml:"subjects"` + Topics []string `toml:"topics"` } type transformerConfig struct { @@ -120,10 +119,10 @@ type config struct { TransformerCfg transformerConfig `toml:"transformer"` } -func loadConfig(configPath string) (config, error) { +func loadConfig(configPath, defaultTopic string) (config, error) { cfg := config{ SubscriberCfg: subscriberConfig{ - Subjects: []string{brokers.SubjectAllMessages}, + Topics: []string{defaultTopic}, }, TransformerCfg: transformerConfig{ Format: defFormat, diff --git a/consumers/notifiers/README.md b/consumers/notifiers/README.md new file mode 100644 index 000000000..ff16de480 --- /dev/null +++ b/consumers/notifiers/README.md @@ -0,0 +1,182 @@ +# Notifiers + +The Notifiers service manages notification subscriptions and dispatches alerts for incoming messages. It stores subscription records (topic + contact), exposes an HTTP API for CRUD operations, and consumes SuperMQ messages to fan out notifications via notifier implementations (SMTP for email, SMPP for SMS). Notifiers are dependencies used by the service, not standalone services. + +## Configuration + +The service is configured using environment variables. Values shown are from [docker/.env](https://github.com/absmach/magistrala/blob/main/docker/.env) when available; otherwise defaults come from code or notifier-specific docs. + +### SMTP notifier (email) + +Used by `consumers/notifiers/smtp` via `internal/email`. + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_EMAIL_HOST` | SMTP host | `smtp.mailtrap.io` | +| `MG_EMAIL_PORT` | SMTP port | `2525` | +| `MG_EMAIL_USERNAME` | SMTP username | `18bf7f70705139` | +| `MG_EMAIL_PASSWORD` | SMTP password | `2b0d302e775b1e` | +| `MG_EMAIL_FROM_ADDRESS` | Default from address (used if `from` is empty) | `from@example.com` | +| `MG_EMAIL_FROM_NAME` | Default from name | `Example` | +| `MG_EMAIL_TEMPLATE` | Email template path | `email.tmpl` | + +### SMPP notifier (SMS) + +#### SMPP transport settings + +Defined in `consumers/notifiers/smpp/config.go`. + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_SMPP_ADDRESS` | SMPP address in `host:port` format | "" | +| `MG_SMPP_USERNAME` | SMPP username | "" | +| `MG_SMPP_PASSWORD` | SMPP password | "" | +| `MG_SMPP_SYSTEM_TYPE` | SMPP system type | "" | +| `MG_SMPP_SRC_ADDR_TON` | SMPP source address TON | `0` | +| `MG_SMPP_DST_ADDR_TON` | SMPP source address NPI | `0` | +| `MG_SMPP_SRC_ADDR_NPI` | SMPP destination address TON | `0` | +| `MG_SMPP_DST_ADDR_NPI` | SMPP destination address NPI | `0` | + +Note: The SMPP env tags are mapped exactly as defined in `consumers/notifiers/smpp/config.go`. + +#### SMPP notifier service settings + +Defined in `consumers/notifiers/smpp/README.md`. + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_SMPP_NOTIFIER_LOG_LEVEL` | Log level for SMPP notifier | `info` | +| `MG_SMPP_NOTIFIER_FROM_ADDRESS` | From address for SMS notifications | "" | +| `MG_SMPP_NOTIFIER_CONFIG_PATH` | Config file path for message broker subjects and payload type | `/config.toml` | +| `MG_SMPP_NOTIFIER_HTTP_HOST` | Service HTTP host | `localhost` | +| `MG_SMPP_NOTIFIER_HTTP_PORT` | Service HTTP port | `9014` | +| `MG_SMPP_NOTIFIER_HTTP_SERVER_CERT` | Service HTTP server certificate path | "" | +| `MG_SMPP_NOTIFIER_HTTP_SERVER_KEY` | Service HTTP server key path | "" | +| `MG_SMPP_NOTIFIER_DB_HOST` | Database host address | `localhost` | +| `MG_SMPP_NOTIFIER_DB_PORT` | Database host port | `5432` | +| `MG_SMPP_NOTIFIER_DB_USER` | Database user | `magistrala` | +| `MG_SMPP_NOTIFIER_DB_PASS` | Database password | `magistrala` | +| `MG_SMPP_NOTIFIER_DB_NAME` | Database name | `subscriptions` | +| `MG_SMPP_NOTIFIER_DB_SSL_MODE` | DB SSL mode (disable, require, verify-ca, verify-full) | `disable` | +| `MG_SMPP_NOTIFIER_DB_SSL_CERT` | DB SSL client cert path | "" | +| `MG_SMPP_NOTIFIER_DB_SSL_KEY` | DB SSL client key path | "" | +| `MG_SMPP_NOTIFIER_DB_SSL_ROOT_CERT` | DB SSL root cert path | "" | +| `MG_AUTH_GRPC_URL` | Auth gRPC URL | `localhost:7001` | +| `MG_AUTH_GRPC_TIMEOUT` | Auth gRPC timeout | `1s` | +| `MG_AUTH_GRPC_CLIENT_TLS` | Auth client TLS flag | `false` | +| `MG_AUTH_GRPC_CA_CERT` | Auth client CA certs path | "" | +| `MG_MESSAGE_BROKER_URL` | Message broker URL | `nats://127.0.0.1:4222` | +| `MG_JAEGER_URL` | Jaeger tracing URL | `http://jaeger:14268/api/traces` | +| `MG_SEND_TELEMETRY` | Send telemetry to Magistrala call-home server | `true` | +| `MG_SMPP_NOTIFIER_INSTANCE_ID` | SMPP notifier instance ID | "" | + +## Features + +- **Subscription management**: Create, view, list, and remove notification subscriptions. +- **Topic-based dispatch**: Matches subscriptions by topic and fan-outs to contacts. +- **Multiple notifier backends**: SMTP (email) and SMPP (SMS) implementations are available. +- **Observability**: Exposes `/metrics` and `/health` endpoints. +- **Uniqueness guardrails**: Prevents duplicate subscriptions for the same topic/contact pair. + +## Architecture + +### Runtime flow + +1. Clients register subscriptions through the HTTP API (`topic` + `contact`). +2. The service authenticates the token, assigns an owner ID, and persists the subscription. +3. When a message arrives, the service builds the topic as `channel` or `channel.subtopic`, retrieves matching subscriptions, and gathers contacts. +4. The notifier implementation sends notifications using the configured backend. + +### Components + +- **HTTP API**: `consumers/notifiers/api` exposes `/subscriptions`, `/health`, and `/metrics`. +- **Service layer**: `consumers/notifiers/service.go` handles authn, ID creation, and notification dispatch. +- **Repository**: `consumers/notifiers/postgres` persists subscriptions and supports filtering. +- **Notifier implementations**: `consumers/notifiers/smtp` (email) and `consumers/notifiers/smpp` (SMS). +- **Email agent**: `internal/email` manages SMTP connectivity and template rendering. + +### Subscriptions table + +Defined in `consumers/notifiers/postgres/init.go`: + +| Column | Type | Description | +| --- | --- | --- | +| `id` | `VARCHAR(254)` | Subscription identifier (primary key) | +| `owner_id` | `VARCHAR(254)` | Owner ID derived from the auth token | +| `contact` | `VARCHAR(254)` | Notification contact (email or phone) | +| `topic` | `TEXT` | Topic to match (`channel` or `channel.subtopic`) | + +Constraint: `UNIQUE(topic, contact)` + +## Deployment + +The Notifiers service is provided as a consumer package. It is typically wired into a notifier-specific binary that provides the HTTP server and message broker subscription. For the SMPP notifier runtime configuration, see `consumers/notifiers/smpp/README.md`. + +### Health check + +```bash +curl -X GET http://localhost:9014/health \ + -H "accept: application/health+json" +``` + +## Testing + +```bash +go test ./consumers/notifiers/... +``` + +## Usage + +The Notifiers service supports the following operations (see `apidocs/openapi/notifiers.yaml`): + +| Operation | Method & Path | Description | +| --- | --- | --- | +| `createSubscription` | `POST /subscriptions` | Create a new subscription | +| `listSubscriptions` | `GET /subscriptions` | List subscriptions with filters | +| `viewSubscription` | `GET /subscriptions/{id}` | Retrieve a subscription | +| `removeSubscription` | `DELETE /subscriptions/{id}` | Delete a subscription | +| `health` | `GET /health` | Service health check | + +### Example: Create a subscription + +```bash +curl -X POST http://localhost:9014/subscriptions \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "topic": "channel.subtopic", + "contact": "user@example.com" + }' +``` + +### Example: List subscriptions + +```bash +curl -X GET "http://localhost:9014/subscriptions?topic=channel.subtopic&contact=user@example.com&limit=20&offset=0" \ + -H "Authorization: Bearer " +``` + +### Example: View a subscription + +```bash +curl -X GET http://localhost:9014/subscriptions/ \ + -H "Authorization: Bearer " +``` + +### Example: Remove a subscription + +```bash +curl -X DELETE http://localhost:9014/subscriptions/ \ + -H "Authorization: Bearer " +``` + +### Example: Health check + +```bash +curl -X GET http://localhost:9014/health \ + -H "accept: application/health+json" +``` + +For an in-depth explanation of the Notifiers, see the [official documentation][doc]. + +[doc]: https://docs.magistrala.absmach.eu/dev-guide/consumers/#notifiers diff --git a/http/api/doc.go b/consumers/notifiers/api/doc.go similarity index 100% rename from http/api/doc.go rename to consumers/notifiers/api/doc.go diff --git a/consumers/notifiers/api/endpoint.go b/consumers/notifiers/api/endpoint.go new file mode 100644 index 000000000..9b9eb8202 --- /dev/null +++ b/consumers/notifiers/api/endpoint.go @@ -0,0 +1,103 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + + apiutil "github.com/absmach/supermq/api/http/util" + notifiers "github.com/absmach/supermq/consumers/notifiers" + "github.com/absmach/supermq/pkg/errors" + "github.com/go-kit/kit/endpoint" +) + +func createSubscriptionEndpoint(svc notifiers.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(createSubReq) + if err := req.validate(); err != nil { + return createSubRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + sub := notifiers.Subscription{ + Contact: req.Contact, + Topic: req.Topic, + } + id, err := svc.CreateSubscription(ctx, req.token, sub) + if err != nil { + return createSubRes{}, err + } + ucr := createSubRes{ + ID: id, + } + + return ucr, nil + } +} + +func viewSubscriptionEndpoint(svc notifiers.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(subReq) + if err := req.validate(); err != nil { + return viewSubRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + sub, err := svc.ViewSubscription(ctx, req.token, req.id) + if err != nil { + return viewSubRes{}, err + } + res := viewSubRes{ + ID: sub.ID, + OwnerID: sub.OwnerID, + Contact: sub.Contact, + Topic: sub.Topic, + } + return res, nil + } +} + +func listSubscriptionsEndpoint(svc notifiers.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(listSubsReq) + if err := req.validate(); err != nil { + return listSubsRes{}, errors.Wrap(apiutil.ErrValidation, err) + } + pm := notifiers.PageMetadata{ + Topic: req.topic, + Contact: req.contact, + Offset: req.offset, + Limit: int(req.limit), + } + page, err := svc.ListSubscriptions(ctx, req.token, pm) + if err != nil { + return listSubsRes{}, err + } + res := listSubsRes{ + Offset: page.Offset, + Limit: page.Limit, + Total: page.Total, + } + for _, sub := range page.Subscriptions { + r := viewSubRes{ + ID: sub.ID, + OwnerID: sub.OwnerID, + Contact: sub.Contact, + Topic: sub.Topic, + } + res.Subscriptions = append(res.Subscriptions, r) + } + + return res, nil + } +} + +func deleteSubscriptionEndpoint(svc notifiers.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(subReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + if err := svc.RemoveSubscription(ctx, req.token, req.id); err != nil { + return nil, err + } + return removeSubRes{}, nil + } +} diff --git a/consumers/notifiers/api/endpoint_test.go b/consumers/notifiers/api/endpoint_test.go new file mode 100644 index 000000000..0e6eb614e --- /dev/null +++ b/consumers/notifiers/api/endpoint_test.go @@ -0,0 +1,548 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api_test + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "path" + "strings" + "testing" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/consumers/notifiers" + "github.com/absmach/supermq/consumers/notifiers/api" + "github.com/absmach/supermq/consumers/notifiers/mocks" + "github.com/absmach/supermq/internal/testsutil" + smqlog "github.com/absmach/supermq/logger" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const ( + contentType = "application/json" + email = "user@example.com" + contact1 = "email1@example.com" + contact2 = "email2@example.com" + token = "token" + invalidToken = "invalid" + topic = "topic" + instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002" + validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22" +) + +var ( + notFoundRes = toJSON(svcerr.ErrNotFound) + unauthRes = toJSON(svcerr.ErrAuthentication) + invalidRes = toJSON(apiutil.ErrInvalidQueryParams) + missingTokRes = toJSON(apiutil.ErrBearerToken) +) + +type testRequest struct { + client *http.Client + method string + url string + contentType string + token string + body io.Reader +} + +func (tr testRequest) make() (*http.Response, error) { + req, err := http.NewRequest(tr.method, tr.url, tr.body) + if err != nil { + return nil, err + } + if tr.token != "" { + req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token) + } + if tr.contentType != "" { + req.Header.Set("Content-Type", tr.contentType) + } + return tr.client.Do(req) +} + +func newServer() (*httptest.Server, *mocks.Service) { + logger := smqlog.NewMock() + svc := new(mocks.Service) + mux := api.MakeHandler(svc, logger, instanceID) + return httptest.NewServer(mux), svc +} + +func toJSON(data any) string { + jsonData, err := json.Marshal(data) + if err != nil { + return "" + } + return string(jsonData) +} + +func TestCreate(t *testing.T) { + ss, svc := newServer() + defer ss.Close() + + sub := notifiers.Subscription{ + Topic: topic, + Contact: contact1, + } + + data := toJSON(sub) + + emptyTopic := toJSON(notifiers.Subscription{Contact: contact1}) + emptyContact := toJSON(notifiers.Subscription{Topic: "topic123"}) + + cases := []struct { + desc string + req string + contentType string + auth string + status int + location string + err error + }{ + { + desc: "add successfully", + req: data, + contentType: contentType, + auth: token, + status: http.StatusCreated, + location: fmt.Sprintf("/subscriptions/%s%012d", uuid.Prefix, 1), + err: nil, + }, + { + desc: "add an existing subscription", + req: data, + contentType: contentType, + auth: token, + status: http.StatusBadRequest, + location: "", + err: svcerr.ErrConflict, + }, + { + desc: "add with empty topic", + req: emptyTopic, + contentType: contentType, + auth: token, + status: http.StatusBadRequest, + location: "", + err: svcerr.ErrMalformedEntity, + }, + { + desc: "add with empty contact", + req: emptyContact, + contentType: contentType, + auth: token, + status: http.StatusBadRequest, + location: "", + err: svcerr.ErrMalformedEntity, + }, + { + desc: "add with invalid auth token", + req: data, + contentType: contentType, + auth: invalidToken, + status: http.StatusUnauthorized, + location: "", + err: svcerr.ErrAuthentication, + }, + { + desc: "add with empty auth token", + req: data, + contentType: contentType, + auth: "", + status: http.StatusUnauthorized, + location: "", + err: svcerr.ErrAuthentication, + }, + { + desc: "add with invalid request format", + req: "}", + contentType: contentType, + auth: token, + status: http.StatusBadRequest, + location: "", + err: svcerr.ErrMalformedEntity, + }, + { + desc: "add without content type", + req: data, + contentType: "", + auth: token, + status: http.StatusUnsupportedMediaType, + location: "", + err: apiutil.ErrUnsupportedContentType, + }, + } + + for _, tc := range cases { + svcCall := svc.On("CreateSubscription", mock.Anything, tc.auth, sub).Return(path.Base(tc.location), tc.err) + + req := testRequest{ + client: ss.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/subscriptions", ss.URL), + contentType: tc.contentType, + token: tc.auth, + body: strings.NewReader(tc.req), + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + + location := res.Header.Get("Location") + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + assert.Equal(t, tc.location, location, fmt.Sprintf("%s: expected location %s got %s", tc.desc, tc.location, location)) + + svcCall.Unset() + } +} + +func TestView(t *testing.T) { + ss, svc := newServer() + defer ss.Close() + + sub := notifiers.Subscription{ + Topic: topic, + Contact: contact1, + ID: testsutil.GenerateUUID(t), + OwnerID: validID, + } + + sr := subRes{ + ID: sub.ID, + OwnerID: validID, + Contact: sub.Contact, + Topic: sub.Topic, + } + data := toJSON(sr) + + cases := []struct { + desc string + id string + auth string + status int + res string + err error + Sub notifiers.Subscription + }{ + { + desc: "view successfully", + id: sub.ID, + auth: token, + status: http.StatusOK, + res: data, + err: nil, + Sub: sub, + }, + { + desc: "view not existing", + id: "not existing", + auth: token, + status: http.StatusNotFound, + res: notFoundRes, + err: svcerr.ErrNotFound, + }, + { + desc: "view with invalid auth token", + id: sub.ID, + auth: invalidToken, + status: http.StatusUnauthorized, + res: unauthRes, + err: svcerr.ErrAuthentication, + }, + { + desc: "view with empty auth token", + id: sub.ID, + auth: "", + status: http.StatusUnauthorized, + res: missingTokRes, + err: svcerr.ErrAuthentication, + }, + } + + for _, tc := range cases { + svcCall := svc.On("ViewSubscription", mock.Anything, tc.auth, tc.id).Return(tc.Sub, tc.err) + + req := testRequest{ + client: ss.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/subscriptions/%s", ss.URL, tc.id), + token: tc.auth, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected request error %s", tc.desc, err)) + body, err := io.ReadAll(res.Body) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected read error %s", tc.desc, err)) + data := strings.Trim(string(body), "\n") + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + assert.Equal(t, tc.res, data, fmt.Sprintf("%s: expected body %s got %s", tc.desc, tc.res, data)) + + svcCall.Unset() + } +} + +func TestList(t *testing.T) { + ss, svc := newServer() + defer ss.Close() + + const numSubs = 100 + var subs []subRes + var sub notifiers.Subscription + + for i := 0; i < numSubs; i++ { + sub = notifiers.Subscription{ + Topic: fmt.Sprintf("topic.subtopic.%d", i), + Contact: contact1, + ID: testsutil.GenerateUUID(t), + } + if i%2 == 0 { + sub.Contact = contact2 + } + sr := subRes{ + ID: sub.ID, + OwnerID: validID, + Contact: sub.Contact, + Topic: sub.Topic, + } + subs = append(subs, sr) + } + noLimit := toJSON(page{Offset: 5, Limit: 20, Total: numSubs, Subscriptions: subs[5:25]}) + one := toJSON(page{Offset: 0, Limit: 20, Total: 1, Subscriptions: subs[10:11]}) + + var contact2Subs []subRes + for i := 20; i < 40; i += 2 { + contact2Subs = append(contact2Subs, subs[i]) + } + contactList := toJSON(page{Offset: 10, Limit: 10, Total: 50, Subscriptions: contact2Subs}) + + cases := []struct { + desc string + query map[string]string + auth string + status int + res string + err error + page notifiers.Page + }{ + { + desc: "list default limit", + query: map[string]string{ + "offset": "5", + }, + auth: token, + status: http.StatusOK, + res: noLimit, + err: nil, + page: notifiers.Page{ + PageMetadata: notifiers.PageMetadata{ + Offset: 5, + Limit: 20, + }, + Total: numSubs, + Subscriptions: subscriptionsSlice(subs, 5, 25), + }, + }, + { + desc: "list not existing", + query: map[string]string{ + "topic": "not-found-topic", + }, + auth: token, + status: http.StatusNotFound, + res: notFoundRes, + err: svcerr.ErrNotFound, + }, + { + desc: "list one with topic", + query: map[string]string{ + "topic": "topic.subtopic.10", + }, + auth: token, + status: http.StatusOK, + res: one, + err: nil, + page: notifiers.Page{ + PageMetadata: notifiers.PageMetadata{ + Offset: 0, + Limit: 20, + }, + Total: 1, + Subscriptions: subscriptionsSlice(subs, 10, 11), + }, + }, + { + desc: "list with contact", + query: map[string]string{ + "contact": contact2, + "offset": "10", + "limit": "10", + }, + auth: token, + status: http.StatusOK, + res: contactList, + err: nil, + page: notifiers.Page{ + PageMetadata: notifiers.PageMetadata{ + Offset: 10, + Limit: 10, + }, + Total: 50, + Subscriptions: subscriptionsSlice(contact2Subs, 0, 10), + }, + }, + { + desc: "list with invalid query", + query: map[string]string{ + "offset": "two", + }, + auth: token, + status: http.StatusBadRequest, + res: invalidRes, + err: svcerr.ErrMalformedEntity, + }, + { + desc: "list with invalid auth token", + auth: invalidToken, + status: http.StatusUnauthorized, + res: unauthRes, + err: svcerr.ErrAuthentication, + }, + { + desc: "list with empty auth token", + auth: "", + status: http.StatusUnauthorized, + res: missingTokRes, + err: svcerr.ErrAuthentication, + }, + } + + for _, tc := range cases { + svcCall := svc.On("ListSubscriptions", mock.Anything, tc.auth, mock.Anything).Return(tc.page, tc.err) + req := testRequest{ + client: ss.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/subscriptions%s", ss.URL, makeQuery(tc.query)), + token: tc.auth, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + body, err := io.ReadAll(res.Body) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + data := strings.Trim(string(body), "\n") + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + assert.Equal(t, tc.res, data, fmt.Sprintf("%s: got unexpected body\n", tc.desc)) + + svcCall.Unset() + } +} + +func TestRemove(t *testing.T) { + ss, svc := newServer() + defer ss.Close() + id := testsutil.GenerateUUID(t) + + cases := []struct { + desc string + id string + auth string + status int + res string + err error + }{ + { + desc: "remove successfully", + id: id, + auth: token, + status: http.StatusNoContent, + err: nil, + }, + { + desc: "remove not existing", + id: "not existing", + auth: token, + status: http.StatusNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "remove empty id", + id: "", + auth: token, + status: http.StatusBadRequest, + err: svcerr.ErrMalformedEntity, + }, + { + desc: "view with invalid auth token", + id: id, + auth: invalidToken, + status: http.StatusUnauthorized, + res: unauthRes, + err: svcerr.ErrAuthentication, + }, + { + desc: "view with empty auth token", + id: id, + auth: "", + status: http.StatusUnauthorized, + res: missingTokRes, + err: svcerr.ErrAuthentication, + }, + } + + for _, tc := range cases { + svcCall := svc.On("RemoveSubscription", mock.Anything, tc.auth, tc.id).Return(tc.err) + + req := testRequest{ + client: ss.Client(), + method: http.MethodDelete, + url: fmt.Sprintf("%s/subscriptions/%s", ss.URL, tc.id), + token: tc.auth, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + + svcCall.Unset() + } +} + +func makeQuery(m map[string]string) string { + var ret string + for k, v := range m { + ret += fmt.Sprintf("&%s=%s", k, v) + } + if ret != "" { + return fmt.Sprintf("?%s", ret[1:]) + } + return "" +} + +type subRes struct { + ID string `json:"id"` + OwnerID string `json:"owner_id"` + Contact string `json:"contact"` + Topic string `json:"topic"` +} +type page struct { + Offset uint `json:"offset"` + Limit int `json:"limit"` + Total uint `json:"total,omitempty"` + Subscriptions []subRes `json:"subscriptions,omitempty"` +} + +func subscriptionsSlice(subs []subRes, start, end int) []notifiers.Subscription { + var res []notifiers.Subscription + for i := start; i < end; i++ { + sub := subs[i] + res = append(res, notifiers.Subscription{ + ID: sub.ID, + OwnerID: sub.OwnerID, + Contact: sub.Contact, + Topic: sub.Topic, + }) + } + return res +} diff --git a/consumers/notifiers/api/logging.go b/consumers/notifiers/api/logging.go new file mode 100644 index 000000000..6793c2640 --- /dev/null +++ b/consumers/notifiers/api/logging.go @@ -0,0 +1,131 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build !test + +package api + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/supermq/consumers/notifiers" +) + +var _ notifiers.Service = (*loggingMiddleware)(nil) + +type loggingMiddleware struct { + logger *slog.Logger + svc notifiers.Service +} + +// LoggingMiddleware adds logging facilities to the core service. +func LoggingMiddleware(svc notifiers.Service, logger *slog.Logger) notifiers.Service { + return &loggingMiddleware{logger, svc} +} + +// CreateSubscription logs the create_subscription request. It logs subscription ID and topic and the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) CreateSubscription(ctx context.Context, token string, sub notifiers.Subscription) (id string, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.Group("subscription", + slog.String("topic", sub.Topic), + slog.String("id", id), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Create subscription failed", args...) + return + } + lm.logger.Info("Create subscription completed successfully", args...) + }(time.Now()) + + return lm.svc.CreateSubscription(ctx, token, sub) +} + +// ViewSubscription logs the view_subscription request. It logs subscription topic and id and the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) ViewSubscription(ctx context.Context, token, topic string) (sub notifiers.Subscription, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.Group("subscription", + slog.String("topic", topic), + slog.String("id", sub.ID), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("View subscription failed", args...) + return + } + lm.logger.Info("View subscription completed successfully", args...) + }(time.Now()) + + return lm.svc.ViewSubscription(ctx, token, topic) +} + +// ListSubscriptions logs the list_subscriptions request. It logs page metadata and subscription topic and the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) ListSubscriptions(ctx context.Context, token string, pm notifiers.PageMetadata) (res notifiers.Page, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.Group("page", + slog.String("topic", pm.Topic), + slog.Int("limit", pm.Limit), + slog.Uint64("offset", uint64(pm.Offset)), + slog.Uint64("total", uint64(res.Total)), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("List subscriptions failed", args...) + return + } + lm.logger.Info("List subscriptions completed successfully", args...) + }(time.Now()) + + return lm.svc.ListSubscriptions(ctx, token, pm) +} + +// RemoveSubscription logs the remove_subscription request. It logs subscription ID and the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) RemoveSubscription(ctx context.Context, token, id string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("subscription_id", id), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Remove subscription failed", args...) + return + } + lm.logger.Info("Remove subscription completed successfully", args...) + }(time.Now()) + + return lm.svc.RemoveSubscription(ctx, token, id) +} + +// ConsumeBlocking logs the consume_blocking request. It logs the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) ConsumeBlocking(ctx context.Context, msg any) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Blocking consumer failed to consume messages successfully", args...) + return + } + lm.logger.Info("Blocking consumer consumed messages successfully", args...) + }(time.Now()) + + return lm.svc.ConsumeBlocking(ctx, msg) +} diff --git a/consumers/notifiers/api/metrics.go b/consumers/notifiers/api/metrics.go new file mode 100644 index 000000000..8ec8167c3 --- /dev/null +++ b/consumers/notifiers/api/metrics.go @@ -0,0 +1,81 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build !test + +package api + +import ( + "context" + "time" + + "github.com/absmach/supermq/consumers/notifiers" + "github.com/go-kit/kit/metrics" +) + +var _ notifiers.Service = (*metricsMiddleware)(nil) + +type metricsMiddleware struct { + counter metrics.Counter + latency metrics.Histogram + svc notifiers.Service +} + +// MetricsMiddleware instruments core service by tracking request count and latency. +func MetricsMiddleware(svc notifiers.Service, counter metrics.Counter, latency metrics.Histogram) notifiers.Service { + return &metricsMiddleware{ + counter: counter, + latency: latency, + svc: svc, + } +} + +// CreateSubscription instruments CreateSubscription method with metrics. +func (ms *metricsMiddleware) CreateSubscription(ctx context.Context, token string, sub notifiers.Subscription) (string, error) { + defer func(begin time.Time) { + ms.counter.With("method", "create_subscription").Add(1) + ms.latency.With("method", "create_subscription").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.CreateSubscription(ctx, token, sub) +} + +// ViewSubscription instruments ViewSubscription method with metrics. +func (ms *metricsMiddleware) ViewSubscription(ctx context.Context, token, topic string) (notifiers.Subscription, error) { + defer func(begin time.Time) { + ms.counter.With("method", "view_subscription").Add(1) + ms.latency.With("method", "view_subscription").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.ViewSubscription(ctx, token, topic) +} + +// ListSubscriptions instruments ListSubscriptions method with metrics. +func (ms *metricsMiddleware) ListSubscriptions(ctx context.Context, token string, pm notifiers.PageMetadata) (notifiers.Page, error) { + defer func(begin time.Time) { + ms.counter.With("method", "list_subscriptions").Add(1) + ms.latency.With("method", "list_subscriptions").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.ListSubscriptions(ctx, token, pm) +} + +// RemoveSubscription instruments RemoveSubscription method with metrics. +func (ms *metricsMiddleware) RemoveSubscription(ctx context.Context, token, id string) error { + defer func(begin time.Time) { + ms.counter.With("method", "remove_subscription").Add(1) + ms.latency.With("method", "remove_subscription").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RemoveSubscription(ctx, token, id) +} + +// ConsumeBlocking instruments ConsumeBlocking method with metrics. +func (ms *metricsMiddleware) ConsumeBlocking(ctx context.Context, msg any) error { + defer func(begin time.Time) { + ms.counter.With("method", "consume").Add(1) + ms.latency.With("method", "consume").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.ConsumeBlocking(ctx, msg) +} diff --git a/consumers/notifiers/api/requests.go b/consumers/notifiers/api/requests.go new file mode 100644 index 000000000..9b133aead --- /dev/null +++ b/consumers/notifiers/api/requests.go @@ -0,0 +1,55 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import apiutil "github.com/absmach/supermq/api/http/util" + +type createSubReq struct { + token string + Topic string `json:"topic,omitempty"` + Contact string `json:"contact,omitempty"` +} + +func (req createSubReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.Topic == "" { + return apiutil.ErrInvalidTopic + } + if req.Contact == "" { + return apiutil.ErrInvalidContact + } + return nil +} + +type subReq struct { + token string + id string +} + +func (req subReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type listSubsReq struct { + token string + topic string + contact string + offset uint + limit uint +} + +func (req listSubsReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken + } + return nil +} diff --git a/consumers/notifiers/api/responses.go b/consumers/notifiers/api/responses.go new file mode 100644 index 000000000..c4732213c --- /dev/null +++ b/consumers/notifiers/api/responses.go @@ -0,0 +1,88 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "fmt" + "net/http" + + "github.com/absmach/supermq" +) + +var ( + _ supermq.Response = (*createSubRes)(nil) + _ supermq.Response = (*viewSubRes)(nil) + _ supermq.Response = (*listSubsRes)(nil) + _ supermq.Response = (*removeSubRes)(nil) +) + +type createSubRes struct { + ID string +} + +func (res createSubRes) Code() int { + return http.StatusCreated +} + +func (res createSubRes) Headers() map[string]string { + return map[string]string{ + "Location": fmt.Sprintf("/subscriptions/%s", res.ID), + } +} + +func (res createSubRes) Empty() bool { + return true +} + +type viewSubRes struct { + ID string `json:"id"` + OwnerID string `json:"owner_id"` + Contact string `json:"contact"` + Topic string `json:"topic"` +} + +func (res viewSubRes) Code() int { + return http.StatusOK +} + +func (res viewSubRes) Headers() map[string]string { + return map[string]string{} +} + +func (res viewSubRes) Empty() bool { + return false +} + +type listSubsRes struct { + Offset uint `json:"offset"` + Limit int `json:"limit"` + Total uint `json:"total,omitempty"` + Subscriptions []viewSubRes `json:"subscriptions,omitempty"` +} + +func (res listSubsRes) Code() int { + return http.StatusOK +} + +func (res listSubsRes) Headers() map[string]string { + return map[string]string{} +} + +func (res listSubsRes) Empty() bool { + return false +} + +type removeSubRes struct{} + +func (res removeSubRes) Code() int { + return http.StatusNoContent +} + +func (res removeSubRes) Headers() map[string]string { + return map[string]string{} +} + +func (res removeSubRes) Empty() bool { + return true +} diff --git a/consumers/notifiers/api/transport.go b/consumers/notifiers/api/transport.go new file mode 100644 index 000000000..2fb673cf3 --- /dev/null +++ b/consumers/notifiers/api/transport.go @@ -0,0 +1,131 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "strings" + + "github.com/absmach/supermq" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/consumers/notifiers" + "github.com/absmach/supermq/pkg/errors" + "github.com/go-chi/chi/v5" + kithttp "github.com/go-kit/kit/transport/http" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" +) + +const ( + contentType = "application/json" + offsetKey = "offset" + limitKey = "limit" + topicKey = "topic" + contactKey = "contact" + defOffset = 0 + defLimit = 20 +) + +// MakeHandler returns a HTTP handler for API endpoints. +func MakeHandler(svc notifiers.Service, logger *slog.Logger, instanceID string) http.Handler { + opts := []kithttp.ServerOption{ + kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), + } + + mux := chi.NewRouter() + + mux.Route("/subscriptions", func(r chi.Router) { + r.Post("/", otelhttp.NewHandler(kithttp.NewServer( + createSubscriptionEndpoint(svc), + decodeCreate, + api.EncodeResponse, + opts..., + ), "create").ServeHTTP) + + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + listSubscriptionsEndpoint(svc), + decodeList, + api.EncodeResponse, + opts..., + ), "list").ServeHTTP) + + r.Delete("/", otelhttp.NewHandler(kithttp.NewServer( + deleteSubscriptionEndpoint(svc), + decodeSubscription, + api.EncodeResponse, + opts..., + ), "delete").ServeHTTP) + + r.Get("/{subID}", otelhttp.NewHandler(kithttp.NewServer( + viewSubscriptionEndpoint(svc), + decodeSubscription, + api.EncodeResponse, + opts..., + ), "view").ServeHTTP) + + r.Delete("/{subID}", otelhttp.NewHandler(kithttp.NewServer( + deleteSubscriptionEndpoint(svc), + decodeSubscription, + api.EncodeResponse, + opts..., + ), "delete").ServeHTTP) + }) + mux.Get("/health", supermq.Health("notifier", instanceID)) + mux.Handle("/metrics", promhttp.Handler()) + + return mux +} + +func decodeCreate(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := createSubReq{token: apiutil.ExtractBearerToken(r)} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + return req, nil +} + +func decodeSubscription(_ context.Context, r *http.Request) (any, error) { + req := subReq{ + id: chi.URLParam(r, "subID"), + token: apiutil.ExtractBearerToken(r), + } + + return req, nil +} + +func decodeList(_ context.Context, r *http.Request) (any, error) { + req := listSubsReq{token: apiutil.ExtractBearerToken(r)} + vals := r.URL.Query()[topicKey] + if len(vals) > 0 { + req.topic = vals[0] + } + + vals = r.URL.Query()[contactKey] + if len(vals) > 0 { + req.contact = vals[0] + } + + offset, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset) + if err != nil { + return listSubsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + req.offset = uint(offset) + + limit, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit) + if err != nil { + return listSubsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + req.limit = uint(limit) + + return req, nil +} diff --git a/consumers/notifiers/doc.go b/consumers/notifiers/doc.go new file mode 100644 index 000000000..ee084f198 --- /dev/null +++ b/consumers/notifiers/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package notifiers contain the domain concept definitions needed to +// support SuperMQ notifications functionality. +package notifiers diff --git a/consumers/notifiers/mocks/service.go b/consumers/notifiers/mocks/service.go new file mode 100644 index 000000000..3f35ff745 --- /dev/null +++ b/consumers/notifiers/mocks/service.go @@ -0,0 +1,379 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/consumers/notifiers" + mock "github.com/stretchr/testify/mock" +) + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +type Service_Expecter struct { + mock *mock.Mock +} + +func (_m *Service) EXPECT() *Service_Expecter { + return &Service_Expecter{mock: &_m.Mock} +} + +// ConsumeBlocking provides a mock function for the type Service +func (_mock *Service) ConsumeBlocking(ctx context.Context, messages any) error { + ret := _mock.Called(ctx, messages) + + if len(ret) == 0 { + panic("no return value specified for ConsumeBlocking") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, any) error); ok { + r0 = returnFunc(ctx, messages) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_ConsumeBlocking_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ConsumeBlocking' +type Service_ConsumeBlocking_Call struct { + *mock.Call +} + +// ConsumeBlocking is a helper method to define mock.On call +// - ctx context.Context +// - messages any +func (_e *Service_Expecter) ConsumeBlocking(ctx interface{}, messages interface{}) *Service_ConsumeBlocking_Call { + return &Service_ConsumeBlocking_Call{Call: _e.mock.On("ConsumeBlocking", ctx, messages)} +} + +func (_c *Service_ConsumeBlocking_Call) Run(run func(ctx context.Context, messages any)) *Service_ConsumeBlocking_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 any + if args[1] != nil { + arg1 = args[1].(any) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_ConsumeBlocking_Call) Return(err error) *Service_ConsumeBlocking_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_ConsumeBlocking_Call) RunAndReturn(run func(ctx context.Context, messages any) error) *Service_ConsumeBlocking_Call { + _c.Call.Return(run) + return _c +} + +// CreateSubscription provides a mock function for the type Service +func (_mock *Service) CreateSubscription(ctx context.Context, token string, sub notifiers.Subscription) (string, error) { + ret := _mock.Called(ctx, token, sub) + + if len(ret) == 0 { + panic("no return value specified for CreateSubscription") + } + + var r0 string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, notifiers.Subscription) (string, error)); ok { + return returnFunc(ctx, token, sub) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, notifiers.Subscription) string); ok { + r0 = returnFunc(ctx, token, sub) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, notifiers.Subscription) error); ok { + r1 = returnFunc(ctx, token, sub) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_CreateSubscription_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateSubscription' +type Service_CreateSubscription_Call struct { + *mock.Call +} + +// CreateSubscription is a helper method to define mock.On call +// - ctx context.Context +// - token string +// - sub notifiers.Subscription +func (_e *Service_Expecter) CreateSubscription(ctx interface{}, token interface{}, sub interface{}) *Service_CreateSubscription_Call { + return &Service_CreateSubscription_Call{Call: _e.mock.On("CreateSubscription", ctx, token, sub)} +} + +func (_c *Service_CreateSubscription_Call) Run(run func(ctx context.Context, token string, sub notifiers.Subscription)) *Service_CreateSubscription_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 notifiers.Subscription + if args[2] != nil { + arg2 = args[2].(notifiers.Subscription) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_CreateSubscription_Call) Return(s string, err error) *Service_CreateSubscription_Call { + _c.Call.Return(s, err) + return _c +} + +func (_c *Service_CreateSubscription_Call) RunAndReturn(run func(ctx context.Context, token string, sub notifiers.Subscription) (string, error)) *Service_CreateSubscription_Call { + _c.Call.Return(run) + return _c +} + +// ListSubscriptions provides a mock function for the type Service +func (_mock *Service) ListSubscriptions(ctx context.Context, token string, pm notifiers.PageMetadata) (notifiers.Page, error) { + ret := _mock.Called(ctx, token, pm) + + if len(ret) == 0 { + panic("no return value specified for ListSubscriptions") + } + + var r0 notifiers.Page + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, notifiers.PageMetadata) (notifiers.Page, error)); ok { + return returnFunc(ctx, token, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, notifiers.PageMetadata) notifiers.Page); ok { + r0 = returnFunc(ctx, token, pm) + } else { + r0 = ret.Get(0).(notifiers.Page) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, notifiers.PageMetadata) error); ok { + r1 = returnFunc(ctx, token, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListSubscriptions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListSubscriptions' +type Service_ListSubscriptions_Call struct { + *mock.Call +} + +// ListSubscriptions is a helper method to define mock.On call +// - ctx context.Context +// - token string +// - pm notifiers.PageMetadata +func (_e *Service_Expecter) ListSubscriptions(ctx interface{}, token interface{}, pm interface{}) *Service_ListSubscriptions_Call { + return &Service_ListSubscriptions_Call{Call: _e.mock.On("ListSubscriptions", ctx, token, pm)} +} + +func (_c *Service_ListSubscriptions_Call) Run(run func(ctx context.Context, token string, pm notifiers.PageMetadata)) *Service_ListSubscriptions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 notifiers.PageMetadata + if args[2] != nil { + arg2 = args[2].(notifiers.PageMetadata) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_ListSubscriptions_Call) Return(page notifiers.Page, err error) *Service_ListSubscriptions_Call { + _c.Call.Return(page, err) + return _c +} + +func (_c *Service_ListSubscriptions_Call) RunAndReturn(run func(ctx context.Context, token string, pm notifiers.PageMetadata) (notifiers.Page, error)) *Service_ListSubscriptions_Call { + _c.Call.Return(run) + return _c +} + +// RemoveSubscription provides a mock function for the type Service +func (_mock *Service) RemoveSubscription(ctx context.Context, token string, id string) error { + ret := _mock.Called(ctx, token, id) + + if len(ret) == 0 { + panic("no return value specified for RemoveSubscription") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = returnFunc(ctx, token, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RemoveSubscription_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveSubscription' +type Service_RemoveSubscription_Call struct { + *mock.Call +} + +// RemoveSubscription is a helper method to define mock.On call +// - ctx context.Context +// - token string +// - id string +func (_e *Service_Expecter) RemoveSubscription(ctx interface{}, token interface{}, id interface{}) *Service_RemoveSubscription_Call { + return &Service_RemoveSubscription_Call{Call: _e.mock.On("RemoveSubscription", ctx, token, id)} +} + +func (_c *Service_RemoveSubscription_Call) Run(run func(ctx context.Context, token string, id string)) *Service_RemoveSubscription_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_RemoveSubscription_Call) Return(err error) *Service_RemoveSubscription_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RemoveSubscription_Call) RunAndReturn(run func(ctx context.Context, token string, id string) error) *Service_RemoveSubscription_Call { + _c.Call.Return(run) + return _c +} + +// ViewSubscription provides a mock function for the type Service +func (_mock *Service) ViewSubscription(ctx context.Context, token string, id string) (notifiers.Subscription, error) { + ret := _mock.Called(ctx, token, id) + + if len(ret) == 0 { + panic("no return value specified for ViewSubscription") + } + + var r0 notifiers.Subscription + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (notifiers.Subscription, error)); ok { + return returnFunc(ctx, token, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) notifiers.Subscription); ok { + r0 = returnFunc(ctx, token, id) + } else { + r0 = ret.Get(0).(notifiers.Subscription) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, token, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ViewSubscription_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewSubscription' +type Service_ViewSubscription_Call struct { + *mock.Call +} + +// ViewSubscription is a helper method to define mock.On call +// - ctx context.Context +// - token string +// - id string +func (_e *Service_Expecter) ViewSubscription(ctx interface{}, token interface{}, id interface{}) *Service_ViewSubscription_Call { + return &Service_ViewSubscription_Call{Call: _e.mock.On("ViewSubscription", ctx, token, id)} +} + +func (_c *Service_ViewSubscription_Call) Run(run func(ctx context.Context, token string, id string)) *Service_ViewSubscription_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_ViewSubscription_Call) Return(subscription notifiers.Subscription, err error) *Service_ViewSubscription_Call { + _c.Call.Return(subscription, err) + return _c +} + +func (_c *Service_ViewSubscription_Call) RunAndReturn(run func(ctx context.Context, token string, id string) (notifiers.Subscription, error)) *Service_ViewSubscription_Call { + _c.Call.Return(run) + return _c +} diff --git a/consumers/notifiers/mocks/subscriptions_repository.go b/consumers/notifiers/mocks/subscriptions_repository.go new file mode 100644 index 000000000..9a4738922 --- /dev/null +++ b/consumers/notifiers/mocks/subscriptions_repository.go @@ -0,0 +1,298 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/consumers/notifiers" + mock "github.com/stretchr/testify/mock" +) + +// NewSubscriptionsRepository creates a new instance of SubscriptionsRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewSubscriptionsRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *SubscriptionsRepository { + mock := &SubscriptionsRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// SubscriptionsRepository is an autogenerated mock type for the SubscriptionsRepository type +type SubscriptionsRepository struct { + mock.Mock +} + +type SubscriptionsRepository_Expecter struct { + mock *mock.Mock +} + +func (_m *SubscriptionsRepository) EXPECT() *SubscriptionsRepository_Expecter { + return &SubscriptionsRepository_Expecter{mock: &_m.Mock} +} + +// Remove provides a mock function for the type SubscriptionsRepository +func (_mock *SubscriptionsRepository) Remove(ctx context.Context, id string) error { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Remove") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// SubscriptionsRepository_Remove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remove' +type SubscriptionsRepository_Remove_Call struct { + *mock.Call +} + +// Remove is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *SubscriptionsRepository_Expecter) Remove(ctx interface{}, id interface{}) *SubscriptionsRepository_Remove_Call { + return &SubscriptionsRepository_Remove_Call{Call: _e.mock.On("Remove", ctx, id)} +} + +func (_c *SubscriptionsRepository_Remove_Call) Run(run func(ctx context.Context, id string)) *SubscriptionsRepository_Remove_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *SubscriptionsRepository_Remove_Call) Return(err error) *SubscriptionsRepository_Remove_Call { + _c.Call.Return(err) + return _c +} + +func (_c *SubscriptionsRepository_Remove_Call) RunAndReturn(run func(ctx context.Context, id string) error) *SubscriptionsRepository_Remove_Call { + _c.Call.Return(run) + return _c +} + +// Retrieve provides a mock function for the type SubscriptionsRepository +func (_mock *SubscriptionsRepository) Retrieve(ctx context.Context, id string) (notifiers.Subscription, error) { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Retrieve") + } + + var r0 notifiers.Subscription + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (notifiers.Subscription, error)); ok { + return returnFunc(ctx, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) notifiers.Subscription); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Get(0).(notifiers.Subscription) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// SubscriptionsRepository_Retrieve_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Retrieve' +type SubscriptionsRepository_Retrieve_Call struct { + *mock.Call +} + +// Retrieve is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *SubscriptionsRepository_Expecter) Retrieve(ctx interface{}, id interface{}) *SubscriptionsRepository_Retrieve_Call { + return &SubscriptionsRepository_Retrieve_Call{Call: _e.mock.On("Retrieve", ctx, id)} +} + +func (_c *SubscriptionsRepository_Retrieve_Call) Run(run func(ctx context.Context, id string)) *SubscriptionsRepository_Retrieve_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *SubscriptionsRepository_Retrieve_Call) Return(subscription notifiers.Subscription, err error) *SubscriptionsRepository_Retrieve_Call { + _c.Call.Return(subscription, err) + return _c +} + +func (_c *SubscriptionsRepository_Retrieve_Call) RunAndReturn(run func(ctx context.Context, id string) (notifiers.Subscription, error)) *SubscriptionsRepository_Retrieve_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveAll provides a mock function for the type SubscriptionsRepository +func (_mock *SubscriptionsRepository) RetrieveAll(ctx context.Context, pm notifiers.PageMetadata) (notifiers.Page, error) { + ret := _mock.Called(ctx, pm) + + if len(ret) == 0 { + panic("no return value specified for RetrieveAll") + } + + var r0 notifiers.Page + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, notifiers.PageMetadata) (notifiers.Page, error)); ok { + return returnFunc(ctx, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, notifiers.PageMetadata) notifiers.Page); ok { + r0 = returnFunc(ctx, pm) + } else { + r0 = ret.Get(0).(notifiers.Page) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, notifiers.PageMetadata) error); ok { + r1 = returnFunc(ctx, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// SubscriptionsRepository_RetrieveAll_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveAll' +type SubscriptionsRepository_RetrieveAll_Call struct { + *mock.Call +} + +// RetrieveAll is a helper method to define mock.On call +// - ctx context.Context +// - pm notifiers.PageMetadata +func (_e *SubscriptionsRepository_Expecter) RetrieveAll(ctx interface{}, pm interface{}) *SubscriptionsRepository_RetrieveAll_Call { + return &SubscriptionsRepository_RetrieveAll_Call{Call: _e.mock.On("RetrieveAll", ctx, pm)} +} + +func (_c *SubscriptionsRepository_RetrieveAll_Call) Run(run func(ctx context.Context, pm notifiers.PageMetadata)) *SubscriptionsRepository_RetrieveAll_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 notifiers.PageMetadata + if args[1] != nil { + arg1 = args[1].(notifiers.PageMetadata) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *SubscriptionsRepository_RetrieveAll_Call) Return(page notifiers.Page, err error) *SubscriptionsRepository_RetrieveAll_Call { + _c.Call.Return(page, err) + return _c +} + +func (_c *SubscriptionsRepository_RetrieveAll_Call) RunAndReturn(run func(ctx context.Context, pm notifiers.PageMetadata) (notifiers.Page, error)) *SubscriptionsRepository_RetrieveAll_Call { + _c.Call.Return(run) + return _c +} + +// Save provides a mock function for the type SubscriptionsRepository +func (_mock *SubscriptionsRepository) Save(ctx context.Context, sub notifiers.Subscription) (string, error) { + ret := _mock.Called(ctx, sub) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, notifiers.Subscription) (string, error)); ok { + return returnFunc(ctx, sub) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, notifiers.Subscription) string); ok { + r0 = returnFunc(ctx, sub) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, notifiers.Subscription) error); ok { + r1 = returnFunc(ctx, sub) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// SubscriptionsRepository_Save_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Save' +type SubscriptionsRepository_Save_Call struct { + *mock.Call +} + +// Save is a helper method to define mock.On call +// - ctx context.Context +// - sub notifiers.Subscription +func (_e *SubscriptionsRepository_Expecter) Save(ctx interface{}, sub interface{}) *SubscriptionsRepository_Save_Call { + return &SubscriptionsRepository_Save_Call{Call: _e.mock.On("Save", ctx, sub)} +} + +func (_c *SubscriptionsRepository_Save_Call) Run(run func(ctx context.Context, sub notifiers.Subscription)) *SubscriptionsRepository_Save_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 notifiers.Subscription + if args[1] != nil { + arg1 = args[1].(notifiers.Subscription) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *SubscriptionsRepository_Save_Call) Return(s string, err error) *SubscriptionsRepository_Save_Call { + _c.Call.Return(s, err) + return _c +} + +func (_c *SubscriptionsRepository_Save_Call) RunAndReturn(run func(ctx context.Context, sub notifiers.Subscription) (string, error)) *SubscriptionsRepository_Save_Call { + _c.Call.Return(run) + return _c +} diff --git a/consumers/notifiers/postgres/database.go b/consumers/notifiers/postgres/database.go new file mode 100644 index 000000000..e6418dda5 --- /dev/null +++ b/consumers/notifiers/postgres/database.go @@ -0,0 +1,74 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/jmoiron/sqlx" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +var _ Database = (*database)(nil) + +type database struct { + db *sqlx.DB + tracer trace.Tracer +} + +// Database provides a database interface. +type Database interface { + NamedExecContext(context.Context, string, any) (sql.Result, error) + QueryRowxContext(context.Context, string, ...any) *sqlx.Row + NamedQueryContext(context.Context, string, any) (*sqlx.Rows, error) + GetContext(context.Context, any, string, ...any) error +} + +// NewDatabase creates a SubscriptionsDatabase instance. +func NewDatabase(db *sqlx.DB, tracer trace.Tracer) Database { + return &database{ + db: db, + tracer: tracer, + } +} + +func (dm database) NamedExecContext(ctx context.Context, query string, args any) (sql.Result, error) { + ctx, span := dm.addSpanTags(ctx, "NamedExecContext", query) + defer span.End() + return dm.db.NamedExecContext(ctx, query, args) +} + +func (dm database) QueryRowxContext(ctx context.Context, query string, args ...any) *sqlx.Row { + ctx, span := dm.addSpanTags(ctx, "QueryRowxContext", query) + defer span.End() + return dm.db.QueryRowxContext(ctx, query, args...) +} + +func (dm database) NamedQueryContext(ctx context.Context, query string, args any) (*sqlx.Rows, error) { + ctx, span := dm.addSpanTags(ctx, "NamedQueryContext", query) + defer span.End() + return dm.db.NamedQueryContext(ctx, query, args) +} + +func (dm database) GetContext(ctx context.Context, dest any, query string, args ...any) error { + ctx, span := dm.addSpanTags(ctx, "GetContext", query) + defer span.End() + return dm.db.GetContext(ctx, dest, query, args...) +} + +func (dm database) addSpanTags(ctx context.Context, method, query string) (context.Context, trace.Span) { + ctx, span := dm.tracer.Start(ctx, + fmt.Sprintf("sql_%s", method), + trace.WithAttributes( + attribute.String("sql.statement", query), + attribute.String("span.kind", "client"), + attribute.String("peer.service", "postgres"), + attribute.String("db.type", "sql"), + ), + ) + return ctx, span +} diff --git a/consumers/notifiers/postgres/doc.go b/consumers/notifiers/postgres/doc.go new file mode 100644 index 000000000..73a678477 --- /dev/null +++ b/consumers/notifiers/postgres/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package postgres contains repository implementations using PostgreSQL as +// the underlying database. +package postgres diff --git a/consumers/notifiers/postgres/init.go b/consumers/notifiers/postgres/init.go new file mode 100644 index 000000000..ac74c3c0b --- /dev/null +++ b/consumers/notifiers/postgres/init.go @@ -0,0 +1,28 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import migrate "github.com/rubenv/sql-migrate" + +func Migration() *migrate.MemoryMigrationSource { + return &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "subscriptions_1", + Up: []string{ + `CREATE TABLE IF NOT EXISTS subscriptions ( + id VARCHAR(254) PRIMARY KEY, + owner_id VARCHAR(254) NOT NULL, + contact VARCHAR(254), + topic TEXT, + UNIQUE(topic, contact) + )`, + }, + Down: []string{ + "DROP TABLE IF EXISTS subscriptions", + }, + }, + }, + } +} diff --git a/consumers/notifiers/postgres/setup_test.go b/consumers/notifiers/postgres/setup_test.go new file mode 100644 index 000000000..523bf92de --- /dev/null +++ b/consumers/notifiers/postgres/setup_test.go @@ -0,0 +1,89 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package postgres_test contains tests for PostgreSQL repository +// implementations. +package postgres_test + +import ( + "fmt" + "log" + "os" + "testing" + + "github.com/absmach/supermq/consumers/notifiers/postgres" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/absmach/supermq/pkg/ulid" + _ "github.com/jackc/pgx/v5/stdlib" // required for SQL access + "github.com/jmoiron/sqlx" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +var ( + idProvider = ulid.New() + db *sqlx.DB +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16.2-alpine", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + if err := pool.Retry(func() error { + db, err = sqlx.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + dbConfig := pgclient.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + if db, err = pgclient.Setup(dbConfig, *postgres.Migration()); err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + code := m.Run() + + // Defers will not be run when using os.Exit + db.Close() + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/consumers/notifiers/postgres/subscriptions.go b/consumers/notifiers/postgres/subscriptions.go new file mode 100644 index 000000000..93b590f54 --- /dev/null +++ b/consumers/notifiers/postgres/subscriptions.go @@ -0,0 +1,164 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/absmach/supermq/consumers/notifiers" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5/pgconn" +) + +var _ notifiers.SubscriptionsRepository = (*subscriptionsRepo)(nil) + +type subscriptionsRepo struct { + db Database +} + +// New instantiates a PostgreSQL implementation of Subscriptions repository. +func New(db Database) notifiers.SubscriptionsRepository { + return &subscriptionsRepo{ + db: db, + } +} + +func (repo subscriptionsRepo) Save(ctx context.Context, sub notifiers.Subscription) (string, error) { + q := `INSERT INTO subscriptions (id, owner_id, contact, topic) VALUES (:id, :owner_id, :contact, :topic) RETURNING id` + + dbSub := dbSubscription{ + ID: sub.ID, + OwnerID: sub.OwnerID, + Contact: sub.Contact, + Topic: sub.Topic, + } + + row, err := repo.db.NamedQueryContext(ctx, q, dbSub) + if err != nil { + if pqErr, ok := err.(*pgconn.PgError); ok && pqErr.Code == pgerrcode.UniqueViolation { + return "", errors.Wrap(notifiers.ErrSubscriptionsAlreadyExists, err) + } + return "", errors.Wrap(repoerr.ErrCreateEntity, err) + } + defer row.Close() + + return sub.ID, nil +} + +func (repo subscriptionsRepo) Retrieve(ctx context.Context, id string) (notifiers.Subscription, error) { + q := `SELECT id, owner_id, contact, topic FROM subscriptions WHERE id = $1` + sub := dbSubscription{} + if err := repo.db.QueryRowxContext(ctx, q, id).StructScan(&sub); err != nil { + if err == sql.ErrNoRows { + return notifiers.Subscription{}, errors.Wrap(repoerr.ErrNotFound, err) + } + return notifiers.Subscription{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return fromDBSub(sub), nil +} + +func (repo subscriptionsRepo) RetrieveAll(ctx context.Context, pm notifiers.PageMetadata) (notifiers.Page, error) { + q := `SELECT id, owner_id, contact, topic FROM subscriptions` + args := make(map[string]any) + if pm.Topic != "" { + args["topic"] = pm.Topic + } + if pm.Contact != "" { + args["contact"] = pm.Contact + } + var condition string + if len(args) > 0 { + var cond []string + for k := range args { + cond = append(cond, fmt.Sprintf("%s = :%s", k, k)) + } + condition = fmt.Sprintf(" WHERE %s", strings.Join(cond, " AND ")) + q = fmt.Sprintf("%s%s", q, condition) + } + args["offset"] = pm.Offset + q = fmt.Sprintf("%s OFFSET :offset", q) + if pm.Limit > 0 { + q = fmt.Sprintf("%s LIMIT :limit", q) + args["limit"] = pm.Limit + } + + rows, err := repo.db.NamedQueryContext(ctx, q, args) + if err != nil { + return notifiers.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + var subs []notifiers.Subscription + for rows.Next() { + sub := dbSubscription{} + if err := rows.StructScan(&sub); err != nil { + return notifiers.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + subs = append(subs, fromDBSub(sub)) + } + + if len(subs) == 0 { + return notifiers.Page{}, repoerr.ErrNotFound + } + + cq := fmt.Sprintf(`SELECT COUNT(*) FROM subscriptions %s`, condition) + total, err := total(ctx, repo.db, cq, args) + if err != nil { + return notifiers.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + ret := notifiers.Page{ + PageMetadata: pm, + Total: total, + Subscriptions: subs, + } + + return ret, nil +} + +func (repo subscriptionsRepo) Remove(ctx context.Context, id string) error { + q := `DELETE from subscriptions WHERE id = $1` + + if r := repo.db.QueryRowxContext(ctx, q, id); r.Err() != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, r.Err()) + } + return nil +} + +func total(ctx context.Context, db Database, query string, params any) (uint, error) { + rows, err := db.NamedQueryContext(ctx, query, params) + if err != nil { + return 0, err + } + defer rows.Close() + var total uint + if rows.Next() { + if err := rows.Scan(&total); err != nil { + return 0, err + } + } + return total, nil +} + +type dbSubscription struct { + ID string `db:"id"` + OwnerID string `db:"owner_id"` + Contact string `db:"contact"` + Topic string `db:"topic"` +} + +func fromDBSub(sub dbSubscription) notifiers.Subscription { + return notifiers.Subscription{ + ID: sub.ID, + OwnerID: sub.OwnerID, + Contact: sub.Contact, + Topic: sub.Topic, + } +} diff --git a/consumers/notifiers/postgres/subscriptions_test.go b/consumers/notifiers/postgres/subscriptions_test.go new file mode 100644 index 000000000..f0f1fa9d6 --- /dev/null +++ b/consumers/notifiers/postgres/subscriptions_test.go @@ -0,0 +1,263 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "fmt" + "testing" + + "github.com/absmach/supermq/consumers/notifiers" + "github.com/absmach/supermq/consumers/notifiers/postgres" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" +) + +const ( + owner = "owner@example.com" + numSubs = 100 +) + +var tracer = otel.Tracer("tests") + +func TestSave(t *testing.T) { + dbMiddleware := postgres.NewDatabase(db, tracer) + repo := postgres.New(dbMiddleware) + + id1, err := idProvider.ID() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + id2, err := idProvider.ID() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + sub1 := notifiers.Subscription{ + OwnerID: id1, + ID: id1, + Contact: owner, + Topic: "topic.subtopic", + } + + sub2 := sub1 + sub2.ID = id2 + + cases := []struct { + desc string + sub notifiers.Subscription + id string + err error + }{ + { + desc: "save successfully", + sub: sub1, + id: id1, + err: nil, + }, + { + desc: "save duplicate", + sub: sub2, + id: "", + err: notifiers.ErrSubscriptionsAlreadyExists, + }, + } + + for _, tc := range cases { + id, err := repo.Save(context.Background(), tc.sub) + assert.Equal(t, tc.id, id, fmt.Sprintf("%s: expected id %s got %s\n", tc.desc, tc.id, id)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestView(t *testing.T) { + dbMiddleware := postgres.NewDatabase(db, tracer) + repo := postgres.New(dbMiddleware) + + id, err := idProvider.ID() + require.Nil(t, err, fmt.Sprintf("got an error creating id: %s", err)) + + sub := notifiers.Subscription{ + OwnerID: id, + ID: id, + Contact: owner, + Topic: "view.subtopic", + } + + ret, err := repo.Save(context.Background(), sub) + require.Nil(t, err, fmt.Sprintf("creating subscription must not fail: %s", err)) + require.Equal(t, id, ret, fmt.Sprintf("provided id %s must be the same as the returned id %s", id, ret)) + + cases := []struct { + desc string + sub notifiers.Subscription + id string + err error + }{ + { + desc: "retrieve successfully", + sub: sub, + id: id, + err: nil, + }, + { + desc: "retrieve not existing", + sub: notifiers.Subscription{}, + id: "non-existing", + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + sub, err := repo.Retrieve(context.Background(), tc.id) + assert.Equal(t, tc.sub, sub, fmt.Sprintf("%s: expected sub %v got %v\n", tc.desc, tc.sub, sub)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestRetrieveAll(t *testing.T) { + _, err := db.Exec("DELETE FROM subscriptions") + require.Nil(t, err, fmt.Sprintf("cleanup must not fail: %s", err)) + + dbMiddleware := postgres.NewDatabase(db, tracer) + repo := postgres.New(dbMiddleware) + + var subs []notifiers.Subscription + + for i := 0; i < numSubs; i++ { + id, err := idProvider.ID() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + sub := notifiers.Subscription{ + OwnerID: "owner", + ID: id, + Contact: owner, + Topic: fmt.Sprintf("list.subtopic.%d", i), + } + + ret, err := repo.Save(context.Background(), sub) + require.Nil(t, err, fmt.Sprintf("creating subscription must not fail: %s", err)) + require.Equal(t, id, ret, fmt.Sprintf("provided id %s must be the same as the returned id %s", id, ret)) + subs = append(subs, sub) + } + + cases := []struct { + desc string + pageMeta notifiers.PageMetadata + page notifiers.Page + err error + }{ + { + desc: "retrieve successfully", + pageMeta: notifiers.PageMetadata{ + Offset: 10, + Limit: 2, + }, + page: notifiers.Page{ + Total: numSubs, + PageMetadata: notifiers.PageMetadata{ + Offset: 10, + Limit: 2, + }, + Subscriptions: subs[10:12], + }, + err: nil, + }, + { + desc: "retrieve with contact", + pageMeta: notifiers.PageMetadata{ + Offset: 10, + Limit: 2, + Contact: owner, + }, + page: notifiers.Page{ + Total: numSubs, + PageMetadata: notifiers.PageMetadata{ + Offset: 10, + Limit: 2, + Contact: owner, + }, + Subscriptions: subs[10:12], + }, + err: nil, + }, + { + desc: "retrieve with topic", + pageMeta: notifiers.PageMetadata{ + Offset: 0, + Limit: 2, + Topic: "list.subtopic.11", + }, + page: notifiers.Page{ + Total: 1, + PageMetadata: notifiers.PageMetadata{ + Offset: 0, + Limit: 2, + Topic: "list.subtopic.11", + }, + Subscriptions: subs[11:12], + }, + err: nil, + }, + { + desc: "retrieve with no limit", + pageMeta: notifiers.PageMetadata{ + Offset: 0, + Limit: -1, + }, + page: notifiers.Page{ + Total: numSubs, + PageMetadata: notifiers.PageMetadata{ + Limit: -1, + }, + Subscriptions: subs, + }, + err: nil, + }, + } + + for _, tc := range cases { + page, err := repo.RetrieveAll(context.Background(), tc.pageMeta) + assert.Equal(t, tc.page, page, fmt.Sprintf("%s: got unexpected page\n", tc.desc)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} + +func TestRemove(t *testing.T) { + dbMiddleware := postgres.NewDatabase(db, tracer) + repo := postgres.New(dbMiddleware) + id, err := idProvider.ID() + require.Nil(t, err, fmt.Sprintf("got an error creating id: %s", err)) + sub := notifiers.Subscription{ + OwnerID: id, + ID: id, + Contact: owner, + Topic: "remove.subtopic.%d", + } + + ret, err := repo.Save(context.Background(), sub) + require.Nil(t, err, fmt.Sprintf("creating subscription must not fail: %s", err)) + require.Equal(t, id, ret, fmt.Sprintf("provided id %s must be the same as the returned id %s", id, ret)) + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "remove successfully", + id: id, + err: nil, + }, + { + desc: "remove not existing", + id: "empty", + err: nil, + }, + } + + for _, tc := range cases { + err := repo.Remove(context.Background(), tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } +} diff --git a/consumers/notifiers/service.go b/consumers/notifiers/service.go new file mode 100644 index 000000000..6b9b1c8bf --- /dev/null +++ b/consumers/notifiers/service.go @@ -0,0 +1,177 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package notifiers + +import ( + "context" + "fmt" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/consumers" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/messaging" +) + +var ( + // ErrMessage indicates an error converting a message to SuperMQ message. + ErrMessage = errors.New("failed to convert to SuperMQ message") + + // ErrSubscriptionsAlreadyExists indicates subscription already exists. + ErrSubscriptionsAlreadyExists = errors.NewRequestError("subscription already exists") +) +var _ consumers.AsyncConsumer = (*notifierService)(nil) + +// Service reprents a notification service. +type Service interface { + // CreateSubscription persists a subscription. + // Successful operation is indicated by non-nil error response. + CreateSubscription(ctx context.Context, token string, sub Subscription) (string, error) + + // ViewSubscription retrieves the subscription for the given user and id. + ViewSubscription(ctx context.Context, token, id string) (Subscription, error) + + // ListSubscriptions lists subscriptions having the provided user token and search params. + ListSubscriptions(ctx context.Context, token string, pm PageMetadata) (Page, error) + + // RemoveSubscription removes the subscription having the provided identifier. + RemoveSubscription(ctx context.Context, token, id string) error + + consumers.BlockingConsumer +} + +var _ Service = (*notifierService)(nil) + +type notifierService struct { + authn smqauthn.Authentication + subs SubscriptionsRepository + idp supermq.IDProvider + notifier consumers.Notifier + errCh chan error + from string +} + +// New instantiates the subscriptions service implementation. +func New(authn smqauthn.Authentication, subs SubscriptionsRepository, idp supermq.IDProvider, notifier consumers.Notifier, from string) Service { + return ¬ifierService{ + authn: authn, + subs: subs, + idp: idp, + notifier: notifier, + errCh: make(chan error, 1), + from: from, + } +} + +func (ns *notifierService) CreateSubscription(ctx context.Context, token string, sub Subscription) (string, error) { + session, err := ns.authn.Authenticate(ctx, token) + if err != nil { + return "", err + } + sub.ID, err = ns.idp.ID() + if err != nil { + return "", err + } + + sub.OwnerID = session.DomainUserID + id, err := ns.subs.Save(ctx, sub) + if err != nil { + return "", errors.Wrap(svcerr.ErrCreateEntity, err) + } + return id, nil +} + +func (ns *notifierService) ViewSubscription(ctx context.Context, token, id string) (Subscription, error) { + if _, err := ns.authn.Authenticate(ctx, token); err != nil { + return Subscription{}, err + } + + return ns.subs.Retrieve(ctx, id) +} + +func (ns *notifierService) ListSubscriptions(ctx context.Context, token string, pm PageMetadata) (Page, error) { + if _, err := ns.authn.Authenticate(ctx, token); err != nil { + return Page{}, err + } + + return ns.subs.RetrieveAll(ctx, pm) +} + +func (ns *notifierService) RemoveSubscription(ctx context.Context, token, id string) error { + if _, err := ns.authn.Authenticate(ctx, token); err != nil { + return err + } + + return ns.subs.Remove(ctx, id) +} + +func (ns *notifierService) ConsumeBlocking(ctx context.Context, message any) error { + msg, ok := message.(*messaging.Message) + if !ok { + return ErrMessage + } + topic := msg.GetChannel() + if msg.GetSubtopic() != "" { + topic = fmt.Sprintf("%s.%s", msg.GetChannel(), msg.GetSubtopic()) + } + pm := PageMetadata{ + Topic: topic, + Offset: 0, + Limit: -1, + } + page, err := ns.subs.RetrieveAll(ctx, pm) + if err != nil { + return err + } + + var to []string + for _, sub := range page.Subscriptions { + to = append(to, sub.Contact) + } + if len(to) > 0 { + err := ns.notifier.Notify(ns.from, to, msg) + if err != nil { + return errors.Wrap(consumers.ErrNotify, err) + } + } + + return nil +} + +func (ns *notifierService) ConsumeAsync(ctx context.Context, message any) { + msg, ok := message.(*messaging.Message) + if !ok { + ns.errCh <- ErrMessage + return + } + topic := msg.GetChannel() + if msg.GetSubtopic() != "" { + topic = fmt.Sprintf("%s.%s", msg.GetChannel(), msg.GetSubtopic()) + } + pm := PageMetadata{ + Topic: topic, + Offset: 0, + Limit: -1, + } + page, err := ns.subs.RetrieveAll(ctx, pm) + if err != nil { + ns.errCh <- err + return + } + + var to []string + for _, sub := range page.Subscriptions { + to = append(to, sub.Contact) + } + if len(to) > 0 { + if err := ns.notifier.Notify(ns.from, to, msg); err != nil { + ns.errCh <- errors.Wrap(consumers.ErrNotify, err) + } + } +} + +func (ns *notifierService) Errors() <-chan error { + return ns.errCh +} diff --git a/consumers/notifiers/service_test.go b/consumers/notifiers/service_test.go new file mode 100644 index 000000000..58bd57715 --- /dev/null +++ b/consumers/notifiers/service_test.go @@ -0,0 +1,360 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package notifiers_test + +import ( + "context" + "fmt" + "testing" + + "github.com/absmach/supermq/consumers" + smqmocks "github.com/absmach/supermq/consumers/mocks" + "github.com/absmach/supermq/consumers/notifiers" + "github.com/absmach/supermq/consumers/notifiers/mocks" + "github.com/absmach/supermq/internal/testsutil" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const ( + total = 100 + exampleUser1 = "token1" + exampleUser2 = "token2" + validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22" +) + +func newService() (notifiers.Service, *authnmocks.Authentication, *mocks.SubscriptionsRepository) { + repo := new(mocks.SubscriptionsRepository) + auth := new(authnmocks.Authentication) + notifier := new(smqmocks.Notifier) + idp := uuid.NewMock() + from := "exampleFrom" + return notifiers.New(auth, repo, idp, notifier, from), auth, repo +} + +func TestCreateSubscription(t *testing.T) { + svc, auth, repo := newService() + + cases := []struct { + desc string + token string + sub notifiers.Subscription + id string + err error + authenticateErr error + userID string + }{ + { + desc: "test success", + token: exampleUser1, + sub: notifiers.Subscription{Contact: exampleUser1, Topic: "valid.topic"}, + id: uuid.Prefix + fmt.Sprintf("%012d", 1), + err: nil, + authenticateErr: nil, + userID: validID, + }, + { + desc: "test already existing", + token: exampleUser1, + sub: notifiers.Subscription{Contact: exampleUser1, Topic: "valid.topic"}, + id: "", + err: notifiers.ErrSubscriptionsAlreadyExists, + authenticateErr: nil, + userID: validID, + }, + { + desc: "test with empty token", + token: "", + sub: notifiers.Subscription{Contact: exampleUser1, Topic: "valid.topic"}, + id: "", + err: svcerr.ErrAuthentication, + authenticateErr: svcerr.ErrAuthentication, + }, + } + + for _, tc := range cases { + repoCall := auth.On("Authenticate", context.Background(), tc.token).Return(smqauthn.Session{UserID: tc.userID}, tc.authenticateErr) + repoCall1 := repo.On("Save", context.Background(), mock.Anything).Return(tc.id, tc.err) + id, err := svc.CreateSubscription(context.Background(), tc.token, tc.sub) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.id, id, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.id, id)) + repoCall.Unset() + repoCall1.Unset() + } +} + +func TestViewSubscription(t *testing.T) { + svc, auth, repo := newService() + sub := notifiers.Subscription{ + Contact: exampleUser1, + Topic: "valid.topic", + ID: testsutil.GenerateUUID(t), + OwnerID: validID, + } + + cases := []struct { + desc string + token string + id string + sub notifiers.Subscription + err error + authenticateErr error + userID string + }{ + { + desc: "test success", + token: exampleUser1, + id: validID, + sub: sub, + err: nil, + authenticateErr: nil, + userID: validID, + }, + { + desc: "test not existing", + token: exampleUser1, + id: "not_exist", + sub: notifiers.Subscription{}, + err: svcerr.ErrNotFound, + authenticateErr: nil, + userID: validID, + }, + { + desc: "test with empty token", + token: "", + id: validID, + sub: notifiers.Subscription{}, + err: svcerr.ErrAuthentication, + authenticateErr: svcerr.ErrAuthentication, + }, + } + + for _, tc := range cases { + repoCall := auth.On("Authenticate", context.Background(), tc.token).Return(smqauthn.Session{UserID: tc.userID}, tc.authenticateErr) + repoCall1 := repo.On("Retrieve", context.Background(), tc.id).Return(tc.sub, tc.err) + sub, err := svc.ViewSubscription(context.Background(), tc.token, tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.sub, sub, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.sub, sub)) + repoCall.Unset() + repoCall1.Unset() + } +} + +func TestListSubscriptions(t *testing.T) { + svc, auth, repo := newService() + sub := notifiers.Subscription{Contact: exampleUser1, OwnerID: exampleUser1} + topic := "topic.subtopic" + var subs []notifiers.Subscription + for i := 0; i < total; i++ { + tmp := sub + if i%2 == 0 { + tmp.Contact = exampleUser2 + tmp.OwnerID = exampleUser2 + } + tmp.Topic = fmt.Sprintf("%s.%d", topic, i) + tmp.ID = testsutil.GenerateUUID(t) + tmp.OwnerID = validID + subs = append(subs, tmp) + } + + var offsetSubs []notifiers.Subscription + for i := 20; i < 40; i += 2 { + offsetSubs = append(offsetSubs, subs[i]) + } + + cases := []struct { + desc string + token string + pageMeta notifiers.PageMetadata + page notifiers.Page + err error + authenticateErr error + userID string + }{ + { + desc: "test success", + token: exampleUser1, + pageMeta: notifiers.PageMetadata{ + Offset: 0, + Limit: 3, + }, + err: nil, + page: notifiers.Page{ + PageMetadata: notifiers.PageMetadata{ + Offset: 0, + Limit: 3, + }, + Subscriptions: subs[:3], + Total: total, + }, + authenticateErr: nil, + userID: validID, + }, + { + desc: "test not existing", + token: exampleUser1, + pageMeta: notifiers.PageMetadata{ + Limit: 10, + Contact: "empty@example.com", + }, + page: notifiers.Page{}, + err: svcerr.ErrNotFound, + authenticateErr: nil, + userID: validID, + }, + { + desc: "test with empty token", + token: "", + pageMeta: notifiers.PageMetadata{ + Offset: 2, + Limit: 12, + Topic: "topic.subtopic.13", + }, + page: notifiers.Page{}, + err: svcerr.ErrAuthentication, + authenticateErr: svcerr.ErrAuthentication, + }, + { + desc: "test with topic", + token: exampleUser1, + pageMeta: notifiers.PageMetadata{ + Limit: 10, + Topic: fmt.Sprintf("%s.%d", topic, 4), + }, + page: notifiers.Page{ + PageMetadata: notifiers.PageMetadata{ + Limit: 10, + Topic: fmt.Sprintf("%s.%d", topic, 4), + }, + Subscriptions: subs[4:5], + Total: 1, + }, + err: nil, + authenticateErr: nil, + userID: validID, + }, + { + desc: "test with contact and offset", + token: exampleUser1, + pageMeta: notifiers.PageMetadata{ + Offset: 10, + Limit: 10, + Contact: exampleUser2, + }, + page: notifiers.Page{ + PageMetadata: notifiers.PageMetadata{ + Offset: 10, + Limit: 10, + Contact: exampleUser2, + }, + Subscriptions: offsetSubs, + Total: uint(total / 2), + }, + err: nil, + authenticateErr: nil, + userID: validID, + }, + } + + for _, tc := range cases { + repoCall := auth.On("Authenticate", context.Background(), tc.token).Return(smqauthn.Session{UserID: tc.userID}, tc.authenticateErr) + repoCall1 := repo.On("RetrieveAll", context.Background(), tc.pageMeta).Return(tc.page, tc.err) + page, err := svc.ListSubscriptions(context.Background(), tc.token, tc.pageMeta) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.page, page, fmt.Sprintf("%s: got unexpected page\n", tc.desc)) + repoCall.Unset() + repoCall1.Unset() + } +} + +func TestRemoveSubscription(t *testing.T) { + svc, auth, repo := newService() + sub := notifiers.Subscription{ + ID: testsutil.GenerateUUID(t), + } + + cases := []struct { + desc string + token string + id string + err error + authenticateErr error + userID string + }{ + { + desc: "test success", + token: exampleUser1, + id: sub.ID, + err: nil, + authenticateErr: nil, + userID: validID, + }, + { + desc: "test not existing", + token: exampleUser1, + id: "not_exist", + err: svcerr.ErrNotFound, + authenticateErr: nil, + userID: validID, + }, + { + desc: "test with empty token", + token: "", + id: sub.ID, + err: svcerr.ErrAuthentication, + authenticateErr: svcerr.ErrAuthentication, + }, + } + + for _, tc := range cases { + repoCall := auth.On("Authenticate", context.Background(), tc.token).Return(smqauthn.Session{UserID: tc.userID}, tc.authenticateErr) + repoCall1 := repo.On("Remove", context.Background(), tc.id).Return(tc.err) + err := svc.RemoveSubscription(context.Background(), tc.token, tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + repoCall1.Unset() + } +} + +func TestConsume(t *testing.T) { + svc, _, repo := newService() + msg := messaging.Message{ + Channel: "topic", + Subtopic: "subtopic", + } + errMsg := messaging.Message{ + Channel: "topic", + Subtopic: "subtopic-2", + } + + cases := []struct { + desc string + msg *messaging.Message + err error + }{ + { + desc: "test success", + msg: &msg, + err: nil, + }, + { + desc: "test fail", + msg: &errMsg, + err: consumers.ErrNotify, + }, + } + + for _, tc := range cases { + repoCall := repo.On("RetrieveAll", context.TODO(), mock.Anything).Return(notifiers.Page{}, tc.err) + err := svc.ConsumeBlocking(context.TODO(), tc.msg) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + } +} diff --git a/consumers/notifiers/smpp/README.md b/consumers/notifiers/smpp/README.md new file mode 100644 index 000000000..db2b9d8dd --- /dev/null +++ b/consumers/notifiers/smpp/README.md @@ -0,0 +1,51 @@ +# SMPP Notifier + +SMPP Notifier implements notifier for send SMS notifications. + +## Configuration + +The Subscription service using SMPP Notifier is configured using the environment variables presented in the +following table. Note that any unset variables will be replaced with their +default values. + +| Variable | Description | Default | +| --------------------------------- | --------------------------------------------------------------------------------- | ------------------------------ | +| MG_SMPP_NOTIFIER_LOG_LEVEL | Log level for SMPP Notifier (debug, info, warn, error) | info | +| MG_SMPP_NOTIFIER_FROM_ADDRESS | From address for SMS notifications | | +| MG_SMPP_NOTIFIER_CONFIG_PATH | Config file path with Message broker subjects list, payload type and content-type | /config.toml | +| MG_SMPP_NOTIFIER_HTTP_HOST | Service HTTP host | localhost | +| MG_SMPP_NOTIFIER_HTTP_PORT | Service HTTP port | 9014 | +| MG_SMPP_NOTIFIER_HTTP_SERVER_CERT | Service HTTP server certificate path | "" | +| MG_SMPP_NOTIFIER_HTTP_SERVER_KEY | Service HTTP server key | "" | +| MG_SMPP_NOTIFIER_DB_HOST | Database host address | localhost | +| MG_SMPP_NOTIFIER_DB_PORT | Database host port | 5432 | +| MG_SMPP_NOTIFIER_DB_USER | Database user | magistrala | +| MG_SMPP_NOTIFIER_DB_PASS | Database password | magistrala | +| MG_SMPP_NOTIFIER_DB_NAME | Name of the database used by the service | subscriptions | +| MG_SMPP_NOTIFIER_DB_SSL_MODE | DB connection SSL mode (disable, require, verify-ca, verify-full) | disable | +| MG_SMPP_NOTIFIER_DB_SSL_CERT | Path to the PEM encoded certificate file | "" | +| MG_SMPP_NOTIFIER_DB_SSL_KEY | Path to the PEM encoded key file | "" | +| MG_SMPP_NOTIFIER_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" | +| MG_SMPP_ADDRESS | SMPP address [host:port] | | +| MG_SMPP_USERNAME | SMPP Username | | +| MG_SMPP_PASSWORD | SMPP Password | | +| MG_SMPP_SYSTEM_TYPE | SMPP System Type | | +| MG_SMPP_SRC_ADDR_TON | SMPP source address TON | | +| MG_SMPP_DST_ADDR_TON | SMPP destination address TON | | +| MG_SMPP_SRC_ADDR_NPI | SMPP source address NPI | | +| MG_SMPP_DST_ADDR_NPI | SMPP destination address NPI | | +| MG_AUTH_GRPC_URL | Auth service gRPC URL | localhost:7001 | +| MG_AUTH_GRPC_TIMEOUT | Auth service gRPC request timeout in seconds | 1s | +| MG_AUTH_GRPC_CLIENT_TLS | Auth client TLS flag | false | +| MG_AUTH_GRPC_CA_CERT | Path to Auth client CA certs in pem format | "" | +| MG_MESSAGE_BROKER_URL | Message broker URL | nats://127.0.0.1:4222 | +| MG_JAEGER_URL | Jaeger server URL | http://jaeger:14268/api/traces | +| MG_SEND_TELEMETRY | Send telemetry to magistrala call home server | true | +| MG_SMPP_NOTIFIER_INSTANCE_ID | SMPP Notifier instance ID | "" | + +## Usage + +Starting service will start consuming messages and sending SMS when a message is received. + +[doc]: https://docs.magistrala.absmach.eu + diff --git a/consumers/notifiers/smpp/config.go b/consumers/notifiers/smpp/config.go new file mode 100644 index 000000000..a8af3a6cf --- /dev/null +++ b/consumers/notifiers/smpp/config.go @@ -0,0 +1,21 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package smpp + +import ( + "crypto/tls" +) + +// Config represents SMPP transmitter configuration. +type Config struct { + Address string `env:"MG_SMPP_ADDRESS" envDefault:""` + Username string `env:"MG_SMPP_USERNAME" envDefault:""` + Password string `env:"MG_SMPP_PASSWORD" envDefault:""` + SystemType string `env:"MG_SMPP_SYSTEM_TYPE" envDefault:""` + SourceAddrTON uint8 `env:"MG_SMPP_SRC_ADDR_TON" envDefault:"0"` + SourceAddrNPI uint8 `env:"MG_SMPP_DST_ADDR_TON" envDefault:"0"` + DestAddrTON uint8 `env:"MG_SMPP_SRC_ADDR_NPI" envDefault:"0"` + DestAddrNPI uint8 `env:"MG_SMPP_DST_ADDR_NPI" envDefault:"0"` + TLS *tls.Config +} diff --git a/consumers/notifiers/smpp/doc.go b/consumers/notifiers/smpp/doc.go new file mode 100644 index 000000000..c81f3e75b --- /dev/null +++ b/consumers/notifiers/smpp/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package smpp contains the domain concept definitions needed to +// support Magistrala SMS notifications. +package smpp diff --git a/consumers/notifiers/smpp/notifier.go b/consumers/notifiers/smpp/notifier.go new file mode 100644 index 000000000..0ee3cfbc8 --- /dev/null +++ b/consumers/notifiers/smpp/notifier.go @@ -0,0 +1,67 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package smpp + +import ( + "time" + + "github.com/absmach/supermq/consumers" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/transformers" + "github.com/absmach/supermq/pkg/transformers/json" + "github.com/fiorix/go-smpp/smpp" + "github.com/fiorix/go-smpp/smpp/pdu/pdufield" + "github.com/fiorix/go-smpp/smpp/pdu/pdutext" +) + +var _ consumers.Notifier = (*notifier)(nil) + +type notifier struct { + transmitter *smpp.Transmitter + transformer transformers.Transformer + sourceAddrTON uint8 + sourceAddrNPI uint8 + destAddrTON uint8 + destAddrNPI uint8 +} + +// New instantiates SMTP message notifier. +func New(cfg Config) consumers.Notifier { + t := &smpp.Transmitter{ + Addr: cfg.Address, + User: cfg.Username, + Passwd: cfg.Password, + SystemType: cfg.SystemType, + RespTimeout: 3 * time.Second, + } + t.Bind() + ret := ¬ifier{ + transmitter: t, + transformer: json.New([]json.TimeField{}), + sourceAddrTON: cfg.SourceAddrTON, + destAddrTON: cfg.DestAddrTON, + sourceAddrNPI: cfg.SourceAddrNPI, + destAddrNPI: cfg.DestAddrNPI, + } + return ret +} + +func (n *notifier) Notify(from string, to []string, msg *messaging.Message) error { + send := &smpp.ShortMessage{ + Src: from, + DstList: to, + Validity: 10 * time.Minute, + SourceAddrTON: n.sourceAddrTON, + DestAddrTON: n.destAddrTON, + SourceAddrNPI: n.sourceAddrNPI, + DestAddrNPI: n.destAddrNPI, + Text: pdutext.Raw(msg.GetPayload()), + Register: pdufield.NoDeliveryReceipt, + } + _, err := n.transmitter.Submit(send) + if err != nil { + return err + } + return nil +} diff --git a/consumers/notifiers/smtp/notifier.go b/consumers/notifiers/smtp/notifier.go new file mode 100644 index 000000000..12fc94939 --- /dev/null +++ b/consumers/notifiers/smtp/notifier.go @@ -0,0 +1,40 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package smtp + +import ( + "fmt" + + "github.com/absmach/supermq/consumers" + "github.com/absmach/supermq/internal/email" + "github.com/absmach/supermq/pkg/messaging" +) + +const ( + footer = "Sent by SuperMQ SMTP Notification" + contentTemplate = "A publisher with an id %s sent the message over %s with the following values \n %s" +) + +var _ consumers.Notifier = (*notifier)(nil) + +type notifier struct { + agent *email.Agent +} + +// New instantiates SMTP message notifier. +func New(agent *email.Agent) consumers.Notifier { + return ¬ifier{agent: agent} +} + +func (n *notifier) Notify(from string, to []string, msg *messaging.Message) error { + subject := fmt.Sprintf(`Notification for Channel %s`, msg.GetChannel()) + if msg.GetSubtopic() != "" { + subject = fmt.Sprintf("%s and subtopic %s", subject, msg.GetSubtopic()) + } + + values := string(msg.GetPayload()) + content := fmt.Sprintf(contentTemplate, msg.GetPublisher(), msg.GetProtocol(), values) + + return n.agent.Send(to, from, subject, "", "", content, footer, map[string][]byte{}) +} diff --git a/consumers/notifiers/subscriptions.go b/consumers/notifiers/subscriptions.go new file mode 100644 index 000000000..0d2166da5 --- /dev/null +++ b/consumers/notifiers/subscriptions.go @@ -0,0 +1,46 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package notifiers + +import "context" + +// Subscription represents a user Subscription. +type Subscription struct { + ID string + OwnerID string + Contact string + Topic string +} + +// Page represents page metadata with content. +type Page struct { + PageMetadata + Total uint + Subscriptions []Subscription +} + +// PageMetadata contains page metadata that helps navigation. +type PageMetadata struct { + Offset uint + // Limit values less than 0 indicate no limit. + Limit int + Topic string + Contact string +} + +// SubscriptionsRepository specifies a Subscription persistence API. +type SubscriptionsRepository interface { + // Save persists a subscription. Successful operation is indicated by non-nil + // error response. + Save(ctx context.Context, sub Subscription) (string, error) + + // Retrieve retrieves the subscription for the given id. + Retrieve(ctx context.Context, id string) (Subscription, error) + + // RetrieveAll retrieves all the subscriptions for the given page metadata. + RetrieveAll(ctx context.Context, pm PageMetadata) (Page, error) + + // Remove removes the subscription for the given ID. + Remove(ctx context.Context, id string) error +} diff --git a/pkg/messaging/rabbitmq/tracing/doc.go b/consumers/notifiers/tracing/doc.go similarity index 72% rename from pkg/messaging/rabbitmq/tracing/doc.go rename to consumers/notifiers/tracing/doc.go index 2f3ee3830..da7ad43bb 100644 --- a/pkg/messaging/rabbitmq/tracing/doc.go +++ b/consumers/notifiers/tracing/doc.go @@ -1,11 +1,11 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -// Package tracing provides tracing instrumentation for SuperMQ clients policies service. +// Package tracing provides tracing instrumentation for SuperMQ WebSocket adapter service. // -// This package provides tracing middleware for SuperMQ clients policies service. +// This package provides tracing middleware for SuperMQ WebSocket adapter service. // It can be used to trace incoming requests and add tracing capabilities to -// SuperMQ clients policies service. +// SuperMQ WebSocket adapter service. // // For more details about tracing instrumentation for SuperMQ messaging refer // to the documentation at https://docs.supermq.absmach.eu/tracing/. diff --git a/consumers/notifiers/tracing/subscriptions.go b/consumers/notifiers/tracing/subscriptions.go new file mode 100644 index 000000000..8236e7635 --- /dev/null +++ b/consumers/notifiers/tracing/subscriptions.go @@ -0,0 +1,73 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package tracing contains middlewares that will add spans +// to existing traces. +package tracing + +import ( + "context" + + "github.com/absmach/supermq/consumers/notifiers" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +const ( + saveOp = "save_op" + retrieveOp = "retrieve_op" + retrieveAllOp = "retrieve_all_op" + removeOp = "remove_op" +) + +var _ notifiers.SubscriptionsRepository = (*subRepositoryMiddleware)(nil) + +type subRepositoryMiddleware struct { + tracer trace.Tracer + repo notifiers.SubscriptionsRepository +} + +// New instantiates a new Subscriptions repository that +// tracks request and their latency, and adds spans to context. +func New(tracer trace.Tracer, repo notifiers.SubscriptionsRepository) notifiers.SubscriptionsRepository { + return subRepositoryMiddleware{ + tracer: tracer, + repo: repo, + } +} + +// Save traces the "Save" operation of the wrapped Subscriptions repository. +func (urm subRepositoryMiddleware) Save(ctx context.Context, sub notifiers.Subscription) (string, error) { + ctx, span := urm.tracer.Start(ctx, saveOp, trace.WithAttributes( + attribute.String("id", sub.ID), + attribute.String("contact", sub.Contact), + attribute.String("topic", sub.Topic), + )) + defer span.End() + + return urm.repo.Save(ctx, sub) +} + +// Retrieve traces the "Retrieve" operation of the wrapped Subscriptions repository. +func (urm subRepositoryMiddleware) Retrieve(ctx context.Context, id string) (notifiers.Subscription, error) { + ctx, span := urm.tracer.Start(ctx, retrieveOp, trace.WithAttributes(attribute.String("id", id))) + defer span.End() + + return urm.repo.Retrieve(ctx, id) +} + +// RetrieveAll traces the "RetrieveAll" operation of the wrapped Subscriptions repository. +func (urm subRepositoryMiddleware) RetrieveAll(ctx context.Context, pm notifiers.PageMetadata) (notifiers.Page, error) { + ctx, span := urm.tracer.Start(ctx, retrieveAllOp) + defer span.End() + + return urm.repo.RetrieveAll(ctx, pm) +} + +// Remove traces the "Remove" operation of the wrapped Subscriptions repository. +func (urm subRepositoryMiddleware) Remove(ctx context.Context, id string) error { + ctx, span := urm.tracer.Start(ctx, removeOp, trace.WithAttributes(attribute.String("id", id))) + defer span.End() + + return urm.repo.Remove(ctx, id) +} diff --git a/consumers/tracing/consumers.go b/consumers/tracing/consumers.go new file mode 100644 index 000000000..2a799259a --- /dev/null +++ b/consumers/tracing/consumers.go @@ -0,0 +1,132 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package tracing + +import ( + "context" + "fmt" + + "github.com/absmach/supermq/consumers" + "github.com/absmach/supermq/pkg/server" + smqjson "github.com/absmach/supermq/pkg/transformers/json" + "github.com/absmach/supermq/pkg/transformers/senml" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +const ( + consumeBlockingOP = "retrieve_blocking" // This is not specified in the open telemetry spec. + consumeAsyncOP = "retrieve_async" // This is not specified in the open telemetry spec. +) + +var defaultAttributes = []attribute.KeyValue{ + attribute.String("messaging.system", "nats"), + attribute.Bool("messaging.destination.anonymous", false), + attribute.String("messaging.destination.template", "channels/{channelID}/messages/*"), + attribute.Bool("messaging.destination.temporary", true), + attribute.String("network.protocol.name", "nats"), + attribute.String("network.protocol.version", "2.2.4"), + attribute.String("network.transport", "tcp"), + attribute.String("network.type", "ipv4"), +} + +var ( + _ consumers.AsyncConsumer = (*tracingMiddlewareAsync)(nil) + _ consumers.BlockingConsumer = (*tracingMiddlewareBlock)(nil) +) + +type tracingMiddlewareAsync struct { + consumer consumers.AsyncConsumer + tracer trace.Tracer + host server.Config +} +type tracingMiddlewareBlock struct { + consumer consumers.BlockingConsumer + tracer trace.Tracer + host server.Config +} + +// NewAsync creates a new traced consumers.AsyncConsumer service. +func NewAsync(tracer trace.Tracer, consumerAsync consumers.AsyncConsumer, host server.Config) consumers.AsyncConsumer { + return &tracingMiddlewareAsync{ + consumer: consumerAsync, + tracer: tracer, + host: host, + } +} + +// NewBlocking creates a new traced consumers.BlockingConsumer service. +func NewBlocking(tracer trace.Tracer, consumerBlock consumers.BlockingConsumer, host server.Config) consumers.BlockingConsumer { + return &tracingMiddlewareBlock{ + consumer: consumerBlock, + tracer: tracer, + host: host, + } +} + +// ConsumeBlocking traces consume operations for message/s consumed. +func (tm *tracingMiddlewareBlock) ConsumeBlocking(ctx context.Context, messages any) error { + var span trace.Span + switch m := messages.(type) { + case smqjson.Messages: + if len(m.Data) > 0 { + firstMsg := m.Data[0] + ctx, span = createSpan(ctx, consumeBlockingOP, firstMsg.Publisher, firstMsg.Channel, firstMsg.Subtopic, len(m.Data), tm.host, trace.SpanKindConsumer, tm.tracer) + defer span.End() + } + case []senml.Message: + if len(m) > 0 { + firstMsg := m[0] + ctx, span = createSpan(ctx, consumeBlockingOP, firstMsg.Publisher, firstMsg.Channel, firstMsg.Subtopic, len(m), tm.host, trace.SpanKindConsumer, tm.tracer) + defer span.End() + } + } + return tm.consumer.ConsumeBlocking(ctx, messages) +} + +// ConsumeAsync traces consume operations for message/s consumed. +func (tm *tracingMiddlewareAsync) ConsumeAsync(ctx context.Context, messages any) { + var span trace.Span + switch m := messages.(type) { + case smqjson.Messages: + if len(m.Data) > 0 { + firstMsg := m.Data[0] + ctx, span = createSpan(ctx, consumeAsyncOP, firstMsg.Publisher, firstMsg.Channel, firstMsg.Subtopic, len(m.Data), tm.host, trace.SpanKindConsumer, tm.tracer) + defer span.End() + } + case []senml.Message: + if len(m) > 0 { + firstMsg := m[0] + ctx, span = createSpan(ctx, consumeAsyncOP, firstMsg.Publisher, firstMsg.Channel, firstMsg.Subtopic, len(m), tm.host, trace.SpanKindConsumer, tm.tracer) + defer span.End() + } + } + tm.consumer.ConsumeAsync(ctx, messages) +} + +// Errors traces async consume errors. +func (tm *tracingMiddlewareAsync) Errors() <-chan error { + return tm.consumer.Errors() +} + +func createSpan(ctx context.Context, operation, clientID, topic, subTopic string, noMessages int, cfg server.Config, spanKind trace.SpanKind, tracer trace.Tracer) (context.Context, trace.Span) { + subject := fmt.Sprintf("channels.%s.messages", topic) + if subTopic != "" { + subject = fmt.Sprintf("%s.%s", subject, subTopic) + } + spanName := fmt.Sprintf("%s %s", subject, operation) + + kvOpts := []attribute.KeyValue{ + attribute.String("messaging.operation", operation), + attribute.String("messaging.client_id", clientID), + attribute.String("messaging.destination.name", subject), + attribute.String("server.address", cfg.Host), + attribute.String("server.socket.port", cfg.Port), + attribute.Int("messaging.batch.message_count", noMessages), + } + + kvOpts = append(kvOpts, defaultAttributes...) + + return tracer.Start(ctx, spanName, trace.WithAttributes(kvOpts...), trace.WithSpanKind(spanKind)) +} diff --git a/consumers/writers/README.md b/consumers/writers/README.md new file mode 100644 index 000000000..d0945b762 --- /dev/null +++ b/consumers/writers/README.md @@ -0,0 +1,272 @@ +# Writers + +Writers consume messages from the message broker, normalize them (SenML or JSON), and persist them to a storage backend. Magistrala provides two writer services: + +- **Postgres writer**: Stores data in PostgreSQL. +- **Timescale writer**: Stores data in TimescaleDB and uses hypertables for time-series workloads. + +Writers are optional services and are treated as plugins. Core services and the message broker must be running first. For platform dependencies, see [Docker Compose](https://github.com/absmach/magistrala/blob/main/docker/docker-compose.yaml). + +## Configuration + +Values shown are from [docker/.env](https://github.com/absmach/magistrala/blob/main/docker/.env) and the add-on compose files in `docker/addons/*-writer/docker-compose.yaml`. + +### Postgres writer + +#### Postgres Service endpoints + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_POSTGRES_WRITER_LOG_LEVEL` | Service log level | `debug` | +| `MG_POSTGRES_WRITER_CONFIG_PATH` | Config file path (topics/transformer) | `/config.toml` | +| `MG_POSTGRES_WRITER_HTTP_HOST` | HTTP host | `postgres-writer` | +| `MG_POSTGRES_WRITER_HTTP_PORT` | HTTP port | `9007` | +| `MG_POSTGRES_WRITER_HTTP_SERVER_CERT` | HTTPS server certificate path | "" | +| `MG_POSTGRES_WRITER_HTTP_SERVER_KEY` | HTTPS server key path | "" | +| `MG_POSTGRES_WRITER_INSTANCE_ID` | Instance ID | "" | + +#### Postgres Database + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_POSTGRES_HOST` | PostgreSQL host | `postgres` | +| `MG_POSTGRES_PORT` | PostgreSQL port | `5432` | +| `MG_POSTGRES_USER` | PostgreSQL user | `supermq` | +| `MG_POSTGRES_PASS` | PostgreSQL password | `supermq` | +| `MG_POSTGRES_NAME` | PostgreSQL database name | `messages` | +| `MG_POSTGRES_SSL_MODE` | PostgreSQL SSL mode | `disable` | +| `MG_POSTGRES_SSL_CERT` | PostgreSQL SSL client cert | "" | +| `MG_POSTGRES_SSL_KEY` | PostgreSQL SSL client key | "" | +| `MG_POSTGRES_SSL_ROOT_CERT` | PostgreSQL SSL root cert | "" | + +#### Postgres Message broker and observability + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_MESSAGE_BROKER_URL` | Message broker URL | `nats://nats:4222` | +| `MG_JAEGER_URL` | Jaeger collector endpoint | `http://jaeger:4318/v1/traces` | +| `MG_JAEGER_TRACE_RATIO` | Trace sampling ratio | `1.0` | +| `MG_SEND_TELEMETRY` | Send telemetry to Magistrala call-home server | `true` | + +### Timescale writer + +#### Timescale Service endpoints + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_TIMESCALE_WRITER_LOG_LEVEL` | Service log level | `debug` | +| `MG_TIMESCALE_WRITER_CONFIG_PATH` | Config file path (topics/transformer) | `/config.toml` | +| `MG_TIMESCALE_WRITER_HTTP_HOST` | HTTP host | `timescale-writer` | +| `MG_TIMESCALE_WRITER_HTTP_PORT` | HTTP port | `9012` | +| `MG_TIMESCALE_WRITER_HTTP_SERVER_CERT` | HTTPS server certificate path | "" | +| `MG_TIMESCALE_WRITER_HTTP_SERVER_KEY` | HTTPS server key path | "" | +| `MG_TIMESCALE_WRITER_INSTANCE_ID` | Instance ID | "" | + +#### Timescale Database + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_TIMESCALE_HOST` | TimescaleDB host | `timescale` | +| `MG_TIMESCALE_PORT` | TimescaleDB port | `5432` | +| `MG_TIMESCALE_USER` | TimescaleDB user | `supermq` | +| `MG_TIMESCALE_PASS` | TimescaleDB password | `supermq` | +| `MG_TIMESCALE_NAME` | TimescaleDB database name | `supermq` | +| `MG_TIMESCALE_SSL_MODE` | TimescaleDB SSL mode | `disable` | +| `MG_TIMESCALE_SSL_CERT` | TimescaleDB SSL client cert | "" | +| `MG_TIMESCALE_SSL_KEY` | TimescaleDB SSL client key | "" | +| `MG_TIMESCALE_SSL_ROOT_CERT` | TimescaleDB SSL root cert | "" | + +#### Timescale Message broker and observability + +Timescale writer uses the same broker and telemetry variables listed for Postgres writer. + +### Writer config file + +Both writers read a config file defined by `*_WRITER_CONFIG_PATH`. The default add-on config files are: + +- `docker/addons/postgres-writer/config.toml` +- `docker/addons/timescale-writer/config.toml` + +The config file controls subscription topics and optional transformer settings for both writers. The default Timescale add-on config omits the transformer section and relies on the built-in defaults: + +```toml +["subscriber"] +topics = ["writers.>"] + +[transformer] +format = "senml" +content_type = "application/senml+json" +time_fields = [ + { field_name = "seconds_key", field_format = "unix", location = "UTC" }, + { field_name = "millis_key", field_format = "unix_ms", location = "UTC" }, + { field_name = "micros_key", field_format = "unix_us", location = "UTC" }, + { field_name = "nanos_key", field_format = "unix_ns", location = "UTC" } +] +``` + +The topic filter uses `writers.*` syntax in the config file for both backends. Writers do not expose broker mode, delivery policy, or consumer-group settings in this file. They always consume through the stream-backed broker adapter in `consumers/writers/brokers`: + +- NATS builds use JetStream streams with durable consumers. +- FluxMQ builds publish to and consume from the `writers` stream queue while preserving the same `writers.>` config syntax. + +## Features + +- **Message persistence**: Stores incoming SenML messages into PostgreSQL or TimescaleDB. +- **JSON payload support**: Saves JSON payloads into dynamically created tables. +- **Stream-backed ingestion**: Consumes through NATS JetStream durable consumers or FluxMQ stream queues. +- **Configurable subscription**: Limits ingestion to specific `writers.*` topics. +- **Observability**: Exposes `/health` and `/metrics` endpoints, with Jaeger tracing. + +## Architecture + +### Runtime flow + +1. The rules engine publishes writer messages under `writers.*`. +2. The writer loads `config.toml` to select topic filters and transformer settings. +3. The broker adapter consumes from the underlying stream-backed implementation. +4. The consumer converts messages to SenML or JSON payloads. +5. The repository writes records to the target database. + +### Components + +- **Message broker adapter**: `consumers/writers/brokers` (NATS JetStream or FluxMQ stream queues). +- **Writer services**: `consumers/writers/postgres` and `consumers/writers/timescale`. +- **HTTP API**: `consumers/writers/api` exposes `/health` and `/metrics`. +- **Migrations**: `consumers/writers/*/init.go` defines the schema and indexes. + +### PostgreSQL schema (SenML messages) + +Defined in `consumers/writers/postgres/init.go`: + +| Column | Type | Description | +| --- | --- | --- | +| `id` | `UUID` | Message ID | +| `channel` | `UUID` | Channel ID | +| `subtopic` | `VARCHAR(254)` | Subtopic | +| `publisher` | `UUID` | Publisher ID | +| `protocol` | `TEXT` | Protocol name | +| `name` | `TEXT` | SenML name | +| `unit` | `TEXT` | SenML unit | +| `value` | `FLOAT` | Numeric value | +| `string_value` | `TEXT` | String value | +| `bool_value` | `BOOL` | Boolean value | +| `data_value` | `BYTEA` | Data value | +| `sum` | `FLOAT` | Sum value | +| `time` | `FLOAT` | Measurement time | +| `update_time` | `FLOAT` | Update time | + +Primary key: `(time, publisher, subtopic, name)` + +### TimescaleDB schema (SenML messages) + +Defined in `consumers/writers/timescale/init.go`: + +| Column | Type | Description | +| --- | --- | --- | +| `time` | `BIGINT` | Measurement time | +| `channel` | `UUID` | Channel ID | +| `subtopic` | `VARCHAR(254)` | Subtopic | +| `publisher` | `VARCHAR(254)` | Publisher ID | +| `protocol` | `TEXT` | Protocol name | +| `name` | `VARCHAR(254)` | SenML name | +| `unit` | `TEXT` | SenML unit | +| `value` | `FLOAT` | Numeric value | +| `string_value` | `TEXT` | String value | +| `bool_value` | `BOOL` | Boolean value | +| `data_value` | `BYTEA` | Data value | +| `sum` | `FLOAT` | Sum value | +| `update_time` | `FLOAT` | Update time | + +Primary key: `(time, channel, subtopic, protocol, publisher, name)` + +Timescale writer creates a hypertable on `messages` and adds time-series indexes for common query paths. + +### JSON payload tables (dynamic) + +If the transformer emits JSON payloads, the writers create a table named after the payload format: + +Postgres JSON table: +`id UUID`, `created BIGINT`, `channel VARCHAR(254)`, `subtopic VARCHAR(254)`, `publisher VARCHAR(254)`, `protocol TEXT`, `payload JSONB` (PK: `id`) + +Timescale JSON table: +`created BIGINT`, `channel VARCHAR(254)`, `subtopic VARCHAR(254)`, `publisher VARCHAR(254)`, `protocol TEXT`, `payload JSONB` (PK: `created`, `publisher`, `subtopic`) + +## Deployment + +### Build and run locally + +Postgres writer: + +```bash +make postgres-writer + +MG_POSTGRES_WRITER_LOG_LEVEL=debug \ +MG_POSTGRES_WRITER_CONFIG_PATH=./docker/addons/postgres-writer/config.toml \ +MG_POSTGRES_WRITER_HTTP_PORT=9007 \ +MG_POSTGRES_HOST=localhost \ +MG_POSTGRES_PORT=5432 \ +MG_POSTGRES_USER=supermq \ +MG_POSTGRES_PASS=supermq \ +MG_POSTGRES_NAME=messages \ +MG_MESSAGE_BROKER_URL=nats://localhost:4222 \ +MG_JAEGER_URL=http://localhost:4318/v1/traces \ +./build/postgres-writer +``` + +Timescale writer: + +```bash +make timescale-writer + +MG_TIMESCALE_WRITER_LOG_LEVEL=debug \ +MG_TIMESCALE_WRITER_CONFIG_PATH=./docker/addons/timescale-writer/config.toml \ +MG_TIMESCALE_WRITER_HTTP_PORT=9012 \ +MG_TIMESCALE_HOST=localhost \ +MG_TIMESCALE_PORT=5432 \ +MG_TIMESCALE_USER=supermq \ +MG_TIMESCALE_PASS=supermq \ +MG_TIMESCALE_NAME=supermq \ +MG_MESSAGE_BROKER_URL=nats://localhost:4222 \ +MG_JAEGER_URL=http://localhost:4318/v1/traces \ +./build/timescale-writer +``` + +### Docker Compose + +Postgres writer add-on: + +```bash +docker compose -f docker/docker-compose.yaml -f docker/addons/postgres-writer/docker-compose.yaml up +``` + +Timescale writer: + +```bash +docker compose -f docker/docker-compose.yaml up +``` + +### Health check + +```bash +curl -X GET http://localhost:9007/health \ + -H "accept: application/health+json" +``` + +## Testing + +```bash +go test ./consumers/writers/... +``` + +## Usage + +Writers do not expose a message ingestion API. Messages are written via the message broker, and writers consume them through the stream-backed broker adapter. The HTTP API provides only health and metrics endpoints. + +| Endpoint | Description | +| --- | --- | +| `GET /health` | Service health check | +| `GET /metrics` | Prometheus metrics | + +For an in-depth explanation of Writers, see the [official documentation][doc]. + +[doc]: https://docs.magistrala.absmach.eu/dev-guide/consumers/ diff --git a/consumers/writers/api/doc.go b/consumers/writers/api/doc.go new file mode 100644 index 000000000..2424852cc --- /dev/null +++ b/consumers/writers/api/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package api contains API-related concerns: endpoint definitions, middlewares +// and all resource representations. +package api diff --git a/consumers/writers/api/logging.go b/consumers/writers/api/logging.go new file mode 100644 index 000000000..2f0cd67b9 --- /dev/null +++ b/consumers/writers/api/logging.go @@ -0,0 +1,47 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build !test + +package api + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/supermq/consumers" +) + +var _ consumers.BlockingConsumer = (*loggingMiddleware)(nil) + +type loggingMiddleware struct { + logger *slog.Logger + consumer consumers.BlockingConsumer +} + +// LoggingMiddleware adds logging facilities to the adapter. +func LoggingMiddleware(consumer consumers.BlockingConsumer, logger *slog.Logger) consumers.BlockingConsumer { + return &loggingMiddleware{ + logger: logger, + consumer: consumer, + } +} + +// ConsumeBlocking logs the consume request. It logs the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) ConsumeBlocking(ctx context.Context, msgs any) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Blocking consumer failed to consume messages successfully", args...) + return + } + lm.logger.Info("Blocking consumer consumed messages successfully", args...) + }(time.Now()) + + return lm.consumer.ConsumeBlocking(ctx, msgs) +} diff --git a/consumers/writers/api/metrics.go b/consumers/writers/api/metrics.go new file mode 100644 index 000000000..2b1a6c547 --- /dev/null +++ b/consumers/writers/api/metrics.go @@ -0,0 +1,41 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build !test + +package api + +import ( + "context" + "time" + + "github.com/absmach/supermq/consumers" + "github.com/go-kit/kit/metrics" +) + +var _ consumers.BlockingConsumer = (*metricsMiddleware)(nil) + +type metricsMiddleware struct { + counter metrics.Counter + latency metrics.Histogram + consumer consumers.BlockingConsumer +} + +// MetricsMiddleware returns new message repository +// with Save method wrapped to expose metrics. +func MetricsMiddleware(consumer consumers.BlockingConsumer, counter metrics.Counter, latency metrics.Histogram) consumers.BlockingConsumer { + return &metricsMiddleware{ + counter: counter, + latency: latency, + consumer: consumer, + } +} + +// ConsumeBlocking instruments ConsumeBlocking method with metrics. +func (mm *metricsMiddleware) ConsumeBlocking(ctx context.Context, msgs any) error { + defer func(begin time.Time) { + mm.counter.With("method", "consume").Add(1) + mm.latency.With("method", "consume").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.consumer.ConsumeBlocking(ctx, msgs) +} diff --git a/consumers/writers/api/transport.go b/consumers/writers/api/transport.go new file mode 100644 index 000000000..7ae45b4c1 --- /dev/null +++ b/consumers/writers/api/transport.go @@ -0,0 +1,21 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "net/http" + + "github.com/absmach/supermq" + "github.com/go-chi/chi/v5" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +// MakeHandler returns a HTTP API handler with health check and metrics. +func MakeHandler(svcName, instanceID string) http.Handler { + r := chi.NewRouter() + r.Get("/health", supermq.Health(svcName, instanceID)) + r.Handle("/metrics", promhttp.Handler()) + + return r +} diff --git a/consumers/writers/brokers/brokers_fluxmq.go b/consumers/writers/brokers/brokers_fluxmq.go new file mode 100644 index 000000000..9e8266dff --- /dev/null +++ b/consumers/writers/brokers/brokers_fluxmq.go @@ -0,0 +1,53 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build msg_fluxmq +// +build msg_fluxmq + +package brokers + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/supermq/pkg/messaging" + broker "github.com/absmach/supermq/pkg/messaging/fluxmq" + "github.com/nats-io/nats.go/jetstream" +) + +const ( + AllTopic = "writers/#" + + prefix = "writers" +) + +var cfg = jetstream.StreamConfig{ + Name: "writers", + Description: "SuperMQ Rules Engine stream for handling internal messages", + Subjects: []string{"writers/#"}, + Retention: jetstream.LimitsPolicy, + MaxMsgsPerSubject: 1e6, + MaxAge: time.Hour * 24, + MaxMsgSize: 1024 * 1024, + Discard: jetstream.DiscardOld, + Storage: jetstream.FileStorage, +} + +func NewPubSub(ctx context.Context, url string, logger *slog.Logger) (messaging.PubSub, error) { + pb, err := broker.NewPubSub(ctx, url, logger, broker.Prefix(prefix), broker.JSStreamConfig(cfg), broker.ConnectionName("writers-msg-pubsub")) + if err != nil { + return nil, err + } + + return pb, nil +} + +func NewPublisher(ctx context.Context, url string) (messaging.Publisher, error) { + pb, err := broker.NewPublisher(ctx, url, broker.Prefix(prefix), broker.JSStreamConfig(cfg), broker.ConnectionName("writers-msg-pub")) + if err != nil { + return nil, err + } + + return pb, nil +} diff --git a/consumers/writers/brokers/brokers_nats.go b/consumers/writers/brokers/brokers_nats.go new file mode 100644 index 000000000..18bfc19c7 --- /dev/null +++ b/consumers/writers/brokers/brokers_nats.go @@ -0,0 +1,53 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build !msg_fluxmq && !msg_rabbitmq && !rabbitmq +// +build !msg_fluxmq,!msg_rabbitmq,!rabbitmq + +package brokers + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/supermq/pkg/messaging" + broker "github.com/absmach/supermq/pkg/messaging/nats" + "github.com/nats-io/nats.go/jetstream" +) + +const ( + AllTopic = "writers.>" + + prefix = "writers" +) + +var cfg = jetstream.StreamConfig{ + Name: "writers", + Description: "SuperMQ Rules Engine stream for handling internal messages", + Subjects: []string{"writers.>"}, + Retention: jetstream.LimitsPolicy, + MaxMsgsPerSubject: 1e6, + MaxAge: time.Hour * 24, + MaxMsgSize: 1024 * 1024, + Discard: jetstream.DiscardOld, + Storage: jetstream.FileStorage, +} + +func NewPubSub(ctx context.Context, url string, logger *slog.Logger) (messaging.PubSub, error) { + pb, err := broker.NewPubSub(ctx, url, logger, broker.Prefix(prefix), broker.JSStreamConfig(cfg)) + if err != nil { + return nil, err + } + + return pb, nil +} + +func NewPublisher(ctx context.Context, url string) (messaging.Publisher, error) { + pb, err := broker.NewPublisher(ctx, url, broker.Prefix(prefix), broker.JSStreamConfig(cfg)) + if err != nil { + return nil, err + } + + return pb, nil +} diff --git a/consumers/writers/doc.go b/consumers/writers/doc.go new file mode 100644 index 000000000..644079245 --- /dev/null +++ b/consumers/writers/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package writers contain the domain concept definitions needed to +// support SuperMQ writer services functionality. +package writers diff --git a/consumers/writers/postgres/README.md b/consumers/writers/postgres/README.md new file mode 100644 index 000000000..3097a66ce --- /dev/null +++ b/consumers/writers/postgres/README.md @@ -0,0 +1,77 @@ +# Postgres writer + +Postgres writer provides message repository implementation for Postgres. + +## Configuration + +The service is configured using the environment variables presented in the +following table. Note that any unset variables will be replaced with their +default values. + +| Variable | Description | Default | +| ------------------------------------ | --------------------------------------------------------------------------------- | ---------------------------- | +| MG_POSTGRES_WRITER_LOG_LEVEL | Service log level | info | +| MG_POSTGRES_WRITER_CONFIG_PATH | Config file path with Message broker subjects list, payload type and content-type | /config.toml | +| MG_POSTGRES_WRITER_HTTP_HOST | Service HTTP host | localhost | +| MG_POSTGRES_WRITER_HTTP_PORT | Service HTTP port | 9010 | +| MG_POSTGRES_WRITER_HTTP_SERVER_CERT | Service HTTP server certificate path | "" | +| MG_POSTGRES_WRITER_HTTP_SERVER_KEY | Service HTTP server key | "" | +| MG_POSTGRES_HOST | Postgres DB host | postgres | +| MG_POSTGRES_PORT | Postgres DB port | 5432 | +| MG_POSTGRES_USER | Postgres user | supermq | +| MG_POSTGRES_PASS | Postgres password | supermq | +| MG_POSTGRES_NAME | Postgres database name | messages | +| MG_POSTGRES_SSL_MODE | Postgres SSL mode | disabled | +| MG_POSTGRES_SSL_CERT | Postgres SSL certificate path | "" | +| MG_POSTGRES_SSL_KEY | Postgres SSL key | "" | +| MG_POSTGRES_SSL_ROOT_CERT | Postgres SSL root certificate path | "" | +| MG_MESSAGE_BROKER_URL | Message broker instance URL | nats://localhost:4222 | +| MG_JAEGER_URL | Jaeger server URL | http://jaeger:4318/v1/traces | +| MG_SEND_TELEMETRY | Send telemetry to supermq call home server | true | +| MG_POSTGRES_WRITER_INSTANCE_ID | Service instance ID | "" | + +## Deployment + +The service itself is distributed as Docker container. Check the [`postgres-writer`](https://github.com/absmach/supermq/blob/main/docker/addons/postgres-writer/docker-compose.yaml#L34-L59) service section in docker-compose file to see how service is deployed. + +To start the service, execute the following shell script: + +```bash +# download the latest version of the service +git clone https://github.com/absmach/supermq + +cd supermq + +# compile the postgres writer +make postgres-writer + +# copy binary to bin +make install + +# Set the environment variables and run the service +MG_POSTGRES_WRITER_LOG_LEVEL=[Service log level] \ +MG_POSTGRES_WRITER_CONFIG_PATH=[Config file path with Message broker subjects list, payload type and content-type] \ +MG_POSTGRES_WRITER_HTTP_HOST=[Service HTTP host] \ +MG_POSTGRES_WRITER_HTTP_PORT=[Service HTTP port] \ +MG_POSTGRES_WRITER_HTTP_SERVER_CERT=[Service HTTP server cert] \ +MG_POSTGRES_WRITER_HTTP_SERVER_KEY=[Service HTTP server key] \ +MG_POSTGRES_HOST=[Postgres host] \ +MG_POSTGRES_PORT=[Postgres port] \ +MG_POSTGRES_USER=[Postgres user] \ +MG_POSTGRES_PASS=[Postgres password] \ +MG_POSTGRES_NAME=[Postgres database name] \ +MG_POSTGRES_SSL_MODE=[Postgres SSL mode] \ +MG_POSTGRES_SSL_CERT=[Postgres SSL cert] \ +MG_POSTGRES_SSL_KEY=[Postgres SSL key] \ +MG_POSTGRES_SSL_ROOT_CERT=[Postgres SSL Root cert] \ +MG_MESSAGE_BROKER_URL=[Message broker instance URL] \ +MG_JAEGER_URL=[Jaeger server URL] \ +MG_SEND_TELEMETRY=[Send telemetry to supermq call home server] \ +MG_POSTGRES_WRITER_INSTANCE_ID=[Service instance ID] \ + +$GOBIN/supermq-postgres-writer +``` + +## Usage + +Starting service will start consuming normalized messages in SenML format. diff --git a/consumers/writers/postgres/consumer.go b/consumers/writers/postgres/consumer.go new file mode 100644 index 000000000..3fd958ca9 --- /dev/null +++ b/consumers/writers/postgres/consumer.go @@ -0,0 +1,216 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/absmach/supermq/consumers" + "github.com/absmach/supermq/pkg/errors" + smqjson "github.com/absmach/supermq/pkg/transformers/json" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/gofrs/uuid/v5" + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jmoiron/sqlx" // required for DB access +) + +var ( + errInvalidMessage = errors.New("invalid message representation") + errSaveMessage = errors.New("failed to save message to postgres database") + errTransRollback = errors.New("failed to rollback transaction") + errNoTable = errors.New("relation does not exist") +) + +var _ consumers.BlockingConsumer = (*postgresRepo)(nil) + +type postgresRepo struct { + db *sqlx.DB +} + +// New returns new PostgreSQL writer. +func New(db *sqlx.DB) consumers.BlockingConsumer { + return &postgresRepo{db: db} +} + +func (pr postgresRepo) ConsumeBlocking(ctx context.Context, message any) (err error) { + switch m := message.(type) { + case smqjson.Messages: + return pr.saveJSON(ctx, m) + default: + return pr.saveSenml(ctx, m) + } +} + +func (pr postgresRepo) saveSenml(ctx context.Context, messages any) (err error) { + msgs, ok := messages.([]senml.Message) + if !ok { + return errSaveMessage + } + q := `INSERT INTO messages (id, channel, subtopic, publisher, protocol, + name, unit, value, string_value, bool_value, data_value, sum, + time, update_time) + VALUES (:id, :channel, :subtopic, :publisher, :protocol, :name, :unit, + :value, :string_value, :bool_value, :data_value, :sum, + :time, :update_time);` + + tx, err := pr.db.BeginTxx(ctx, nil) + if err != nil { + return errors.Wrap(errSaveMessage, err) + } + defer func() { + if err != nil { + if txErr := tx.Rollback(); txErr != nil { + err = errors.Wrap(err, errors.Wrap(errTransRollback, txErr)) + } + return + } + + if err = tx.Commit(); err != nil { + err = errors.Wrap(errSaveMessage, err) + } + }() + + for _, msg := range msgs { + id, err := uuid.NewV4() + if err != nil { + return err + } + m := senmlMessage{Message: msg, ID: id.String()} + if _, err := tx.NamedExec(q, m); err != nil { + pgErr, ok := err.(*pgconn.PgError) + if ok { + if pgErr.Code == pgerrcode.InvalidTextRepresentation { + return errors.Wrap(errSaveMessage, errInvalidMessage) + } + } + + return errors.Wrap(errSaveMessage, err) + } + } + return err +} + +func (pr postgresRepo) saveJSON(ctx context.Context, msgs smqjson.Messages) error { + if err := pr.insertJSON(ctx, msgs); err != nil { + if err == errNoTable { + if err := pr.createTable(msgs.Format); err != nil { + return err + } + return pr.insertJSON(ctx, msgs) + } + return err + } + return nil +} + +func (pr postgresRepo) insertJSON(ctx context.Context, msgs smqjson.Messages) error { + tx, err := pr.db.BeginTxx(ctx, nil) + if err != nil { + return errors.Wrap(errSaveMessage, err) + } + defer func() { + if err != nil { + if txErr := tx.Rollback(); txErr != nil { + err = errors.Wrap(err, errors.Wrap(errTransRollback, txErr)) + } + return + } + + if err = tx.Commit(); err != nil { + err = errors.Wrap(errSaveMessage, err) + } + }() + + q := `INSERT INTO %s (id, channel, created, subtopic, publisher, protocol, payload) + VALUES (:id, :channel, :created, :subtopic, :publisher, :protocol, :payload);` + q = fmt.Sprintf(q, msgs.Format) + + for _, m := range msgs.Data { + var dbmsg jsonMessage + dbmsg, err = toJSONMessage(m) + if err != nil { + return errors.Wrap(errSaveMessage, err) + } + + if _, err = tx.NamedExec(q, dbmsg); err != nil { + if preErr, ok := err.(*pgconn.PrepareError); ok { + err = preErr.Unwrap() + } + pgErr, ok := err.(*pgconn.PgError) + if ok { + switch pgErr.Code { + case pgerrcode.InvalidTextRepresentation: + return errors.Wrap(errSaveMessage, errInvalidMessage) + case pgerrcode.UndefinedTable: + return errNoTable + } + } + return err + } + } + return nil +} + +func (pr postgresRepo) createTable(name string) error { + q := `CREATE TABLE IF NOT EXISTS %s ( + id UUID, + created BIGINT, + channel VARCHAR(254), + subtopic VARCHAR(254), + publisher VARCHAR(254), + protocol TEXT, + payload JSONB, + PRIMARY KEY (id) + )` + q = fmt.Sprintf(q, name) + + _, err := pr.db.Exec(q) + return err +} + +type senmlMessage struct { + senml.Message + ID string `db:"id"` +} + +type jsonMessage struct { + ID string `db:"id"` + Channel string `db:"channel"` + Created int64 `db:"created"` + Subtopic string `db:"subtopic"` + Publisher string `db:"publisher"` + Protocol string `db:"protocol"` + Payload []byte `db:"payload"` +} + +func toJSONMessage(msg smqjson.Message) (jsonMessage, error) { + id, err := uuid.NewV4() + if err != nil { + return jsonMessage{}, err + } + + data := []byte("{}") + if msg.Payload != nil { + b, err := json.Marshal(msg.Payload) + if err != nil { + return jsonMessage{}, errors.Wrap(errSaveMessage, err) + } + data = b + } + + m := jsonMessage{ + ID: id.String(), + Channel: msg.Channel, + Created: msg.Created, + Subtopic: msg.Subtopic, + Publisher: msg.Publisher, + Protocol: msg.Protocol, + Payload: data, + } + + return m, nil +} diff --git a/consumers/writers/postgres/consumer_test.go b/consumers/writers/postgres/consumer_test.go new file mode 100644 index 000000000..80401fa5c --- /dev/null +++ b/consumers/writers/postgres/consumer_test.go @@ -0,0 +1,112 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/absmach/supermq/consumers/writers/postgres" + "github.com/absmach/supermq/pkg/transformers/json" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/gofrs/uuid/v5" + "github.com/stretchr/testify/assert" +) + +const ( + msgsNum = 42 + valueFields = 5 + subtopic = "topic" +) + +var ( + v float64 = 5 + stringV = "value" + boolV = true + dataV = "base64" + sum float64 = 42 +) + +func TestSaveSenml(t *testing.T) { + repo := postgres.New(db) + + chid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + msg := senml.Message{} + msg.Channel = chid.String() + + pubid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + msg.Publisher = pubid.String() + + now := time.Now().Unix() + var msgs []senml.Message + + for i := 0; i < msgsNum; i++ { + // Mix possible values as well as value sum. + count := i % valueFields + switch count { + case 0: + msg.Subtopic = subtopic + msg.Value = &v + case 1: + msg.BoolValue = &boolV + case 2: + msg.StringValue = &stringV + case 3: + msg.DataValue = &dataV + case 4: + msg.Sum = &sum + } + + msg.Time = float64(now + int64(i)) + msgs = append(msgs, msg) + } + + err = repo.ConsumeBlocking(context.TODO(), msgs) + assert.Nil(t, err, fmt.Sprintf("expected no error got %s\n", err)) +} + +func TestSaveJSON(t *testing.T) { + repo := postgres.New(db) + + chid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + pubid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + msg := json.Message{ + Channel: chid.String(), + Publisher: pubid.String(), + Created: time.Now().Unix(), + Subtopic: "subtopic/format/some_json", + Protocol: "mqtt", + Payload: map[string]any{ + "field_1": 123, + "field_2": "value", + "field_3": false, + "field_4": 12.344, + "field_5": map[string]any{ + "field_1": "value", + "field_2": 42, + }, + }, + } + + now := time.Now().Unix() + msgs := json.Messages{ + Format: "some_json", + } + + for i := 0; i < msgsNum; i++ { + msg.Created = now + int64(i) + msgs.Data = append(msgs.Data, msg) + } + + err = repo.ConsumeBlocking(context.TODO(), msgs) + assert.Nil(t, err, fmt.Sprintf("expected no error got %s\n", err)) +} diff --git a/consumers/writers/postgres/doc.go b/consumers/writers/postgres/doc.go new file mode 100644 index 000000000..a92d4f9b5 --- /dev/null +++ b/consumers/writers/postgres/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package postgres contains repository implementations using Postgres as +// the underlying database. +package postgres diff --git a/consumers/writers/postgres/init.go b/consumers/writers/postgres/init.go new file mode 100644 index 000000000..de140b258 --- /dev/null +++ b/consumers/writers/postgres/init.go @@ -0,0 +1,46 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import migrate "github.com/rubenv/sql-migrate" + +// Migration of postgres-writer. +func Migration() *migrate.MemoryMigrationSource { + return &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "messages_1", + Up: []string{ + `CREATE TABLE IF NOT EXISTS messages ( + id UUID, + channel UUID, + subtopic VARCHAR(254), + publisher UUID, + protocol TEXT, + name TEXT, + unit TEXT, + value FLOAT, + string_value TEXT, + bool_value BOOL, + data_value BYTEA, + sum FLOAT, + time FLOAT, + update_time FLOAT, + PRIMARY KEY (id) + )`, + }, + Down: []string{ + "DROP TABLE messages", + }, + }, + { + Id: "messages_2", + Up: []string{ + `ALTER TABLE messages DROP CONSTRAINT messages_pkey`, + `ALTER TABLE messages ADD PRIMARY KEY (time, publisher, subtopic, name)`, + }, + }, + }, + } +} diff --git a/consumers/writers/postgres/setup_test.go b/consumers/writers/postgres/setup_test.go new file mode 100644 index 000000000..ecb2ab34c --- /dev/null +++ b/consumers/writers/postgres/setup_test.go @@ -0,0 +1,85 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package postgres_test contains tests for PostgreSQL repository +// implementations. +package postgres_test + +import ( + "fmt" + "log" + "os" + "testing" + + "github.com/absmach/supermq/consumers/writers/postgres" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/jmoiron/sqlx" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +var db *sqlx.DB + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16.2-alpine", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + + if err := pool.Retry(func() error { + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + db, err = sqlx.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + dbConfig := pgclient.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + db, err = pgclient.Setup(dbConfig, *postgres.Migration()) + if err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + code := m.Run() + + // Defers will not be run when using os.Exit + db.Close() + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/consumers/writers/timescale/README.md b/consumers/writers/timescale/README.md new file mode 100644 index 000000000..45dc7f010 --- /dev/null +++ b/consumers/writers/timescale/README.md @@ -0,0 +1,76 @@ +# Timescale writer + +Timescale writer provides message repository implementation for Timescale. + +## Configuration + +The service is configured using the environment variables presented in the +following table. Note that any unset variables will be replaced with their +default values. + +| Variable | Description | Default | +| ------------------------------------- | --------------------------------------------------------- | ---------------------------- | +| MG_TIMESCALE_WRITER_LOG_LEVEL | Service log level | info | +| MG_TIMESCALE_WRITER_CONFIG_PATH | Configuration file path with Message broker subjects list | /config.toml | +| MG_TIMESCALE_WRITER_HTTP_HOST | Service HTTP host | localhost | +| MG_TIMESCALE_WRITER_HTTP_PORT | Service HTTP port | 9012 | +| MG_TIMESCALE_WRITER_HTTP_SERVER_CERT | Service HTTP server certificate path | "" | +| MG_TIMESCALE_WRITER_HTTP_SERVER_KEY | Service HTTP server key | "" | +| MG_TIMESCALE_HOST | Timescale DB host | timescale | +| MG_TIMESCALE_PORT | Timescale DB port | 5432 | +| MG_TIMESCALE_USER | Timescale user | supermq | +| MG_TIMESCALE_PASS | Timescale password | supermq | +| MG_TIMESCALE_NAME | Timescale database name | messages | +| MG_TIMESCALE_SSL_MODE | Timescale SSL mode | disabled | +| MG_TIMESCALE_SSL_CERT | Timescale SSL certificate path | "" | +| MG_TIMESCALE_SSL_KEY | Timescale SSL key | "" | +| MG_TIMESCALE_SSL_ROOT_CERT | Timescale SSL root certificate path | "" | +| MG_MESSAGE_BROKER_URL | Message broker instance URL | nats://localhost:4222 | +| MG_JAEGER_URL | Jaeger server URL | http://jaeger:4318/v1/traces | +| MG_SEND_TELEMETRY | Send telemetry to supermq call home server | true | +| MG_TIMESCALE_WRITER_INSTANCE_ID | Timescale writer instance ID | "" | + +## Deployment + +The service itself is distributed as Docker container. Check the [`timescale-writer`](https://github.com/absmach/supermq/blob/main/docker/docker-compose.yaml) service section in the root docker-compose file to see how service is deployed. + +To start the service, execute the following shell script: + +```bash +# download the latest version of the service +git clone https://github.com/absmach/supermq + +cd supermq + +# compile the timescale writer +make timescale-writer + +# copy binary to bin +make install + +# Set the environment variables and run the service +MG_TIMESCALE_WRITER_LOG_LEVEL=[Service log level] \ +MG_TIMESCALE_WRITER_CONFIG_PATH=[Configuration file path with Message broker subjects list] \ +MG_TIMESCALE_WRITER_HTTP_HOST=[Service HTTP host] \ +MG_TIMESCALE_WRITER_HTTP_PORT=[Service HTTP port] \ +MG_TIMESCALE_WRITER_HTTP_SERVER_CERT=[Service HTTP server cert] \ +MG_TIMESCALE_WRITER_HTTP_SERVER_KEY=[Service HTTP server key] \ +MG_TIMESCALE_HOST=[Timescale host] \ +MG_TIMESCALE_PORT=[Timescale port] \ +MG_TIMESCALE_USER=[Timescale user] \ +MG_TIMESCALE_PASS=[Timescale password] \ +MG_TIMESCALE_NAME=[Timescale database name] \ +MG_TIMESCALE_SSL_MODE=[Timescale SSL mode] \ +MG_TIMESCALE_SSL_CERT=[Timescale SSL cert] \ +MG_TIMESCALE_SSL_KEY=[Timescale SSL key] \ +MG_TIMESCALE_SSL_ROOT_CERT=[Timescale SSL Root cert] \ +MG_MESSAGE_BROKER_URL=[Message broker instance URL] \ +MG_JAEGER_URL=[Jaeger server URL] \ +MG_SEND_TELEMETRY=[Send telemetry to supermq call home server] \ +MG_TIMESCALE_WRITER_INSTANCE_ID=[Timescale writer instance ID] \ +$GOBIN/supermq-timescale-writer +``` + +## Usage + +Starting service will start consuming normalized messages in SenML format. diff --git a/consumers/writers/timescale/consumer.go b/consumers/writers/timescale/consumer.go new file mode 100644 index 000000000..40b82b61c --- /dev/null +++ b/consumers/writers/timescale/consumer.go @@ -0,0 +1,209 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package timescale + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/absmach/supermq/consumers" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/postgres" + smqjson "github.com/absmach/supermq/pkg/transformers/json" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jmoiron/sqlx" // required for DB access +) + +var ( + errInvalidMessage = errors.New("invalid message representation") + errSaveMessage = errors.New("failed to save message to timescale database") + errTransRollback = errors.New("failed to rollback transaction") + errNoTable = errors.New("relation does not exist") +) + +var _ consumers.BlockingConsumer = (*timescaleRepo)(nil) + +type timescaleRepo struct { + db *sqlx.DB +} + +// New returns new TimescaleSQL writer. +func New(db *sqlx.DB) consumers.BlockingConsumer { + return ×caleRepo{db: db} +} + +func (tr *timescaleRepo) ConsumeBlocking(ctx context.Context, message any) (err error) { + switch m := message.(type) { + case smqjson.Messages: + return tr.saveJSON(ctx, m) + default: + err := tr.saveSenml(ctx, m) + if err != nil && errors.Contains(err, repoerr.ErrConflict) { + return messaging.NewError(repoerr.ErrConflict, messaging.Term) + } + return err + } +} + +func (tr timescaleRepo) saveSenml(ctx context.Context, messages any) (err error) { + msgs, ok := messages.([]senml.Message) + if !ok { + return errSaveMessage + } + q := `INSERT INTO messages (channel, subtopic, publisher, protocol, + name, unit, value, string_value, bool_value, data_value, sum, + time, update_time) + VALUES (:channel, :subtopic, :publisher, :protocol, :name, :unit, + :value, :string_value, :bool_value, :data_value, :sum, + :time, :update_time);` + + tx, err := tr.db.BeginTxx(ctx, nil) + if err != nil { + return errors.Wrap(errSaveMessage, err) + } + defer func() { + if err != nil { + if txErr := tx.Rollback(); txErr != nil { + err = errors.Wrap(err, errors.Wrap(errTransRollback, txErr)) + } + return + } + + if err = tx.Commit(); err != nil { + err = errors.Wrap(errSaveMessage, err) + } + }() + + for _, msg := range msgs { + m := senmlMessage{Message: msg} + if _, err := tx.NamedExec(q, m); err != nil { + pgErr, ok := err.(*pgconn.PgError) + if ok { + if pgErr.Code == pgerrcode.InvalidTextRepresentation { + return errors.Wrap(errSaveMessage, errInvalidMessage) + } + return postgres.HandleError(errSaveMessage, err) + } + + return errors.Wrap(errSaveMessage, err) + } + } + return err +} + +func (tr timescaleRepo) saveJSON(ctx context.Context, msgs smqjson.Messages) error { + if err := tr.insertJSON(ctx, msgs); err != nil { + if err == errNoTable { + if err := tr.createTable(msgs.Format); err != nil { + return err + } + return tr.insertJSON(ctx, msgs) + } + return err + } + return nil +} + +func (tr timescaleRepo) insertJSON(ctx context.Context, msgs smqjson.Messages) error { + tx, err := tr.db.BeginTxx(ctx, nil) + if err != nil { + return errors.Wrap(errSaveMessage, err) + } + defer func() { + if err != nil { + if txErr := tx.Rollback(); txErr != nil { + err = errors.Wrap(err, errors.Wrap(errTransRollback, txErr)) + } + return + } + + if err = tx.Commit(); err != nil { + err = errors.Wrap(errSaveMessage, err) + } + }() + + q := `INSERT INTO %s (channel, created, subtopic, publisher, protocol, payload) + VALUES (:channel, :created, :subtopic, :publisher, :protocol, :payload);` + q = fmt.Sprintf(q, msgs.Format) + + for _, m := range msgs.Data { + var dbmsg jsonMessage + dbmsg, err = toJSONMessage(m) + if err != nil { + return errors.Wrap(errSaveMessage, err) + } + if _, err = tx.NamedExec(q, dbmsg); err != nil { + if preErr, ok := err.(*pgconn.PrepareError); ok { + err = preErr.Unwrap() + } + pgErr, ok := err.(*pgconn.PgError) + if ok { + switch pgErr.Code { + case pgerrcode.InvalidTextRepresentation: + return errors.Wrap(errSaveMessage, errInvalidMessage) + case pgerrcode.UndefinedTable: + return errNoTable + } + } + return err + } + } + return nil +} + +func (tr timescaleRepo) createTable(name string) error { + q := `CREATE TABLE IF NOT EXISTS %s ( + created BIGINT NOT NULL, + channel VARCHAR(254), + subtopic VARCHAR(254), + publisher VARCHAR(254), + protocol TEXT, + payload JSONB, + PRIMARY KEY (created, publisher, subtopic) + );` + q = fmt.Sprintf(q, name) + + _, err := tr.db.Exec(q) + return err +} + +type senmlMessage struct { + senml.Message +} + +type jsonMessage struct { + Channel string `db:"channel"` + Created int64 `db:"created"` + Subtopic string `db:"subtopic"` + Publisher string `db:"publisher"` + Protocol string `db:"protocol"` + Payload []byte `db:"payload"` +} + +func toJSONMessage(msg smqjson.Message) (jsonMessage, error) { + data := []byte("{}") + if msg.Payload != nil { + b, err := json.Marshal(msg.Payload) + if err != nil { + return jsonMessage{}, errors.Wrap(errSaveMessage, err) + } + data = b + } + + m := jsonMessage{ + Channel: msg.Channel, + Created: msg.Created, + Subtopic: msg.Subtopic, + Publisher: msg.Publisher, + Protocol: msg.Protocol, + Payload: data, + } + + return m, nil +} diff --git a/consumers/writers/timescale/consumer_test.go b/consumers/writers/timescale/consumer_test.go new file mode 100644 index 000000000..45646014e --- /dev/null +++ b/consumers/writers/timescale/consumer_test.go @@ -0,0 +1,112 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package timescale_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/absmach/supermq/consumers/writers/timescale" + "github.com/absmach/supermq/pkg/transformers/json" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/gofrs/uuid/v5" + "github.com/stretchr/testify/assert" +) + +const ( + msgsNum = 42 + valueFields = 5 + subtopic = "topic" +) + +var ( + v float64 = 5 + stringV = "value" + boolV = true + dataV = "base64" + sum float64 = 42 +) + +func TestSaveSenml(t *testing.T) { + repo := timescale.New(db) + + chid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + msg := senml.Message{} + msg.Channel = chid.String() + + pubid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + msg.Publisher = pubid.String() + + now := time.Now().Unix() + var msgs []senml.Message + + for i := 0; i < msgsNum; i++ { + // Mix possible values as well as value sum. + count := i % valueFields + switch count { + case 0: + msg.Subtopic = subtopic + msg.Value = &v + case 1: + msg.BoolValue = &boolV + case 2: + msg.StringValue = &stringV + case 3: + msg.DataValue = &dataV + case 4: + msg.Sum = &sum + } + + msg.Time = float64(now + int64(i)) + msgs = append(msgs, msg) + } + + err = repo.ConsumeBlocking(context.TODO(), msgs) + assert.Nil(t, err, fmt.Sprintf("expected no error got %s\n", err)) +} + +func TestSaveJSON(t *testing.T) { + repo := timescale.New(db) + + chid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + pubid, err := uuid.NewV4() + assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) + + msg := json.Message{ + Channel: chid.String(), + Publisher: pubid.String(), + Created: time.Now().Unix(), + Subtopic: "subtopic/format/some_json", + Protocol: "mqtt", + Payload: map[string]any{ + "field_1": 123, + "field_2": "value", + "field_3": false, + "field_4": 12.344, + "field_5": map[string]any{ + "field_1": "value", + "field_2": 42, + }, + }, + } + + now := time.Now().Unix() + msgs := json.Messages{ + Format: "some_json", + } + + for i := 0; i < msgsNum; i++ { + msg.Created = now + int64(i) + msgs.Data = append(msgs.Data, msg) + } + + err = repo.ConsumeBlocking(context.TODO(), msgs) + assert.Nil(t, err, fmt.Sprintf("expected no error got %s\n", err)) +} diff --git a/consumers/writers/timescale/doc.go b/consumers/writers/timescale/doc.go new file mode 100644 index 000000000..302be6ea5 --- /dev/null +++ b/consumers/writers/timescale/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package timescale contains repository implementations using Timescale as +// the underlying database. +package timescale diff --git a/consumers/writers/timescale/init.go b/consumers/writers/timescale/init.go new file mode 100644 index 000000000..d4e9525c1 --- /dev/null +++ b/consumers/writers/timescale/init.go @@ -0,0 +1,72 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package timescale + +import migrate "github.com/rubenv/sql-migrate" + +// Migration of timescale-writer. +func Migration() *migrate.MemoryMigrationSource { + return &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "messages_1", + Up: []string{ + `CREATE TABLE IF NOT EXISTS messages ( + time BIGINT NOT NULL, + channel UUID, + subtopic VARCHAR(254), + publisher VARCHAR(254), + protocol TEXT, + name VARCHAR(254), + unit TEXT, + value FLOAT, + string_value TEXT, + bool_value BOOL, + data_value BYTEA, + sum FLOAT, + update_time FLOAT, + PRIMARY KEY (time, channel, subtopic, protocol, publisher, name) + );`, + + // Creating HyperTable with chunks interval of 1 day = 86400000000000 Nanoseconds + "SELECT create_hypertable('messages', by_range('time', 86400000000000 ), if_not_exists => TRUE, migrate_data => TRUE);", + }, + Down: []string{ + "DROP TABLE messages", + }, + }, + { + Id: "messages_2", + Up: []string{ + // Index on channel, time + "CREATE INDEX IF NOT EXISTS idx_channel_time ON messages (channel, time DESC) WITH (timescaledb.transaction_per_chunk);", + + // Index on channel, name, time + "CREATE INDEX IF NOT EXISTS idx_channel_name_time ON messages (channel, name, time DESC) WITH (timescaledb.transaction_per_chunk);", + + // Index on channel, subtopic, name, time + "CREATE INDEX IF NOT EXISTS idx_channel_subtopic_name_time ON messages (channel, subtopic, name, time DESC) WITH (timescaledb.transaction_per_chunk);", + + // Index on channel, publisher, name, time + "CREATE INDEX IF NOT EXISTS idx_channel_publisher_name_time ON messages (channel, publisher, name, time DESC) WITH (timescaledb.transaction_per_chunk);", + + // Index on channel, subtopic, publisher, name, time + "CREATE INDEX IF NOT EXISTS idx_channel_subtopic_publisher_name_time ON messages (channel, subtopic, publisher, name, time DESC) WITH (timescaledb.transaction_per_chunk);", + }, + DisableTransactionUp: true, + Down: []string{ + "DROP INDEX IF EXISTS idx_channel_time ;", + + "DROP INDEX IF EXISTS idx_channel_name_time ;", + + "DROP INDEX IF EXISTS idx_channel_subtopic_name_time ;", + + "DROP INDEX IF EXISTS idx_channel_publisher_name_time ;", + + "DROP INDEX IF EXISTS idx_channel_subtopic_publisher_name_time ;", + }, + }, + }, + } +} diff --git a/consumers/writers/timescale/setup_test.go b/consumers/writers/timescale/setup_test.go new file mode 100644 index 000000000..aacf15c05 --- /dev/null +++ b/consumers/writers/timescale/setup_test.go @@ -0,0 +1,85 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package timescale_test contains tests for TimescaleSQL repository +// implementations. +package timescale_test + +import ( + "fmt" + "log" + "os" + "testing" + + "github.com/absmach/supermq/consumers/writers/timescale" + pgclient "github.com/absmach/supermq/pkg/postgres" + "github.com/jmoiron/sqlx" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +var db *sqlx.DB + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "timescale/timescaledb", + Tag: "2.13.1-pg16", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + + if err := pool.Retry(func() error { + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + db, err = sqlx.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + dbConfig := pgclient.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + db, err = pgclient.Setup(dbConfig, *timescale.Migration()) + if err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + code := m.Run() + + // Defers will not be run when using os.Exit + db.Close() + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/docker/.env b/docker/.env index 8c6afec18..1a8e82a5a 100644 --- a/docker/.env +++ b/docker/.env @@ -7,560 +7,845 @@ GRPC_MTLS= ## NginX -SMQ_NGINX_HTTP_PORT=80 -SMQ_NGINX_SSL_PORT=443 -SMQ_NGINX_MQTT_PORT=1883 -SMQ_NGINX_MQTTS_PORT=8883 -SMQ_NGINX_SERVER_NAME= +MG_NGINX_HTTP_PORT=80 +MG_NGINX_SSL_PORT=443 +MG_NGINX_MQTT_PORT=1883 +MG_NGINX_MQTTS_PORT=8883 +MG_NGINX_AMQP_PORT=5682 +MG_NGINX_SERVER_NAME= -## Nats -SMQ_NATS_PORT=4222 -SMQ_NATS_HTTP_PORT=8222 -SMQ_NATS_JETSTREAM_KEY=u7wFoAPgXpDueXOFldBnXDh4xjnSOyEJ2Cb8Z5SZvGLzIZ3U4exWhhoIBZHzuNvh -SMQ_NATS_URL=nats://nats:${SMQ_NATS_PORT} -# Configs for nats as MQTT broker -SMQ_NATS_HEALTH_CHECK=http://nats:${SMQ_NATS_HTTP_PORT}/healthz -SMQ_NATS_WS_TARGET_PATH= -SMQ_NATS_MQTT_QOS=0 - -## RabbitMQ -SMQ_RABBITMQ_PORT=5672 -SMQ_RABBITMQ_HTTP_PORT=15672 -SMQ_RABBITMQ_WS_PORT=15675 -SMQ_RABBITMQ_USER=supermq -SMQ_RABBITMQ_PASS=supermq -SMQ_RABBITMQ_COOKIE=supermq -SMQ_RABBITMQ_VHOST=/ -SMQ_RABBITMQ_URL=amqp://${SMQ_RABBITMQ_USER}:${SMQ_RABBITMQ_PASS}@rabbitmq:${SMQ_RABBITMQ_PORT}${SMQ_RABBITMQ_VHOST} -SMQ_RABBITMQ_MQTT_QOS=0 -SMQ_RABBITMQ_WS_TARGET_PATH=/ws +## FluxMQ Cluster +MG_FLUXMQ_IMAGE_TAG=latest +MG_FLUXMQ_AMQP091_PORT=5682 +MG_FLUXMQ_API_PORT_1=9081 +MG_FLUXMQ_API_PORT_2=9082 +MG_FLUXMQ_API_PORT_3=9083 ## Message Broker -SMQ_MESSAGE_BROKER_TYPE=msg_nats -SMQ_MESSAGE_BROKER_URL=${SMQ_NATS_URL} - -## MQTT Broker -SMQ_MQTT_BROKER_TYPE=rabbitmq -SMQ_MQTT_BROKER_HEALTH_CHECK= -SMQ_MQTT_ADAPTER_MQTT_QOS=${SMQ_RABBITMQ_MQTT_QOS} -SMQ_MQTT_ADAPTER_MQTT_TARGET_PROTOCOL=mqtt -SMQ_MQTT_ADAPTER_MQTT_TARGET_HOST=${SMQ_MQTT_BROKER_TYPE} -SMQ_MQTT_ADAPTER_MQTT_TARGET_PORT=1883 -SMQ_MQTT_ADAPTER_MQTT_TARGET_USERNAME=${SMQ_RABBITMQ_USER} -SMQ_MQTT_ADAPTER_MQTT_TARGET_PASSWORD=${SMQ_RABBITMQ_PASS} -SMQ_MQTT_ADAPTER_MQTT_TARGET_HEALTH_CHECK=${SMQ_MQTT_BROKER_HEALTH_CHECK} -SMQ_MQTT_ADAPTER_WS_TARGET_PROTOCOL=http -SMQ_MQTT_ADAPTER_WS_TARGET_HOST=${SMQ_MQTT_BROKER_TYPE} -SMQ_MQTT_ADAPTER_WS_TARGET_PORT=${SMQ_RABBITMQ_WS_PORT} -SMQ_MQTT_ADAPTER_WS_TARGET_PATH=${SMQ_RABBITMQ_WS_TARGET_PATH} +MG_MESSAGE_BROKER_URL=amqp://guest:guest@nginx:${MG_FLUXMQ_AMQP091_PORT}/ ## Redis -SMQ_REDIS_TCP_PORT=6379 -SMQ_REDIS_URL=redis://es-redis:${SMQ_REDIS_TCP_PORT}/0 +MG_REDIS_TCP_PORT=6379 +MG_REDIS_URL=redis://es-redis:${MG_REDIS_TCP_PORT}/0 ## Event Store -SMQ_ES_TYPE=${SMQ_MESSAGE_BROKER_TYPE} -SMQ_ES_URL=${SMQ_MESSAGE_BROKER_URL} +MG_ES_TYPE=es_fluxmq +MG_ES_URL=amqp://guest:guest@nginx:5682/ ## Jaeger -SMQ_JAEGER_COLLECTOR_OTLP_ENABLED=true -SMQ_JAEGER_FRONTEND=16686 -SMQ_JAEGER_OLTP_HTTP=4318 -SMQ_JAEGER_URL=http://jaeger:4318/v1/traces -SMQ_JAEGER_TRACE_RATIO=1.0 -SMQ_JAEGER_MEMORY_MAX_TRACES=5000 +MG_JAEGER_COLLECTOR_OTLP_ENABLED=true +MG_JAEGER_FRONTEND=16686 +MG_JAEGER_OLTP_HTTP=4318 +MG_JAEGER_URL=http://jaeger:4318/v1/traces +MG_JAEGER_TRACE_RATIO=1.0 +MG_JAEGER_MEMORY_MAX_TRACES=5000 ## Call home -SMQ_SEND_TELEMETRY=true +MG_SEND_TELEMETRY=true ## Postgres -SMQ_POSTGRES_MAX_CONNECTIONS=100 +MG_POSTGRES_MAX_CONNECTIONS=100 ## Core Services ### Auth -SMQ_AUTH_LOG_LEVEL=debug -SMQ_AUTH_HTTP_HOST=auth -SMQ_AUTH_HTTP_PORT=9001 -SMQ_AUTH_HTTP_SERVER_CERT= -SMQ_AUTH_HTTP_SERVER_KEY= -SMQ_AUTH_GRPC_HOST=auth -SMQ_AUTH_GRPC_PORT=7001 -SMQ_AUTH_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/auth-grpc-server.crt}${GRPC_TLS:+./ssl/certs/auth-grpc-server.crt} -SMQ_AUTH_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/auth-grpc-server.key}${GRPC_TLS:+./ssl/certs/auth-grpc-server.key} -SMQ_AUTH_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} -SMQ_AUTH_DB_HOST=auth-db -SMQ_AUTH_DB_PORT=5432 -SMQ_AUTH_DB_USER=supermq -SMQ_AUTH_DB_PASS=supermq -SMQ_AUTH_DB_NAME=auth -SMQ_AUTH_DB_SSL_MODE=disable -SMQ_AUTH_DB_SSL_CERT= -SMQ_AUTH_DB_SSL_KEY= -SMQ_AUTH_DB_SSL_ROOT_CERT= -SMQ_AUTH_ACCESS_TOKEN_DURATION="1h" -SMQ_AUTH_REFRESH_TOKEN_DURATION="24h" -SMQ_AUTH_KEYS_ALGORITHM="EdDSA" -SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/active.key" -SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key" -SMQ_AUTH_INVITATION_DURATION="168h" -SMQ_AUTH_ADAPTER_INSTANCE_ID= -SMQ_AUTH_CACHE_URL=redis://auth-redis:${SMQ_REDIS_TCP_PORT}/0 -SMQ_AUTH_CACHE_KEY_DURATION=10m -SMQ_AUTH_JWKS_URL=http://${SMQ_AUTH_HTTP_HOST}:${SMQ_AUTH_HTTP_PORT}/keys/.well-known/jwks.json -SMQ_AUTH_JWKS_CACHE_MAX_AGE=900 -SMQ_AUTH_JWKS_CACHE_STALE_WHILE_REVALIDATE=60 +MG_AUTH_LOG_LEVEL=debug +MG_AUTH_HTTP_HOST=auth +MG_AUTH_HTTP_PORT=9001 +MG_AUTH_HTTP_SERVER_CERT= +MG_AUTH_HTTP_SERVER_KEY= +MG_AUTH_GRPC_HOST=auth +MG_AUTH_GRPC_PORT=7001 +MG_AUTH_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/auth-grpc-server.crt}${GRPC_TLS:+./ssl/certs/auth-grpc-server.crt} +MG_AUTH_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/auth-grpc-server.key}${GRPC_TLS:+./ssl/certs/auth-grpc-server.key} +MG_AUTH_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} +MG_AUTH_DB_HOST=auth-db +MG_AUTH_DB_PORT=5432 +MG_AUTH_DB_USER=magistrala +MG_AUTH_DB_PASS=magistrala +MG_AUTH_DB_NAME=auth +MG_AUTH_DB_SSL_MODE=disable +MG_AUTH_DB_SSL_CERT= +MG_AUTH_DB_SSL_KEY= +MG_AUTH_DB_SSL_ROOT_CERT= +MG_AUTH_ACCESS_TOKEN_DURATION="1h" +MG_AUTH_REFRESH_TOKEN_DURATION="24h" +MG_AUTH_KEYS_ALGORITHM="EdDSA" +MG_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/active.key" +MG_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key" +MG_AUTH_INVITATION_DURATION="168h" +MG_AUTH_ADAPTER_INSTANCE_ID= +MG_AUTH_CACHE_URL=redis://auth-redis:${MG_REDIS_TCP_PORT}/0 +MG_AUTH_CACHE_KEY_DURATION=10m +MG_AUTH_JWKS_URL=http://${MG_AUTH_HTTP_HOST}:${MG_AUTH_HTTP_PORT}/keys/.well-known/jwks.json +MG_AUTH_JWKS_CACHE_MAX_AGE=900 +MG_AUTH_JWKS_CACHE_STALE_WHILE_REVALIDATE=60 #### Client Callout -SMQ_CLIENTS_CALLOUT_URLS="" -SMQ_CLIENTS_CALLOUT_METHOD="POST" -SMQ_CLIENTS_CALLOUT_TLS_VERIFICATION="false" -SMQ_CLIENTS_CALLOUT_TIMEOUT="10s" -SMQ_CLIENTS_CALLOUT_CA_CERT="" -SMQ_CLIENTS_CALLOUT_CERT="" -SMQ_CLIENTS_CALLOUT_KEY="" -SMQ_CLIENTS_CALLOUT_OPERATIONS="" +MG_CLIENTS_CALLOUT_URLS="" +MG_CLIENTS_CALLOUT_METHOD="POST" +MG_CLIENTS_CALLOUT_TLS_VERIFICATION="false" +MG_CLIENTS_CALLOUT_TIMEOUT="10s" +MG_CLIENTS_CALLOUT_CA_CERT="" +MG_CLIENTS_CALLOUT_CERT="" +MG_CLIENTS_CALLOUT_KEY="" +MG_CLIENTS_CALLOUT_OPERATIONS="" #### Channel Callout -SMQ_CHANNELS_CALLOUT_URLS="" -SMQ_CHANNELS_CALLOUT_METHOD="POST" -SMQ_CHANNELS_CALLOUT_TLS_VERIFICATION="false" -SMQ_CHANNELS_CALLOUT_TIMEOUT="10s" -SMQ_CHANNELS_CALLOUT_CA_CERT="" -SMQ_CHANNELS_CALLOUT_CERT="" -SMQ_CHANNELS_CALLOUT_KEY="" -SMQ_CHANNELS_CALLOUT_OPERATIONS="" +MG_CHANNELS_CALLOUT_URLS="" +MG_CHANNELS_CALLOUT_METHOD="POST" +MG_CHANNELS_CALLOUT_TLS_VERIFICATION="false" +MG_CHANNELS_CALLOUT_TIMEOUT="10s" +MG_CHANNELS_CALLOUT_CA_CERT="" +MG_CHANNELS_CALLOUT_CERT="" +MG_CHANNELS_CALLOUT_KEY="" +MG_CHANNELS_CALLOUT_OPERATIONS="" #### Group Callout -SMQ_GROUPS_CALLOUT_URLS="" -SMQ_GROUPS_CALLOUT_METHOD="POST" -SMQ_GROUPS_CALLOUT_TLS_VERIFICATION="false" -SMQ_GROUPS_CALLOUT_TIMEOUT="10s" -SMQ_GROUPS_CALLOUT_CA_CERT="" -SMQ_GROUPS_CALLOUT_CERT="" -SMQ_GROUPS_CALLOUT_KEY="" -SMQ_GROUPS_CALLOUT_OPERATIONS="" +MG_GROUPS_CALLOUT_URLS="" +MG_GROUPS_CALLOUT_METHOD="POST" +MG_GROUPS_CALLOUT_TLS_VERIFICATION="false" +MG_GROUPS_CALLOUT_TIMEOUT="10s" +MG_GROUPS_CALLOUT_CA_CERT="" +MG_GROUPS_CALLOUT_CERT="" +MG_GROUPS_CALLOUT_KEY="" +MG_GROUPS_CALLOUT_OPERATIONS="" #### Domain Callout -SMQ_DOMAINS_CALLOUT_URLS="" -SMQ_DOMAINS_CALLOUT_METHOD="POST" -SMQ_DOMAINS_CALLOUT_TLS_VERIFICATION="false" -SMQ_DOMAINS_CALLOUT_TIMEOUT="10s" -SMQ_DOMAINS_CALLOUT_CA_CERT="" -SMQ_DOMAINS_CALLOUT_CERT="" -SMQ_DOMAINS_CALLOUT_KEY="" -SMQ_DOMAINS_CALLOUT_OPERATIONS="" +MG_DOMAINS_CALLOUT_URLS="" +MG_DOMAINS_CALLOUT_METHOD="POST" +MG_DOMAINS_CALLOUT_TLS_VERIFICATION="false" +MG_DOMAINS_CALLOUT_TIMEOUT="10s" +MG_DOMAINS_CALLOUT_CA_CERT="" +MG_DOMAINS_CALLOUT_CERT="" +MG_DOMAINS_CALLOUT_KEY="" +MG_DOMAINS_CALLOUT_OPERATIONS="" #### Auth Client Config -SMQ_AUTH_URL=auth:9001 -SMQ_AUTH_GRPC_URL=auth:7001 -SMQ_AUTH_GRPC_TIMEOUT=300s -SMQ_AUTH_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/auth-grpc-client.crt} -SMQ_AUTH_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/auth-grpc-client.key} -SMQ_AUTH_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} +MG_AUTH_URL=auth:9001 +MG_AUTH_GRPC_URL=auth:7001 +MG_AUTH_GRPC_TIMEOUT=300s +MG_AUTH_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/auth-grpc-client.crt} +MG_AUTH_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/auth-grpc-client.key} +MG_AUTH_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} ### Domains -SMQ_DOMAINS_LOG_LEVEL=debug -SMQ_DOMAINS_HTTP_HOST=domains -SMQ_DOMAINS_HTTP_PORT=9003 -SMQ_DOMAINS_HTTP_SERVER_KEY= -SMQ_DOMAINS_HTTP_SERVER_CERT= -SMQ_DOMAINS_GRPC_HOST=domains -SMQ_DOMAINS_GRPC_PORT=7003 -SMQ_DOMAINS_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/domains-grpc-server.crt}${GRPC_TLS:+./ssl/certs/domains-grpc-server.crt} -SMQ_DOMAINS_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/domains-grpc-server.key}${GRPC_TLS:+./ssl/certs/domains-grpc-server.key} -SMQ_DOMAINS_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} -SMQ_DOMAINS_DB_HOST=domains-db -SMQ_DOMAINS_DB_PORT=5432 -SMQ_DOMAINS_DB_NAME=domains -SMQ_DOMAINS_DB_USER=supermq -SMQ_DOMAINS_DB_PASS=supermq -SMQ_DOMAINS_DB_SSL_MODE= -SMQ_DOMAINS_DB_SSL_KEY= -SMQ_DOMAINS_DB_SSL_CERT= -SMQ_DOMAINS_DB_SSL_ROOT_CERT= -SMQ_DOMAINS_INSTANCE_ID= -SMQ_DOMAINS_CACHE_URL=redis://domains-redis:${SMQ_REDIS_TCP_PORT}/0 -SMQ_DOMAINS_CACHE_KEY_DURATION=10m +MG_DOMAINS_LOG_LEVEL=debug +MG_DOMAINS_HTTP_HOST=domains +MG_DOMAINS_HTTP_PORT=9003 +MG_DOMAINS_HTTP_SERVER_KEY= +MG_DOMAINS_HTTP_SERVER_CERT= +MG_DOMAINS_GRPC_HOST=domains +MG_DOMAINS_GRPC_PORT=7003 +MG_DOMAINS_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/domains-grpc-server.crt}${GRPC_TLS:+./ssl/certs/domains-grpc-server.crt} +MG_DOMAINS_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/domains-grpc-server.key}${GRPC_TLS:+./ssl/certs/domains-grpc-server.key} +MG_DOMAINS_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} +MG_DOMAINS_DB_HOST=domains-db +MG_DOMAINS_DB_PORT=5432 +MG_DOMAINS_DB_NAME=domains +MG_DOMAINS_DB_USER=magistrala +MG_DOMAINS_DB_PASS=magistrala +MG_DOMAINS_DB_SSL_MODE= +MG_DOMAINS_DB_SSL_KEY= +MG_DOMAINS_DB_SSL_CERT= +MG_DOMAINS_DB_SSL_ROOT_CERT= +MG_DOMAINS_INSTANCE_ID= +MG_DOMAINS_CACHE_URL=redis://domains-redis:${MG_REDIS_TCP_PORT}/0 +MG_DOMAINS_CACHE_KEY_DURATION=10m #### Domains Client Config -SMQ_DOMAINS_URL=http://domains:9003 -SMQ_DOMAINS_GRPC_URL=domains:7003 -SMQ_DOMAINS_GRPC_TIMEOUT=300s -SMQ_DOMAINS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.crt} -SMQ_DOMAINS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.key} -SMQ_DOMAINS_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} +MG_DOMAINS_URL=http://domains:9003 +MG_DOMAINS_GRPC_URL=domains:7003 +MG_DOMAINS_GRPC_TIMEOUT=300s +MG_DOMAINS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.crt} +MG_DOMAINS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.key} +MG_DOMAINS_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} ### SpiceDB Datastore config -SMQ_SPICEDB_DB_USER=supermq -SMQ_SPICEDB_DB_PASS=supermq -SMQ_SPICEDB_DB_NAME=spicedb -SMQ_SPICEDB_DB_PORT=5432 +MG_SPICEDB_DB_USER=magistrala +MG_SPICEDB_DB_PASS=magistrala +MG_SPICEDB_DB_NAME=spicedb +MG_SPICEDB_DB_PORT=5432 ### SpiceDB config -SMQ_SPICEDB_PRE_SHARED_KEY="12345678" -SMQ_SPICEDB_SCHEMA_FILE="/schema.zed" -SMQ_SPICEDB_HOST=supermq-spicedb -SMQ_SPICEDB_PORT=50051 -SMQ_SPICEDB_DATASTORE_ENGINE=postgres - -### UI -SMQ_UI_LOG_LEVEL=debug -SMQ_UI_PORT=9095 -SMQ_HTTP_ADAPTER_URL=http://http-adapter:8008 -SMQ_CLIENTS_URL=http://clients:9006 -SMQ_USERS_URL=http://users:9002 -SMQ_INVITATIONS_URL=http://invitations:9020 -SMQ_DOMAINS_URL=http://domains:9003 -SMQ_UI_HOST_URL=http://localhost:9095 -SMQ_UI_VERIFICATION_TLS=false -SMQ_UI_CONTENT_TYPE=application/senml+json -SMQ_UI_INSTANCE_ID= -SMQ_UI_DB_HOST=ui-db -SMQ_UI_DB_PORT=5432 -SMQ_UI_DB_USER=supermq -SMQ_UI_DB_PASS=supermq -SMQ_UI_DB_NAME=ui -SMQ_UI_DB_SSL_MODE=disable -SMQ_UI_DB_SSL_CERT= -SMQ_UI_DB_SSL_KEY= -SMQ_UI_DB_SSL_ROOT_CERT= -SMQ_UI_HASH_KEY=5jx4x2Qg9OUmzpP5dbveWQ -SMQ_UI_BLOCK_KEY=UtgZjr92jwRY6SPUndHXiyl9QY8qTUyZ -SMQ_UI_PATH_PREFIX=/ui +MG_SPICEDB_PRE_SHARED_KEY="12345678" +MG_SPICEDB_SCHEMA_FILE="/schema.zed" +MG_PERMISSIONS_FILE="/permission.yaml" +MG_SPICEDB_HOST=spicedb +MG_SPICEDB_PORT=50051 +MG_SPICEDB_DATASTORE_ENGINE=postgres ### Users -SMQ_USERS_LOG_LEVEL=debug -SMQ_USERS_SECRET_KEY=HyE2D4RUt9nnKG6v8zKEqAp6g6ka8hhZsqUpzgKvnwpXrNVQSH -SMQ_USERS_ADMIN_EMAIL=admin@example.com -SMQ_USERS_ADMIN_PASSWORD=12345678 -SMQ_USERS_ADMIN_USERNAME=admin -SMQ_USERS_ADMIN_FIRST_NAME=super -SMQ_USERS_ADMIN_LAST_NAME=admin -SMQ_USERS_PASS_REGEX=^.{8,}$ -SMQ_USERS_HTTP_HOST=users -SMQ_USERS_HTTP_PORT=9002 -SMQ_USERS_HTTP_SERVER_CERT= -SMQ_USERS_HTTP_SERVER_KEY= -SMQ_USERS_GRPC_HOST=users -SMQ_USERS_GRPC_PORT=7002 -SMQ_USERS_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/domains-grpc-server.crt}${GRPC_TLS:+./ssl/certs/domains-grpc-server.crt} -SMQ_USERS_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/domains-grpc-server.key}${GRPC_TLS:+./ssl/certs/domains-grpc-server.key} -SMQ_USERS_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} -SMQ_USERS_DB_HOST=users-db -SMQ_USERS_DB_PORT=5432 -SMQ_USERS_DB_USER=supermq -SMQ_USERS_DB_PASS=supermq -SMQ_USERS_DB_NAME=users -SMQ_USERS_DB_SSL_MODE=disable -SMQ_USERS_DB_SSL_CERT= -SMQ_USERS_DB_SSL_KEY= -SMQ_USERS_DB_SSL_ROOT_CERT= -SMQ_USERS_INSTANCE_ID= -SMQ_USERS_SECRET_KEY=HyE2D4RUt9nnKG6v8zKEqAp6g6ka8hhZsqUpzgKvnwpXrNVQSH -SMQ_USERS_ADMIN_EMAIL=admin@example.com -SMQ_USERS_ADMIN_PASSWORD=12345678 -SMQ_USERS_PASS_REGEX=^.{8,}$ -SMQ_USERS_ALLOW_SELF_REGISTER=true -SMQ_OAUTH_UI_REDIRECT_URL=http://localhost:9095${SMQ_UI_PATH_PREFIX}/tokens/secure -SMQ_OAUTH_UI_ERROR_URL=http://localhost:9095${SMQ_UI_PATH_PREFIX}/error -SMQ_USERS_DELETE_INTERVAL=24h -SMQ_USERS_DELETE_AFTER=720h -SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost/password-reset -SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=reset-password-email.tmpl -SMQ_VERIFICATION_URL_PREFIX=http://localhost/verify-email -SMQ_VERIFICATION_EMAIL_TEMPLATE=verification-email.tmpl +MG_USERS_LOG_LEVEL=debug +MG_USERS_SECRET_KEY=HyE2D4RUt9nnKG6v8zKEqAp6g6ka8hhZsqUpzgKvnwpXrNVQSH +MG_USERS_ADMIN_EMAIL=admin@example.com +MG_USERS_ADMIN_PASSWORD=12345678 +MG_USERS_ADMIN_USERNAME=admin +MG_USERS_ADMIN_FIRST_NAME=super +MG_USERS_ADMIN_LAST_NAME=admin +MG_USERS_PASS_REGEX=^.{8,}$ +MG_USERS_HTTP_HOST=users +MG_USERS_HTTP_PORT=9002 +MG_USERS_HTTP_SERVER_CERT= +MG_USERS_HTTP_SERVER_KEY= +MG_USERS_GRPC_HOST=users +MG_USERS_GRPC_PORT=7002 +MG_USERS_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/domains-grpc-server.crt}${GRPC_TLS:+./ssl/certs/domains-grpc-server.crt} +MG_USERS_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/domains-grpc-server.key}${GRPC_TLS:+./ssl/certs/domains-grpc-server.key} +MG_USERS_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} +MG_USERS_DB_HOST=users-db +MG_USERS_DB_PORT=5432 +MG_USERS_DB_USER=magistrala +MG_USERS_DB_PASS=magistrala +MG_USERS_DB_NAME=users +MG_USERS_DB_SSL_MODE=disable +MG_USERS_DB_SSL_CERT= +MG_USERS_DB_SSL_KEY= +MG_USERS_DB_SSL_ROOT_CERT= +MG_USERS_INSTANCE_ID= +MG_USERS_SECRET_KEY=HyE2D4RUt9nnKG6v8zKEqAp6g6ka8hhZsqUpzgKvnwpXrNVQSH +MG_USERS_ADMIN_EMAIL=admin@example.com +MG_USERS_ADMIN_PASSWORD=12345678 +MG_USERS_PASS_REGEX=^.{8,}$ +MG_USERS_ALLOW_SELF_REGISTER=true +MG_UI_PATH_PREFIX=/ui +MG_OAUTH_UI_REDIRECT_URL=http://localhost:9095${MG_UI_PATH_PREFIX}/tokens/secure +MG_OAUTH_UI_ERROR_URL=http://localhost:9095${MG_UI_PATH_PREFIX}/error +MG_USERS_DELETE_INTERVAL=24h +MG_USERS_DELETE_AFTER=720h +MG_PASSWORD_RESET_URL_PREFIX=http://localhost/password-reset +MG_PASSWORD_RESET_EMAIL_TEMPLATE=reset-password-email.tmpl +MG_VERIFICATION_URL_PREFIX=http://localhost/verify-email +MG_VERIFICATION_EMAIL_TEMPLATE=verification-email.tmpl #### Users Client Config -SMQ_USERS_URL=http://users:9002 -SMQ_USERS_GRPC_URL=users:7002 -SMQ_USERS_GRPC_TIMEOUT=300s -SMQ_USERS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.crt} -SMQ_USERS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.key} -SMQ_USERS_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} +MG_USERS_URL=http://users:9002 +MG_USERS_GRPC_URL=users:7002 +MG_USERS_GRPC_TIMEOUT=300s +MG_USERS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.crt} +MG_USERS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.key} +MG_USERS_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} ### Email utility -SMQ_EMAIL_HOST=host.docker.internal -SMQ_EMAIL_PORT=2525 -SMQ_EMAIL_USERNAME=from@example.com -SMQ_EMAIL_PASSWORD=password -SMQ_EMAIL_FROM_ADDRESS=from@example.com -SMQ_EMAIL_FROM_NAME=Example -SMQ_EMAIL_INVITATION_TEMPLATE=invitation-sent-email.tmpl -SMQ_EMAIL_ACCEPTANCE_TEMPLATE=invitation-accepted-email.tmpl -SMQ_EMAIL_REJECTION_TEMPLATE=invitation-rejected-email.tmpl +MG_EMAIL_HOST=host.docker.internal +MG_EMAIL_PORT=2525 +MG_EMAIL_USERNAME=from@example.com +MG_EMAIL_PASSWORD=password +MG_EMAIL_FROM_ADDRESS=from@example.com +MG_EMAIL_FROM_NAME=Example +MG_EMAIL_INVITATION_TEMPLATE=invitation-sent-email.tmpl +MG_EMAIL_ACCEPTANCE_TEMPLATE=invitation-accepted-email.tmpl +MG_EMAIL_REJECTION_TEMPLATE=invitation-rejected-email.tmpl ### Notifications -SMQ_NOTIFICATIONS_LOG_LEVEL=debug -SMQ_NOTIFICATIONS_INSTANCE_ID= +MG_NOTIFICATIONS_LOG_LEVEL=debug +MG_NOTIFICATIONS_INSTANCE_ID= ### Google OAuth2 -SMQ_GOOGLE_CLIENT_ID= -SMQ_GOOGLE_CLIENT_SECRET= -SMQ_GOOGLE_REDIRECT_URL= -SMQ_GOOGLE_STATE= +MG_GOOGLE_CLIENT_ID= +MG_GOOGLE_CLIENT_SECRET= +MG_GOOGLE_REDIRECT_URL= +MG_GOOGLE_STATE= ### Groups -SMQ_GROUPS_LOG_LEVEL=debug -SMQ_GROUPS_HTTP_HOST=groups -SMQ_GROUPS_HTTP_PORT=9004 -SMQ_GROUPS_HTTP_SERVER_CERT= -SMQ_GROUPS_HTTP_SERVER_KEY= -SMQ_GROUPS_GRPC_HOST=groups -SMQ_GROUPS_GRPC_PORT=7004 -SMQ_GROUPS_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/groups-grpc-server.crt}${GRPC_TLS:+./ssl/certs/groups-grpc-server.crt} -SMQ_GROUPS_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/groups-grpc-server.key}${GRPC_TLS:+./ssl/certs/groups-grpc-server.key} -SMQ_GROUPS_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} -SMQ_GROUPS_DB_HOST=groups-db -SMQ_GROUPS_DB_PORT=5432 -SMQ_GROUPS_DB_USER=supermq -SMQ_GROUPS_DB_PASS=supermq -SMQ_GROUPS_DB_NAME=groups -SMQ_GROUPS_DB_SSL_MODE=disable -SMQ_GROUPS_DB_SSL_CERT= -SMQ_GROUPS_DB_SSL_KEY= -SMQ_GROUPS_DB_SSL_ROOT_CERT= -SMQ_GROUPS_INSTANCE_ID= +MG_GROUPS_LOG_LEVEL=debug +MG_GROUPS_HTTP_HOST=groups +MG_GROUPS_HTTP_PORT=9004 +MG_GROUPS_HTTP_SERVER_CERT= +MG_GROUPS_HTTP_SERVER_KEY= +MG_GROUPS_GRPC_HOST=groups +MG_GROUPS_GRPC_PORT=7004 +MG_GROUPS_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/groups-grpc-server.crt}${GRPC_TLS:+./ssl/certs/groups-grpc-server.crt} +MG_GROUPS_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/groups-grpc-server.key}${GRPC_TLS:+./ssl/certs/groups-grpc-server.key} +MG_GROUPS_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} +MG_GROUPS_DB_HOST=groups-db +MG_GROUPS_DB_PORT=5432 +MG_GROUPS_DB_USER=magistrala +MG_GROUPS_DB_PASS=magistrala +MG_GROUPS_DB_NAME=groups +MG_GROUPS_DB_SSL_MODE=disable +MG_GROUPS_DB_SSL_CERT= +MG_GROUPS_DB_SSL_KEY= +MG_GROUPS_DB_SSL_ROOT_CERT= +MG_GROUPS_INSTANCE_ID= #### Groups Client Config -SMQ_GROUPS_URL=groups:9004 -SMQ_GROUPS_GRPC_URL=groups:7004 -SMQ_GROUPS_GRPC_TIMEOUT=300s -SMQ_GROUPS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/groups-grpc-client.crt} -SMQ_GROUPS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/groups-grpc-client.key} -SMQ_GROUPS_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} +MG_GROUPS_URL=groups:9004 +MG_GROUPS_GRPC_URL=groups:7004 +MG_GROUPS_GRPC_TIMEOUT=300s +MG_GROUPS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/groups-grpc-client.crt} +MG_GROUPS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/groups-grpc-client.key} +MG_GROUPS_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} ### Clients -SMQ_CLIENTS_LOG_LEVEL=debug -SMQ_CLIENTS_STANDALONE_ID= -SMQ_CLIENTS_STANDALONE_TOKEN= -SMQ_CLIENTS_CACHE_KEY_DURATION=10m -SMQ_CLIENTS_HTTP_HOST=clients -SMQ_CLIENTS_HTTP_PORT=9006 -SMQ_CLIENTS_GRPC_HOST=clients -SMQ_CLIENTS_GRPC_PORT=7006 -SMQ_CLIENTS_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/clients-grpc-server.crt}${GRPC_TLS:+./ssl/certs/clients-grpc-server.crt} -SMQ_CLIENTS_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/clients-grpc-server.key}${GRPC_TLS:+./ssl/certs/clients-grpc-server.key} -SMQ_CLIENTS_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} -SMQ_CLIENTS_CACHE_URL=redis://clients-redis:${SMQ_REDIS_TCP_PORT}/0 -SMQ_CLIENTS_DB_HOST=clients-db -SMQ_CLIENTS_DB_PORT=5432 -SMQ_CLIENTS_DB_USER=supermq -SMQ_CLIENTS_DB_PASS=supermq -SMQ_CLIENTS_DB_NAME=clients -SMQ_CLIENTS_DB_SSL_MODE=disable -SMQ_CLIENTS_DB_SSL_CERT= -SMQ_CLIENTS_DB_SSL_KEY= -SMQ_CLIENTS_DB_SSL_ROOT_CERT= -SMQ_CLIENTS_INSTANCE_ID= +MG_CLIENTS_LOG_LEVEL=debug +MG_CLIENTS_STANDALONE_ID= +MG_CLIENTS_STANDALONE_TOKEN= +MG_CLIENTS_CACHE_KEY_DURATION=10m +MG_CLIENTS_HTTP_HOST=clients +MG_CLIENTS_HTTP_PORT=9006 +MG_CLIENTS_GRPC_HOST=clients +MG_CLIENTS_GRPC_PORT=7006 +MG_CLIENTS_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/clients-grpc-server.crt}${GRPC_TLS:+./ssl/certs/clients-grpc-server.crt} +MG_CLIENTS_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/clients-grpc-server.key}${GRPC_TLS:+./ssl/certs/clients-grpc-server.key} +MG_CLIENTS_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} +MG_CLIENTS_CACHE_URL=redis://clients-redis:${MG_REDIS_TCP_PORT}/0 +MG_CLIENTS_DB_HOST=clients-db +MG_CLIENTS_DB_PORT=5432 +MG_CLIENTS_DB_USER=magistrala +MG_CLIENTS_DB_PASS=magistrala +MG_CLIENTS_DB_NAME=clients +MG_CLIENTS_DB_SSL_MODE=disable +MG_CLIENTS_DB_SSL_CERT= +MG_CLIENTS_DB_SSL_KEY= +MG_CLIENTS_DB_SSL_ROOT_CERT= +MG_CLIENTS_INSTANCE_ID= #### Clients Client Config -SMQ_CLIENTS_URL=http://clients:9006 -SMQ_CLIENTS_GRPC_URL=clients:7006 -SMQ_CLIENTS_GRPC_TIMEOUT=300s -SMQ_CLIENTS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/clients-grpc-client.crt} -SMQ_CLIENTS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/clients-grpc-client.key} -SMQ_CLIENTS_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} +MG_CLIENTS_URL=http://clients:9006 +MG_CLIENTS_GRPC_URL=clients:7006 +MG_CLIENTS_GRPC_TIMEOUT=300s +MG_CLIENTS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/clients-grpc-client.crt} +MG_CLIENTS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/clients-grpc-client.key} +MG_CLIENTS_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} ### Channels -SMQ_CHANNELS_LOG_LEVEL=debug -SMQ_CHANNELS_HTTP_HOST=channels -SMQ_CHANNELS_HTTP_PORT=9005 -SMQ_CHANNELS_GRPC_HOST=channels -SMQ_CHANNELS_GRPC_PORT=7005 -SMQ_CHANNELS_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/channels-grpc-server.crt}${GRPC_TLS:+./ssl/certs/channels-grpc-server.crt} -SMQ_CHANNELS_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/channels-grpc-server.key}${GRPC_TLS:+./ssl/certs/channels-grpc-server.key} -SMQ_CHANNELS_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} -SMQ_CHANNELS_DB_HOST=channels-db -SMQ_CHANNELS_DB_PORT=5432 -SMQ_CHANNELS_DB_USER=supermq -SMQ_CHANNELS_DB_PASS=supermq -SMQ_CHANNELS_DB_NAME=channels -SMQ_CHANNELS_DB_SSL_MODE=disable -SMQ_CHANNELS_DB_SSL_CERT= -SMQ_CHANNELS_DB_SSL_KEY= -SMQ_CHANNELS_DB_SSL_ROOT_CERT= -SMQ_CHANNELS_INSTANCE_ID= -SMQ_CHANNELS_CACHE_URL=redis://channels-redis:${SMQ_REDIS_TCP_PORT}/0 -SMQ_CHANNELS_CACHE_KEY_DURATION=10m +MG_CHANNELS_LOG_LEVEL=debug +MG_CHANNELS_HTTP_HOST=channels +MG_CHANNELS_HTTP_PORT=9005 +MG_CHANNELS_GRPC_HOST=channels +MG_CHANNELS_GRPC_PORT=7005 +MG_CHANNELS_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/channels-grpc-server.crt}${GRPC_TLS:+./ssl/certs/channels-grpc-server.crt} +MG_CHANNELS_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/channels-grpc-server.key}${GRPC_TLS:+./ssl/certs/channels-grpc-server.key} +MG_CHANNELS_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} +MG_CHANNELS_DB_HOST=channels-db +MG_CHANNELS_DB_PORT=5432 +MG_CHANNELS_DB_USER=magistrala +MG_CHANNELS_DB_PASS=magistrala +MG_CHANNELS_DB_NAME=channels +MG_CHANNELS_DB_SSL_MODE=disable +MG_CHANNELS_DB_SSL_CERT= +MG_CHANNELS_DB_SSL_KEY= +MG_CHANNELS_DB_SSL_ROOT_CERT= +MG_CHANNELS_INSTANCE_ID= +MG_CHANNELS_CACHE_URL=redis://channels-redis:${MG_REDIS_TCP_PORT}/0 +MG_CHANNELS_CACHE_KEY_DURATION=10m #### Channels Client Config -SMQ_CHANNELS_URL=http://channels:9005 -SMQ_CHANNELS_GRPC_URL=channels:7005 -SMQ_CHANNELS_GRPC_TIMEOUT=300s -SMQ_CHANNELS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/channels-grpc-client.crt} -SMQ_CHANNELS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/channels-grpc-client.key} -SMQ_CHANNELS_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} +MG_CHANNELS_URL=http://channels:9005 +MG_CHANNELS_GRPC_URL=channels:7005 +MG_CHANNELS_GRPC_TIMEOUT=300s +MG_CHANNELS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/channels-grpc-client.crt} +MG_CHANNELS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/channels-grpc-client.key} +MG_CHANNELS_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} -### HTTP -SMQ_HTTP_ADAPTER_LOG_LEVEL=debug -SMQ_HTTP_ADAPTER_HOST=http-adapter -SMQ_HTTP_ADAPTER_PORT=8008 -SMQ_HTTP_ADAPTER_SERVER_CERT= -SMQ_HTTP_ADAPTER_SERVER_KEY= -SMQ_HTTP_ADAPTER_CACHE_NUM_COUNTERS=200000 -SMQ_HTTP_ADAPTER_CACHE_MAX_COST=1048576 -SMQ_HTTP_ADAPTER_CACHE_BUFFER_ITEMS=64 -SMQ_HTTP_ADAPTER_INSTANCE_ID= - -### MQTT -SMQ_MQTT_ADAPTER_LOG_LEVEL=debug -SMQ_MQTT_ADAPTER_MQTT_PORT=1883 -SMQ_MQTT_ADAPTER_FORWARDER_TIMEOUT=30s -SMQ_MQTT_ADAPTER_WS_PORT=8080 -SMQ_MQTT_ADAPTER_INSTANCE= -SMQ_MQTT_ADAPTER_INSTANCE_ID= -SMQ_MQTT_ADAPTER_ES_DB=0 -SMQ_MQTT_ADAPTER_CACHE_NUM_COUNTERS=200000 -SMQ_MQTT_ADAPTER_CACHE_MAX_COST=1048576 -SMQ_MQTT_ADAPTER_CACHE_BUFFER_ITEMS=64 -SMQ_MQTT_ADAPTER_CERT_FILE= -SMQ_MQTT_ADAPTER_KEY_FILE= -SMQ_MQTT_ADAPTER_SERVER_CA_FILE= -SMQ_MQTT_ADAPTER_CLIENT_CA_FILE= -SMQ_MQTT_ADAPTER_CERT_VERIFICATION_METHODS= -SMQ_MQTT_ADAPTER_OCSP_RESPONDER_URL= +### FluxMQ Auth Bridge +MG_FLUXMQ_LOG_LEVEL=debug +MG_FLUXMQ_GRPC_HOST=fluxmq-auth +MG_FLUXMQ_GRPC_PORT=7016 +MG_FLUXMQ_GRPC_URL=fluxmq-auth:7016 +MG_FLUXMQ_INSTANCE_ID= +MG_FLUXMQ_CACHE_NUM_COUNTERS=200000 +MG_FLUXMQ_CACHE_MAX_COST=1048576 +MG_FLUXMQ_CACHE_BUFFER_ITEMS=64 ### CoAP -## If enabled run make all inside docker/ssl directory to generate the DTLS certs -SMQ_COAP_DTLS= -SMQ_COAP_ADAPTER_LOG_LEVEL=debug -SMQ_COAP_ADAPTER_HOST=coap-adapter -SMQ_COAP_ADAPTER_PORT=5683 -SMQ_COAP_ADAPTER_SERVER_CERT_FILE=${SMQ_COAP_DTLS:+./ssl/certs/coap-server.crt} -SMQ_COAP_ADAPTER_SERVER_KEY_FILE=${SMQ_COAP_DTLS:+./ssl/certs/coap-server.key} -SMQ_COAP_ADAPTER_SERVER_CA_FILE=${SMQ_COAP_DTLS:+./ssl/certs/coap-server-ca.crt} -SMQ_COAP_ADAPTER_HTTP_HOST=coap-adapter -SMQ_COAP_ADAPTER_HTTP_PORT=5683 -SMQ_COAP_ADAPTER_HTTP_SERVER_CERT= -SMQ_COAP_ADAPTER_HTTP_SERVER_KEY= -SMQ_COAP_ADAPTER_CACHE_NUM_COUNTERS=200000 -SMQ_COAP_ADAPTER_CACHE_MAX_COST=1048576 -SMQ_COAP_ADAPTER_CACHE_BUFFER_ITEMS=64 -SMQ_COAP_ADAPTER_INSTANCE_ID= +MG_COAP_PORT=5683 ## Addons Services # Certs -AM_CERTS_LOG_LEVEL=debug -AM_CERTS_HTTP_HOST=certs -AM_CERTS_HTTP_PORT=9019 -AM_CERTS_GRPC_HOST=certs -AM_CERTS_GRPC_PORT=7012 -AM_CERTS_RELEASE_TAG=latest -AM_CERTS_SECRET=12345678 +MG_CERTS_LOG_LEVEL=debug +MG_CERTS_HTTP_HOST=certs +MG_CERTS_HTTP_PORT=9019 +MG_CERTS_GRPC_HOST=certs +MG_CERTS_GRPC_PORT=7012 +# WARNING: This is a development/testing secret only. +# NEVER use this weak secret in production! Generate a strong random secret for production deployments. +MG_CERTS_SECRET=12345678 ## Certs Database Configuration -AM_CERTS_DB_HOST=certs-db -AM_CERTS_DB_PORT=5432 -AM_CERTS_DB_USER=absmach -AM_CERTS_DB_PASS=absmach -AM_CERTS_DB=certs -AM_CERTS_DB_SSL_MODE=disable -AM_CERTS_DB_MAX_CONNECTIONS=100 +MG_CERTS_DB_HOST=certs-db +MG_CERTS_DB_PORT=5432 +MG_CERTS_DB_USER=absmach +MG_CERTS_DB_PASS=absmach +MG_CERTS_DB=certs +MG_CERTS_DB_SSL_MODE=disable +MG_CERTS_DB_MAX_CONNECTIONS=100 ## OpenBao Configuration for Certs -AM_CERTS_OPENBAO_HOST=http://certs-openbao:8200 -AM_CERTS_OPENBAO_APP_ROLE=absmach -AM_CERTS_OPENBAO_APP_SECRET=absmach -AM_CERTS_OPENBAO_NAMESPACE= -AM_CERTS_OPENBAO_PKI_PATH=pki -AM_CERTS_OPENBAO_ROLE=absmach -AM_CERTS_OPENBAO_SECRET_ID_TTL=720h -AM_CERTS_SERVICE_TOKEN_PATH=/openbao/service_token -AM_CERTS_SECRET_ID_PATH=/openbao/secret_id -AM_CERTS_SECRET_RENEW_THRESHOLD=24h -AM_CERTS_SECRET_CHECK_INTERVAL=1h +MG_CERTS_OPENBAO_HOST=http://openbao:8200 +MG_CERTS_OPENBAO_APP_ROLE=absmach +MG_CERTS_OPENBAO_APP_SECRET=absmach +MG_CERTS_OPENBAO_NAMESPACE= +MG_CERTS_OPENBAO_PKI_PATH=pki +MG_CERTS_OPENBAO_ROLE=absmach +MG_CERTS_OPENBAO_SECRET_ID_TTL=720h +MG_CERTS_SERVICE_TOKEN_PATH=/openbao/service_token +MG_CERTS_SECRET_ID_PATH=/openbao/secret_id +MG_CERTS_SECRET_RENEW_THRESHOLD=24h +MG_CERTS_SECRET_CHECK_INTERVAL=1h ## OpenBao PKI CA Configuration -AM_CERTS_OPENBAO_PKI_CA_CN=Abstract Machines Certificate Authority -AM_CERTS_OPENBAO_PKI_CA_OU=Abstract Machines -AM_CERTS_OPENBAO_PKI_CA_O=AbstractMachines -AM_CERTS_OPENBAO_PKI_CA_C=FRANCE -AM_CERTS_OPENBAO_PKI_CA_L=PARIS -AM_CERTS_OPENBAO_PKI_CA_ST=PARIS -AM_CERTS_OPENBAO_PKI_CA_ADDR=5 Av. Anatole -AM_CERTS_OPENBAO_PKI_CA_PO=75007 -AM_CERTS_OPENBAO_PKI_CA_DNS_NAMES=localhost -AM_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES=127.0.0.1,::1 -AM_CERTS_OPENBAO_PKI_CA_URI_SANS= -AM_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES=info@abstractmachines.rs +MG_CERTS_OPENBAO_PKI_CA_CN=Abstract Machines Certificate Authority +MG_CERTS_OPENBAO_PKI_CA_OU=Abstract Machines +MG_CERTS_OPENBAO_PKI_CA_O=AbstractMachines +MG_CERTS_OPENBAO_PKI_CA_C=FRANCE +MG_CERTS_OPENBAO_PKI_CA_L=PARIS +MG_CERTS_OPENBAO_PKI_CA_ST=PARIS +MG_CERTS_OPENBAO_PKI_CA_ADDR=5 Av. Anatole +MG_CERTS_OPENBAO_PKI_CA_PO=75007 +MG_CERTS_OPENBAO_PKI_CA_DNS_NAMES=localhost +MG_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES=127.0.0.1,::1 +MG_CERTS_OPENBAO_PKI_CA_URI_SANS= +MG_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES=info@abstractmachines.rs ## OpenBao Unseal Keys and Token -AM_CERTS_OPENBAO_UNSEAL_KEY_1= -AM_CERTS_OPENBAO_UNSEAL_KEY_2= -AM_CERTS_OPENBAO_UNSEAL_KEY_3= -AM_CERTS_OPENBAO_ROOT_TOKEN= +MG_CERTS_OPENBAO_UNSEAL_KEY_1= +MG_CERTS_OPENBAO_UNSEAL_KEY_2= +MG_CERTS_OPENBAO_UNSEAL_KEY_3= +MG_CERTS_OPENBAO_ROOT_TOKEN= ## Jaeger Configuration for Certs -AM_JAEGER_URL=http://jaeger:4318/v1/traces -AM_JAEGER_TRACE_RATIO=1.0 +MG_JAEGER_URL=http://jaeger:4318/v1/traces +MG_JAEGER_TRACE_RATIO=1.0 #### Auth Client Config for Certs Service -SMQ_ADDONS_CERTS_PATH_PREFIX=../../ -AM_AUTH_GRPC_URL=auth:7001 -AM_AUTH_GRPC_TIMEOUT=300s -AM_AUTH_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/auth-grpc-client.crt} -AM_AUTH_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/auth-grpc-client.key} -AM_AUTH_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} +MG_ADDONS_CERTS_PATH_PREFIX=../../ +MG_AUTH_GRPC_URL=auth:7001 +MG_AUTH_GRPC_TIMEOUT=300s +MG_AUTH_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/auth-grpc-client.crt} +MG_AUTH_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/auth-grpc-client.key} +MG_AUTH_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} #### Domains Client Config for Certs Service -AM_DOMAINS_GRPC_URL=domains:7003 -AM_DOMAINS_GRPC_TIMEOUT=300s -AM_DOMAINS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.crt} -AM_DOMAINS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.key} -AM_DOMAINS_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} +MG_DOMAINS_GRPC_URL=domains:7003 +MG_DOMAINS_GRPC_TIMEOUT=300s +MG_DOMAINS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.crt} +MG_DOMAINS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.key} +MG_DOMAINS_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} -SMQ_CERTS_JAEGER_FRONTEND=16687 -SMQ_CERTS_JAEGER_OLTP_HTTP=4319 +MG_CERTS_JAEGER_FRONTEND=16687 +MG_CERTS_JAEGER_OLTP_HTTP=4319 ### Postgres -SMQ_POSTGRES_HOST=supermq-postgres -SMQ_POSTGRES_PORT=5432 -SMQ_POSTGRES_USER=supermq -SMQ_POSTGRES_PASS=supermq -SMQ_POSTGRES_NAME=messages -SMQ_POSTGRES_SSL_MODE=disable -SMQ_POSTGRES_SSL_CERT= -SMQ_POSTGRES_SSL_KEY= -SMQ_POSTGRES_SSL_ROOT_CERT= +MG_POSTGRES_HOST=postgres +MG_POSTGRES_PORT=5432 +MG_POSTGRES_USER=magistrala +MG_POSTGRES_PASS=magistrala +MG_POSTGRES_NAME=messages +MG_POSTGRES_SSL_MODE=disable +MG_POSTGRES_SSL_CERT= +MG_POSTGRES_SSL_KEY= +MG_POSTGRES_SSL_ROOT_CERT= ### Timescale -SMQ_TIMESCALE_HOST=supermq-timescale -SMQ_TIMESCALE_PORT=5432 -SMQ_TIMESCALE_USER=supermq -SMQ_TIMESCALE_PASS=supermq -SMQ_TIMESCALE_NAME=supermq -SMQ_TIMESCALE_SSL_MODE=disable -SMQ_TIMESCALE_SSL_CERT= -SMQ_TIMESCALE_SSL_KEY= -SMQ_TIMESCALE_SSL_ROOT_CERT= +MG_TIMESCALE_HOST=timescale +MG_TIMESCALE_PORT=5432 +MG_TIMESCALE_USER=magistrala +MG_TIMESCALE_PASS=magistrala +MG_TIMESCALE_NAME=magistrala +MG_TIMESCALE_SSL_MODE=disable +MG_TIMESCALE_SSL_CERT= +MG_TIMESCALE_SSL_KEY= +MG_TIMESCALE_SSL_ROOT_CERT= ### Journal -SMQ_JOURNAL_LOG_LEVEL=info -SMQ_JOURNAL_HTTP_HOST=journal -SMQ_JOURNAL_HTTP_PORT=9021 -SMQ_JOURNAL_HTTP_SERVER_CERT= -SMQ_JOURNAL_HTTP_SERVER_KEY= -SMQ_JOURNAL_DB_HOST=journal-db -SMQ_JOURNAL_DB_PORT=5432 -SMQ_JOURNAL_DB_USER=supermq -SMQ_JOURNAL_DB_PASS=supermq -SMQ_JOURNAL_DB_NAME=journal -SMQ_JOURNAL_DB_SSL_MODE=disable -SMQ_JOURNAL_DB_SSL_CERT= -SMQ_JOURNAL_DB_SSL_KEY= -SMQ_JOURNAL_DB_SSL_ROOT_CERT= -SMQ_JOURNAL_INSTANCE_ID= +MG_JOURNAL_LOG_LEVEL=info +MG_JOURNAL_HTTP_HOST=journal +MG_JOURNAL_HTTP_PORT=9021 +MG_JOURNAL_HTTP_SERVER_CERT= +MG_JOURNAL_HTTP_SERVER_KEY= +MG_JOURNAL_DB_HOST=journal-db +MG_JOURNAL_DB_PORT=5432 +MG_JOURNAL_DB_USER=magistrala +MG_JOURNAL_DB_PASS=magistrala +MG_JOURNAL_DB_NAME=journal +MG_JOURNAL_DB_SSL_MODE=disable +MG_JOURNAL_DB_SSL_CERT= +MG_JOURNAL_DB_SSL_KEY= +MG_JOURNAL_DB_SSL_ROOT_CERT= +MG_JOURNAL_INSTANCE_ID= ### GRAFANA and PROMETHEUS -SMQ_PROMETHEUS_PORT=9090 -SMQ_GRAFANA_PORT=3000 -SMQ_GRAFANA_ADMIN_USER=supermq -SMQ_GRAFANA_ADMIN_PASSWORD=supermq - -## Allow unverified user to access -SMQ_ALLOW_UNVERIFIED_USER=true +MG_PROMETHEUS_PORT=9090 +MG_GRAFANA_PORT=3001 +MG_GRAFANA_ADMIN_USER=magistrala +MG_GRAFANA_ADMIN_PASSWORD=magistrala +# Allow unverified user +MG_ALLOW_UNVERIFIED_USER=true # Docker image tag -SMQ_RELEASE_TAG=latest +MG_RELEASE_TAG=latest + +MG_BOOTSTRAP_URL=http://bootstrap:9013 +MG_CERTS_URL=http://certs:9019 +MG_HTTP_ADAPTER_URL=http://http-adapter:8008 +MG_READER_URL=http://timescale-reader:9011 +MG_JOURNAL_URL=http://journal:9021 + +## Object Storage (SeaweedFS) +MG_BACKEND_OBJECT_STORAGE_REGION=us-east-1 +MG_BACKEND_OBJECT_STORAGE_BUCKET=magistrala +MG_BACKEND_OBJECT_STORAGE_ENDPOINT=http://seaweedfs-s3:8333 +MG_BACKEND_OBJECT_STORAGE_USE_PATH_STYLE=true +MG_BACKEND_OBJECT_STORAGE_PRESIGN_ENDPOINT= +MG_BACKEND_OBJECT_STORAGE_ACCESS_KEY=admin +MG_BACKEND_OBJECT_STORAGE_SECRET_KEY=admin +MG_BACKEND_OBJECT_STORAGE_TTL=1h +MG_BACKEND_OBJECT_STORAGE_READ_TTL=1h + +#### Timescale Reader gRPC Client Config +MG_TIMESCALE_READER_GRPC_URL=timescale-reader:7011 +MG_TIMESCALE_READER_GRPC_TIMEOUT=300s +MG_TIMESCALE_READER_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/readers-grpc-client.crt} +MG_TIMESCALE_READER_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/readers-grpc-client.key} +MG_TIMESCALE_READER_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} +MG_TIMESCALE_READER_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} + +## Rules Engine +MG_RE_LOG_LEVEL=debug +MG_RE_HTTP_HOST=re +MG_RE_HTTP_PORT=9008 +MG_RE_HTTP_SERVER_CERT= +MG_RE_HTTP_SERVER_KEY= +MG_RE_DB_HOST=re-db +MG_RE_DB_PORT=5432 +MG_RE_DB_USER=magistrala +MG_RE_DB_PASS=magistrala +MG_RE_DB_NAME=rules_engine +MG_RE_DB_SSL_MODE=disable +MG_RE_DB_SSL_CERT= +MG_RE_DB_SSL_KEY= +MG_RE_DB_SSL_ROOT_CERT= +MG_RE_INSTANCE_ID= +MG_RE_EMAIL_TEMPLATE=re.tmpl +MG_RE_CALLOUT_URLS="" +MG_RE_CALLOUT_METHOD="POST" +MG_RE_CALLOUT_TLS_VERIFICATION="false" +MG_RE_CALLOUT_TIMEOUT="10s" +MG_RE_CALLOUT_CA_CERT="" +MG_RE_CALLOUT_CERT="" +MG_RE_CALLOUT_KEY="" +MG_RE_CALLOUT_OPERATIONS="" +MG_RE_URL=http://re:9008 + +## Email +MG_EMAIL_HOST=host.docker.internal +MG_EMAIL_PORT=2525 +MG_EMAIL_USERNAME=from@example.com +MG_EMAIL_PASSWORD=password +MG_EMAIL_FROM_ADDRESS=from@example.com +MG_EMAIL_FROM_NAME=Example +MG_EMAIL_TEMPLATE=email.tmpl + +## Alarms +MG_ALARMS_LOG_LEVEL=debug +MG_ALARMS_HTTP_HOST=alarms +MG_ALARMS_HTTP_PORT=8050 +MG_ALARMS_HTTP_SERVER_CERT= +MG_ALARMS_HTTP_SERVER_KEY= +MG_ALARMS_DB_HOST=alarms-db +MG_ALARMS_DB_PORT=5432 +MG_ALARMS_DB_USER=magistrala +MG_ALARMS_DB_PASS=magistrala +MG_ALARMS_DB_NAME=alarms +MG_ALARMS_DB_SSL_MODE=disable +MG_ALARMS_DB_SSL_CERT= +MG_ALARMS_DB_SSL_KEY= +MG_ALARMS_DB_SSL_ROOT_CERT= +MG_ALARMS_INSTANCE_ID= +MG_ALARMS_EVENT_CONSUMER=alarms +MG_ALARMS_URL=http://alarms:8050 + +## Reports +MG_REPORTS_LOG_LEVEL=debug +MG_REPORTS_HTTP_HOST=reports +MG_REPORTS_HTTP_PORT=9017 +MG_REPORTS_HTTP_SERVER_CERT= +MG_REPORTS_HTTP_SERVER_KEY= +MG_REPORTS_DB_HOST=reports-db +MG_REPORTS_DB_PORT=5432 +MG_REPORTS_DB_USER=magistrala +MG_REPORTS_DB_PASS=magistrala +MG_REPORTS_DB_NAME=reports +MG_REPORTS_DB_SSL_MODE=disable +MG_REPORTS_DB_SSL_CERT= +MG_REPORTS_DB_SSL_KEY= +MG_REPORTS_DB_SSL_ROOT_CERT= +MG_REPORTS_INSTANCE_ID= +MG_REPORTS_EMAIL_TEMPLATE=reports.tmpl +MG_REPORTS_DEFAULT_TEMPLATE= +MG_PDF_CONVERTER_URL=http://pdf-generator:3000/forms/chromium/convert/html +MG_REPORTS_URL=http://reports:9017 + +## Addon Services + +### Bootstrap +MG_BOOTSTRAP_LOG_LEVEL=debug +MG_BOOTSTRAP_ENCRYPT_KEY=v7aT0HGxJxt2gULzr3RHwf4WIf6DusPp +MG_BOOTSTRAP_EVENT_CONSUMER=bootstrap +MG_BOOTSTRAP_HTTP_HOST=bootstrap +MG_BOOTSTRAP_HTTP_PORT=9013 +MG_BOOTSTRAP_HTTP_SERVER_CERT= +MG_BOOTSTRAP_HTTP_SERVER_KEY= +MG_BOOTSTRAP_DB_HOST=bootstrap-db +MG_BOOTSTRAP_DB_PORT=5432 +MG_BOOTSTRAP_DB_USER=magistrala +MG_BOOTSTRAP_DB_PASS=magistrala +MG_BOOTSTRAP_DB_NAME=bootstrap +MG_BOOTSTRAP_DB_SSL_MODE=disable +MG_BOOTSTRAP_DB_SSL_CERT= +MG_BOOTSTRAP_DB_SSL_KEY= +MG_BOOTSTRAP_DB_SSL_ROOT_CERT= +MG_BOOTSTRAP_INSTANCE_ID= + +### Provision +MG_PROVISION_CONFIG_FILE=/configs/config.toml +MG_PROVISION_LOG_LEVEL=debug +MG_PROVISION_HTTP_PORT=9016 +MG_PROVISION_ENV_CLIENTS_TLS=false +MG_PROVISION_SERVER_CERT= +MG_PROVISION_SERVER_KEY= +MG_PROVISION_USERS_URL=http://users:9002 +MG_PROVISION_CHANNELS_URL=http://channels:9005 +MG_PROVISION_CLIENTS_URL=http://clients:9006 +MG_PROVISION_CERTS_URL=http://certs:9019 +MG_PROVISION_USER= +MG_PROVISION_USERNAME= +MG_PROVISION_PASS= +MG_PROVISION_API_KEY= +MG_PROVISION_X509_PROVISIONING=false +MG_PROVISION_BS_SVC_URL=http://bootstrap:9013 +MG_PROVISION_BS_CONFIG_PROVISIONING=true +MG_PROVISION_BS_AUTO_WHITELIST=true +MG_PROVISION_BS_CONTENT= +MG_PROVISION_CERTS_HOURS_VALID=2400h +MG_PROVISION_CERTS_RSA_BITS=2048 +MG_PROVISION_INSTANCE_ID= + +### Postgres Writer +MG_POSTGRES_WRITER_LOG_LEVEL=debug +MG_POSTGRES_WRITER_CONFIG_PATH=/config.toml +MG_POSTGRES_WRITER_HTTP_HOST=postgres-writer +MG_POSTGRES_WRITER_HTTP_PORT=9007 +MG_POSTGRES_WRITER_HTTP_SERVER_CERT= +MG_POSTGRES_WRITER_HTTP_SERVER_KEY= +MG_POSTGRES_WRITER_INSTANCE_ID= + +### Postgres Reader +MG_POSTGRES_READER_LOG_LEVEL=debug +MG_POSTGRES_READER_HTTP_HOST=postgres-reader +MG_POSTGRES_READER_HTTP_PORT=9009 +MG_POSTGRES_READER_GRPC_HOST=postgres-reader +MG_POSTGRES_READER_GRPC_PORT=7009 +MG_POSTGRES_READER_HTTP_SERVER_CERT= +MG_POSTGRES_READER_HTTP_SERVER_KEY= +MG_POSTGRES_READER_INSTANCE_ID= +MG_POSTGRES_READER_GRPC_URL=postgres-reader:7009 +MG_POSTGRES_READER_GRPC_TIMEOUT=300s +MG_POSTGRES_READER_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/readers-grpc-client.crt} +MG_POSTGRES_READER_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} +MG_POSTGRES_READER_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/readers-grpc-client.key} +MG_POSTGRES_READER_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/readers-grpc-server.crt}${GRPC_TLS:+./ssl/certs/readers-grpc-server.crt} +MG_POSTGRES_READER_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/readers-grpc-server.key}${GRPC_TLS:+./ssl/certs/readers-grpc-server.key} +MG_POSTGRES_READER_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} + +### Timescale Writer +MG_TIMESCALE_WRITER_LOG_LEVEL=debug +MG_TIMESCALE_WRITER_CONFIG_PATH=/config.toml +MG_TIMESCALE_WRITER_HTTP_HOST=timescale-writer +MG_TIMESCALE_WRITER_HTTP_PORT=9012 +MG_TIMESCALE_WRITER_HTTP_SERVER_CERT= +MG_TIMESCALE_WRITER_HTTP_SERVER_KEY= +MG_TIMESCALE_WRITER_INSTANCE_ID= + +### Timescale Reader +MG_TIMESCALE_READER_LOG_LEVEL=debug +MG_TIMESCALE_READER_HTTP_HOST=timescale-reader +MG_TIMESCALE_READER_HTTP_PORT=9011 +MG_TIMESCALE_READER_GRPC_HOST=timescale-reader +MG_TIMESCALE_READER_GRPC_PORT=7011 +MG_TIMESCALE_READER_HTTP_SERVER_CERT= +MG_TIMESCALE_READER_HTTP_SERVER_KEY= +MG_TIMESCALE_READER_INSTANCE_ID= +MG_TIMESCALE_READER_GRPC_SERVER_CERT=${GRPC_MTLS:+./ssl/certs/readers-grpc-server.crt}${GRPC_TLS:+./ssl/certs/readers-grpc-server.crt} +MG_TIMESCALE_READER_GRPC_SERVER_KEY=${GRPC_MTLS:+./ssl/certs/readers-grpc-server.key}${GRPC_TLS:+./ssl/certs/readers-grpc-server.key} +MG_TIMESCALE_READER_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} + +#### Timescale Reader Client Config +MG_TIMESCALE_READER_URL=http://timescale-reader:9011 +MG_TIMESCALE_READER_GRPC_URL=timescale-reader:7011 +MG_TIMESCALE_READER_GRPC_TIMEOUT=300s +MG_TIMESCALE_READER_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/reader-grpc-client.crt} +MG_TIMESCALE_READER_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} +MG_TIMESCALE_READER_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/readers-grpc-client.key} + +## Magistrala Services (MG_ prefix) + +### RE (Rules Engine) +MG_RE_LOG_LEVEL=debug +MG_RE_HTTP_HOST=re +MG_RE_HTTP_PORT=9008 +MG_RE_HTTP_SERVER_CERT= +MG_RE_HTTP_SERVER_KEY= +MG_RE_DB_HOST=re-db +MG_RE_DB_PORT=5432 +MG_RE_DB_USER=magistrala +MG_RE_DB_PASS=magistrala +MG_RE_DB_NAME=rules_engine +MG_RE_DB_SSL_MODE=disable +MG_RE_DB_SSL_CERT= +MG_RE_DB_SSL_KEY= +MG_RE_DB_SSL_ROOT_CERT= +MG_RE_INSTANCE_ID= +MG_RE_EMAIL_TEMPLATE=re.tmpl +MG_RE_CALLOUT_URLS="" +MG_RE_CALLOUT_METHOD="POST" +MG_RE_CALLOUT_TLS_VERIFICATION="false" +MG_RE_CALLOUT_TIMEOUT="10s" +MG_RE_CALLOUT_CA_CERT="" +MG_RE_CALLOUT_CERT="" +MG_RE_CALLOUT_KEY="" +MG_RE_CALLOUT_OPERATIONS="" +MG_RE_URL=http://re:9008 + +### Email (shared by RE and Reports) +MG_EMAIL_HOST=smtp.mailtrap.io +MG_EMAIL_PORT=2525 +MG_EMAIL_USERNAME=18bf7f70705139 +MG_EMAIL_PASSWORD=2b0d302e775b1e +MG_EMAIL_FROM_ADDRESS=from@example.com +MG_EMAIL_FROM_NAME=Example +MG_EMAIL_TEMPLATE=email.tmpl + +### Alarms +MG_ALARMS_LOG_LEVEL=debug +MG_ALARMS_HTTP_HOST=alarms +MG_ALARMS_HTTP_PORT=8050 +MG_ALARMS_HTTP_SERVER_CERT= +MG_ALARMS_HTTP_SERVER_KEY= +MG_ALARMS_DB_HOST=alarms-db +MG_ALARMS_DB_PORT=5432 +MG_ALARMS_DB_USER=magistrala +MG_ALARMS_DB_PASS=magistrala +MG_ALARMS_DB_NAME=alarms +MG_ALARMS_DB_SSL_MODE=disable +MG_ALARMS_DB_SSL_CERT= +MG_ALARMS_DB_SSL_KEY= +MG_ALARMS_DB_SSL_ROOT_CERT= +MG_ALARMS_INSTANCE_ID= +MG_ALARMS_EVENT_CONSUMER=alarms +MG_ALARMS_URL=http://alarms:8050 + +### Reports +MG_REPORTS_LOG_LEVEL=debug +MG_REPORTS_HTTP_HOST=reports +MG_REPORTS_HTTP_PORT=9017 +MG_REPORTS_HTTP_SERVER_CERT= +MG_REPORTS_HTTP_SERVER_KEY= +MG_REPORTS_DB_HOST=reports-db +MG_REPORTS_DB_PORT=5432 +MG_REPORTS_DB_USER=magistrala +MG_REPORTS_DB_PASS=magistrala +MG_REPORTS_DB_NAME=reports +MG_REPORTS_DB_SSL_MODE=disable +MG_REPORTS_DB_SSL_CERT= +MG_REPORTS_DB_SSL_KEY= +MG_REPORTS_DB_SSL_ROOT_CERT= +MG_REPORTS_INSTANCE_ID= +MG_REPORTS_EMAIL_TEMPLATE=reports.tmpl +MG_REPORTS_DEFAULT_TEMPLATE= +MG_REPORTS_URL=http://reports:9017 +MG_PDF_CONVERTER_URL=http://pdf-generator:3000/forms/chromium/convert/html + +### Timescale Reader gRPC Client Config (Magistrala) +MG_TIMESCALE_READER_GRPC_URL=timescale-reader:7011 +MG_TIMESCALE_READER_GRPC_TIMEOUT=300s +MG_TIMESCALE_READER_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/reader-grpc-client.crt} +MG_TIMESCALE_READER_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} +MG_TIMESCALE_READER_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/readers-grpc-client.key} +MG_TIMESCALE_READER_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt}${GRPC_TLS:+./ssl/certs/ca.crt} + +# UI components and RE +## Dashboards +MG_UI_BACKEND_LOG_LEVEL=debug +MG_UI_BACKEND_HTTP_HOST=ui-backend +MG_UI_BACKEND_HTTP_PORT=9097 +MG_UI_BACKEND_HTTP_SERVER_CERT= +MG_UI_BACKEND_HTTP_SERVER_KEY= +MG_UI_BACKEND_INSTANCE_ID= +MG_UI_BACKEND_URL=http://ui-backend:9097 +MG_UI_VERIFICATION_TLS=false +MG_UI_CONTENT_TYPE=application/senml+json +# Set to yes to accept the EULA for the UI services. To view the EULA visit: https://github.com/absmach/eula +MG_UI_DOCKER_ACCEPT_EULA=no + +# Object storage for images +# See docker/seaweedfs/s3.json. +MG_BACKEND_OBJECT_STORAGE_REGION=fra1 +MG_BACKEND_OBJECT_STORAGE_BUCKET=mg-ui-images +MG_BACKEND_OBJECT_STORAGE_ENDPOINT=http://seaweedfs-s3:8333 +MG_BACKEND_OBJECT_STORAGE_USE_PATH_STYLE=true +MG_BACKEND_OBJECT_STORAGE_PRESIGN_ENDPOINT=http://localhost:8333 +MG_BACKEND_OBJECT_STORAGE_ACCESS_KEY=localKey +MG_BACKEND_OBJECT_STORAGE_SECRET_KEY=localSecret +MG_BACKEND_OBJECT_STORAGE_TTL=15m +MG_BACKEND_OBJECT_STORAGE_READ_TTL=15m + +## Postgres +MG_UI_BACKEND_DB_HOST=ui-backend-db +MG_UI_BACKEND_DB_PORT=5432 +MG_UI_BACKEND_DB_USER=magistrala +MG_UI_BACKEND_DB_PASS=magistrala +MG_UI_BACKEND_DB_NAME=magistrala +MG_UI_BACKEND_DB_SSL_MODE=disable +MG_UI_BACKEND_DB_SSL_CERT= +MG_UI_BACKEND_DB_SSL_KEY= +MG_UI_BACKEND_DB_SSL_ROOT_CERT= + +## UI +MG_AUTH_URL=http://auth:9001 +MG_DOMAINS_URL=http://domains:9003 +MG_USERS_URL=http://users:9002 +MG_CLIENTS_URL=http://clients:9006 +MG_CHANNELS_URL=http://channels:9005 +MG_GROUPS_URL=http://groups:9004 +MG_BOOTSTRAP_URL=http://bootstrap:9013 +MG_CERTS_URL=http://certs:9019 +MG_HTTP_ADAPTER_URL=http://nginx:80/http +MG_READER_URL=http://timescale-reader:9011 +MG_JOURNAL_URL=http://journal:9021 + +### UI Configuration +MG_UI_TYPE=mg +MG_UI_BASE_PATH=/ +MG_NEXTAUTH_BASE_PATH=/api/auth +NEXTAUTH_SECRET=4WdW0Z0tAOyQ/ZAI3YLVV/wNu+yUZXBLDDQ3AGrgfJ4= +NEXTAUTH_URL=http://localhost:3000 +MG_HOST_URL=http://localhost:3000 +MG_UI_IMAGE_URL=http://ui-backend:9097 +MG_UI_BASEURL=http://localhost:3000 + +### Google OAuth2 (UI) +MG_GOOGLE_CLIENT_ID= +MG_GOOGLE_CLIENT_SECRET= +MG_GOOGLE_REDIRECT_URL=http://localhost:3000/oauth/callback/google +MG_GOOGLE_STATE=pGXVNhEeKfycuBzk5InlSfMlEU9UrhlkTUOSqhsgDzXP2Y4RsN + +#Customer support email variables +MG_SUPPORT_EMAIL= +MG_SUPPORT_EMAIL_PASS= + +## SMTP Variables +MG_UI_SMTP_HOST=host.docker.internal +MG_UI_SMTP_PORT=2525 +MG_UI_SMTP_SECURE= +MG_UI_SUPPORT_FROM=from@example.com + +# Message cli variables +MG_UI_CLI_MQTT_HOST=localhost +MG_UI_CLI_WS_URL=ws://localhost:80/mqtt +MG_UI_CLI_COAP_HOST=0.0.0.0 +MG_UI_CLI_COAP_PORT=5684 +MG_UI_CLI_HTTP_URL=http://localhost:80/http diff --git a/docker/README.md b/docker/README.md index bcf209f2a..942a66e79 100644 --- a/docker/README.md +++ b/docker/README.md @@ -22,7 +22,7 @@ To start additional addon services: docker compose -f docker/addons//docker-compose.yaml up ``` -To pull docker images from a specific release you need to change the value of `SMQ_RELEASE_TAG` in `.env` before running these commands. +To pull docker images from a specific release you need to change the value of `MG_RELEASE_TAG` in `.env` before running these commands. ## Broker Configuration @@ -53,41 +53,41 @@ Depending on the desired setup, the following broker configurations are valid: > For non-default brokers (e.g. RabbitMQ as message broker), adjust the environment variables appropriately and rebuild Docker images. Example: ```bash -SMQ_MESSAGE_BROKER_TYPE=msg_rabbitmq make dockers +MG_MESSAGE_BROKER_TYPE=msg_rabbitmq make dockers ``` Then in `.env`: ```text -SMQ_MESSAGE_BROKER_TYPE=msg_rabbitmq -SMQ_MESSAGE_BROKER_URL=${SMQ_RABBITMQ_URL} +MG_MESSAGE_BROKER_TYPE=msg_rabbitmq +MG_MESSAGE_BROKER_URL=${MG_RABBITMQ_URL} ``` For Redis as an events store, you would need to run RabbitMQ or NATS as a message broker. For example, to use Redis as an events store with rabbitmq as a message broker: ```bash -SMQ_ES_TYPE=es_redis SMQ_MESSAGE_BROKER_TYPE=msg_rabbitmq make dockers +MG_ES_TYPE=es_redis MG_MESSAGE_BROKER_TYPE=msg_rabbitmq make dockers ``` ```env -SMQ_MESSAGE_BROKER_TYPE=msg_rabbitmq -SMQ_MESSAGE_BROKER_URL=${SMQ_RABBITMQ_URL} -SMQ_ES_TYPE=es_redis -SMQ_ES_URL=${SMQ_REDIS_URL} +MG_MESSAGE_BROKER_TYPE=msg_rabbitmq +MG_MESSAGE_BROKER_URL=${MG_RABBITMQ_URL} +MG_ES_TYPE=es_redis +MG_ES_URL=${MG_REDIS_URL} ``` For MQTT broker other than RabbitMQ, you would need to change the `docker/.env`. For example, to use NATS as a MQTT broker: ```env -SMQ_MQTT_BROKER_TYPE=nats -SMQ_MQTT_BROKER_HEALTH_CHECK=${SMQ_NATS_HEALTH_CHECK} -SMQ_MQTT_ADAPTER_MQTT_QOS=${SMQ_NATS_MQTT_QOS} -SMQ_MQTT_ADAPTER_MQTT_TARGET_HOST=${SMQ_MQTT_BROKER_TYPE} -SMQ_MQTT_ADAPTER_MQTT_TARGET_PORT=1883 -SMQ_MQTT_ADAPTER_MQTT_TARGET_HEALTH_CHECK=${SMQ_MQTT_BROKER_HEALTH_CHECK} -SMQ_MQTT_ADAPTER_WS_TARGET_HOST=${SMQ_MQTT_BROKER_TYPE} -SMQ_MQTT_ADAPTER_WS_TARGET_PORT=8080 -SMQ_MQTT_ADAPTER_WS_TARGET_PATH=${SMQ_NATS_WS_TARGET_PATH} +MG_MQTT_BROKER_TYPE=nats +MG_MQTT_BROKER_HEALTH_CHECK=${MG_NATS_HEALTH_CHECK} +MG_MQTT_ADAPTER_MQTT_QOS=${MG_NATS_MQTT_QOS} +MG_MQTT_ADAPTER_MQTT_TARGET_HOST=${MG_MQTT_BROKER_TYPE} +MG_MQTT_ADAPTER_MQTT_TARGET_PORT=1883 +MG_MQTT_ADAPTER_MQTT_TARGET_HEALTH_CHECK=${MG_MQTT_BROKER_HEALTH_CHECK} +MG_MQTT_ADAPTER_WS_TARGET_HOST=${MG_MQTT_BROKER_TYPE} +MG_MQTT_ADAPTER_WS_TARGET_PORT=8080 +MG_MQTT_ADAPTER_WS_TARGET_PATH=${MG_NATS_WS_TARGET_PATH} ``` ### RabbitMQ configuration (as MQTT broker or MESSAGE_BROKER) @@ -99,13 +99,13 @@ services: container_name: supermq-rabbitmq restart: on-failure environment: - RABBITMQ_ERLANG_COOKIE: ${SMQ_RABBITMQ_COOKIE} - RABBITMQ_DEFAULT_USER: ${SMQ_RABBITMQ_USER} - RABBITMQ_DEFAULT_PASS: ${SMQ_RABBITMQ_PASS} - RABBITMQ_DEFAULT_VHOST: ${SMQ_RABBITMQ_VHOST} + RABBITMQ_ERLANG_COOKIE: ${MG_RABBITMQ_COOKIE} + RABBITMQ_DEFAULT_USER: ${MG_RABBITMQ_USER} + RABBITMQ_DEFAULT_PASS: ${MG_RABBITMQ_PASS} + RABBITMQ_DEFAULT_VHOST: ${MG_RABBITMQ_VHOST} ports: - - ${SMQ_RABBITMQ_PORT}:${SMQ_RABBITMQ_PORT} - - ${SMQ_RABBITMQ_HTTP_PORT}:${SMQ_RABBITMQ_HTTP_PORT} + - ${MG_RABBITMQ_PORT}:${MG_RABBITMQ_PORT} + - ${MG_RABBITMQ_HTTP_PORT}:${MG_RABBITMQ_HTTP_PORT} networks: - supermq-base-net ``` @@ -131,11 +131,11 @@ By using environment variables file at `docker/.env` you can modify the below gi | Environment Variable | Description | |----------------------|-------------| -| `SMQ_NGINX_SERVER_NAME` | `SMQ_NGINX_SERVER_NAME` environmental variable is used to configure nginx directive `server_name`. If environmental variable `SMQ_NGINX_SERVER_NAME` is empty then default value `localhost` will set to `server_name`. | -| `SMQ_NGINX_SERVER_CERT` | `SMQ_NGINX_SERVER_CERT` environmental variable is used to configure nginx directive `ssl_certificate`. If environmental variable `SMQ_NGINX_SERVER_CERT` is empty then by default server certificate in the path `docker/ssl/certs/supermq-server.crt` will be assigned. | -| `SMQ_NGINX_SERVER_KEY` | `SMQ_NGINX_SERVER_KEY` environmental variable is used to configure nginx directive `ssl_certificate_key`. If environmental variable `SMQ_NGINX_SERVER_KEY` is empty then by default server certificate key in the path `docker/ssl/certs/supermq-server.key` will be assigned. | -| `SMQ_NGINX_SERVER_CLIENT_CA` | `SMQ_NGINX_SERVER_CLIENT_CA` environmental variable is used to configure nginx directive `ssl_client_certificate`. If environmental variable `SMQ_NGINX_SERVER_CLIENT_CA` is empty then by default certificate in the path `docker/ssl/certs/ca.crt` will be assigned. | -| `SMQ_NGINX_SERVER_DHPARAM` | `SMQ_NGINX_SERVER_DHPARAM` environmental variable is used to configure nginx directive `ssl_dhparam`. If environmental variable `SMQ_NGINX_SERVER_DHPARAM` is empty then by default file in the path `docker/ssl/dhparam.pem` will be assigned. | +| `MG_NGINX_SERVER_NAME` | `MG_NGINX_SERVER_NAME` environmental variable is used to configure nginx directive `server_name`. If environmental variable `MG_NGINX_SERVER_NAME` is empty then default value `localhost` will set to `server_name`. | +| `MG_NGINX_SERVER_CERT` | `MG_NGINX_SERVER_CERT` environmental variable is used to configure nginx directive `ssl_certificate`. If environmental variable `MG_NGINX_SERVER_CERT` is empty then by default server certificate in the path `docker/ssl/certs/magistrala-server.crt` will be assigned. | +| `MG_NGINX_SERVER_KEY` | `MG_NGINX_SERVER_KEY` environmental variable is used to configure nginx directive `ssl_certificate_key`. If environmental variable `MG_NGINX_SERVER_KEY` is empty then by default server certificate key in the path `docker/ssl/certs/magistrala-server.key` will be assigned. | +| `MG_NGINX_SERVER_CLIENT_CA` | `MG_NGINX_SERVER_CLIENT_CA` environmental variable is used to configure nginx directive `ssl_client_certificate`. If environmental variable `MG_NGINX_SERVER_CLIENT_CA` is empty then by default certificate in the path `docker/ssl/certs/ca.crt` will be assigned. | +| `MG_NGINX_SERVER_DHPARAM` | `MG_NGINX_SERVER_DHPARAM` environmental variable is used to configure nginx directive `ssl_dhparam`. If environmental variable `MG_NGINX_SERVER_DHPARAM` is empty then by default file in the path `docker/ssl/dhparam.pem` will be assigned. | Adjust these values in `.env` to configure TLS / SSL behavior for your deployment. diff --git a/docker/addons/bootstrap/docker-compose.yaml b/docker/addons/bootstrap/docker-compose.yaml new file mode 100644 index 000000000..88e3fcc7b --- /dev/null +++ b/docker/addons/bootstrap/docker-compose.yaml @@ -0,0 +1,93 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +# This docker-compose file contains optional bootstrap services. Since it's optional, this file is +# dependent of docker-compose file from /docker. In order to run this services, execute command: +# docker compose -f docker/docker-compose.yaml -f docker/addons/bootstrap/docker-compose.yaml up +# from project root. + +networks: + magistrala-base-net: + external: true + +volumes: + magistrala-bootstrap-db-volume: + +services: + bootstrap-db: + image: postgres:16.2-alpine + container_name: magistrala-bootstrap-db + restart: on-failure + environment: + POSTGRES_USER: ${MG_BOOTSTRAP_DB_USER} + POSTGRES_PASSWORD: ${MG_BOOTSTRAP_DB_PASS} + POSTGRES_DB: ${MG_BOOTSTRAP_DB_NAME} + networks: + - magistrala-base-net + volumes: + - magistrala-bootstrap-db-volume:/var/lib/postgresql/data + + bootstrap: + image: docker.io/magistrala/bootstrap:${MG_RELEASE_TAG} + container_name: magistrala-bootstrap + depends_on: + - bootstrap-db + restart: on-failure + ports: + - ${MG_BOOTSTRAP_HTTP_PORT}:${MG_BOOTSTRAP_HTTP_PORT} + environment: + MG_BOOTSTRAP_LOG_LEVEL: ${MG_BOOTSTRAP_LOG_LEVEL} + MG_BOOTSTRAP_ENCRYPT_KEY: ${MG_BOOTSTRAP_ENCRYPT_KEY} + MG_BOOTSTRAP_EVENT_CONSUMER: ${MG_BOOTSTRAP_EVENT_CONSUMER} + MG_ES_URL: ${MG_ES_URL} + MG_BOOTSTRAP_HTTP_HOST: ${MG_BOOTSTRAP_HTTP_HOST} + MG_BOOTSTRAP_HTTP_PORT: ${MG_BOOTSTRAP_HTTP_PORT} + MG_BOOTSTRAP_HTTP_SERVER_CERT: ${MG_BOOTSTRAP_HTTP_SERVER_CERT} + MG_BOOTSTRAP_HTTP_SERVER_KEY: ${MG_BOOTSTRAP_HTTP_SERVER_KEY} + MG_BOOTSTRAP_DB_HOST: ${MG_BOOTSTRAP_DB_HOST} + MG_BOOTSTRAP_DB_PORT: ${MG_BOOTSTRAP_DB_PORT} + MG_BOOTSTRAP_DB_USER: ${MG_BOOTSTRAP_DB_USER} + MG_BOOTSTRAP_DB_PASS: ${MG_BOOTSTRAP_DB_PASS} + MG_BOOTSTRAP_DB_NAME: ${MG_BOOTSTRAP_DB_NAME} + MG_BOOTSTRAP_DB_SSL_MODE: ${MG_BOOTSTRAP_DB_SSL_MODE} + MG_BOOTSTRAP_DB_SSL_CERT: ${MG_BOOTSTRAP_DB_SSL_CERT} + MG_BOOTSTRAP_DB_SSL_KEY: ${MG_BOOTSTRAP_DB_SSL_KEY} + MG_BOOTSTRAP_DB_SSL_ROOT_CERT: ${MG_BOOTSTRAP_DB_SSL_ROOT_CERT} + MG_BOOTSTRAP_INSTANCE_ID: ${MG_BOOTSTRAP_INSTANCE_ID} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_DOMAINS_GRPC_URL: ${MG_DOMAINS_GRPC_URL} + MG_DOMAINS_GRPC_TIMEOUT: ${MG_DOMAINS_GRPC_TIMEOUT} + MG_DOMAINS_GRPC_CLIENT_CERT: ${MG_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + MG_DOMAINS_GRPC_CLIENT_KEY: ${MG_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + MG_DOMAINS_GRPC_SERVER_CA_CERTS: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_CLIENTS_URL: ${MG_CLIENTS_URL} + MG_CHANNELS_URL: ${MG_CHANNELS_URL} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_SPICEDB_PRE_SHARED_KEY: ${MG_SPICEDB_PRE_SHARED_KEY} + MG_SPICEDB_HOST: ${MG_SPICEDB_HOST} + MG_SPICEDB_PORT: ${MG_SPICEDB_PORT} + MG_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} + networks: + - magistrala-base-net + volumes: + - type: bind + source: ${MG_ADDONS_CERTS_PATH_PREFIX}${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /auth-grpc-client${MG_AUTH_GRPC_CLIENT_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_ADDONS_CERTS_PATH_PREFIX}${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /auth-grpc-client${MG_AUTH_GRPC_CLIENT_KEY:+.key} + bind: + create_host_path: true + - type: bind + source: ${MG_ADDONS_CERTS_PATH_PREFIX}${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /auth-grpc-server-ca${MG_AUTH_GRPC_SERVER_CA_CERTS:+.crt} + bind: + create_host_path: true diff --git a/docker/addons/bootstrap/ssl/placeholder b/docker/addons/bootstrap/ssl/placeholder new file mode 100644 index 000000000..f5f101481 --- /dev/null +++ b/docker/addons/bootstrap/ssl/placeholder @@ -0,0 +1 @@ +optional bind-mount placeholder diff --git a/docker/addons/certs/.env b/docker/addons/certs/.env deleted file mode 100644 index 91eb12128..000000000 --- a/docker/addons/certs/.env +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 -# Docker: Environment variables in Compose - -## CERTS -AM_CERTS_LOG_LEVEL=debug -AM_CERTS_DB_HOST=certs-db -AM_CERTS_DB_PORT=5432 -AM_CERTS_DB_USER=absmach -AM_CERTS_DB_PASS=absmach -AM_CERTS_DB=certs -AM_CERTS_DB_SSL_MODE=disable -AM_CERTS_DB_SSL_CERT= -AM_CERTS_DB_SSL_KEY= -AM_CERTS_DB_SSL_ROOT_CERT= -AM_CERTS_DB_MAX_CONNECTIONS=100 -AM_CERTS_HTTP_HOST=certs -AM_CERTS_HTTP_PORT=9010 -AM_CERTS_HTTP_SERVER_CERT= -AM_CERTS_HTTP_SERVER_KEY= -AM_CERTS_GRPC_HOST=certs -AM_CERTS_GRPC_PORT=7012 -AM_CERTS_GRPC_SERVER_CERT= -AM_CERTS_GRPC_SERVER_KEY= -AM_CERTS_GRPC_SERVER_CA_CERTS= -AM_CERTS_GRPC_SERVER_CA_KEY= -AM_CERTS_GRPC_CLIENT_CA_CERTS= -AM_CERTS_GRPC_URL=${AM_CERTS_GRPC_HOST}:${AM_CERTS_GRPC_PORT} -AM_CERTS_GRPC_TIMEOUT= -AM_CERTS_GRPC_CLIENT_CERT= -AM_CERTS_GRPC_CLIENT_KEY= -AM_CERTS_GRPC_CLIENT_TLS= -AM_CERTS_GRPC_CA_CERTS= -AM_CERTS_INSTANCE_ID= -AM_CERTS_RELEASE_TAG=latest -# WARNING: This is a development/testing secret only. -# NEVER use this weak secret in production! Generate a strong random secret for production deployments. -AM_CERTS_SECRET=12345678 - -## OpenBao PKI Config -AM_CERTS_OPENBAO_HOST=http://certs-openbao:8200 -AM_CERTS_OPENBAO_APP_ROLE=absmach -AM_CERTS_OPENBAO_APP_SECRET=absmach -AM_CERTS_OPENBAO_SECRET_ID_TTL=720h -AM_CERTS_OPENBAO_NAMESPACE= -AM_CERTS_OPENBAO_PKI_PATH=pki -AM_CERTS_OPENBAO_ROLE=absmach -AM_CERTS_SERVICE_TOKEN_PATH=/openbao/service_token -AM_CERTS_SECRET_ID_PATH=/openbao/secret_id -AM_CERTS_SECRET_RENEW_THRESHOLD=24h -AM_CERTS_SECRET_CHECK_INTERVAL=1h -AM_CERTS_OPENBAO_PKI_CA_CN=Abstract Machines Certificate Authority -AM_CERTS_OPENBAO_PKI_CA_OU=Abstract Machines -AM_CERTS_OPENBAO_PKI_CA_O=AbstractMachines -AM_CERTS_OPENBAO_PKI_CA_C=FRANCE -AM_CERTS_OPENBAO_PKI_CA_L=PARIS -AM_CERTS_OPENBAO_PKI_CA_ST=PARIS -AM_CERTS_OPENBAO_PKI_CA_ADDR=5 Av. Anatole -AM_CERTS_OPENBAO_PKI_CA_PO=75007 -AM_CERTS_OPENBAO_PKI_CA_DNS_NAMES=localhost -AM_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES=127.0.0.1,::1 -AM_CERTS_OPENBAO_PKI_CA_URI_SANS= -AM_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES=info@abstractmachines.rs -AM_CERTS_OPENBAO_UNSEAL_KEY_1= -AM_CERTS_OPENBAO_UNSEAL_KEY_2= -AM_CERTS_OPENBAO_UNSEAL_KEY_3= -AM_CERTS_OPENBAO_ROOT_TOKEN= - -## Jaeger -AM_JAEGER_PORT=6831 -AM_JAEGER_FRONTEND=16686 -AM_JAEGER_URL=http://jaeger:4318/v1/traces -AM_JAEGER_TRACE_RATIO=1.0 -AM_JAEGER_COLLECTOR_OTLP_ENABLED=true -AM_JAEGER_OLTP_HTTP_PORT=4318 -AM_JAEGER_MEMORY_MAX_TRACES=5000 - -#### Auth Client Config -AM_AUTH_URL=auth:9001 -AM_AUTH_GRPC_URL=auth:7001 -AM_AUTH_GRPC_TIMEOUT=300s -AM_AUTH_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/auth-grpc-client.crt} -AM_AUTH_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/auth-grpc-client.key} -AM_AUTH_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} -AM_AUTH_GRPC_SERVER_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} - -#### Domains Client Config -AM_DOMAINS_URL=domains:9003 -AM_DOMAINS_GRPC_URL=domains:7003 -AM_DOMAINS_GRPC_TIMEOUT=300s -AM_DOMAINS_GRPC_CLIENT_CERT=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.crt} -AM_DOMAINS_GRPC_CLIENT_KEY=${GRPC_MTLS:+./ssl/certs/domains-grpc-client.key} -AM_DOMAINS_GRPC_CLIENT_CA_CERTS=${GRPC_MTLS:+./ssl/certs/ca.crt} diff --git a/docker/addons/certs/Dockerfile b/docker/addons/certs/Dockerfile deleted file mode 100644 index 3724d09c8..000000000 --- a/docker/addons/certs/Dockerfile +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 - -FROM golang:1.26-alpine AS builder - -ARG SVC -ARG GOARCH -ARG GOARM -ARG VERSION -ARG COMMIT -ARG TIME - -WORKDIR /go/src/github.com/absmach/certs - -COPY . . - -RUN apk update \ - && apk add make upx\ - && make $SVC \ - && upx build/$SVC \ - && mv build/$SVC /exe - -FROM scratch - -# Required for certs service -COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt -COPY --from=builder /exe / -ENTRYPOINT ["/exe"] diff --git a/docker/addons/certs/Dockerfile.dev b/docker/addons/certs/Dockerfile.dev deleted file mode 100644 index 2f6c330c2..000000000 --- a/docker/addons/certs/Dockerfile.dev +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 - -FROM scratch -ARG SVC -COPY --from=alpine:latest /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ca-certificates.crt -COPY ./build/$SVC /exe -ENTRYPOINT ["/exe"] diff --git a/docker/addons/certs/docker-compose.yaml b/docker/addons/certs/docker-compose.yaml deleted file mode 100644 index f406d5c7b..000000000 --- a/docker/addons/certs/docker-compose.yaml +++ /dev/null @@ -1,161 +0,0 @@ -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 - -name: "certs" - -networks: - certs-base-net: - driver: bridge - -volumes: - openbao-data: - certs-db-volume: - -services: - certs: - image: ghcr.io/absmach/certs:${AM_CERTS_RELEASE_TAG} - container_name: certs - depends_on: - openbao: - condition: service_healthy - certs-db: - condition: service_started - restart: on-failure - networks: - - certs-base-net - environment: - AM_CERTS_LOG_LEVEL: ${AM_CERTS_LOG_LEVEL} - AM_CERTS_HTTP_HOST: ${AM_CERTS_HTTP_HOST} - AM_CERTS_HTTP_PORT: ${AM_CERTS_HTTP_PORT} - AM_CERTS_GRPC_HOST: ${AM_CERTS_GRPC_HOST} - AM_CERTS_GRPC_PORT: ${AM_CERTS_GRPC_PORT} - AM_JAEGER_URL: ${AM_JAEGER_URL} - AM_JAEGER_TRACE_RATIO: ${AM_JAEGER_TRACE_RATIO} - AM_CERTS_OPENBAO_HOST: ${AM_CERTS_OPENBAO_HOST} - AM_CERTS_OPENBAO_APP_ROLE: ${AM_CERTS_OPENBAO_APP_ROLE} - AM_CERTS_OPENBAO_APP_SECRET: ${AM_CERTS_OPENBAO_APP_SECRET} - AM_CERTS_OPENBAO_NAMESPACE: ${AM_CERTS_OPENBAO_NAMESPACE} - AM_CERTS_OPENBAO_PKI_PATH: ${AM_CERTS_OPENBAO_PKI_PATH} - AM_CERTS_OPENBAO_ROLE: ${AM_CERTS_OPENBAO_ROLE} - AM_CERTS_OPENBAO_SECRET_ID_TTL: ${AM_CERTS_OPENBAO_SECRET_ID_TTL} - AM_CERTS_DB_HOST: ${AM_CERTS_DB_HOST} - AM_CERTS_DB_PORT: ${AM_CERTS_DB_PORT} - AM_CERTS_DB_USER: ${AM_CERTS_DB_USER} - AM_CERTS_DB_PASS: ${AM_CERTS_DB_PASS} - AM_CERTS_DB: ${AM_CERTS_DB} - AM_CERTS_DB_SSL_MODE: ${AM_CERTS_DB_SSL_MODE} - AM_AUTH_GRPC_URL: ${AM_AUTH_GRPC_URL} - AM_AUTH_GRPC_TIMEOUT: ${AM_AUTH_GRPC_TIMEOUT} - AM_AUTH_GRPC_CLIENT_CERT: ${AM_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} - AM_AUTH_GRPC_CLIENT_KEY: ${AM_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} - AM_AUTH_GRPC_SERVER_CA_CERTS: ${AM_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} - AM_DOMAINS_GRPC_URL: ${AM_DOMAINS_GRPC_URL} - AM_DOMAINS_GRPC_TIMEOUT: ${AM_DOMAINS_GRPC_TIMEOUT} - AM_DOMAINS_GRPC_CLIENT_CERT: ${AM_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} - AM_DOMAINS_GRPC_CLIENT_KEY: ${AM_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} - AM_DOMAINS_GRPC_SERVER_CA_CERTS: ${AM_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - AM_CERTS_SECRET: ${AM_CERTS_SECRET} - AM_CERTS_SERVICE_TOKEN_PATH: ${AM_CERTS_SERVICE_TOKEN_PATH} - AM_CERTS_SECRET_ID_PATH: ${AM_CERTS_SECRET_ID_PATH} - AM_CERTS_SECRET_RENEW_THRESHOLD: ${AM_CERTS_SECRET_RENEW_THRESHOLD} - AM_CERTS_SECRET_CHECK_INTERVAL: ${AM_CERTS_SECRET_CHECK_INTERVAL} - SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} - ports: - - ${AM_CERTS_HTTP_PORT}:${AM_CERTS_HTTP_PORT} - - ${AM_CERTS_GRPC_PORT}:${AM_CERTS_GRPC_PORT} - volumes: - - openbao-data:/openbao:ro - - type: bind - source: ${SMQ_ADDONS_CERTS_PATH_PREFIX}${AM_AUTH_GRPC_CLIENT_CERT:-./ssl/certs/dummy/client_cert} - target: /auth-grpc-client.crt - bind: - create_host_path: true - - type: bind - source: ${SMQ_ADDONS_CERTS_PATH_PREFIX}${AM_AUTH_GRPC_CLIENT_KEY:-./ssl/certs/dummy/client_key} - target: /auth-grpc-client.key - bind: - create_host_path: true - - type: bind - source: ${SMQ_ADDONS_CERTS_PATH_PREFIX}${AM_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/certs/dummy/server_ca} - target: /auth-grpc-server-ca.crt - bind: - create_host_path: true - - type: bind - source: ${SMQ_ADDONS_CERTS_PATH_PREFIX}${AM_DOMAINS_GRPC_CLIENT_CERT:-./ssl/certs/dummy/client_cert} - target: /domains-grpc-client.crt - bind: - create_host_path: true - - type: bind - source: ${SMQ_ADDONS_CERTS_PATH_PREFIX}${AM_DOMAINS_GRPC_CLIENT_KEY:-./ssl/certs/dummy/client_key} - target: /domains-grpc-client.key - bind: - create_host_path: true - - type: bind - source: ${SMQ_ADDONS_CERTS_PATH_PREFIX}${AM_DOMAINS_GRPC_SERVER_CA_CERTS:-./ssl/certs/dummy/server_ca} - target: /domains-grpc-server-ca.crt - bind: - create_host_path: true - - certs-db: - image: postgres:16.2-alpine - container_name: certs-db - restart: on-failure - networks: - - certs-base-net - command: postgres -c "max_connections=${AM_CERTS_DB_MAX_CONNECTIONS}" - environment: - POSTGRES_USER: ${AM_CERTS_DB_USER} - POSTGRES_PASSWORD: ${AM_CERTS_DB_PASS} - POSTGRES_DB: ${AM_CERTS_DB} - ports: - - 5454:5432 - volumes: - - certs-db-volume:/var/lib/postgresql/data - - openbao: - image: openbao/openbao:2.4.0 - container_name: certs-openbao - restart: on-failure - networks: - - certs-base-net - ports: - - 8200:8200 - healthcheck: - test: ["CMD", "sh", "-c", "test -f /opt/openbao/data/service_token"] - interval: 5s - timeout: 3s - retries: 20 - start_period: 30s - environment: - - BAO_ADDR=http://127.0.0.1:8200 - - BAO_LOG_LEVEL=info - - AM_CERTS_OPENBAO_PKI_ROLE=${AM_CERTS_OPENBAO_ROLE} - - AM_CERTS_OPENBAO_APP_ROLE=${AM_CERTS_OPENBAO_APP_ROLE} - - AM_CERTS_OPENBAO_APP_SECRET=${AM_CERTS_OPENBAO_APP_SECRET} - - AM_CERTS_OPENBAO_SECRET_ID_TTL=${AM_CERTS_OPENBAO_SECRET_ID_TTL} - - AM_CERTS_OPENBAO_NAMESPACE=${AM_CERTS_OPENBAO_NAMESPACE} - - AM_CERTS_OPENBAO_PKI_CA_CN=${AM_CERTS_OPENBAO_PKI_CA_CN} - - AM_CERTS_OPENBAO_PKI_CA_OU=${AM_CERTS_OPENBAO_PKI_CA_OU} - - AM_CERTS_OPENBAO_PKI_CA_O=${AM_CERTS_OPENBAO_PKI_CA_O} - - AM_CERTS_OPENBAO_PKI_CA_C=${AM_CERTS_OPENBAO_PKI_CA_C} - - AM_CERTS_OPENBAO_PKI_CA_L=${AM_CERTS_OPENBAO_PKI_CA_L} - - AM_CERTS_OPENBAO_PKI_CA_ST=${AM_CERTS_OPENBAO_PKI_CA_ST} - - AM_CERTS_OPENBAO_PKI_CA_ADDR=${AM_CERTS_OPENBAO_PKI_CA_ADDR} - - AM_CERTS_OPENBAO_PKI_CA_PO=${AM_CERTS_OPENBAO_PKI_CA_PO} - - AM_CERTS_OPENBAO_PKI_CA_DNS_NAMES=${AM_CERTS_OPENBAO_PKI_CA_DNS_NAMES} - - AM_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES=${AM_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES} - - AM_CERTS_OPENBAO_PKI_CA_URI_SANS=${AM_CERTS_OPENBAO_PKI_CA_URI_SANS} - - AM_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES=${AM_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES} - - AM_CERTS_OPENBAO_UNSEAL_KEY_1=${AM_CERTS_OPENBAO_UNSEAL_KEY_1} - - AM_CERTS_OPENBAO_UNSEAL_KEY_2=${AM_CERTS_OPENBAO_UNSEAL_KEY_2} - - AM_CERTS_OPENBAO_UNSEAL_KEY_3=${AM_CERTS_OPENBAO_UNSEAL_KEY_3} - - AM_CERTS_OPENBAO_ROOT_TOKEN=${AM_CERTS_OPENBAO_ROOT_TOKEN} - cap_add: - - IPC_LOCK - mem_swappiness: 0 - volumes: - - openbao-data:/opt/openbao/data - - openbao-data:/opt/openbao/config - - ./openbao-entrypoint.sh:/entrypoint.sh - entrypoint: /bin/sh - command: /entrypoint.sh diff --git a/docker/addons/journal/docker-compose.yaml b/docker/addons/journal/docker-compose.yaml deleted file mode 100644 index 6ababaf53..000000000 --- a/docker/addons/journal/docker-compose.yaml +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 - -# This docker-compose file contains optional Postgres and journal services -# for SuperMQ platform. Since these are optional, this file is dependent of docker-compose file -# from /docker. In order to run these services, execute command: -# docker-compose -f docker/docker-compose.yaml -f docker/addons/journal/docker-compose.yaml up -# from project root. PostgreSQL default port (5432) is exposed, so you can use various tools for database -# inspection and data visualization. - -networks: - supermq-base-net: - name: supermq-base-net - external: true - -volumes: - supermq-journal-volume: - - -services: - journal-db: - image: postgres:16.2-alpine - container_name: supermq-journal-db - restart: on-failure - command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}" - environment: - POSTGRES_USER: ${SMQ_JOURNAL_DB_USER} - POSTGRES_PASSWORD: ${SMQ_JOURNAL_DB_PASS} - POSTGRES_DB: ${SMQ_JOURNAL_DB_NAME} - SMQ_POSTGRES_MAX_CONNECTIONS: ${SMQ_POSTGRES_MAX_CONNECTIONS} - networks: - - supermq-base-net - volumes: - - supermq-journal-volume:/var/lib/postgresql/data - - journal: - image: supermq/journal:${SMQ_RELEASE_TAG} - container_name: supermq-journal - depends_on: - - journal-db - restart: on-failure - environment: - SMQ_JOURNAL_LOG_LEVEL: ${SMQ_JOURNAL_LOG_LEVEL} - SMQ_JOURNAL_HTTP_HOST: ${SMQ_JOURNAL_HTTP_HOST} - SMQ_JOURNAL_HTTP_PORT: ${SMQ_JOURNAL_HTTP_PORT} - SMQ_JOURNAL_HTTP_SERVER_CERT: ${SMQ_JOURNAL_HTTP_SERVER_CERT} - SMQ_JOURNAL_HTTP_SERVER_KEY: ${SMQ_JOURNAL_HTTP_SERVER_KEY} - SMQ_JOURNAL_DB_HOST: ${SMQ_JOURNAL_DB_HOST} - SMQ_JOURNAL_DB_PORT: ${SMQ_JOURNAL_DB_PORT} - SMQ_JOURNAL_DB_USER: ${SMQ_JOURNAL_DB_USER} - SMQ_JOURNAL_DB_PASS: ${SMQ_JOURNAL_DB_PASS} - SMQ_JOURNAL_DB_NAME: ${SMQ_JOURNAL_DB_NAME} - SMQ_JOURNAL_DB_SSL_MODE: ${SMQ_JOURNAL_DB_SSL_MODE} - SMQ_JOURNAL_DB_SSL_CERT: ${SMQ_JOURNAL_DB_SSL_CERT} - SMQ_JOURNAL_DB_SSL_KEY: ${SMQ_JOURNAL_DB_SSL_KEY} - SMQ_JOURNAL_DB_SSL_ROOT_CERT: ${SMQ_JOURNAL_DB_SSL_ROOT_CERT} - SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL} - SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT} - SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} - SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} - SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} - SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM} - SMQ_ES_URL: ${SMQ_ES_URL} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} - SMQ_JOURNAL_INSTANCE_ID: ${SMQ_JOURNAL_INSTANCE_ID} - SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL} - SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT} - SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} - SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} - SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} - ports: - - ${SMQ_JOURNAL_HTTP_PORT}:${SMQ_JOURNAL_HTTP_PORT} - networks: - - supermq-base-net diff --git a/docker/addons/postgres-reader/docker-compose.yaml b/docker/addons/postgres-reader/docker-compose.yaml new file mode 100644 index 000000000..b3c34a2b5 --- /dev/null +++ b/docker/addons/postgres-reader/docker-compose.yaml @@ -0,0 +1,123 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +# This docker-compose file contains optional Postgres-reader service for Magistrala platform. +# Since this service is optional, this file is dependent of docker-compose.yaml file +# from /docker. In order to run this service, execute command: +# docker compose -f docker/docker-compose.yaml -f docker/addons/postgres-reader/docker-compose.yaml up +# from project root. + +networks: + magistrala-base-net: + external: true + +services: + postgres-reader: + image: docker.io/magistrala/postgres-reader:${MG_RELEASE_TAG} + container_name: magistrala-postgres-reader + restart: on-failure + environment: + MG_POSTGRES_READER_LOG_LEVEL: ${MG_POSTGRES_READER_LOG_LEVEL} + MG_POSTGRES_READER_HTTP_HOST: ${MG_POSTGRES_READER_HTTP_HOST} + MG_POSTGRES_READER_HTTP_PORT: ${MG_POSTGRES_READER_HTTP_PORT} + MG_POSTGRES_READER_HTTP_SERVER_CERT: ${MG_POSTGRES_READER_HTTP_SERVER_CERT} + MG_POSTGRES_READER_HTTP_SERVER_KEY: ${MG_POSTGRES_READER_HTTP_SERVER_KEY} + MG_POSTGRES_HOST: ${MG_POSTGRES_HOST} + MG_POSTGRES_PORT: ${MG_POSTGRES_PORT} + MG_POSTGRES_USER: ${MG_POSTGRES_USER} + MG_POSTGRES_PASS: ${MG_POSTGRES_PASS} + MG_POSTGRES_NAME: ${MG_POSTGRES_NAME} + MG_POSTGRES_SSL_MODE: ${MG_POSTGRES_SSL_MODE} + MG_POSTGRES_SSL_CERT: ${MG_POSTGRES_SSL_CERT} + MG_POSTGRES_SSL_KEY: ${MG_POSTGRES_SSL_KEY} + MG_POSTGRES_SSL_ROOT_CERT: ${MG_POSTGRES_SSL_ROOT_CERT} + MG_CLIENTS_GRPC_URL: ${MG_CLIENTS_GRPC_URL} + MG_CLIENTS_GRPC_TIMEOUT: ${MG_CLIENTS_GRPC_TIMEOUT} + MG_CLIENTS_GRPC_CLIENT_CERT: ${MG_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} + MG_CLIENTS_GRPC_CLIENT_KEY: ${MG_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} + MG_CLIENTS_GRPC_SERVER_CA_CERTS: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} + MG_POSTGRES_READER_GRPC_URL: ${MG_POSTGRES_READER_GRPC_URL} + MG_POSTGRES_READER_GRPC_PORT: ${MG_POSTGRES_READER_GRPC_PORT} + MG_POSTGRES_READER_GRPC_HOST: ${MG_POSTGRES_READER_GRPC_HOST} + MG_POSTGRES_READER_GRPC_TIMEOUT: ${MG_POSTGRES_READER_GRPC_TIMEOUT} + MG_POSTGRES_READER_GRPC_CLIENT_CERT: ${MG_POSTGRES_READER_GRPC_CLIENT_CERT:+./ssl/certs/reader-grpc-client.crt} + MG_POSTGRES_READER_GRPC_CLIENT_CA_CERTS: ${MG_POSTGRES_READER_GRPC_CLIENT_CA_CERTS:+./ssl/certs/ca.crt} + MG_POSTGRES_READER_GRPC_SERVER_CA_CERTS: ${MG_POSTGRES_READER_GRPC_SERVER_CA_CERTS:+./ssl/certs/ca.crt} + MG_POSTGRES_READER_GRPC_CLIENT_KEY: ${MG_POSTGRES_READER_GRPC_CLIENT_KEY:+/readers-grpc-client.key} + MG_POSTGRES_READER_GRPC_SERVER_CERT: ${MG_POSTGRES_READER_GRPC_SERVER_CERT:+./ssl/certs/readers-grpc-server.crt} + MG_POSTGRES_READER_GRPC_SERVER_KEY: ${MG_POSTGRES_READER_GRPC_SERVER_KEY:+./ssl/certs/readers-grpc-server.key} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_POSTGRES_READER_INSTANCE_ID: ${MG_POSTGRES_READER_INSTANCE_ID} + ports: + - ${MG_POSTGRES_READER_HTTP_PORT}:${MG_POSTGRES_READER_HTTP_PORT} + - ${MG_POSTGRES_READER_GRPC_PORT}:${MG_POSTGRES_READER_GRPC_PORT} + networks: + - magistrala-base-net + volumes: + - type: bind + source: ${MG_ADDONS_CERTS_PATH_PREFIX}${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /auth-grpc-client${MG_AUTH_GRPC_CLIENT_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_ADDONS_CERTS_PATH_PREFIX}${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /auth-grpc-client${MG_AUTH_GRPC_CLIENT_KEY:+.key} + bind: + create_host_path: true + - type: bind + source: ${MG_ADDONS_CERTS_PATH_PREFIX}${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /auth-grpc-server-ca${MG_AUTH_GRPC_SERVER_CA_CERTS:+.crt} + bind: + create_host_path: true + # Clients gRPC mTLS client certificates + - type: bind + source: ${MG_ADDONS_CERTS_PATH_PREFIX}${MG_CLIENTS_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /clients-grpc-client${MG_CLIENTS_GRPC_CLIENT_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_ADDONS_CERTS_PATH_PREFIX}${MG_CLIENTS_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /clients-grpc-client${MG_CLIENTS_GRPC_CLIENT_KEY:+.key} + bind: + create_host_path: true + - type: bind + source: ${MG_ADDONS_CERTS_PATH_PREFIX}${MG_CLIENTS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /clients-grpc-server-ca${MG_CLIENTS_GRPC_SERVER_CA_CERTS:+.crt} + bind: + create_host_path: true + # Reader gRPC mTLS client certificates + - type: bind + source: ${MG_POSTGRES_READER_GRPC_SERVER_CERT:-./ssl/placeholder} + target: /readers-grpc-server${MG_POSTGRES_READER_GRPC_SERVER_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_POSTGRES_READER_GRPC_SERVER_KEY:-./ssl/placeholder} + target: /readers-grpc-server${MG_POSTGRES_READER_GRPC_SERVER_KEY:+.key} + bind: + create_host_path: true + - type: bind + source: ${MG_POSTGRES_READER_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /readers-grpc-server-ca${MG_POSTGRES_READER_GRPC_SERVER_CA_CERTS:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_POSTGRES_READER_GRPC_CLIENT_CA_CERTS:-./ssl/placeholder} + target: /readers-grpc-server${MG_POSTGRES_READER_GRPC_CLIENT_CA_CERTS:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_POSTGRES_READER_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /readers-grpc-client${MG_POSTGRES_READER_GRPC_CLIENT_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_POSTGRES_READER_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /readers-grpc-client${MG_POSTGRES_READER_GRPC_CLIENT_KEY:+.key} + bind: + create_host_path: true diff --git a/docker/addons/postgres-reader/ssl/placeholder b/docker/addons/postgres-reader/ssl/placeholder new file mode 100644 index 000000000..f5f101481 --- /dev/null +++ b/docker/addons/postgres-reader/ssl/placeholder @@ -0,0 +1 @@ +optional bind-mount placeholder diff --git a/docker/addons/postgres-writer/config.toml b/docker/addons/postgres-writer/config.toml new file mode 100644 index 000000000..0f3343bcc --- /dev/null +++ b/docker/addons/postgres-writer/config.toml @@ -0,0 +1,22 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +# Writers consume through the broker's stream-backed path; this file only +# selects topic filters. Use NATS-style filters here even when built with +# FluxMQ. +# To listen on all writer topics use the default value "writers.>". +# To subscribe to specific topics use values starting with "writers." and +# followed by a subtopic (e.g. ["writers..sub.topic.x", ...]). +["subscriber"] +topics = ["writers.>"] + +[transformer] +# SenML or JSON +format = "senml" +# Used if format is SenML +content_type = "application/senml+json" +# Used as timestamp fields if format is JSON +time_fields = [{ field_name = "seconds_key", field_format = "unix", location = "UTC"}, + { field_name = "millis_key", field_format = "unix_ms", location = "UTC"}, + { field_name = "micros_key", field_format = "unix_us", location = "UTC"}, + { field_name = "nanos_key", field_format = "unix_ns", location = "UTC"}] diff --git a/docker/addons/postgres-writer/docker-compose.yaml b/docker/addons/postgres-writer/docker-compose.yaml new file mode 100644 index 000000000..bf8b971af --- /dev/null +++ b/docker/addons/postgres-writer/docker-compose.yaml @@ -0,0 +1,66 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +# This docker-compose file contains optional Postgres and Postgres-writer services +# for Magistrala platform. Since these are optional, this file is dependent of docker-compose file +# from /docker. In order to run these services, execute command: +# docker compose -f docker/docker-compose.yaml -f docker/addons/postgres-writer/docker-compose.yaml up +# from project root. PostgreSQL default port (5432) is exposed, so you can use various tools for database +# inspection and data visualization. + +networks: + magistrala-base-net: + external: true + +volumes: + magistrala-postgres-writer-volume: + +services: + postgres: + image: postgres:16.2-alpine + container_name: magistrala-postgres + restart: on-failure + environment: + POSTGRES_USER: ${MG_POSTGRES_USER} + POSTGRES_PASSWORD: ${MG_POSTGRES_PASS} + POSTGRES_DB: ${MG_POSTGRES_NAME} + ports: + - 5434:5432 + networks: + - magistrala-base-net + volumes: + - magistrala-postgres-writer-volume:/var/lib/postgresql/data + + postgres-writer: + image: docker.io/magistrala/postgres-writer:${MG_RELEASE_TAG} + container_name: magistrala-postgres-writer + depends_on: + - postgres + restart: on-failure + environment: + MG_POSTGRES_WRITER_LOG_LEVEL: ${MG_POSTGRES_WRITER_LOG_LEVEL} + MG_POSTGRES_WRITER_CONFIG_PATH: ${MG_POSTGRES_WRITER_CONFIG_PATH} + MG_POSTGRES_WRITER_HTTP_HOST: ${MG_POSTGRES_WRITER_HTTP_HOST} + MG_POSTGRES_WRITER_HTTP_PORT: ${MG_POSTGRES_WRITER_HTTP_PORT} + MG_POSTGRES_WRITER_HTTP_SERVER_CERT: ${MG_POSTGRES_WRITER_HTTP_SERVER_CERT} + MG_POSTGRES_WRITER_HTTP_SERVER_KEY: ${MG_POSTGRES_WRITER_HTTP_SERVER_KEY} + MG_POSTGRES_HOST: ${MG_POSTGRES_HOST} + MG_POSTGRES_PORT: ${MG_POSTGRES_PORT} + MG_POSTGRES_USER: ${MG_POSTGRES_USER} + MG_POSTGRES_PASS: ${MG_POSTGRES_PASS} + MG_POSTGRES_NAME: ${MG_POSTGRES_NAME} + MG_POSTGRES_SSL_MODE: ${MG_POSTGRES_SSL_MODE} + MG_POSTGRES_SSL_CERT: ${MG_POSTGRES_SSL_CERT} + MG_POSTGRES_SSL_KEY: ${MG_POSTGRES_SSL_KEY} + MG_POSTGRES_SSL_ROOT_CERT: ${MG_POSTGRES_SSL_ROOT_CERT} + MG_MESSAGE_BROKER_URL: ${MG_MESSAGE_BROKER_URL} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_POSTGRES_WRITER_INSTANCE_ID: ${MG_POSTGRES_WRITER_INSTANCE_ID} + ports: + - ${MG_POSTGRES_WRITER_HTTP_PORT}:${MG_POSTGRES_WRITER_HTTP_PORT} + networks: + - magistrala-base-net + volumes: + - ./config.toml:/config.toml diff --git a/docker/addons/prometheus/docker-compose.yaml b/docker/addons/prometheus/docker-compose.yaml index 5abc0bf22..06543f359 100644 --- a/docker/addons/prometheus/docker-compose.yaml +++ b/docker/addons/prometheus/docker-compose.yaml @@ -1,49 +1,49 @@ # Copyright (c) Abstract Machines # SPDX-License-Identifier: Apache-2.0 -# This docker-compose file contains optional Prometheus and Grafana service for SuperMQ platform. +# This docker-compose file contains optional Prometheus and Grafana service for Magistrala platform. # Since this service is optional, this file is dependent of docker-compose.yaml file # from /docker. In order to run this service, execute command: # docker compose -f docker/addons/prometheus/docker-compose.yaml up # from project root. networks: - supermq-base-net: - name: supermq-base-net + magistrala-base-net: + name: magistrala-base-net external: true volumes: - supermq-prometheus-volume: + magistrala-prometheus-volume: services: promethues: image: prom/prometheus:v2.49.1 - container_name: supermq-prometheus + container_name: magistrala-prometheus restart: on-failure ports: - - ${SMQ_PROMETHEUS_PORT}:${SMQ_PROMETHEUS_PORT} + - ${MG_PROMETHEUS_PORT}:${MG_PROMETHEUS_PORT} networks: - - supermq-base-net + - magistrala-base-net volumes: - type: bind source: ./metrics/prometheus.yaml target: /etc/prometheus/prometheus.yaml - - supermq-prometheus-volume:/prometheus + - magistrala-prometheus-volume:/prometheus grafana: image: grafana/grafana:10.2.3 - container_name: supermq-grafana + container_name: magistrala-grafana depends_on: - promethues restart: on-failure ports: - - ${SMQ_GRAFANA_PORT}:${SMQ_GRAFANA_PORT} + - ${MG_GRAFANA_PORT}:${MG_GRAFANA_PORT} environment: - - GF_SECURITY_ADMIN_USER=${SMQ_GRAFANA_ADMIN_USER} - - GF_SECURITY_ADMIN_PASSWORD=${SMQ_GRAFANA_ADMIN_PASSWORD} + - GF_SECURITY_ADMIN_USER=${MG_GRAFANA_ADMIN_USER} + - GF_SECURITY_ADMIN_PASSWORD=${MG_GRAFANA_ADMIN_PASSWORD} networks: - - supermq-base-net + - magistrala-base-net volumes: - type: bind source: ./grafana/datasource.yaml diff --git a/docker/addons/provision/configs/config.toml b/docker/addons/provision/configs/config.toml new file mode 100644 index 000000000..d40a4d663 --- /dev/null +++ b/docker/addons/provision/configs/config.toml @@ -0,0 +1,74 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +[bootstrap] + [bootstrap.content] + [bootstrap.content.agent.edgex] + url = "http://localhost:48090/api/v1/" + + [bootstrap.content.agent.log] + level = "info" + + [bootstrap.content.agent.mqtt] + mtls = false + qos = 0 + retain = false + skip_tls_ver = true + url = "localhost:1883" + + [bootstrap.content.agent.server] + nats_url = "localhost:4222" + port = "9000" + + [bootstrap.content.agent.heartbeat] + interval = "30s" + + [bootstrap.content.agent.terminal] + session_timeout = "30s" + + + [bootstrap.content.export.exp] + log_level = "debug" + nats = "nats://localhost:4222" + port = "8172" + cache_url = "localhost:6379" + cache_pass = "" + cache_db = "0" + + [bootstrap.content.export.mqtt] + ca_path = "ca.crt" + cert_path = "thing.crt" + channel = "" + host = "tcp://localhost:1883" + mtls = false + password = "" + priv_key_path = "thing.key" + qos = 0 + retain = false + skip_tls_ver = false + username = "" + + [[bootstrap.content.export.routes]] + mqtt_topic = "" + nats_topic = ">" + subtopic = "" + type = "plain" + workers = 10 + +[[clients]] + name = "client" + + [clients.metadata] + external_id = "xxxxxx" + +[[channels]] + name = "control-channel" + + [channels.metadata] + type = "control" + +[[channels]] + name = "data-channel" + + [channels.metadata] + type = "data" diff --git a/docker/addons/provision/docker-compose.yaml b/docker/addons/provision/docker-compose.yaml new file mode 100644 index 000000000..a3ab05be4 --- /dev/null +++ b/docker/addons/provision/docker-compose.yaml @@ -0,0 +1,54 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +# This docker-compose file contains optional provision services. Since it's optional, this file is +# dependent of docker-compose file from /docker. In order to run this services, execute command: +# docker compose -f docker/docker-compose.yaml -f docker/addons/provision/docker-compose.yaml up +# from project root. + +networks: + magistrala-base-net: + external: true + +services: + provision: + image: docker.io/magistrala/provision:${MG_RELEASE_TAG} + container_name: magistrala-provision + restart: on-failure + networks: + - magistrala-base-net + ports: + - ${MG_PROVISION_HTTP_PORT}:${MG_PROVISION_HTTP_PORT} + environment: + MG_PROVISION_LOG_LEVEL: ${MG_PROVISION_LOG_LEVEL} + MG_PROVISION_HTTP_PORT: ${MG_PROVISION_HTTP_PORT} + MG_PROVISION_CONFIG_FILE: ${MG_PROVISION_CONFIG_FILE} + MG_PROVISION_ENV_CLIENTS_TLS: ${MG_PROVISION_ENV_CLIENTS_TLS} + MG_PROVISION_SERVER_CERT: ${MG_PROVISION_SERVER_CERT} + MG_PROVISION_SERVER_KEY: ${MG_PROVISION_SERVER_KEY} + MG_PROVISION_USERS_URL: ${MG_PROVISION_USERS_URL} + MG_PROVISION_CHANNELS_URL: ${MG_PROVISION_CHANNELS_URL} + MG_PROVISION_CLIENTS_URL: ${MG_PROVISION_CLIENTS_URL} + MG_PROVISION_USER: ${MG_PROVISION_USER} + MG_PROVISION_USERNAME: ${MG_PROVISION_USERNAME} + MG_PROVISION_PASS: ${MG_PROVISION_PASS} + MG_PROVISION_API_KEY: ${MG_PROVISION_API_KEY} + MG_PROVISION_CERTS_URL: ${MG_PROVISION_CERTS_URL} + MG_PROVISION_X509_PROVISIONING: ${MG_PROVISION_X509_PROVISIONING} + MG_PROVISION_BS_SVC_URL: ${MG_PROVISION_BS_SVC_URL} + MG_PROVISION_BS_CONFIG_PROVISIONING: ${MG_PROVISION_BS_CONFIG_PROVISIONING} + MG_PROVISION_BS_AUTO_WHITELIST: ${MG_PROVISION_BS_AUTO_WHITELIST} + MG_PROVISION_BS_CONTENT: ${MG_PROVISION_BS_CONTENT} + MG_PROVISION_CERTS_HOURS_VALID: ${MG_PROVISION_CERTS_HOURS_VALID} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_PROVISION_INSTANCE_ID: ${MG_PROVISION_INSTANCE_ID} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} + volumes: + - ./configs:/configs + - ../../ssl/certs/ca.key:/etc/ssl/certs/ca.key + - ../../ssl/certs/ca.crt:/etc/ssl/certs/ca.crt diff --git a/docker/addons/ssl/placeholder b/docker/addons/ssl/placeholder new file mode 100644 index 000000000..f5f101481 --- /dev/null +++ b/docker/addons/ssl/placeholder @@ -0,0 +1 @@ +optional bind-mount placeholder diff --git a/docker/addons/timescale-writer/config.toml b/docker/addons/timescale-writer/config.toml new file mode 100644 index 000000000..820dde6ac --- /dev/null +++ b/docker/addons/timescale-writer/config.toml @@ -0,0 +1,11 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +# Writers consume through the broker's stream-backed path; this file only +# selects topic filters. Use NATS-style filters here even when built with +# FluxMQ. +# To listen on all writer topics use the default value "writers.>". +# To subscribe to specific topics use values starting with "writers." and +# followed by a subtopic (e.g. ["writers..sub.topic.x", ...]). +["subscriber"] +topics = ["writers.>"] diff --git a/docker/certs-docker-compose-override.yaml b/docker/certs-docker-compose-override.yaml deleted file mode 100644 index 1aae5c3bf..000000000 --- a/docker/certs-docker-compose-override.yaml +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 - -services: - certs: - environment: - AM_CERTS_LOG_LEVEL: ${AM_CERTS_LOG_LEVEL} - AM_CERTS_HTTP_HOST: ${AM_CERTS_HTTP_HOST} - AM_CERTS_HTTP_PORT: ${AM_CERTS_HTTP_PORT} - AM_CERTS_GRPC_HOST: ${AM_CERTS_GRPC_HOST} - AM_CERTS_GRPC_PORT: ${AM_CERTS_GRPC_PORT} - AM_CERTS_RELEASE_TAG: ${AM_CERTS_RELEASE_TAG} - AM_CERTS_SECRET: ${AM_CERTS_SECRET} - AM_CERTS_DB_HOST: ${AM_CERTS_DB_HOST} - AM_CERTS_DB_PORT: ${AM_CERTS_DB_PORT} - AM_CERTS_DB_USER: ${AM_CERTS_DB_USER} - AM_CERTS_DB_PASS: ${AM_CERTS_DB_PASS} - AM_CERTS_DB: ${AM_CERTS_DB} - AM_CERTS_DB_SSL_MODE: ${AM_CERTS_DB_SSL_MODE} - AM_CERTS_DB_MAX_CONNECTIONS: ${AM_CERTS_DB_MAX_CONNECTIONS} - AM_CERTS_OPENBAO_HOST: ${AM_CERTS_OPENBAO_HOST} - AM_CERTS_OPENBAO_APP_ROLE: ${AM_CERTS_OPENBAO_APP_ROLE} - AM_CERTS_OPENBAO_APP_SECRET: ${AM_CERTS_OPENBAO_APP_SECRET} - AM_CERTS_OPENBAO_NAMESPACE: ${AM_CERTS_OPENBAO_NAMESPACE} - AM_CERTS_OPENBAO_PKI_PATH: ${AM_CERTS_OPENBAO_PKI_PATH} - AM_CERTS_OPENBAO_ROLE: ${AM_CERTS_OPENBAO_ROLE} - AM_CERTS_OPENBAO_SECRET_ID_TTL: ${AM_CERTS_OPENBAO_SECRET_ID_TTL} - AM_CERTS_SERVICE_TOKEN_PATH: ${AM_CERTS_SERVICE_TOKEN_PATH} - AM_CERTS_SECRET_ID_PATH: ${AM_CERTS_SECRET_ID_PATH} - AM_CERTS_SECRET_RENEW_THRESHOLD: ${AM_CERTS_SECRET_RENEW_THRESHOLD} - AM_CERTS_SECRET_CHECK_INTERVAL: ${AM_CERTS_SECRET_CHECK_INTERVAL} - AM_CERTS_OPENBAO_PKI_CA_CN: ${AM_CERTS_OPENBAO_PKI_CA_CN} - AM_CERTS_OPENBAO_PKI_CA_OU: ${AM_CERTS_OPENBAO_PKI_CA_OU} - AM_CERTS_OPENBAO_PKI_CA_O: ${AM_CERTS_OPENBAO_PKI_CA_O} - AM_CERTS_OPENBAO_PKI_CA_C: ${AM_CERTS_OPENBAO_PKI_CA_C} - AM_CERTS_OPENBAO_PKI_CA_L: ${AM_CERTS_OPENBAO_PKI_CA_L} - AM_CERTS_OPENBAO_PKI_CA_ST: ${AM_CERTS_OPENBAO_PKI_CA_ST} - AM_CERTS_OPENBAO_PKI_CA_ADDR: ${AM_CERTS_OPENBAO_PKI_CA_ADDR} - AM_CERTS_OPENBAO_PKI_CA_PO: ${AM_CERTS_OPENBAO_PKI_CA_PO} - AM_CERTS_OPENBAO_PKI_CA_DNS_NAMES: ${AM_CERTS_OPENBAO_PKI_CA_DNS_NAMES} - AM_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES: ${AM_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES} - AM_CERTS_OPENBAO_PKI_CA_URI_SANS: ${AM_CERTS_OPENBAO_PKI_CA_URI_SANS} - AM_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES: ${AM_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES} - AM_CERTS_OPENBAO_UNSEAL_KEY_1: ${AM_CERTS_OPENBAO_UNSEAL_KEY_1} - AM_CERTS_OPENBAO_UNSEAL_KEY_2: ${AM_CERTS_OPENBAO_UNSEAL_KEY_2} - AM_CERTS_OPENBAO_UNSEAL_KEY_3: ${AM_CERTS_OPENBAO_UNSEAL_KEY_3} - AM_CERTS_OPENBAO_ROOT_TOKEN: ${AM_CERTS_OPENBAO_ROOT_TOKEN} - AM_JAEGER_URL: ${AM_JAEGER_URL} - AM_JAEGER_TRACE_RATIO: ${AM_JAEGER_TRACE_RATIO} - AM_AUTH_GRPC_URL: ${AM_AUTH_GRPC_URL} - AM_AUTH_GRPC_TIMEOUT: ${AM_AUTH_GRPC_TIMEOUT} - AM_AUTH_GRPC_CLIENT_CERT: ${AM_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} - AM_AUTH_GRPC_CLIENT_KEY: ${AM_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} - AM_AUTH_GRPC_SERVER_CA_CERTS: ${AM_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} - AM_DOMAINS_GRPC_URL: ${AM_DOMAINS_GRPC_URL} - AM_DOMAINS_GRPC_TIMEOUT: ${AM_DOMAINS_GRPC_TIMEOUT} - AM_DOMAINS_GRPC_CLIENT_CERT: ${AM_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} - AM_DOMAINS_GRPC_CLIENT_KEY: ${AM_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} - AM_DOMAINS_GRPC_SERVER_CA_CERTS: ${AM_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - networks: !override - - supermq-base-net - - certs-db: - networks: !override - - supermq-base-net - - openbao: - networks: !override - - supermq-base-net diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index f8c88d077..7f889210a 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -1,406 +1,512 @@ # Copyright (c) Abstract Machines # SPDX-License-Identifier: Apache-2.0 -name: "supermq" +name: "magistrala" networks: - supermq-base-net: + magistrala-base-net: driver: bridge - name: supermq-base-net + name: magistrala-base-net + ipam: + config: + - subnet: 172.30.0.0/24 volumes: - supermq-users-db-volume: - supermq-groups-db-volume: - supermq-clients-db-volume: - supermq-channels-db-volume: - supermq-channels-redis-volume: - supermq-clients-redis-volume: - supermq-broker-volume: - supermq-mqtt-broker-volume: - supermq-spicedb-db-volume: - supermq-auth-db-volume: - supermq-pat-db-volume: - supermq-domains-db-volume: - supermq-domains-redis-volume: - supermq-auth-redis-volume: - supermq-auth-keys-volume: + magistrala-users-db-volume: + magistrala-groups-db-volume: + magistrala-clients-db-volume: + magistrala-channels-db-volume: + magistrala-channels-redis-volume: + magistrala-clients-redis-volume: + magistrala-spicedb-db-volume: + magistrala-auth-db-volume: + magistrala-pat-db-volume: + magistrala-domains-db-volume: + magistrala-domains-redis-volume: + magistrala-auth-redis-volume: + magistrala-auth-keys-volume: + magistrala-ui-backend-db-volume: + magistrala-journal-volume: + magistrala-re-db-volume: + magistrala-alarms-db-volume: + magistrala-reports-db-volume: + magistrala-certs-db-volume: + magistrala-openbao-data: + magistrala-timescale-writer-volume: + magistrala-fluxmq-node1-volume: + magistrala-fluxmq-node2-volume: + magistrala-fluxmq-node3-volume: services: spicedb: - image: docker.io/authzed/spicedb:v1.37.0 - container_name: supermq-spicedb + image: docker.io/authzed/spicedb:v1.50.0 + container_name: magistrala-spicedb command: "serve" restart: "always" networks: - - supermq-base-net + - magistrala-base-net ports: - "8080:8080" - "9091:9090" - "50051:50051" environment: - SPICEDB_GRPC_PRESHARED_KEY: ${SMQ_SPICEDB_PRE_SHARED_KEY} - SPICEDB_DATASTORE_ENGINE: ${SMQ_SPICEDB_DATASTORE_ENGINE} - SPICEDB_DATASTORE_CONN_URI: "${SMQ_SPICEDB_DATASTORE_ENGINE}://${SMQ_SPICEDB_DB_USER}:${SMQ_SPICEDB_DB_PASS}@spicedb-db:${SMQ_SPICEDB_DB_PORT}/${SMQ_SPICEDB_DB_NAME}?sslmode=disable" + SPICEDB_GRPC_PRESHARED_KEY: ${MG_SPICEDB_PRE_SHARED_KEY} + SPICEDB_DATASTORE_ENGINE: ${MG_SPICEDB_DATASTORE_ENGINE} + SPICEDB_DATASTORE_CONN_URI: "${MG_SPICEDB_DATASTORE_ENGINE}://${MG_SPICEDB_DB_USER}:${MG_SPICEDB_DB_PASS}@spicedb-db:${MG_SPICEDB_DB_PORT}/${MG_SPICEDB_DB_NAME}?sslmode=disable" depends_on: - spicedb-migrate spicedb-migrate: - image: docker.io/authzed/spicedb:v1.37.0 - container_name: supermq-spicedb-migrate + image: docker.io/authzed/spicedb:v1.50.0 + container_name: magistrala-spicedb-migrate command: "migrate head" restart: "on-failure" networks: - - supermq-base-net + - magistrala-base-net environment: - SPICEDB_DATASTORE_ENGINE: ${SMQ_SPICEDB_DATASTORE_ENGINE} - SPICEDB_DATASTORE_CONN_URI: "${SMQ_SPICEDB_DATASTORE_ENGINE}://${SMQ_SPICEDB_DB_USER}:${SMQ_SPICEDB_DB_PASS}@spicedb-db:${SMQ_SPICEDB_DB_PORT}/${SMQ_SPICEDB_DB_NAME}?sslmode=disable" + SPICEDB_DATASTORE_ENGINE: ${MG_SPICEDB_DATASTORE_ENGINE} + SPICEDB_DATASTORE_CONN_URI: "${MG_SPICEDB_DATASTORE_ENGINE}://${MG_SPICEDB_DB_USER}:${MG_SPICEDB_DB_PASS}@spicedb-db:${MG_SPICEDB_DB_PORT}/${MG_SPICEDB_DB_NAME}?sslmode=disable" depends_on: - spicedb-db spicedb-db: image: docker.io/postgres:18.0-alpine3.22 - container_name: supermq-spicedb-db + container_name: magistrala-spicedb-db networks: - - supermq-base-net + - magistrala-base-net ports: - "6010:5432" environment: - POSTGRES_USER: ${SMQ_SPICEDB_DB_USER} - POSTGRES_PASSWORD: ${SMQ_SPICEDB_DB_PASS} - POSTGRES_DB: ${SMQ_SPICEDB_DB_NAME} + POSTGRES_USER: ${MG_SPICEDB_DB_USER} + POSTGRES_PASSWORD: ${MG_SPICEDB_DB_PASS} + POSTGRES_DB: ${MG_SPICEDB_DB_NAME} volumes: - - supermq-spicedb-db-volume:/var/lib/postgresql/data + - magistrala-spicedb-db-volume:/var/lib/postgresql/data command: ["postgres", "-c", "track_commit_timestamp=on"] auth-db: image: docker.io/postgres:18.0-alpine3.22 - container_name: supermq-auth-db + container_name: magistrala-auth-db restart: on-failure ports: - 6001:5432 environment: - POSTGRES_USER: ${SMQ_AUTH_DB_USER} - POSTGRES_PASSWORD: ${SMQ_AUTH_DB_PASS} - POSTGRES_DB: ${SMQ_AUTH_DB_NAME} + POSTGRES_USER: ${MG_AUTH_DB_USER} + POSTGRES_PASSWORD: ${MG_AUTH_DB_PASS} + POSTGRES_DB: ${MG_AUTH_DB_NAME} networks: - - supermq-base-net + - magistrala-base-net volumes: - - supermq-auth-db-volume:/var/lib/postgresql/data + - magistrala-auth-db-volume:/var/lib/postgresql/data auth-redis: image: docker.io/redis:8.2.2-alpine3.22 - container_name: supermq-auth-redis + container_name: magistrala-auth-redis restart: on-failure networks: - - supermq-base-net + - magistrala-base-net volumes: - - supermq-auth-redis-volume:/data + - magistrala-auth-redis-volume:/data - ./redis/redis.conf:/etc/redis/redis.conf:ro command: ["redis-server", "/etc/redis/redis.conf"] auth: - image: docker.io/supermq/auth:${SMQ_RELEASE_TAG} - container_name: supermq-auth + image: docker.io/magistrala/auth:${MG_RELEASE_TAG} + container_name: magistrala-auth depends_on: - auth-db - spicedb + - nginx expose: - - ${SMQ_AUTH_GRPC_PORT} + - ${MG_AUTH_GRPC_PORT} restart: on-failure environment: - SMQ_AUTH_LOG_LEVEL: ${SMQ_AUTH_LOG_LEVEL} - SMQ_SPICEDB_SCHEMA_FILE: ${SMQ_SPICEDB_SCHEMA_FILE} - SMQ_SPICEDB_PRE_SHARED_KEY: ${SMQ_SPICEDB_PRE_SHARED_KEY} - SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST} - SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT} - SMQ_AUTH_INVITATION_DURATION: ${SMQ_AUTH_INVITATION_DURATION} - SMQ_AUTH_HTTP_HOST: ${SMQ_AUTH_HTTP_HOST} - SMQ_AUTH_HTTP_PORT: ${SMQ_AUTH_HTTP_PORT} - SMQ_AUTH_HTTP_SERVER_CERT: ${SMQ_AUTH_HTTP_SERVER_CERT} - SMQ_AUTH_HTTP_SERVER_KEY: ${SMQ_AUTH_HTTP_SERVER_KEY} - SMQ_AUTH_GRPC_HOST: ${SMQ_AUTH_GRPC_HOST} - SMQ_AUTH_GRPC_PORT: ${SMQ_AUTH_GRPC_PORT} - SMQ_AUTH_ACCESS_TOKEN_DURATION: ${SMQ_AUTH_ACCESS_TOKEN_DURATION} - SMQ_AUTH_REFRESH_TOKEN_DURATION: ${SMQ_AUTH_REFRESH_TOKEN_DURATION} - SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM} - SMQ_AUTH_KEYS_ACTIVE_KEY_PATH: ${SMQ_AUTH_KEYS_ACTIVE_KEY_PATH:+/keys/active.key} - SMQ_AUTH_KEYS_RETIRING_KEY_PATH: ${SMQ_AUTH_KEYS_RETIRING_KEY_PATH:+/keys/retiring.key} + MG_AUTH_LOG_LEVEL: ${MG_AUTH_LOG_LEVEL} + MG_SPICEDB_SCHEMA_FILE: ${MG_SPICEDB_SCHEMA_FILE} + MG_SPICEDB_PRE_SHARED_KEY: ${MG_SPICEDB_PRE_SHARED_KEY} + MG_SPICEDB_HOST: ${MG_SPICEDB_HOST} + MG_SPICEDB_PORT: ${MG_SPICEDB_PORT} + MG_AUTH_INVITATION_DURATION: ${MG_AUTH_INVITATION_DURATION} + MG_AUTH_HTTP_HOST: ${MG_AUTH_HTTP_HOST} + MG_AUTH_HTTP_PORT: ${MG_AUTH_HTTP_PORT} + MG_AUTH_HTTP_SERVER_CERT: ${MG_AUTH_HTTP_SERVER_CERT} + MG_AUTH_HTTP_SERVER_KEY: ${MG_AUTH_HTTP_SERVER_KEY} + MG_AUTH_GRPC_HOST: ${MG_AUTH_GRPC_HOST} + MG_AUTH_GRPC_PORT: ${MG_AUTH_GRPC_PORT} + MG_AUTH_ACCESS_TOKEN_DURATION: ${MG_AUTH_ACCESS_TOKEN_DURATION} + MG_AUTH_REFRESH_TOKEN_DURATION: ${MG_AUTH_REFRESH_TOKEN_DURATION} + MG_AUTH_KEYS_ALGORITHM: ${MG_AUTH_KEYS_ALGORITHM} + MG_AUTH_KEYS_ACTIVE_KEY_PATH: ${MG_AUTH_KEYS_ACTIVE_KEY_PATH:+/keys/active.key} + MG_AUTH_KEYS_RETIRING_KEY_PATH: ${MG_AUTH_KEYS_RETIRING_KEY_PATH:+/keys/retiring.key} ## Compose supports parameter expansion in environment, ## Eg: ${VAR:+replacement} or ${VAR+replacement} -> replacement if VAR is set and non-empty, otherwise empty ## Eg :${VAR:-default} or ${VAR-default} -> value of VAR if set and non-empty, otherwise default - SMQ_AUTH_GRPC_SERVER_CERT: ${SMQ_AUTH_GRPC_SERVER_CERT:+/auth-grpc-server.crt} - SMQ_AUTH_GRPC_SERVER_KEY: ${SMQ_AUTH_GRPC_SERVER_KEY:+/auth-grpc-server.key} - SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} - SMQ_AUTH_GRPC_CLIENT_CA_CERTS: ${SMQ_AUTH_GRPC_CLIENT_CA_CERTS:+/auth-grpc-client-ca.crt} - SMQ_AUTH_DB_HOST: ${SMQ_AUTH_DB_HOST} - SMQ_AUTH_DB_PORT: ${SMQ_AUTH_DB_PORT} - SMQ_AUTH_DB_USER: ${SMQ_AUTH_DB_USER} - SMQ_AUTH_DB_PASS: ${SMQ_AUTH_DB_PASS} - SMQ_AUTH_DB_NAME: ${SMQ_AUTH_DB_NAME} - SMQ_AUTH_DB_SSL_MODE: ${SMQ_AUTH_DB_SSL_MODE} - SMQ_AUTH_DB_SSL_CERT: ${SMQ_AUTH_DB_SSL_CERT} - SMQ_AUTH_DB_SSL_KEY: ${SMQ_AUTH_DB_SSL_KEY} - SMQ_AUTH_DB_SSL_ROOT_CERT: ${SMQ_AUTH_DB_SSL_ROOT_CERT} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} - SMQ_AUTH_ADAPTER_INSTANCE_ID: ${SMQ_AUTH_ADAPTER_INSTANCE_ID} - SMQ_ES_URL: ${SMQ_ES_URL} - SMQ_AUTH_CACHE_URL: ${SMQ_AUTH_CACHE_URL} + MG_AUTH_GRPC_SERVER_CERT: ${MG_AUTH_GRPC_SERVER_CERT:+/auth-grpc-server.crt} + MG_AUTH_GRPC_SERVER_KEY: ${MG_AUTH_GRPC_SERVER_KEY:+/auth-grpc-server.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_AUTH_GRPC_CLIENT_CA_CERTS: ${MG_AUTH_GRPC_CLIENT_CA_CERTS:+/auth-grpc-client-ca.crt} + MG_AUTH_DB_HOST: ${MG_AUTH_DB_HOST} + MG_AUTH_DB_PORT: ${MG_AUTH_DB_PORT} + MG_AUTH_DB_USER: ${MG_AUTH_DB_USER} + MG_AUTH_DB_PASS: ${MG_AUTH_DB_PASS} + MG_AUTH_DB_NAME: ${MG_AUTH_DB_NAME} + MG_AUTH_DB_SSL_MODE: ${MG_AUTH_DB_SSL_MODE} + MG_AUTH_DB_SSL_CERT: ${MG_AUTH_DB_SSL_CERT} + MG_AUTH_DB_SSL_KEY: ${MG_AUTH_DB_SSL_KEY} + MG_AUTH_DB_SSL_ROOT_CERT: ${MG_AUTH_DB_SSL_ROOT_CERT} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_AUTH_ADAPTER_INSTANCE_ID: ${MG_AUTH_ADAPTER_INSTANCE_ID} + MG_ES_URL: ${MG_ES_URL} + MG_AUTH_CACHE_URL: ${MG_AUTH_CACHE_URL} ports: - - ${SMQ_AUTH_HTTP_PORT}:${SMQ_AUTH_HTTP_PORT} - - ${SMQ_AUTH_GRPC_PORT}:${SMQ_AUTH_GRPC_PORT} + - ${MG_AUTH_HTTP_PORT}:${MG_AUTH_HTTP_PORT} + - ${MG_AUTH_GRPC_PORT}:${MG_AUTH_GRPC_PORT} networks: - - supermq-base-net + - magistrala-base-net volumes: - - ./spicedb/schema.zed:${SMQ_SPICEDB_SCHEMA_FILE} - - supermq-pat-db-volume:/supermq-data + - ./spicedb/schema.zed:${MG_SPICEDB_SCHEMA_FILE} + - magistrala-pat-db-volume:/magistrala-data # Auth active private key file - type: bind - source: ${SMQ_AUTH_KEYS_ACTIVE_KEY_PATH} + source: ${MG_AUTH_KEYS_ACTIVE_KEY_PATH} target: /keys/active.key read_only: true # Auth retiring private key file (optional, for key rotation) - type: bind - source: ${SMQ_AUTH_KEYS_RETIRING_KEY_PATH:-ssl/certs/dummy/retiring_key} + source: ${MG_AUTH_KEYS_RETIRING_KEY_PATH:-./ssl/placeholder} target: /keys/retiring.key read_only: true bind: create_host_path: true # Auth gRPC mTLS server certificates - type: bind - source: ${SMQ_AUTH_GRPC_SERVER_CERT:-ssl/certs/dummy/server_cert} + source: ${MG_AUTH_GRPC_SERVER_CERT:-./ssl/placeholder} target: /auth-grpc-server.crt bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_SERVER_KEY:-ssl/certs/dummy/server_key} + source: ${MG_AUTH_GRPC_SERVER_KEY:-./ssl/placeholder} target: /auth-grpc-server.key bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca_certs} + source: ${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /auth-grpc-server-ca.crt bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_CA_CERTS:-ssl/certs/dummy/client_ca_certs} + source: ${MG_AUTH_GRPC_CLIENT_CA_CERTS:-./ssl/placeholder} target: /auth-grpc-client-ca.crt bind: create_host_path: true # Auth Callout Client Certificates - type: bind - source: ${SMQ_AUTH_CALLOUT_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_AUTH_CALLOUT_CLIENT_CERT:-./ssl/placeholder} target: /auth-callout-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_CALLOUT_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_AUTH_CALLOUT_CLIENT_KEY:-./ssl/placeholder} target: /auth-callout-client.key bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_CALLOUT_CLIENT_CA_CERTS:-ssl/certs/dummy/client_ca_certs} + source: ${MG_AUTH_CALLOUT_CLIENT_CA_CERTS:-./ssl/placeholder} target: /auth-callout-client-ca.crt bind: create_host_path: true domains-db: image: docker.io/postgres:18.0-alpine3.22 - container_name: supermq-domains-db + container_name: magistrala-domains-db restart: on-failure ports: - 6003:5432 environment: - POSTGRES_USER: ${SMQ_DOMAINS_DB_USER} - POSTGRES_PASSWORD: ${SMQ_DOMAINS_DB_PASS} - POSTGRES_DB: ${SMQ_DOMAINS_DB_NAME} + POSTGRES_USER: ${MG_DOMAINS_DB_USER} + POSTGRES_PASSWORD: ${MG_DOMAINS_DB_PASS} + POSTGRES_DB: ${MG_DOMAINS_DB_NAME} networks: - - supermq-base-net + - magistrala-base-net volumes: - - supermq-domains-db-volume:/var/lib/postgresql/data + - magistrala-domains-db-volume:/var/lib/postgresql/data domains-redis: image: docker.io/redis:8.2.2-alpine3.22 - container_name: supermq-domains-redis + container_name: magistrala-domains-redis restart: on-failure networks: - - supermq-base-net + - magistrala-base-net volumes: - - supermq-domains-redis-volume:/data + - magistrala-domains-redis-volume:/data domains: - image: docker.io/supermq/domains:${SMQ_RELEASE_TAG} - container_name: supermq-domains + image: docker.io/magistrala/domains:${MG_RELEASE_TAG} + container_name: magistrala-domains depends_on: - domains-db - spicedb + - nginx expose: - - ${SMQ_DOMAINS_GRPC_PORT} + - ${MG_DOMAINS_GRPC_PORT} restart: on-failure environment: - SMQ_DOMAINS_LOG_LEVEL: ${SMQ_DOMAINS_LOG_LEVEL} - SMQ_SPICEDB_PRE_SHARED_KEY: ${SMQ_SPICEDB_PRE_SHARED_KEY} - SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST} - SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT} - SMQ_SPICEDB_SCHEMA_FILE: ${SMQ_SPICEDB_SCHEMA_FILE} - SMQ_DOMAINS_HTTP_HOST: ${SMQ_DOMAINS_HTTP_HOST} - SMQ_DOMAINS_HTTP_PORT: ${SMQ_DOMAINS_HTTP_PORT} - SMQ_DOMAINS_HTTP_SERVER_CERT: ${SMQ_DOMAINS_HTTP_SERVER_CERT} - SMQ_DOMAINS_HTTP_SERVER_KEY: ${SMQ_DOMAINS_HTTP_SERVER_KEY} - SMQ_DOMAINS_GRPC_HOST: ${SMQ_DOMAINS_GRPC_HOST} - SMQ_DOMAINS_GRPC_PORT: ${SMQ_DOMAINS_GRPC_PORT} + MG_DOMAINS_LOG_LEVEL: ${MG_DOMAINS_LOG_LEVEL} + MG_SPICEDB_PRE_SHARED_KEY: ${MG_SPICEDB_PRE_SHARED_KEY} + MG_SPICEDB_HOST: ${MG_SPICEDB_HOST} + MG_SPICEDB_PORT: ${MG_SPICEDB_PORT} + MG_SPICEDB_SCHEMA_FILE: ${MG_SPICEDB_SCHEMA_FILE} + MG_DOMAINS_HTTP_HOST: ${MG_DOMAINS_HTTP_HOST} + MG_DOMAINS_HTTP_PORT: ${MG_DOMAINS_HTTP_PORT} + MG_DOMAINS_HTTP_SERVER_CERT: ${MG_DOMAINS_HTTP_SERVER_CERT} + MG_DOMAINS_HTTP_SERVER_KEY: ${MG_DOMAINS_HTTP_SERVER_KEY} + MG_DOMAINS_GRPC_HOST: ${MG_DOMAINS_GRPC_HOST} + MG_DOMAINS_GRPC_PORT: ${MG_DOMAINS_GRPC_PORT} ## Compose supports parameter expansion in environment, ## Eg: ${VAR:+replacement} or ${VAR+replacement} -> replacement if VAR is set and non-empty, otherwise empty ## Eg :${VAR:-default} or ${VAR-default} -> value of VAR if set and non-empty, otherwise default - SMQ_DOMAINS_GRPC_SERVER_CERT: ${SMQ_DOMAINS_GRPC_SERVER_CERT:+/domains-grpc-server.crt} - SMQ_DOMAINS_GRPC_SERVER_KEY: ${SMQ_DOMAINS_GRPC_SERVER_KEY:+/domains-grpc-server.key} - SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - SMQ_DOMAINS_GRPC_CLIENT_CA_CERTS: ${SMQ_DOMAINS_GRPC_CLIENT_CA_CERTS:+/domains-grpc-client-ca.crt} - SMQ_DOMAINS_DB_HOST: ${SMQ_DOMAINS_DB_HOST} - SMQ_DOMAINS_DB_PORT: ${SMQ_DOMAINS_DB_PORT} - SMQ_DOMAINS_DB_USER: ${SMQ_DOMAINS_DB_USER} - SMQ_DOMAINS_DB_PASS: ${SMQ_DOMAINS_DB_PASS} - SMQ_DOMAINS_DB_NAME: ${SMQ_DOMAINS_DB_NAME} - SMQ_DOMAINS_DB_SSL_MODE: ${SMQ_DOMAINS_DB_SSL_MODE} - SMQ_DOMAINS_DB_SSL_CERT: ${SMQ_DOMAINS_DB_SSL_CERT} - SMQ_DOMAINS_DB_SSL_KEY: ${SMQ_DOMAINS_DB_SSL_KEY} - SMQ_DOMAINS_DB_SSL_ROOT_CERT: ${SMQ_DOMAINS_DB_SSL_ROOT_CERT} - SMQ_DOMAINS_INSTANCE_ID: ${SMQ_DOMAINS_INSTANCE_ID} - SMQ_ES_URL: ${SMQ_ES_URL} - SMQ_DOMAINS_CACHE_URL: ${SMQ_DOMAINS_CACHE_URL} - SMQ_DOMAINS_CACHE_KEY_DURATION: ${SMQ_DOMAINS_CACHE_KEY_DURATION} - SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL} - SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT} - SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} - SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} - SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} - SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM} - SMQ_GROUPS_GRPC_URL: ${SMQ_GROUPS_GRPC_URL} - SMQ_GROUPS_GRPC_TIMEOUT: ${SMQ_GROUPS_GRPC_TIMEOUT} - SMQ_GROUPS_GRPC_CLIENT_CERT: ${SMQ_GROUPS_GRPC_CLIENT_CERT:+/groups-grpc-client.crt} - SMQ_GROUPS_GRPC_CLIENT_KEY: ${SMQ_GROUPS_GRPC_CLIENT_KEY:+/groups-grpc-client.key} - SMQ_GROUPS_GRPC_SERVER_CA_CERTS: ${SMQ_GROUPS_GRPC_SERVER_CA_CERTS:+/groups-grpc-server-ca.crt} - SMQ_CHANNELS_URL: ${SMQ_CHANNELS_URL} - SMQ_CHANNELS_GRPC_URL: ${SMQ_CHANNELS_GRPC_URL} - SMQ_CHANNELS_GRPC_TIMEOUT: ${SMQ_CHANNELS_GRPC_TIMEOUT} - SMQ_CHANNELS_GRPC_CLIENT_CERT: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} - SMQ_CHANNELS_GRPC_CLIENT_KEY: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} - SMQ_CHANNELS_GRPC_SERVER_CA_CERTS: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} - SMQ_CLIENTS_GRPC_URL: ${SMQ_CLIENTS_GRPC_URL} - SMQ_CLIENTS_GRPC_TIMEOUT: ${SMQ_CLIENTS_GRPC_TIMEOUT} - SMQ_CLIENTS_GRPC_CLIENT_CERT: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} - SMQ_CLIENTS_GRPC_CLIENT_KEY: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} - SMQ_CLIENTS_GRPC_SERVER_CA_CERTS: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} - SMQ_DOMAINS_CALLOUT_URLS: ${SMQ_DOMAINS_CALLOUT_URLS} - SMQ_DOMAINS_CALLOUT_METHOD: ${SMQ_DOMAINS_CALLOUT_METHOD} - SMQ_DOMAINS_CALLOUT_TLS_VERIFICATION: ${SMQ_DOMAINS_CALLOUT_TLS_VERIFICATION} - SMQ_DOMAINS_CALLOUT_TIMEOUT: ${SMQ_DOMAINS_CALLOUT_TIMEOUT} - SMQ_DOMAINS_CALLOUT_CA_CERT: ${SMQ_DOMAINS_CALLOUT_CA_CERT} - SMQ_DOMAINS_CALLOUT_CERT: ${SMQ_DOMAINS_CALLOUT_CERT} - SMQ_DOMAINS_CALLOUT_KEY: ${SMQ_DOMAINS_CALLOUT_KEY} - SMQ_DOMAINS_CALLOUT_OPERATIONS: ${SMQ_DOMAINS_CALLOUT_OPERATIONS} - SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} + MG_DOMAINS_GRPC_SERVER_CERT: ${MG_DOMAINS_GRPC_SERVER_CERT:+/domains-grpc-server.crt} + MG_DOMAINS_GRPC_SERVER_KEY: ${MG_DOMAINS_GRPC_SERVER_KEY:+/domains-grpc-server.key} + MG_DOMAINS_GRPC_SERVER_CA_CERTS: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_DOMAINS_GRPC_CLIENT_CA_CERTS: ${MG_DOMAINS_GRPC_CLIENT_CA_CERTS:+/domains-grpc-client-ca.crt} + MG_DOMAINS_DB_HOST: ${MG_DOMAINS_DB_HOST} + MG_DOMAINS_DB_PORT: ${MG_DOMAINS_DB_PORT} + MG_DOMAINS_DB_USER: ${MG_DOMAINS_DB_USER} + MG_DOMAINS_DB_PASS: ${MG_DOMAINS_DB_PASS} + MG_DOMAINS_DB_NAME: ${MG_DOMAINS_DB_NAME} + MG_DOMAINS_DB_SSL_MODE: ${MG_DOMAINS_DB_SSL_MODE} + MG_DOMAINS_DB_SSL_CERT: ${MG_DOMAINS_DB_SSL_CERT} + MG_DOMAINS_DB_SSL_KEY: ${MG_DOMAINS_DB_SSL_KEY} + MG_DOMAINS_DB_SSL_ROOT_CERT: ${MG_DOMAINS_DB_SSL_ROOT_CERT} + MG_DOMAINS_INSTANCE_ID: ${MG_DOMAINS_INSTANCE_ID} + MG_ES_URL: ${MG_ES_URL} + MG_DOMAINS_CACHE_URL: ${MG_DOMAINS_CACHE_URL} + MG_DOMAINS_CACHE_KEY_DURATION: ${MG_DOMAINS_CACHE_KEY_DURATION} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_AUTH_KEYS_ALGORITHM: ${MG_AUTH_KEYS_ALGORITHM} + MG_GROUPS_GRPC_URL: ${MG_GROUPS_GRPC_URL} + MG_GROUPS_GRPC_TIMEOUT: ${MG_GROUPS_GRPC_TIMEOUT} + MG_GROUPS_GRPC_CLIENT_CERT: ${MG_GROUPS_GRPC_CLIENT_CERT:+/groups-grpc-client.crt} + MG_GROUPS_GRPC_CLIENT_KEY: ${MG_GROUPS_GRPC_CLIENT_KEY:+/groups-grpc-client.key} + MG_GROUPS_GRPC_SERVER_CA_CERTS: ${MG_GROUPS_GRPC_SERVER_CA_CERTS:+/groups-grpc-server-ca.crt} + MG_CHANNELS_URL: ${MG_CHANNELS_URL} + MG_CHANNELS_GRPC_URL: ${MG_CHANNELS_GRPC_URL} + MG_CHANNELS_GRPC_TIMEOUT: ${MG_CHANNELS_GRPC_TIMEOUT} + MG_CHANNELS_GRPC_CLIENT_CERT: ${MG_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} + MG_CHANNELS_GRPC_CLIENT_KEY: ${MG_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} + MG_CHANNELS_GRPC_SERVER_CA_CERTS: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} + MG_CLIENTS_GRPC_URL: ${MG_CLIENTS_GRPC_URL} + MG_CLIENTS_GRPC_TIMEOUT: ${MG_CLIENTS_GRPC_TIMEOUT} + MG_CLIENTS_GRPC_CLIENT_CERT: ${MG_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} + MG_CLIENTS_GRPC_CLIENT_KEY: ${MG_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} + MG_CLIENTS_GRPC_SERVER_CA_CERTS: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_DOMAINS_CALLOUT_URLS: ${MG_DOMAINS_CALLOUT_URLS} + MG_DOMAINS_CALLOUT_METHOD: ${MG_DOMAINS_CALLOUT_METHOD} + MG_DOMAINS_CALLOUT_TLS_VERIFICATION: ${MG_DOMAINS_CALLOUT_TLS_VERIFICATION} + MG_DOMAINS_CALLOUT_TIMEOUT: ${MG_DOMAINS_CALLOUT_TIMEOUT} + MG_DOMAINS_CALLOUT_CA_CERT: ${MG_DOMAINS_CALLOUT_CA_CERT} + MG_DOMAINS_CALLOUT_CERT: ${MG_DOMAINS_CALLOUT_CERT} + MG_DOMAINS_CALLOUT_KEY: ${MG_DOMAINS_CALLOUT_KEY} + MG_DOMAINS_CALLOUT_OPERATIONS: ${MG_DOMAINS_CALLOUT_OPERATIONS} + MG_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} ports: - - ${SMQ_DOMAINS_HTTP_PORT}:${SMQ_DOMAINS_HTTP_PORT} - - ${SMQ_DOMAINS_GRPC_PORT}:${SMQ_DOMAINS_GRPC_PORT} + - ${MG_DOMAINS_HTTP_PORT}:${MG_DOMAINS_HTTP_PORT} + - ${MG_DOMAINS_GRPC_PORT}:${MG_DOMAINS_GRPC_PORT} networks: - - supermq-base-net + - magistrala-base-net volumes: - ./permission.yaml:/permission.yaml - - ./spicedb/schema.zed:${SMQ_SPICEDB_SCHEMA_FILE} - # Auth gRPC mTLS server certificates + - ./spicedb/schema.zed:${MG_SPICEDB_SCHEMA_FILE} + # Domains gRPC mTLS server certificates - type: bind - source: ${SMQ_DOMAINS_GRPC_SERVER_CERT:-ssl/certs/dummy/server_cert} + source: ${MG_DOMAINS_GRPC_SERVER_CERT:-./ssl/placeholder} target: /domains-grpc-server.crt bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_SERVER_KEY:-ssl/certs/dummy/server_key} + source: ${MG_DOMAINS_GRPC_SERVER_KEY:-./ssl/placeholder} target: /domains-grpc-server.key bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca_certs} + source: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /domains-grpc-server-ca.crt bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_CA_CERTS:-ssl/certs/dummy/client_ca_certs} + source: ${MG_DOMAINS_GRPC_CLIENT_CA_CERTS:-./ssl/placeholder} target: /domains-grpc-client-ca.crt bind: create_host_path: true # Auth gRPC client certificates - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /auth-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /auth-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /auth-grpc-server-ca.crt bind: create_host_path: true # Groups gRPC client certificates - type: bind - source: ${SMQ_GROUPS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_GROUPS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /groups-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_GROUPS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_GROUPS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /groups-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_GROUPS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_GROUPS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /groups-grpc-server-ca.crt bind: create_host_path: true # Channels gRPC client certificates - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_CHANNELS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /channels-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_CHANNELS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /channels-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /channels-grpc-server-ca.crt bind: create_host_path: true # Clients gRPC client certificates - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_CLIENTS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /clients-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_CLIENTS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /clients-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /clients-grpc-server-ca.crt bind: create_host_path: true + journal-db: + image: postgres:16.2-alpine + container_name: magistrala-journal-db + restart: on-failure + command: postgres -c "max_connections=${MG_POSTGRES_MAX_CONNECTIONS}" + environment: + POSTGRES_USER: ${MG_JOURNAL_DB_USER} + POSTGRES_PASSWORD: ${MG_JOURNAL_DB_PASS} + POSTGRES_DB: ${MG_JOURNAL_DB_NAME} + MG_POSTGRES_MAX_CONNECTIONS: ${MG_POSTGRES_MAX_CONNECTIONS} + networks: + - magistrala-base-net + volumes: + - magistrala-journal-volume:/var/lib/postgresql/data + + journal: + image: docker.io/magistrala/journal:${MG_RELEASE_TAG} + container_name: magistrala-journal + depends_on: + - journal-db + - auth + - domains + - nginx + restart: on-failure + environment: + MG_JOURNAL_LOG_LEVEL: ${MG_JOURNAL_LOG_LEVEL} + MG_JOURNAL_HTTP_HOST: ${MG_JOURNAL_HTTP_HOST} + MG_JOURNAL_HTTP_PORT: ${MG_JOURNAL_HTTP_PORT} + MG_JOURNAL_HTTP_SERVER_CERT: ${MG_JOURNAL_HTTP_SERVER_CERT} + MG_JOURNAL_HTTP_SERVER_KEY: ${MG_JOURNAL_HTTP_SERVER_KEY} + MG_JOURNAL_DB_HOST: ${MG_JOURNAL_DB_HOST} + MG_JOURNAL_DB_PORT: ${MG_JOURNAL_DB_PORT} + MG_JOURNAL_DB_USER: ${MG_JOURNAL_DB_USER} + MG_JOURNAL_DB_PASS: ${MG_JOURNAL_DB_PASS} + MG_JOURNAL_DB_NAME: ${MG_JOURNAL_DB_NAME} + MG_JOURNAL_DB_SSL_MODE: ${MG_JOURNAL_DB_SSL_MODE} + MG_JOURNAL_DB_SSL_CERT: ${MG_JOURNAL_DB_SSL_CERT} + MG_JOURNAL_DB_SSL_KEY: ${MG_JOURNAL_DB_SSL_KEY} + MG_JOURNAL_DB_SSL_ROOT_CERT: ${MG_JOURNAL_DB_SSL_ROOT_CERT} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_AUTH_KEYS_ALGORITHM: ${MG_AUTH_KEYS_ALGORITHM} + MG_ES_URL: ${MG_ES_URL} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_JOURNAL_INSTANCE_ID: ${MG_JOURNAL_INSTANCE_ID} + MG_DOMAINS_GRPC_URL: ${MG_DOMAINS_GRPC_URL} + MG_DOMAINS_GRPC_TIMEOUT: ${MG_DOMAINS_GRPC_TIMEOUT} + MG_DOMAINS_GRPC_CLIENT_CERT: ${MG_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + MG_DOMAINS_GRPC_CLIENT_KEY: ${MG_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + MG_DOMAINS_GRPC_SERVER_CA_CERTS: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} + ports: + - ${MG_JOURNAL_HTTP_PORT}:${MG_JOURNAL_HTTP_PORT} + networks: + - magistrala-base-net + volumes: + - type: bind + source: ${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /auth-grpc-client.crt + bind: + create_host_path: true + - type: bind + source: ${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /auth-grpc-client.key + bind: + create_host_path: true + - type: bind + source: ${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /auth-grpc-server-ca.crt + bind: + create_host_path: true + - type: bind + source: ${MG_DOMAINS_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /domains-grpc-client.crt + bind: + create_host_path: true + - type: bind + source: ${MG_DOMAINS_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /domains-grpc-client.key + bind: + create_host_path: true + - type: bind + source: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /domains-grpc-server-ca.crt + bind: + create_host_path: true + nginx: image: docker.io/nginx:1.29.2-alpine3.22 - container_name: supermq-nginx + container_name: magistrala-nginx restart: on-failure volumes: - ./nginx/nginx-${AUTH-key}.conf:/etc/nginx/nginx.conf.template @@ -408,33 +514,31 @@ services: - ./nginx/snippets:/etc/nginx/snippets - ./ssl/authorization.js:/etc/nginx/authorization.js - type: bind - source: ${SMQ_NGINX_SERVER_CERT:-./ssl/certs/supermq-server.crt} - target: /etc/ssl/certs/supermq-server.crt + source: ${MG_NGINX_SERVER_CERT:-./ssl/certs/magistrala-server.crt} + target: /etc/ssl/certs/magistrala-server.crt - type: bind - source: ${SMQ_NGINX_SERVER_KEY:-./ssl/certs/supermq-server.key} - target: /etc/ssl/private/supermq-server.key + source: ${MG_NGINX_SERVER_KEY:-./ssl/certs/magistrala-server.key} + target: /etc/ssl/private/magistrala-server.key - type: bind - source: ${SMQ_NGINX_SERVER_CLIENT_CA:-./ssl/certs/ca.crt} + source: ${MG_NGINX_SERVER_CLIENT_CA:-./ssl/certs/ca.crt} target: /etc/ssl/certs/ca.crt - type: bind - source: ${SMQ_NGINX_SERVER_DHPARAM:-./ssl/dhparam.pem} + source: ${MG_NGINX_SERVER_DHPARAM:-./ssl/dhparam.pem} target: /etc/ssl/certs/dhparam.pem ports: - - ${SMQ_NGINX_HTTP_PORT}:${SMQ_NGINX_HTTP_PORT} - - ${SMQ_NGINX_SSL_PORT}:${SMQ_NGINX_SSL_PORT} - - ${SMQ_NGINX_MQTT_PORT}:${SMQ_NGINX_MQTT_PORT} - - ${SMQ_NGINX_MQTTS_PORT}:${SMQ_NGINX_MQTTS_PORT} + - ${MG_NGINX_HTTP_PORT}:${MG_NGINX_HTTP_PORT} + - ${MG_NGINX_SSL_PORT}:${MG_NGINX_SSL_PORT} + - ${MG_NGINX_MQTT_PORT}:${MG_NGINX_MQTT_PORT} + - ${MG_NGINX_MQTTS_PORT}:${MG_NGINX_MQTTS_PORT} + - ${MG_NGINX_AMQP_PORT}:${MG_NGINX_AMQP_PORT} networks: - - supermq-base-net + - magistrala-base-net env_file: - .env depends_on: - - auth - - clients - - users - - mqtt-adapter - - http-adapter - - coap-adapter + - fluxmq-node1 + - fluxmq-node2 + - fluxmq-node3 ulimits: nofile: soft: 65536 @@ -442,1163 +546,1828 @@ services: clients-db: image: docker.io/postgres:18.0-alpine3.22 - container_name: supermq-clients-db + container_name: magistrala-clients-db restart: on-failure - command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}" + command: postgres -c "max_connections=${MG_POSTGRES_MAX_CONNECTIONS}" environment: - POSTGRES_USER: ${SMQ_CLIENTS_DB_USER} - POSTGRES_PASSWORD: ${SMQ_CLIENTS_DB_PASS} - POSTGRES_DB: ${SMQ_CLIENTS_DB_NAME} - SMQ_POSTGRES_MAX_CONNECTIONS: ${SMQ_POSTGRES_MAX_CONNECTIONS} + POSTGRES_USER: ${MG_CLIENTS_DB_USER} + POSTGRES_PASSWORD: ${MG_CLIENTS_DB_PASS} + POSTGRES_DB: ${MG_CLIENTS_DB_NAME} + MG_POSTGRES_MAX_CONNECTIONS: ${MG_POSTGRES_MAX_CONNECTIONS} networks: - - supermq-base-net + - magistrala-base-net ports: - 6006:5432 volumes: - - supermq-clients-db-volume:/var/lib/postgresql/data + - magistrala-clients-db-volume:/var/lib/postgresql/data clients-redis: image: docker.io/redis:8.2.2-alpine3.22 - container_name: supermq-clients-redis + container_name: magistrala-clients-redis restart: on-failure networks: - - supermq-base-net + - magistrala-base-net volumes: - - supermq-clients-redis-volume:/data + - magistrala-clients-redis-volume:/data clients: - image: docker.io/supermq/clients:${SMQ_RELEASE_TAG} - container_name: supermq-clients + image: docker.io/magistrala/clients:${MG_RELEASE_TAG} + container_name: magistrala-clients depends_on: - clients-db - users - auth - - nats + - nginx restart: on-failure environment: - SMQ_CLIENTS_LOG_LEVEL: ${SMQ_CLIENTS_LOG_LEVEL} - SMQ_CLIENTS_STANDALONE_ID: ${SMQ_CLIENTS_STANDALONE_ID} - SMQ_CLIENTS_STANDALONE_TOKEN: ${SMQ_CLIENTS_STANDALONE_TOKEN} - SMQ_CLIENTS_CACHE_KEY_DURATION: ${SMQ_CLIENTS_CACHE_KEY_DURATION} - SMQ_CLIENTS_HTTP_HOST: ${SMQ_CLIENTS_HTTP_HOST} - SMQ_CLIENTS_HTTP_PORT: ${SMQ_CLIENTS_HTTP_PORT} - SMQ_CLIENTS_GRPC_HOST: ${SMQ_CLIENTS_GRPC_HOST} - SMQ_CLIENTS_GRPC_PORT: ${SMQ_CLIENTS_GRPC_PORT} + MG_CLIENTS_LOG_LEVEL: ${MG_CLIENTS_LOG_LEVEL} + MG_CLIENTS_STANDALONE_ID: ${MG_CLIENTS_STANDALONE_ID} + MG_CLIENTS_STANDALONE_TOKEN: ${MG_CLIENTS_STANDALONE_TOKEN} + MG_CLIENTS_CACHE_KEY_DURATION: ${MG_CLIENTS_CACHE_KEY_DURATION} + MG_CLIENTS_HTTP_HOST: ${MG_CLIENTS_HTTP_HOST} + MG_CLIENTS_HTTP_PORT: ${MG_CLIENTS_HTTP_PORT} + MG_CLIENTS_GRPC_HOST: ${MG_CLIENTS_GRPC_HOST} + MG_CLIENTS_GRPC_PORT: ${MG_CLIENTS_GRPC_PORT} ## Compose supports parameter expansion in environment, ## Eg: ${VAR:+replacement} or ${VAR+replacement} -> replacement if VAR is set and non-empty, otherwise empty ## Eg :${VAR:-default} or ${VAR-default} -> value of VAR if set and non-empty, otherwise default - SMQ_CLIENTS_GRPC_SERVER_CERT: ${SMQ_CLIENTS_GRPC_SERVER_CERT:+/clients-grpc-server.crt} - SMQ_CLIENTS_GRPC_SERVER_KEY: ${SMQ_CLIENTS_GRPC_SERVER_KEY:+/clients-grpc-server.key} - SMQ_CLIENTS_GRPC_SERVER_CA_CERTS: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} - SMQ_CLIENTS_GRPC_CLIENT_CA_CERTS: ${SMQ_CLIENTS_GRPC_CLIENT_CA_CERTS:+/clients-grpc-client-ca.crt} - SMQ_ES_URL: ${SMQ_ES_URL} - SMQ_CLIENTS_CACHE_URL: ${SMQ_CLIENTS_CACHE_URL} - SMQ_CLIENTS_DB_HOST: ${SMQ_CLIENTS_DB_HOST} - SMQ_CLIENTS_DB_PORT: ${SMQ_CLIENTS_DB_PORT} - SMQ_CLIENTS_DB_USER: ${SMQ_CLIENTS_DB_USER} - SMQ_CLIENTS_DB_PASS: ${SMQ_CLIENTS_DB_PASS} - SMQ_CLIENTS_DB_NAME: ${SMQ_CLIENTS_DB_NAME} - SMQ_CLIENTS_DB_SSL_MODE: ${SMQ_CLIENTS_DB_SSL_MODE} - SMQ_CLIENTS_DB_SSL_CERT: ${SMQ_CLIENTS_DB_SSL_CERT} - SMQ_CLIENTS_DB_SSL_KEY: ${SMQ_CLIENTS_DB_SSL_KEY} - SMQ_CLIENTS_DB_SSL_ROOT_CERT: ${SMQ_CLIENTS_DB_SSL_ROOT_CERT} - SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL} - SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT} - SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} - SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} - SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} - SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM} - SMQ_CHANNELS_URL: ${SMQ_CHANNELS_URL} - SMQ_CHANNELS_GRPC_URL: ${SMQ_CHANNELS_GRPC_URL} - SMQ_CHANNELS_GRPC_TIMEOUT: ${SMQ_CHANNELS_GRPC_TIMEOUT} - SMQ_CHANNELS_GRPC_CLIENT_CERT: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} - SMQ_CHANNELS_GRPC_CLIENT_KEY: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} - SMQ_CHANNELS_GRPC_SERVER_CA_CERTS: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} - SMQ_GROUPS_URL: ${SMQ_GROUPS_URL} - SMQ_GROUPS_GRPC_URL: ${SMQ_GROUPS_GRPC_URL} - SMQ_GROUPS_GRPC_TIMEOUT: ${SMQ_GROUPS_GRPC_TIMEOUT} - SMQ_GROUPS_GRPC_CLIENT_CERT: ${SMQ_GROUPS_GRPC_CLIENT_CERT:+/groups-grpc-client.crt} - SMQ_GROUPS_GRPC_CLIENT_KEY: ${SMQ_GROUPS_GRPC_CLIENT_KEY:+/groups-grpc-client.key} - SMQ_GROUPS_GRPC_SERVER_CA_CERTS: ${SMQ_GROUPS_GRPC_SERVER_CA_CERTS:+/groups-grpc-server-ca.crt} - SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL} - SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT} - SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} - SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} - SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} - SMQ_SPICEDB_PRE_SHARED_KEY: ${SMQ_SPICEDB_PRE_SHARED_KEY} - SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST} - SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT} - SMQ_SPICEDB_SCHEMA_FILE: ${SMQ_SPICEDB_SCHEMA_FILE} - SMQ_CLIENTS_CALLOUT_URLS: ${SMQ_CLIENTS_CALLOUT_URLS} - SMQ_CLIENTS_CALLOUT_METHOD: ${SMQ_CLIENTS_CALLOUT_METHOD} - SMQ_CLIENTS_CALLOUT_TLS_VERIFICATION: ${SMQ_CLIENTS_CALLOUT_TLS_VERIFICATION} - SMQ_CLIENTS_CALLOUT_TIMEOUT: ${SMQ_CLIENTS_CALLOUT_TIMEOUT} - SMQ_CLIENTS_CALLOUT_CA_CERT: ${SMQ_CLIENTS_CALLOUT_CA_CERT} - SMQ_CLIENTS_CALLOUT_CERT: ${SMQ_CLIENTS_CALLOUT_CERT} - SMQ_CLIENTS_CALLOUT_KEY: ${SMQ_CLIENTS_CALLOUT_KEY} - SMQ_CLIENTS_CALLOUT_OPERATIONS: ${SMQ_CLIENTS_CALLOUT_OPERATIONS} - SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} + MG_CLIENTS_GRPC_SERVER_CERT: ${MG_CLIENTS_GRPC_SERVER_CERT:+/clients-grpc-server.crt} + MG_CLIENTS_GRPC_SERVER_KEY: ${MG_CLIENTS_GRPC_SERVER_KEY:+/clients-grpc-server.key} + MG_CLIENTS_GRPC_SERVER_CA_CERTS: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} + MG_CLIENTS_GRPC_CLIENT_CA_CERTS: ${MG_CLIENTS_GRPC_CLIENT_CA_CERTS:+/clients-grpc-client-ca.crt} + MG_ES_URL: ${MG_ES_URL} + MG_CLIENTS_CACHE_URL: ${MG_CLIENTS_CACHE_URL} + MG_CLIENTS_DB_HOST: ${MG_CLIENTS_DB_HOST} + MG_CLIENTS_DB_PORT: ${MG_CLIENTS_DB_PORT} + MG_CLIENTS_DB_USER: ${MG_CLIENTS_DB_USER} + MG_CLIENTS_DB_PASS: ${MG_CLIENTS_DB_PASS} + MG_CLIENTS_DB_NAME: ${MG_CLIENTS_DB_NAME} + MG_CLIENTS_DB_SSL_MODE: ${MG_CLIENTS_DB_SSL_MODE} + MG_CLIENTS_DB_SSL_CERT: ${MG_CLIENTS_DB_SSL_CERT} + MG_CLIENTS_DB_SSL_KEY: ${MG_CLIENTS_DB_SSL_KEY} + MG_CLIENTS_DB_SSL_ROOT_CERT: ${MG_CLIENTS_DB_SSL_ROOT_CERT} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_AUTH_KEYS_ALGORITHM: ${MG_AUTH_KEYS_ALGORITHM} + MG_CHANNELS_URL: ${MG_CHANNELS_URL} + MG_CHANNELS_GRPC_URL: ${MG_CHANNELS_GRPC_URL} + MG_CHANNELS_GRPC_TIMEOUT: ${MG_CHANNELS_GRPC_TIMEOUT} + MG_CHANNELS_GRPC_CLIENT_CERT: ${MG_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} + MG_CHANNELS_GRPC_CLIENT_KEY: ${MG_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} + MG_CHANNELS_GRPC_SERVER_CA_CERTS: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} + MG_GROUPS_URL: ${MG_GROUPS_URL} + MG_GROUPS_GRPC_URL: ${MG_GROUPS_GRPC_URL} + MG_GROUPS_GRPC_TIMEOUT: ${MG_GROUPS_GRPC_TIMEOUT} + MG_GROUPS_GRPC_CLIENT_CERT: ${MG_GROUPS_GRPC_CLIENT_CERT:+/groups-grpc-client.crt} + MG_GROUPS_GRPC_CLIENT_KEY: ${MG_GROUPS_GRPC_CLIENT_KEY:+/groups-grpc-client.key} + MG_GROUPS_GRPC_SERVER_CA_CERTS: ${MG_GROUPS_GRPC_SERVER_CA_CERTS:+/groups-grpc-server-ca.crt} + MG_DOMAINS_GRPC_URL: ${MG_DOMAINS_GRPC_URL} + MG_DOMAINS_GRPC_TIMEOUT: ${MG_DOMAINS_GRPC_TIMEOUT} + MG_DOMAINS_GRPC_CLIENT_CERT: ${MG_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + MG_DOMAINS_GRPC_CLIENT_KEY: ${MG_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + MG_DOMAINS_GRPC_SERVER_CA_CERTS: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_SPICEDB_PRE_SHARED_KEY: ${MG_SPICEDB_PRE_SHARED_KEY} + MG_SPICEDB_HOST: ${MG_SPICEDB_HOST} + MG_SPICEDB_PORT: ${MG_SPICEDB_PORT} + MG_SPICEDB_SCHEMA_FILE: ${MG_SPICEDB_SCHEMA_FILE} + MG_CLIENTS_CALLOUT_URLS: ${MG_CLIENTS_CALLOUT_URLS} + MG_CLIENTS_CALLOUT_METHOD: ${MG_CLIENTS_CALLOUT_METHOD} + MG_CLIENTS_CALLOUT_TLS_VERIFICATION: ${MG_CLIENTS_CALLOUT_TLS_VERIFICATION} + MG_CLIENTS_CALLOUT_TIMEOUT: ${MG_CLIENTS_CALLOUT_TIMEOUT} + MG_CLIENTS_CALLOUT_CA_CERT: ${MG_CLIENTS_CALLOUT_CA_CERT} + MG_CLIENTS_CALLOUT_CERT: ${MG_CLIENTS_CALLOUT_CERT} + MG_CLIENTS_CALLOUT_KEY: ${MG_CLIENTS_CALLOUT_KEY} + MG_CLIENTS_CALLOUT_OPERATIONS: ${MG_CLIENTS_CALLOUT_OPERATIONS} + MG_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} ports: - - ${SMQ_CLIENTS_HTTP_PORT}:${SMQ_CLIENTS_HTTP_PORT} - - ${SMQ_CLIENTS_GRPC_PORT}:${SMQ_CLIENTS_GRPC_PORT} + - ${MG_CLIENTS_HTTP_PORT}:${MG_CLIENTS_HTTP_PORT} + - ${MG_CLIENTS_GRPC_PORT}:${MG_CLIENTS_GRPC_PORT} networks: - - supermq-base-net + - magistrala-base-net volumes: - ./permission.yaml:/permission.yaml - - ./spicedb/schema.zed:${SMQ_SPICEDB_SCHEMA_FILE} + - ./spicedb/schema.zed:${MG_SPICEDB_SCHEMA_FILE} # Clients gRPC server certificates - type: bind - source: ${SMQ_CLIENTS_GRPC_SERVER_CERT:-ssl/certs/dummy/server_cert} + source: ${MG_CLIENTS_GRPC_SERVER_CERT:-./ssl/placeholder} target: /clients-grpc-server.crt bind: create_host_path: true - type: bind - source: ${SMQ_CLIENTS_GRPC_SERVER_KEY:-ssl/certs/dummy/server_key} + source: ${MG_CLIENTS_GRPC_SERVER_KEY:-./ssl/placeholder} target: /clients-grpc-server.key bind: create_host_path: true - type: bind - source: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca_certs} + source: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /clients-grpc-server-ca.crt bind: create_host_path: true - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_CA_CERTS:-ssl/certs/dummy/client_ca_certs} + source: ${MG_CLIENTS_GRPC_CLIENT_CA_CERTS:-./ssl/placeholder} target: /clients-grpc-client-ca.crt bind: create_host_path: true # Auth gRPC client certificates - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /auth-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /auth-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /auth-grpc-server-ca.crt bind: create_host_path: true # Channel gRPC client certificates - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_CHANNELS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /channels-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_CHANNELS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /channels-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /channels-grpc-server-ca.crt bind: create_host_path: true # Group gRPC client certificates - type: bind - source: ${SMQ_GROUPS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_GROUPS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /groups-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_GROUPS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_GROUPS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /groups-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_GROUPS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_GROUPS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /groups-grpc-server-ca.crt bind: create_host_path: true # Domain gRPC client certificates - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_DOMAINS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /domains-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_DOMAINS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /domains-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /domains-grpc-server-ca.crt bind: create_host_path: true channels-db: image: docker.io/postgres:18.0-alpine3.22 - container_name: supermq-channels-db + container_name: magistrala-channels-db restart: on-failure - command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}" + command: postgres -c "max_connections=${MG_POSTGRES_MAX_CONNECTIONS}" environment: - POSTGRES_USER: ${SMQ_CHANNELS_DB_USER} - POSTGRES_PASSWORD: ${SMQ_CHANNELS_DB_PASS} - POSTGRES_DB: ${SMQ_CHANNELS_DB_NAME} - SMQ_POSTGRES_MAX_CONNECTIONS: ${SMQ_POSTGRES_MAX_CONNECTIONS} + POSTGRES_USER: ${MG_CHANNELS_DB_USER} + POSTGRES_PASSWORD: ${MG_CHANNELS_DB_PASS} + POSTGRES_DB: ${MG_CHANNELS_DB_NAME} + MG_POSTGRES_MAX_CONNECTIONS: ${MG_POSTGRES_MAX_CONNECTIONS} networks: - - supermq-base-net + - magistrala-base-net ports: - 6005:5432 volumes: - - supermq-channels-db-volume:/var/lib/postgresql/data + - magistrala-channels-db-volume:/var/lib/postgresql/data channels-redis: image: docker.io/redis:8.2.2-alpine3.22 - container_name: supermq-channels-redis + container_name: magistrala-channels-redis restart: on-failure networks: - - supermq-base-net + - magistrala-base-net volumes: - - supermq-channels-redis-volume:/data + - magistrala-channels-redis-volume:/data channels: - image: docker.io/supermq/channels:${SMQ_RELEASE_TAG} - container_name: supermq-channels + image: docker.io/magistrala/channels:${MG_RELEASE_TAG} + container_name: magistrala-channels depends_on: - channels-db - channels-redis - users - auth - - nats + - nginx restart: on-failure environment: - SMQ_CHANNELS_LOG_LEVEL: ${SMQ_CHANNELS_LOG_LEVEL} - SMQ_CHANNELS_INSTANCE_ID: ${SMQ_CHANNELS_INSTANCE_ID} - SMQ_CHANNELS_HTTP_HOST: ${SMQ_CHANNELS_HTTP_HOST} - SMQ_CHANNELS_HTTP_PORT: ${SMQ_CHANNELS_HTTP_PORT} - SMQ_CHANNELS_GRPC_HOST: ${SMQ_CHANNELS_GRPC_HOST} - SMQ_CHANNELS_GRPC_PORT: ${SMQ_CHANNELS_GRPC_PORT} + MG_CHANNELS_LOG_LEVEL: ${MG_CHANNELS_LOG_LEVEL} + MG_CHANNELS_INSTANCE_ID: ${MG_CHANNELS_INSTANCE_ID} + MG_CHANNELS_HTTP_HOST: ${MG_CHANNELS_HTTP_HOST} + MG_CHANNELS_HTTP_PORT: ${MG_CHANNELS_HTTP_PORT} + MG_CHANNELS_GRPC_HOST: ${MG_CHANNELS_GRPC_HOST} + MG_CHANNELS_GRPC_PORT: ${MG_CHANNELS_GRPC_PORT} ## Compose supports parameter expansion in environment, ## Eg: ${VAR:+replacement} or ${VAR+replacement} -> replacement if VAR is set and non-empty, otherwise empty ## Eg :${VAR:-default} or ${VAR-default} -> value of VAR if set and non-empty, otherwise default - SMQ_CHANNELS_GRPC_SERVER_CERT: ${SMQ_CHANNELS_GRPC_SERVER_CERT:+/channels-grpc-server.crt} - SMQ_CHANNELS_GRPC_SERVER_KEY: ${SMQ_CHANNELS_GRPC_SERVER_KEY:+/channels-grpc-server.key} - SMQ_CHANNELS_GRPC_SERVER_CA_CERTS: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} - SMQ_CHANNELS_GRPC_CLIENT_CA_CERTS: ${SMQ_CHANNELS_GRPC_CLIENT_CA_CERTS:+/channels-grpc-client-ca.crt} - SMQ_CHANNELS_DB_HOST: ${SMQ_CHANNELS_DB_HOST} - SMQ_CHANNELS_DB_PORT: ${SMQ_CHANNELS_DB_PORT} - SMQ_CHANNELS_DB_USER: ${SMQ_CHANNELS_DB_USER} - SMQ_CHANNELS_DB_PASS: ${SMQ_CHANNELS_DB_PASS} - SMQ_CHANNELS_DB_NAME: ${SMQ_CHANNELS_DB_NAME} - SMQ_CHANNELS_DB_SSL_MODE: ${SMQ_CHANNELS_DB_SSL_MODE} - SMQ_CHANNELS_DB_SSL_CERT: ${SMQ_CHANNELS_DB_SSL_CERT} - SMQ_CHANNELS_DB_SSL_KEY: ${SMQ_CHANNELS_DB_SSL_KEY} - SMQ_CHANNELS_DB_SSL_ROOT_CERT: ${SMQ_CHANNELS_DB_SSL_ROOT_CERT} - SMQ_CHANNELS_CACHE_URL: ${SMQ_CHANNELS_CACHE_URL} - SMQ_CHANNELS_CACHE_KEY_DURATION: ${SMQ_CHANNELS_CACHE_KEY_DURATION} - SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL} - SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT} - SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} - SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} - SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} - SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM} - SMQ_CLIENTS_GRPC_URL: ${SMQ_CLIENTS_GRPC_URL} - SMQ_CLIENTS_GRPC_TIMEOUT: ${SMQ_CLIENTS_GRPC_TIMEOUT} - SMQ_CLIENTS_GRPC_CLIENT_CERT: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} - SMQ_CLIENTS_GRPC_CLIENT_KEY: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} - SMQ_CLIENTS_GRPC_SERVER_CA_CERTS: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} - SMQ_GROUPS_GRPC_URL: ${SMQ_GROUPS_GRPC_URL} - SMQ_GROUPS_GRPC_TIMEOUT: ${SMQ_GROUPS_GRPC_TIMEOUT} - SMQ_GROUPS_GRPC_CLIENT_CERT: ${SMQ_GROUPS_GRPC_CLIENT_CERT:+/groups-grpc-client.crt} - SMQ_GROUPS_GRPC_CLIENT_KEY: ${SMQ_GROUPS_GRPC_CLIENT_KEY:+/groups-grpc-client.key} - SMQ_GROUPS_GRPC_SERVER_CA_CERTS: ${SMQ_GROUPS_GRPC_SERVER_CA_CERTS:+/groups-grpc-server-ca.crt} - SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL} - SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT} - SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} - SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} - SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - SMQ_ES_URL: ${SMQ_ES_URL} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} - SMQ_SPICEDB_PRE_SHARED_KEY: ${SMQ_SPICEDB_PRE_SHARED_KEY} - SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST} - SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT} - SMQ_SPICEDB_SCHEMA_FILE: ${SMQ_SPICEDB_SCHEMA_FILE} - SMQ_CHANNELS_CALLOUT_URLS: ${SMQ_CHANNELS_CALLOUT_URLS} - SMQ_CHANNELS_CALLOUT_METHOD: ${SMQ_CHANNELS_CALLOUT_METHOD} - SMQ_CHANNELS_CALLOUT_TLS_VERIFICATION: ${SMQ_CHANNELS_CALLOUT_TLS_VERIFICATION} - SMQ_CHANNELS_CALLOUT_TIMEOUT: ${SMQ_CHANNELS_CALLOUT_TIMEOUT} - SMQ_CHANNELS_CALLOUT_CA_CERT: ${SMQ_CHANNELS_CALLOUT_CA_CERT} - SMQ_CHANNELS_CALLOUT_CERT: ${SMQ_CHANNELS_CALLOUT_CERT} - SMQ_CHANNELS_CALLOUT_KEY: ${SMQ_CHANNELS_CALLOUT_KEY} - SMQ_CHANNELS_CALLOUT_OPERATIONS: ${SMQ_CHANNELS_CALLOUT_OPERATIONS} - SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} + MG_CHANNELS_GRPC_SERVER_CERT: ${MG_CHANNELS_GRPC_SERVER_CERT:+/channels-grpc-server.crt} + MG_CHANNELS_GRPC_SERVER_KEY: ${MG_CHANNELS_GRPC_SERVER_KEY:+/channels-grpc-server.key} + MG_CHANNELS_GRPC_SERVER_CA_CERTS: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} + MG_CHANNELS_GRPC_CLIENT_CA_CERTS: ${MG_CHANNELS_GRPC_CLIENT_CA_CERTS:+/channels-grpc-client-ca.crt} + MG_CHANNELS_DB_HOST: ${MG_CHANNELS_DB_HOST} + MG_CHANNELS_DB_PORT: ${MG_CHANNELS_DB_PORT} + MG_CHANNELS_DB_USER: ${MG_CHANNELS_DB_USER} + MG_CHANNELS_DB_PASS: ${MG_CHANNELS_DB_PASS} + MG_CHANNELS_DB_NAME: ${MG_CHANNELS_DB_NAME} + MG_CHANNELS_DB_SSL_MODE: ${MG_CHANNELS_DB_SSL_MODE} + MG_CHANNELS_DB_SSL_CERT: ${MG_CHANNELS_DB_SSL_CERT} + MG_CHANNELS_DB_SSL_KEY: ${MG_CHANNELS_DB_SSL_KEY} + MG_CHANNELS_DB_SSL_ROOT_CERT: ${MG_CHANNELS_DB_SSL_ROOT_CERT} + MG_CHANNELS_CACHE_URL: ${MG_CHANNELS_CACHE_URL} + MG_CHANNELS_CACHE_KEY_DURATION: ${MG_CHANNELS_CACHE_KEY_DURATION} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_AUTH_KEYS_ALGORITHM: ${MG_AUTH_KEYS_ALGORITHM} + MG_CLIENTS_GRPC_URL: ${MG_CLIENTS_GRPC_URL} + MG_CLIENTS_GRPC_TIMEOUT: ${MG_CLIENTS_GRPC_TIMEOUT} + MG_CLIENTS_GRPC_CLIENT_CERT: ${MG_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} + MG_CLIENTS_GRPC_CLIENT_KEY: ${MG_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} + MG_CLIENTS_GRPC_SERVER_CA_CERTS: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} + MG_GROUPS_GRPC_URL: ${MG_GROUPS_GRPC_URL} + MG_GROUPS_GRPC_TIMEOUT: ${MG_GROUPS_GRPC_TIMEOUT} + MG_GROUPS_GRPC_CLIENT_CERT: ${MG_GROUPS_GRPC_CLIENT_CERT:+/groups-grpc-client.crt} + MG_GROUPS_GRPC_CLIENT_KEY: ${MG_GROUPS_GRPC_CLIENT_KEY:+/groups-grpc-client.key} + MG_GROUPS_GRPC_SERVER_CA_CERTS: ${MG_GROUPS_GRPC_SERVER_CA_CERTS:+/groups-grpc-server-ca.crt} + MG_DOMAINS_GRPC_URL: ${MG_DOMAINS_GRPC_URL} + MG_DOMAINS_GRPC_TIMEOUT: ${MG_DOMAINS_GRPC_TIMEOUT} + MG_DOMAINS_GRPC_CLIENT_CERT: ${MG_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + MG_DOMAINS_GRPC_CLIENT_KEY: ${MG_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + MG_DOMAINS_GRPC_SERVER_CA_CERTS: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_ES_URL: ${MG_ES_URL} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_SPICEDB_PRE_SHARED_KEY: ${MG_SPICEDB_PRE_SHARED_KEY} + MG_SPICEDB_HOST: ${MG_SPICEDB_HOST} + MG_SPICEDB_PORT: ${MG_SPICEDB_PORT} + MG_SPICEDB_SCHEMA_FILE: ${MG_SPICEDB_SCHEMA_FILE} + MG_CHANNELS_CALLOUT_URLS: ${MG_CHANNELS_CALLOUT_URLS} + MG_CHANNELS_CALLOUT_METHOD: ${MG_CHANNELS_CALLOUT_METHOD} + MG_CHANNELS_CALLOUT_TLS_VERIFICATION: ${MG_CHANNELS_CALLOUT_TLS_VERIFICATION} + MG_CHANNELS_CALLOUT_TIMEOUT: ${MG_CHANNELS_CALLOUT_TIMEOUT} + MG_CHANNELS_CALLOUT_CA_CERT: ${MG_CHANNELS_CALLOUT_CA_CERT} + MG_CHANNELS_CALLOUT_CERT: ${MG_CHANNELS_CALLOUT_CERT} + MG_CHANNELS_CALLOUT_KEY: ${MG_CHANNELS_CALLOUT_KEY} + MG_CHANNELS_CALLOUT_OPERATIONS: ${MG_CHANNELS_CALLOUT_OPERATIONS} + MG_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} ports: - - ${SMQ_CHANNELS_HTTP_PORT}:${SMQ_CHANNELS_HTTP_PORT} - - ${SMQ_CHANNELS_GRPC_PORT}:${SMQ_CHANNELS_GRPC_PORT} + - ${MG_CHANNELS_HTTP_PORT}:${MG_CHANNELS_HTTP_PORT} + - ${MG_CHANNELS_GRPC_PORT}:${MG_CHANNELS_GRPC_PORT} networks: - - supermq-base-net + - magistrala-base-net volumes: - ./permission.yaml:/permission.yaml - - ./spicedb/schema.zed:${SMQ_SPICEDB_SCHEMA_FILE} + - ./spicedb/schema.zed:${MG_SPICEDB_SCHEMA_FILE} # Channels gRPC server certificates - type: bind - source: ${SMQ_CHANNELS_GRPC_SERVER_CERT:-ssl/certs/dummy/server_cert} + source: ${MG_CHANNELS_GRPC_SERVER_CERT:-./ssl/placeholder} target: /channels-grpc-server.crt bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_SERVER_KEY:-ssl/certs/dummy/server_key} + source: ${MG_CHANNELS_GRPC_SERVER_KEY:-./ssl/placeholder} target: /channels-grpc-server.key bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca_certs} + source: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /channels-grpc-server-ca.crt bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_CA_CERTS:-ssl/certs/dummy/client_ca_certs} + source: ${MG_CHANNELS_GRPC_CLIENT_CA_CERTS:-./ssl/placeholder} target: /channels-grpc-client-ca.crt bind: create_host_path: true # Auth gRPC client certificates - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /auth-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /auth-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /auth-grpc-server-ca.crt bind: create_host_path: true # Clients gRPC client certificates - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_CLIENTS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /clients-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_CLIENTS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /clients-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /clients-grpc-server-ca.crt bind: create_host_path: true # Groups gRPC client certificates - type: bind - source: ${SMQ_GROUPS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_GROUPS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /groups-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_GROUPS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_GROUPS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /groups-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_GROUPS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_GROUPS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /groups-grpc-server-ca.crt bind: create_host_path: true # Domains gRPC client certificates - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_DOMAINS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /domains-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_DOMAINS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /domains-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /domains-grpc-server-ca.crt bind: create_host_path: true users-db: image: docker.io/postgres:18.0-alpine3.22 - container_name: supermq-users-db + container_name: magistrala-users-db restart: on-failure - command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}" + command: postgres -c "max_connections=${MG_POSTGRES_MAX_CONNECTIONS}" environment: - POSTGRES_USER: ${SMQ_USERS_DB_USER} - POSTGRES_PASSWORD: ${SMQ_USERS_DB_PASS} - POSTGRES_DB: ${SMQ_USERS_DB_NAME} - SMQ_POSTGRES_MAX_CONNECTIONS: ${SMQ_POSTGRES_MAX_CONNECTIONS} + POSTGRES_USER: ${MG_USERS_DB_USER} + POSTGRES_PASSWORD: ${MG_USERS_DB_PASS} + POSTGRES_DB: ${MG_USERS_DB_NAME} + MG_POSTGRES_MAX_CONNECTIONS: ${MG_POSTGRES_MAX_CONNECTIONS} ports: - 6002:5432 networks: - - supermq-base-net + - magistrala-base-net volumes: - - supermq-users-db-volume:/var/lib/postgresql/data + - magistrala-users-db-volume:/var/lib/postgresql/data users: - image: docker.io/supermq/users:${SMQ_RELEASE_TAG} - container_name: supermq-users + image: docker.io/magistrala/users:${MG_RELEASE_TAG} + container_name: magistrala-users depends_on: - users-db - auth - - nats + - nginx restart: on-failure environment: - SMQ_USERS_LOG_LEVEL: ${SMQ_USERS_LOG_LEVEL} - SMQ_USERS_SECRET_KEY: ${SMQ_USERS_SECRET_KEY} - SMQ_USERS_ADMIN_EMAIL: ${SMQ_USERS_ADMIN_EMAIL} - SMQ_USERS_ADMIN_PASSWORD: ${SMQ_USERS_ADMIN_PASSWORD} - SMQ_USERS_ADMIN_USERNAME: ${SMQ_USERS_ADMIN_USERNAME} - SMQ_USERS_ADMIN_FIRST_NAME: ${SMQ_USERS_ADMIN_FIRST_NAME} - SMQ_USERS_ADMIN_LAST_NAME: ${SMQ_USERS_ADMIN_LAST_NAME} - SMQ_USERS_PASS_REGEX: ${SMQ_USERS_PASS_REGEX} - SMQ_USERS_HTTP_HOST: ${SMQ_USERS_HTTP_HOST} - SMQ_USERS_HTTP_PORT: ${SMQ_USERS_HTTP_PORT} - SMQ_USERS_HTTP_SERVER_CERT: ${SMQ_USERS_HTTP_SERVER_CERT} - SMQ_USERS_HTTP_SERVER_KEY: ${SMQ_USERS_HTTP_SERVER_KEY} - SMQ_USERS_GRPC_HOST: ${SMQ_USERS_GRPC_HOST} - SMQ_USERS_GRPC_PORT: ${SMQ_USERS_GRPC_PORT} + MG_USERS_LOG_LEVEL: ${MG_USERS_LOG_LEVEL} + MG_USERS_SECRET_KEY: ${MG_USERS_SECRET_KEY} + MG_USERS_ADMIN_EMAIL: ${MG_USERS_ADMIN_EMAIL} + MG_USERS_ADMIN_PASSWORD: ${MG_USERS_ADMIN_PASSWORD} + MG_USERS_ADMIN_USERNAME: ${MG_USERS_ADMIN_USERNAME} + MG_USERS_ADMIN_FIRST_NAME: ${MG_USERS_ADMIN_FIRST_NAME} + MG_USERS_ADMIN_LAST_NAME: ${MG_USERS_ADMIN_LAST_NAME} + MG_USERS_PASS_REGEX: ${MG_USERS_PASS_REGEX} + MG_USERS_HTTP_HOST: ${MG_USERS_HTTP_HOST} + MG_USERS_HTTP_PORT: ${MG_USERS_HTTP_PORT} + MG_USERS_HTTP_SERVER_CERT: ${MG_USERS_HTTP_SERVER_CERT} + MG_USERS_HTTP_SERVER_KEY: ${MG_USERS_HTTP_SERVER_KEY} + MG_USERS_GRPC_HOST: ${MG_USERS_GRPC_HOST} + MG_USERS_GRPC_PORT: ${MG_USERS_GRPC_PORT} ## Compose supports parameter expansion in environment, ## Eg: ${VAR:+replacement} or ${VAR+replacement} -> replacement if VAR is set and non-empty, otherwise empty ## Eg :${VAR:-default} or ${VAR-default} -> value of VAR if set and non-empty, otherwise default - SMQ_USERS_GRPC_SERVER_CERT: ${SMQ_USERS_GRPC_SERVER_CERT:+/users-grpc-server.crt} - SMQ_USERS_GRPC_SERVER_KEY: ${SMQ_USERS_GRPC_SERVER_KEY:+/users-grpc-server.key} - SMQ_USERS_GRPC_SERVER_CA_CERTS: ${SMQ_USERS_GRPC_SERVER_CA_CERTS:+/users-grpc-server-ca.crt} - SMQ_USERS_GRPC_CLIENT_CA_CERTS: ${SMQ_USERS_GRPC_CLIENT_CA_CERTS:+/users-grpc-client-ca.crt} - SMQ_USERS_DB_HOST: ${SMQ_USERS_DB_HOST} - SMQ_USERS_DB_PORT: ${SMQ_USERS_DB_PORT} - SMQ_USERS_DB_USER: ${SMQ_USERS_DB_USER} - SMQ_USERS_DB_PASS: ${SMQ_USERS_DB_PASS} - SMQ_USERS_DB_NAME: ${SMQ_USERS_DB_NAME} - SMQ_USERS_DB_SSL_MODE: ${SMQ_USERS_DB_SSL_MODE} - SMQ_USERS_DB_SSL_CERT: ${SMQ_USERS_DB_SSL_CERT} - SMQ_USERS_DB_SSL_KEY: ${SMQ_USERS_DB_SSL_KEY} - SMQ_USERS_DB_SSL_ROOT_CERT: ${SMQ_USERS_DB_SSL_ROOT_CERT} - SMQ_USERS_ALLOW_SELF_REGISTER: ${SMQ_USERS_ALLOW_SELF_REGISTER} - SMQ_EMAIL_HOST: ${SMQ_EMAIL_HOST} - SMQ_EMAIL_PORT: ${SMQ_EMAIL_PORT} - SMQ_EMAIL_USERNAME: ${SMQ_EMAIL_USERNAME} - SMQ_EMAIL_PASSWORD: ${SMQ_EMAIL_PASSWORD} - SMQ_EMAIL_FROM_ADDRESS: ${SMQ_EMAIL_FROM_ADDRESS} - SMQ_EMAIL_FROM_NAME: ${SMQ_EMAIL_FROM_NAME} - SMQ_ES_URL: ${SMQ_ES_URL} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} - SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL} - SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT} - SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} - SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} - SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} - SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM} - SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL} - SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT} - SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} - SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} - SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - SMQ_GOOGLE_CLIENT_ID: ${SMQ_GOOGLE_CLIENT_ID} - SMQ_GOOGLE_CLIENT_SECRET: ${SMQ_GOOGLE_CLIENT_SECRET} - SMQ_GOOGLE_REDIRECT_URL: ${SMQ_GOOGLE_REDIRECT_URL} - SMQ_GOOGLE_STATE: ${SMQ_GOOGLE_STATE} - SMQ_OAUTH_UI_REDIRECT_URL: ${SMQ_OAUTH_UI_REDIRECT_URL} - SMQ_OAUTH_UI_ERROR_URL: ${SMQ_OAUTH_UI_ERROR_URL} - SMQ_USERS_DELETE_INTERVAL: ${SMQ_USERS_DELETE_INTERVAL} - SMQ_USERS_DELETE_AFTER: ${SMQ_USERS_DELETE_AFTER} - SMQ_SPICEDB_PRE_SHARED_KEY: ${SMQ_SPICEDB_PRE_SHARED_KEY} - SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST} - SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT} - SMQ_PASSWORD_RESET_URL_PREFIX: ${SMQ_PASSWORD_RESET_URL_PREFIX} - SMQ_PASSWORD_RESET_EMAIL_TEMPLATE: ${SMQ_PASSWORD_RESET_EMAIL_TEMPLATE} - SMQ_VERIFICATION_URL_PREFIX: ${SMQ_VERIFICATION_URL_PREFIX} - SMQ_VERIFICATION_EMAIL_TEMPLATE: ${SMQ_VERIFICATION_EMAIL_TEMPLATE} - SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} + MG_USERS_GRPC_SERVER_CERT: ${MG_USERS_GRPC_SERVER_CERT:+/users-grpc-server.crt} + MG_USERS_GRPC_SERVER_KEY: ${MG_USERS_GRPC_SERVER_KEY:+/users-grpc-server.key} + MG_USERS_GRPC_SERVER_CA_CERTS: ${MG_USERS_GRPC_SERVER_CA_CERTS:+/users-grpc-server-ca.crt} + MG_USERS_GRPC_CLIENT_CA_CERTS: ${MG_USERS_GRPC_CLIENT_CA_CERTS:+/users-grpc-client-ca.crt} + MG_USERS_DB_HOST: ${MG_USERS_DB_HOST} + MG_USERS_DB_PORT: ${MG_USERS_DB_PORT} + MG_USERS_DB_USER: ${MG_USERS_DB_USER} + MG_USERS_DB_PASS: ${MG_USERS_DB_PASS} + MG_USERS_DB_NAME: ${MG_USERS_DB_NAME} + MG_USERS_DB_SSL_MODE: ${MG_USERS_DB_SSL_MODE} + MG_USERS_DB_SSL_CERT: ${MG_USERS_DB_SSL_CERT} + MG_USERS_DB_SSL_KEY: ${MG_USERS_DB_SSL_KEY} + MG_USERS_DB_SSL_ROOT_CERT: ${MG_USERS_DB_SSL_ROOT_CERT} + MG_USERS_ALLOW_SELF_REGISTER: ${MG_USERS_ALLOW_SELF_REGISTER} + MG_EMAIL_HOST: ${MG_EMAIL_HOST} + MG_EMAIL_PORT: ${MG_EMAIL_PORT} + MG_EMAIL_USERNAME: ${MG_EMAIL_USERNAME} + MG_EMAIL_PASSWORD: ${MG_EMAIL_PASSWORD} + MG_EMAIL_FROM_ADDRESS: ${MG_EMAIL_FROM_ADDRESS} + MG_EMAIL_FROM_NAME: ${MG_EMAIL_FROM_NAME} + MG_ES_URL: ${MG_ES_URL} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_AUTH_KEYS_ALGORITHM: ${MG_AUTH_KEYS_ALGORITHM} + MG_DOMAINS_GRPC_URL: ${MG_DOMAINS_GRPC_URL} + MG_DOMAINS_GRPC_TIMEOUT: ${MG_DOMAINS_GRPC_TIMEOUT} + MG_DOMAINS_GRPC_CLIENT_CERT: ${MG_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + MG_DOMAINS_GRPC_CLIENT_KEY: ${MG_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + MG_DOMAINS_GRPC_SERVER_CA_CERTS: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_GOOGLE_CLIENT_ID: ${MG_GOOGLE_CLIENT_ID} + MG_GOOGLE_CLIENT_SECRET: ${MG_GOOGLE_CLIENT_SECRET} + MG_GOOGLE_REDIRECT_URL: ${MG_GOOGLE_REDIRECT_URL} + MG_GOOGLE_STATE: ${MG_GOOGLE_STATE} + MG_OAUTH_UI_REDIRECT_URL: ${MG_OAUTH_UI_REDIRECT_URL} + MG_OAUTH_UI_ERROR_URL: ${MG_OAUTH_UI_ERROR_URL} + MG_USERS_DELETE_INTERVAL: ${MG_USERS_DELETE_INTERVAL} + MG_USERS_DELETE_AFTER: ${MG_USERS_DELETE_AFTER} + MG_SPICEDB_PRE_SHARED_KEY: ${MG_SPICEDB_PRE_SHARED_KEY} + MG_SPICEDB_HOST: ${MG_SPICEDB_HOST} + MG_SPICEDB_PORT: ${MG_SPICEDB_PORT} + MG_PASSWORD_RESET_URL_PREFIX: ${MG_PASSWORD_RESET_URL_PREFIX} + MG_PASSWORD_RESET_EMAIL_TEMPLATE: ${MG_PASSWORD_RESET_EMAIL_TEMPLATE} + MG_VERIFICATION_URL_PREFIX: ${MG_VERIFICATION_URL_PREFIX} + MG_VERIFICATION_EMAIL_TEMPLATE: ${MG_VERIFICATION_EMAIL_TEMPLATE} + MG_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} ports: - - ${SMQ_USERS_HTTP_PORT}:${SMQ_USERS_HTTP_PORT} - - ${SMQ_USERS_GRPC_PORT}:${SMQ_USERS_GRPC_PORT} + - ${MG_USERS_HTTP_PORT}:${MG_USERS_HTTP_PORT} + - ${MG_USERS_GRPC_PORT}:${MG_USERS_GRPC_PORT} networks: - - supermq-base-net + - magistrala-base-net volumes: - - ./templates/${SMQ_PASSWORD_RESET_EMAIL_TEMPLATE}:/${SMQ_PASSWORD_RESET_EMAIL_TEMPLATE} - - ./templates/${SMQ_VERIFICATION_EMAIL_TEMPLATE}:/${SMQ_VERIFICATION_EMAIL_TEMPLATE} + - ./templates/${MG_PASSWORD_RESET_EMAIL_TEMPLATE}:/${MG_PASSWORD_RESET_EMAIL_TEMPLATE} + - ./templates/${MG_VERIFICATION_EMAIL_TEMPLATE}:/${MG_VERIFICATION_EMAIL_TEMPLATE} + # Users gRPC server certificates + - type: bind + source: ${MG_USERS_GRPC_SERVER_CERT:-./ssl/placeholder} + target: /users-grpc-server.crt + bind: + create_host_path: true + - type: bind + source: ${MG_USERS_GRPC_SERVER_KEY:-./ssl/placeholder} + target: /users-grpc-server.key + bind: + create_host_path: true + - type: bind + source: ${MG_USERS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /users-grpc-server-ca.crt + bind: + create_host_path: true + - type: bind + source: ${MG_USERS_GRPC_CLIENT_CA_CERTS:-./ssl/placeholder} + target: /users-grpc-client-ca.crt + bind: + create_host_path: true # Auth gRPC client certificates - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /auth-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /auth-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /auth-grpc-server-ca.crt bind: create_host_path: true # Domains gRPC client certificates - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_DOMAINS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /domains-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_DOMAINS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /domains-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /domains-grpc-server-ca.crt bind: create_host_path: true notifications: - image: docker.io/supermq/notifications:${SMQ_RELEASE_TAG} - container_name: supermq-notifications + image: docker.io/magistrala/notifications:${MG_RELEASE_TAG} + container_name: magistrala-notifications depends_on: - - nats + - nginx restart: on-failure environment: - SMQ_NOTIFICATIONS_LOG_LEVEL: ${SMQ_NOTIFICATIONS_LOG_LEVEL} - SMQ_NOTIFICATIONS_INSTANCE_ID: ${SMQ_NOTIFICATIONS_INSTANCE_ID} - SMQ_ES_URL: ${SMQ_ES_URL} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} - SMQ_EMAIL_HOST: ${SMQ_EMAIL_HOST} - SMQ_EMAIL_PORT: ${SMQ_EMAIL_PORT} - SMQ_EMAIL_USERNAME: ${SMQ_EMAIL_USERNAME} - SMQ_EMAIL_PASSWORD: ${SMQ_EMAIL_PASSWORD} - SMQ_EMAIL_FROM_ADDRESS: ${SMQ_EMAIL_FROM_ADDRESS} - SMQ_EMAIL_FROM_NAME: ${SMQ_EMAIL_FROM_NAME} - SMQ_EMAIL_INVITATION_TEMPLATE: ${SMQ_EMAIL_INVITATION_TEMPLATE} - SMQ_EMAIL_ACCEPTANCE_TEMPLATE: ${SMQ_EMAIL_ACCEPTANCE_TEMPLATE} - SMQ_EMAIL_REJECTION_TEMPLATE: ${SMQ_EMAIL_REJECTION_TEMPLATE} - SMQ_USERS_GRPC_URL: ${SMQ_USERS_GRPC_URL} - SMQ_USERS_GRPC_TIMEOUT: ${SMQ_USERS_GRPC_TIMEOUT} - SMQ_USERS_GRPC_CLIENT_CERT: ${SMQ_USERS_GRPC_CLIENT_CERT:+/users-grpc-client.crt} - SMQ_USERS_GRPC_CLIENT_KEY: ${SMQ_USERS_GRPC_CLIENT_KEY:+/users-grpc-client.key} - SMQ_USERS_GRPC_SERVER_CA_CERTS: ${SMQ_USERS_GRPC_SERVER_CA_CERTS:+/users-grpc-server-ca.crt} + MG_NOTIFICATIONS_LOG_LEVEL: ${MG_NOTIFICATIONS_LOG_LEVEL} + MG_NOTIFICATIONS_INSTANCE_ID: ${MG_NOTIFICATIONS_INSTANCE_ID} + MG_ES_URL: ${MG_ES_URL} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_EMAIL_HOST: ${MG_EMAIL_HOST} + MG_EMAIL_PORT: ${MG_EMAIL_PORT} + MG_EMAIL_USERNAME: ${MG_EMAIL_USERNAME} + MG_EMAIL_PASSWORD: ${MG_EMAIL_PASSWORD} + MG_EMAIL_FROM_ADDRESS: ${MG_EMAIL_FROM_ADDRESS} + MG_EMAIL_FROM_NAME: ${MG_EMAIL_FROM_NAME} + MG_EMAIL_INVITATION_TEMPLATE: ${MG_EMAIL_INVITATION_TEMPLATE} + MG_EMAIL_ACCEPTANCE_TEMPLATE: ${MG_EMAIL_ACCEPTANCE_TEMPLATE} + MG_EMAIL_REJECTION_TEMPLATE: ${MG_EMAIL_REJECTION_TEMPLATE} + MG_USERS_GRPC_URL: ${MG_USERS_GRPC_URL} + MG_USERS_GRPC_TIMEOUT: ${MG_USERS_GRPC_TIMEOUT} + MG_USERS_GRPC_CLIENT_CERT: ${MG_USERS_GRPC_CLIENT_CERT:+/users-grpc-client.crt} + MG_USERS_GRPC_CLIENT_KEY: ${MG_USERS_GRPC_CLIENT_KEY:+/users-grpc-client.key} + MG_USERS_GRPC_SERVER_CA_CERTS: ${MG_USERS_GRPC_SERVER_CA_CERTS:+/users-grpc-server-ca.crt} networks: - - supermq-base-net + - magistrala-base-net volumes: - - ./templates/${SMQ_EMAIL_INVITATION_TEMPLATE}:/${SMQ_EMAIL_INVITATION_TEMPLATE} - - ./templates/${SMQ_EMAIL_ACCEPTANCE_TEMPLATE}:/${SMQ_EMAIL_ACCEPTANCE_TEMPLATE} - - ./templates/${SMQ_EMAIL_REJECTION_TEMPLATE}:/${SMQ_EMAIL_REJECTION_TEMPLATE} + - ./templates/${MG_EMAIL_INVITATION_TEMPLATE}:/${MG_EMAIL_INVITATION_TEMPLATE} + - ./templates/${MG_EMAIL_ACCEPTANCE_TEMPLATE}:/${MG_EMAIL_ACCEPTANCE_TEMPLATE} + - ./templates/${MG_EMAIL_REJECTION_TEMPLATE}:/${MG_EMAIL_REJECTION_TEMPLATE} # Users gRPC client certificates - type: bind - source: ${SMQ_USERS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_USERS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /users-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_USERS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_USERS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /users-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_USERS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_USERS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /users-grpc-server-ca.crt bind: create_host_path: true groups-db: image: docker.io/postgres:18.0-alpine3.22 - container_name: supermq-groups-db + container_name: magistrala-groups-db restart: on-failure - command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}" + command: postgres -c "max_connections=${MG_POSTGRES_MAX_CONNECTIONS}" environment: - POSTGRES_USER: ${SMQ_GROUPS_DB_USER} - POSTGRES_PASSWORD: ${SMQ_GROUPS_DB_PASS} - POSTGRES_DB: ${SMQ_GROUPS_DB_NAME} - SMQ_POSTGRES_MAX_CONNECTIONS: ${SMQ_POSTGRES_MAX_CONNECTIONS} + POSTGRES_USER: ${MG_GROUPS_DB_USER} + POSTGRES_PASSWORD: ${MG_GROUPS_DB_PASS} + POSTGRES_DB: ${MG_GROUPS_DB_NAME} + MG_POSTGRES_MAX_CONNECTIONS: ${MG_POSTGRES_MAX_CONNECTIONS} ports: - 6004:5432 networks: - - supermq-base-net + - magistrala-base-net volumes: - - supermq-groups-db-volume:/var/lib/postgresql/data + - magistrala-groups-db-volume:/var/lib/postgresql/data groups: - image: docker.io/supermq/groups:${SMQ_RELEASE_TAG} - container_name: supermq-groups + image: docker.io/magistrala/groups:${MG_RELEASE_TAG} + container_name: magistrala-groups depends_on: - groups-db - auth - - nats + - nginx restart: on-failure environment: - SMQ_GROUPS_LOG_LEVEL: ${SMQ_GROUPS_LOG_LEVEL} - SMQ_GROUPS_HTTP_HOST: ${SMQ_GROUPS_HTTP_HOST} - SMQ_GROUPS_HTTP_PORT: ${SMQ_GROUPS_HTTP_PORT} - SMQ_GROUPS_HTTP_SERVER_CERT: ${SMQ_GROUPS_HTTP_SERVER_CERT} - SMQ_GROUPS_HTTP_SERVER_KEY: ${SMQ_GROUPS_HTTP_SERVER_KEY} - SMQ_GROUPS_GRPC_HOST: ${SMQ_GROUPS_GRPC_HOST} - SMQ_GROUPS_GRPC_PORT: ${SMQ_GROUPS_GRPC_PORT} + MG_GROUPS_LOG_LEVEL: ${MG_GROUPS_LOG_LEVEL} + MG_GROUPS_HTTP_HOST: ${MG_GROUPS_HTTP_HOST} + MG_GROUPS_HTTP_PORT: ${MG_GROUPS_HTTP_PORT} + MG_GROUPS_HTTP_SERVER_CERT: ${MG_GROUPS_HTTP_SERVER_CERT} + MG_GROUPS_HTTP_SERVER_KEY: ${MG_GROUPS_HTTP_SERVER_KEY} + MG_GROUPS_GRPC_HOST: ${MG_GROUPS_GRPC_HOST} + MG_GROUPS_GRPC_PORT: ${MG_GROUPS_GRPC_PORT} ## Compose supports parameter expansion in environment, ## Eg: ${VAR:+replacement} or ${VAR+replacement} -> replacement if VAR is set and non-empty, otherwise empty ## Eg :${VAR:-default} or ${VAR-default} -> value of VAR if set and non-empty, otherwise default - SMQ_GROUPS_GRPC_SERVER_CERT: ${SMQ_GROUPS_GRPC_SERVER_CERT:+/groups-grpc-server.crt} - SMQ_GROUPS_GRPC_SERVER_KEY: ${SMQ_GROUPS_GRPC_SERVER_KEY:+/groups-grpc-server.key} - SMQ_GROUPS_GRPC_SERVER_CA_CERTS: ${SMQ_GROUPS_GRPC_SERVER_CA_CERTS:+/groups-grpc-server-ca.crt} - SMQ_GROUPS_GRPC_CLIENT_CA_CERTS: ${SMQ_GROUPS_GRPC_CLIENT_CA_CERTS:+/groups-grpc-client-ca.crt} - SMQ_GROUPS_DB_HOST: ${SMQ_GROUPS_DB_HOST} - SMQ_GROUPS_DB_PORT: ${SMQ_GROUPS_DB_PORT} - SMQ_GROUPS_DB_USER: ${SMQ_GROUPS_DB_USER} - SMQ_GROUPS_DB_PASS: ${SMQ_GROUPS_DB_PASS} - SMQ_GROUPS_DB_NAME: ${SMQ_GROUPS_DB_NAME} - SMQ_GROUPS_DB_SSL_MODE: ${SMQ_GROUPS_DB_SSL_MODE} - SMQ_GROUPS_DB_SSL_CERT: ${SMQ_GROUPS_DB_SSL_CERT} - SMQ_GROUPS_DB_SSL_KEY: ${SMQ_GROUPS_DB_SSL_KEY} - SMQ_GROUPS_DB_SSL_ROOT_CERT: ${SMQ_GROUPS_DB_SSL_ROOT_CERT} - SMQ_CHANNELS_URL: ${SMQ_CHANNELS_URL} - SMQ_CHANNELS_GRPC_URL: ${SMQ_CHANNELS_GRPC_URL} - SMQ_CHANNELS_GRPC_TIMEOUT: ${SMQ_CHANNELS_GRPC_TIMEOUT} - SMQ_CHANNELS_GRPC_CLIENT_CERT: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} - SMQ_CHANNELS_GRPC_CLIENT_KEY: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} - SMQ_CHANNELS_GRPC_SERVER_CA_CERTS: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} - SMQ_CLIENTS_GRPC_URL: ${SMQ_CLIENTS_GRPC_URL} - SMQ_CLIENTS_GRPC_TIMEOUT: ${SMQ_CLIENTS_GRPC_TIMEOUT} - SMQ_CLIENTS_GRPC_CLIENT_CERT: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} - SMQ_CLIENTS_GRPC_CLIENT_KEY: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} - SMQ_CLIENTS_GRPC_SERVER_CA_CERTS: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} - SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL} - SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT} - SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} - SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} - SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - SMQ_ES_URL: ${SMQ_ES_URL} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} - SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL} - SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT} - SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} - SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} - SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} - SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM} - SMQ_SPICEDB_PRE_SHARED_KEY: ${SMQ_SPICEDB_PRE_SHARED_KEY} - SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST} - SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT} - SMQ_SPICEDB_SCHEMA_FILE: ${SMQ_SPICEDB_SCHEMA_FILE} - SMQ_GROUPS_CALLOUT_URLS: ${SMQ_GROUPS_CALLOUT_URLS} - SMQ_GROUPS_CALLOUT_METHOD: ${SMQ_GROUPS_CALLOUT_METHOD} - SMQ_GROUPS_CALLOUT_TLS_VERIFICATION: ${SMQ_GROUPS_CALLOUT_TLS_VERIFICATION} - SMQ_GROUPS_CALLOUT_TIMEOUT: ${SMQ_GROUPS_CALLOUT_TIMEOUT} - SMQ_GROUPS_CALLOUT_CA_CERT: ${SMQ_GROUPS_CALLOUT_CA_CERT} - SMQ_GROUPS_CALLOUT_CERT: ${SMQ_GROUPS_CALLOUT_CERT} - SMQ_GROUPS_CALLOUT_KEY: ${SMQ_GROUPS_CALLOUT_KEY} - SMQ_GROUPS_CALLOUT_OPERATIONS: ${SMQ_GROUPS_CALLOUT_OPERATIONS} - SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} + MG_GROUPS_GRPC_SERVER_CERT: ${MG_GROUPS_GRPC_SERVER_CERT:+/groups-grpc-server.crt} + MG_GROUPS_GRPC_SERVER_KEY: ${MG_GROUPS_GRPC_SERVER_KEY:+/groups-grpc-server.key} + MG_GROUPS_GRPC_SERVER_CA_CERTS: ${MG_GROUPS_GRPC_SERVER_CA_CERTS:+/groups-grpc-server-ca.crt} + MG_GROUPS_GRPC_CLIENT_CA_CERTS: ${MG_GROUPS_GRPC_CLIENT_CA_CERTS:+/groups-grpc-client-ca.crt} + MG_GROUPS_DB_HOST: ${MG_GROUPS_DB_HOST} + MG_GROUPS_DB_PORT: ${MG_GROUPS_DB_PORT} + MG_GROUPS_DB_USER: ${MG_GROUPS_DB_USER} + MG_GROUPS_DB_PASS: ${MG_GROUPS_DB_PASS} + MG_GROUPS_DB_NAME: ${MG_GROUPS_DB_NAME} + MG_GROUPS_DB_SSL_MODE: ${MG_GROUPS_DB_SSL_MODE} + MG_GROUPS_DB_SSL_CERT: ${MG_GROUPS_DB_SSL_CERT} + MG_GROUPS_DB_SSL_KEY: ${MG_GROUPS_DB_SSL_KEY} + MG_GROUPS_DB_SSL_ROOT_CERT: ${MG_GROUPS_DB_SSL_ROOT_CERT} + MG_CHANNELS_URL: ${MG_CHANNELS_URL} + MG_CHANNELS_GRPC_URL: ${MG_CHANNELS_GRPC_URL} + MG_CHANNELS_GRPC_TIMEOUT: ${MG_CHANNELS_GRPC_TIMEOUT} + MG_CHANNELS_GRPC_CLIENT_CERT: ${MG_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} + MG_CHANNELS_GRPC_CLIENT_KEY: ${MG_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} + MG_CHANNELS_GRPC_SERVER_CA_CERTS: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} + MG_CLIENTS_GRPC_URL: ${MG_CLIENTS_GRPC_URL} + MG_CLIENTS_GRPC_TIMEOUT: ${MG_CLIENTS_GRPC_TIMEOUT} + MG_CLIENTS_GRPC_CLIENT_CERT: ${MG_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} + MG_CLIENTS_GRPC_CLIENT_KEY: ${MG_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} + MG_CLIENTS_GRPC_SERVER_CA_CERTS: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} + MG_DOMAINS_GRPC_URL: ${MG_DOMAINS_GRPC_URL} + MG_DOMAINS_GRPC_TIMEOUT: ${MG_DOMAINS_GRPC_TIMEOUT} + MG_DOMAINS_GRPC_CLIENT_CERT: ${MG_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + MG_DOMAINS_GRPC_CLIENT_KEY: ${MG_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + MG_DOMAINS_GRPC_SERVER_CA_CERTS: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_ES_URL: ${MG_ES_URL} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_AUTH_KEYS_ALGORITHM: ${MG_AUTH_KEYS_ALGORITHM} + MG_SPICEDB_PRE_SHARED_KEY: ${MG_SPICEDB_PRE_SHARED_KEY} + MG_SPICEDB_HOST: ${MG_SPICEDB_HOST} + MG_SPICEDB_PORT: ${MG_SPICEDB_PORT} + MG_SPICEDB_SCHEMA_FILE: ${MG_SPICEDB_SCHEMA_FILE} + MG_GROUPS_CALLOUT_URLS: ${MG_GROUPS_CALLOUT_URLS} + MG_GROUPS_CALLOUT_METHOD: ${MG_GROUPS_CALLOUT_METHOD} + MG_GROUPS_CALLOUT_TLS_VERIFICATION: ${MG_GROUPS_CALLOUT_TLS_VERIFICATION} + MG_GROUPS_CALLOUT_TIMEOUT: ${MG_GROUPS_CALLOUT_TIMEOUT} + MG_GROUPS_CALLOUT_CA_CERT: ${MG_GROUPS_CALLOUT_CA_CERT} + MG_GROUPS_CALLOUT_CERT: ${MG_GROUPS_CALLOUT_CERT} + MG_GROUPS_CALLOUT_KEY: ${MG_GROUPS_CALLOUT_KEY} + MG_GROUPS_CALLOUT_OPERATIONS: ${MG_GROUPS_CALLOUT_OPERATIONS} + MG_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} ports: - - ${SMQ_GROUPS_HTTP_PORT}:${SMQ_GROUPS_HTTP_PORT} - - ${SMQ_GROUPS_GRPC_PORT}:${SMQ_GROUPS_GRPC_PORT} + - ${MG_GROUPS_HTTP_PORT}:${MG_GROUPS_HTTP_PORT} + - ${MG_GROUPS_GRPC_PORT}:${MG_GROUPS_GRPC_PORT} networks: - - supermq-base-net + - magistrala-base-net volumes: - ./permission.yaml:/permission.yaml - - ./spicedb/schema.zed:${SMQ_SPICEDB_SCHEMA_FILE} + - ./spicedb/schema.zed:${MG_SPICEDB_SCHEMA_FILE} # Groups gRPC server certificates - type: bind - source: ${SMQ_GROUPS_GRPC_SERVER_CERT:-ssl/certs/dummy/server_cert} + source: ${MG_GROUPS_GRPC_SERVER_CERT:-./ssl/placeholder} target: /groups-grpc-server.crt bind: create_host_path: true - type: bind - source: ${SMQ_GROUPS_GRPC_SERVER_KEY:-ssl/certs/dummy/server_key} + source: ${MG_GROUPS_GRPC_SERVER_KEY:-./ssl/placeholder} target: /groups-grpc-server.key bind: create_host_path: true - type: bind - source: ${SMQ_GROUPS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca_certs} + source: ${MG_GROUPS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /groups-grpc-server-ca.crt bind: create_host_path: true - type: bind - source: ${SMQ_GROUPS_GRPC_CLIENT_CA_CERTS:-ssl/certs/dummy/client_ca_certs} + source: ${MG_GROUPS_GRPC_CLIENT_CA_CERTS:-./ssl/placeholder} target: /groups-grpc-client-ca.crt bind: create_host_path: true # Auth gRPC client certificates - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /auth-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /auth-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /auth-grpc-server-ca.crt bind: create_host_path: true # Clients gRPC client certificates - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_CLIENTS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /clients-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_CLIENTS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /clients-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /clients-grpc-server-ca.crt bind: create_host_path: true # Channels gRPC client certificates - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_CHANNELS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /channels-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_CHANNELS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /channels-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /channels-grpc-server-ca.crt bind: create_host_path: true # Domains gRPC client certificates - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_DOMAINS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /domains-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_DOMAINS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /domains-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /domains-grpc-server-ca.crt bind: create_host_path: true jaeger: image: docker.io/jaegertracing/all-in-one:1.74.0 - container_name: supermq-jaeger + container_name: magistrala-jaeger environment: - COLLECTOR_OTLP_ENABLED: ${SMQ_JAEGER_COLLECTOR_OTLP_ENABLED} - command: --memory.max-traces ${SMQ_JAEGER_MEMORY_MAX_TRACES} + COLLECTOR_OTLP_ENABLED: ${MG_JAEGER_COLLECTOR_OTLP_ENABLED} + command: --memory.max-traces ${MG_JAEGER_MEMORY_MAX_TRACES} ports: - - ${SMQ_JAEGER_FRONTEND}:${SMQ_JAEGER_FRONTEND} - - ${SMQ_JAEGER_OLTP_HTTP}:${SMQ_JAEGER_OLTP_HTTP} + - ${MG_JAEGER_FRONTEND}:${MG_JAEGER_FRONTEND} + - ${MG_JAEGER_OLTP_HTTP}:${MG_JAEGER_OLTP_HTTP} networks: - - supermq-base-net + - magistrala-base-net - mqtt-adapter: - image: docker.io/supermq/mqtt:${SMQ_RELEASE_TAG} - container_name: supermq-mqtt + fluxmq-node1: + image: ghcr.io/absmach/fluxmq:${MG_FLUXMQ_IMAGE_TAG} + container_name: magistrala-fluxmq-node1 + user: "0:0" + command: ["-config", "/etc/fluxmq/config.yaml"] depends_on: - - clients - - rabbitmq - - nats + - fluxmq-auth + restart: on-failure + ports: + - ${MG_COAP_PORT}:5683/udp + - ${MG_FLUXMQ_API_PORT_1}:8082 + networks: + magistrala-base-net: + ipv4_address: 172.30.0.201 + volumes: + - ./fluxmq/node1.yaml:/etc/fluxmq/config.yaml:ro + - magistrala-fluxmq-node1-volume:/tmp/fluxmq + + fluxmq-node2: + image: ghcr.io/absmach/fluxmq:${MG_FLUXMQ_IMAGE_TAG} + container_name: magistrala-fluxmq-node2 + user: "0:0" + command: ["-config", "/etc/fluxmq/config.yaml"] + depends_on: + - fluxmq-node1 + - fluxmq-auth + restart: on-failure + ports: + - ${MG_FLUXMQ_API_PORT_2}:8082 + networks: + magistrala-base-net: + ipv4_address: 172.30.0.202 + volumes: + - ./fluxmq/node2.yaml:/etc/fluxmq/config.yaml:ro + - magistrala-fluxmq-node2-volume:/tmp/fluxmq + + fluxmq-node3: + image: ghcr.io/absmach/fluxmq:${MG_FLUXMQ_IMAGE_TAG} + container_name: magistrala-fluxmq-node3 + user: "0:0" + command: ["-config", "/etc/fluxmq/config.yaml"] + depends_on: + - fluxmq-node1 + - fluxmq-auth + restart: on-failure + ports: + - ${MG_FLUXMQ_API_PORT_3}:8082 + networks: + magistrala-base-net: + ipv4_address: 172.30.0.203 + volumes: + - ./fluxmq/node3.yaml:/etc/fluxmq/config.yaml:ro + - magistrala-fluxmq-node3-volume:/tmp/fluxmq + + fluxmq-auth: + image: docker.io/magistrala/fluxmq:${MG_RELEASE_TAG} + container_name: magistrala-fluxmq-auth restart: on-failure environment: - SMQ_MQTT_ADAPTER_LOG_LEVEL: ${SMQ_MQTT_ADAPTER_LOG_LEVEL} - SMQ_MQTT_ADAPTER_MQTT_PORT: ${SMQ_MQTT_ADAPTER_MQTT_PORT} - SMQ_MQTT_ADAPTER_MQTT_TARGET_HOST: ${SMQ_MQTT_ADAPTER_MQTT_TARGET_HOST} - SMQ_MQTT_ADAPTER_MQTT_TARGET_PORT: ${SMQ_MQTT_ADAPTER_MQTT_TARGET_PORT} - SMQ_MQTT_ADAPTER_MQTT_TARGET_USERNAME: ${SMQ_MQTT_ADAPTER_MQTT_TARGET_USERNAME} - SMQ_MQTT_ADAPTER_MQTT_TARGET_PASSWORD: ${SMQ_MQTT_ADAPTER_MQTT_TARGET_PASSWORD} - SMQ_MQTT_ADAPTER_FORWARDER_TIMEOUT: ${SMQ_MQTT_ADAPTER_FORWARDER_TIMEOUT} - SMQ_MQTT_ADAPTER_MQTT_TARGET_HEALTH_CHECK: ${SMQ_MQTT_ADAPTER_MQTT_TARGET_HEALTH_CHECK} - SMQ_MQTT_ADAPTER_MQTT_QOS: ${SMQ_MQTT_ADAPTER_MQTT_QOS} - SMQ_MQTT_ADAPTER_WS_PORT: ${SMQ_MQTT_ADAPTER_WS_PORT} - SMQ_MQTT_ADAPTER_INSTANCE_ID: ${SMQ_MQTT_ADAPTER_INSTANCE_ID} - SMQ_MQTT_ADAPTER_WS_TARGET_HOST: ${SMQ_MQTT_ADAPTER_WS_TARGET_HOST} - SMQ_MQTT_ADAPTER_WS_TARGET_PORT: ${SMQ_MQTT_ADAPTER_WS_TARGET_PORT} - SMQ_MQTT_ADAPTER_WS_TARGET_PATH: ${SMQ_MQTT_ADAPTER_WS_TARGET_PATH} - SMQ_MQTT_ADAPTER_INSTANCE: ${SMQ_MQTT_ADAPTER_INSTANCE} - SMQ_MQTT_ADAPTER_CACHE_NUM_COUNTERS: ${SMQ_MQTT_ADAPTER_CACHE_NUM_COUNTERS} - SMQ_MQTT_ADAPTER_CACHE_MAX_COST: ${SMQ_MQTT_ADAPTER_CACHE_MAX_COST} - SMQ_MQTT_ADAPTER_CACHE_BUFFER_ITEMS: ${SMQ_MQTT_ADAPTER_CACHE_BUFFER_ITEMS} - SMQ_MQTT_ADAPTER_CERT_FILE: ${SMQ_MQTT_ADAPTER_CERT_FILE:+/mqtt-adapter.crt} - SMQ_MQTT_ADAPTER_KEY_FILE: ${SMQ_MQTT_ADAPTER_KEY_FILE:+/mqtt-adapter.key} - SMQ_MQTT_ADAPTER_SERVER_CA_FILE: ${SMQ_MQTT_ADAPTER_SERVER_CA_FILE:+/mqtt-adapter-server-ca.crt} - SMQ_MQTT_ADAPTER_CLIENT_CA_FILE: ${SMQ_MQTT_ADAPTER_CLIENT_CA_FILE:+/mqtt-adapter-client-ca.crt} - SMQ_MQTT_ADAPTER_CERT_VERIFICATION_METHODS: ${SMQ_MQTT_ADAPTER_CERT_VERIFICATION_METHODS} - SMQ_MQTT_ADAPTER_OCSP_RESPONDER_URL: ${SMQ_MQTT_ADAPTER_OCSP_RESPONDER_URL} - SMQ_ES_URL: ${SMQ_ES_URL} - SMQ_CLIENTS_GRPC_URL: ${SMQ_CLIENTS_GRPC_URL} - SMQ_CLIENTS_GRPC_TIMEOUT: ${SMQ_CLIENTS_GRPC_TIMEOUT} - SMQ_CLIENTS_GRPC_CLIENT_CERT: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} - SMQ_CLIENTS_GRPC_CLIENT_KEY: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} - SMQ_CLIENTS_GRPC_SERVER_CA_CERTS: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} - SMQ_CHANNELS_GRPC_URL: ${SMQ_CHANNELS_GRPC_URL} - SMQ_CHANNELS_GRPC_TIMEOUT: ${SMQ_CHANNELS_GRPC_TIMEOUT} - SMQ_CHANNELS_GRPC_CLIENT_CERT: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} - SMQ_CHANNELS_GRPC_CLIENT_KEY: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} - SMQ_CHANNELS_GRPC_SERVER_CA_CERTS: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} - SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL} - SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT} - SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} - SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} - SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_MESSAGE_BROKER_URL: ${SMQ_MESSAGE_BROKER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} + MG_FLUXMQ_LOG_LEVEL: ${MG_FLUXMQ_LOG_LEVEL} + MG_FLUXMQ_GRPC_HOST: ${MG_FLUXMQ_GRPC_HOST} + MG_FLUXMQ_GRPC_PORT: ${MG_FLUXMQ_GRPC_PORT} + MG_FLUXMQ_INSTANCE_ID: ${MG_FLUXMQ_INSTANCE_ID} + MG_FLUXMQ_CACHE_NUM_COUNTERS: ${MG_FLUXMQ_CACHE_NUM_COUNTERS} + MG_FLUXMQ_CACHE_MAX_COST: ${MG_FLUXMQ_CACHE_MAX_COST} + MG_FLUXMQ_CACHE_BUFFER_ITEMS: ${MG_FLUXMQ_CACHE_BUFFER_ITEMS} + MG_CLIENTS_GRPC_URL: ${MG_CLIENTS_GRPC_URL} + MG_CLIENTS_GRPC_TIMEOUT: ${MG_CLIENTS_GRPC_TIMEOUT} + MG_CLIENTS_GRPC_CLIENT_CERT: ${MG_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} + MG_CLIENTS_GRPC_CLIENT_KEY: ${MG_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} + MG_CLIENTS_GRPC_SERVER_CA_CERTS: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} + MG_CHANNELS_GRPC_URL: ${MG_CHANNELS_GRPC_URL} + MG_CHANNELS_GRPC_TIMEOUT: ${MG_CHANNELS_GRPC_TIMEOUT} + MG_CHANNELS_GRPC_CLIENT_CERT: ${MG_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} + MG_CHANNELS_GRPC_CLIENT_KEY: ${MG_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} + MG_CHANNELS_GRPC_SERVER_CA_CERTS: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} + MG_DOMAINS_GRPC_URL: ${MG_DOMAINS_GRPC_URL} + MG_DOMAINS_GRPC_TIMEOUT: ${MG_DOMAINS_GRPC_TIMEOUT} + MG_DOMAINS_GRPC_CLIENT_CERT: ${MG_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + MG_DOMAINS_GRPC_CLIENT_KEY: ${MG_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + MG_DOMAINS_GRPC_SERVER_CA_CERTS: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} networks: - - supermq-base-net + - magistrala-base-net volumes: - # TLS certificate for MQTT - - type: bind - source: ${SMQ_MQTT_ADAPTER_CERT_FILE:-ssl/certs/dummy/server_cert} - target: /mqtt-adapter.crt - bind: - create_host_path: true - - type: bind - source: ${SMQ_MQTT_ADAPTER_KEY_FILE:-ssl/certs/dummy/server_key} - target: /mqtt-adapter.key - bind: - create_host_path: true - - type: bind - source: ${SMQ_MQTT_ADAPTER_SERVER_CA_FILE:-ssl/certs/dummy/server_ca} - target: /mqtt-adapter-server-ca.crt - bind: - create_host_path: true - - type: bind - source: ${SMQ_MQTT_ADAPTER_CLIENT_CA_FILE:-ssl/certs/dummy/client_ca} - target: /mqtt-adapter-client-ca.crt - bind: - create_host_path: true # Clients gRPC mTLS client certificates - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} - target: /clients-grpc-client.crt + source: ${MG_CLIENTS_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /clients-grpc-client${MG_CLIENTS_GRPC_CLIENT_CERT:+.crt} bind: create_host_path: true - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} - target: /clients-grpc-client.key + source: ${MG_CLIENTS_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /clients-grpc-client${MG_CLIENTS_GRPC_CLIENT_KEY:+.key} bind: create_host_path: true - type: bind - source: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} - target: /clients-grpc-server-ca.crt + source: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /clients-grpc-server-ca${MG_CLIENTS_GRPC_SERVER_CA_CERTS:+.crt} bind: create_host_path: true # Channels gRPC mTLS client certificates - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} - target: /channels-grpc-client.crt + source: ${MG_CHANNELS_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /channels-grpc-client${MG_CHANNELS_GRPC_CLIENT_CERT:+.crt} bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} - target: /channels-grpc-client.key + source: ${MG_CHANNELS_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /channels-grpc-client${MG_CHANNELS_GRPC_CLIENT_KEY:+.key} bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} - target: /channels-grpc-server-ca.crt + source: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /channels-grpc-server-ca${MG_CHANNELS_GRPC_SERVER_CA_CERTS:+.crt} bind: create_host_path: true # Domains gRPC mTLS client certificates - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} - target: /domains-grpc-client.crt + source: ${MG_DOMAINS_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /domains-grpc-client${MG_DOMAINS_GRPC_CLIENT_CERT:+.crt} bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} - target: /domains-grpc-client.key + source: ${MG_DOMAINS_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /domains-grpc-client${MG_DOMAINS_GRPC_CLIENT_KEY:+.key} bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} - target: /domains-grpc-server-ca.crt + source: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /domains-grpc-server-ca${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+.crt} bind: create_host_path: true - http-adapter: - image: docker.io/supermq/http:${SMQ_RELEASE_TAG} - container_name: supermq-http - depends_on: - - clients - - nats - restart: on-failure - environment: - SMQ_HTTP_ADAPTER_LOG_LEVEL: ${SMQ_HTTP_ADAPTER_LOG_LEVEL} - SMQ_HTTP_ADAPTER_HOST: ${SMQ_HTTP_ADAPTER_HOST} - SMQ_HTTP_ADAPTER_PORT: ${SMQ_HTTP_ADAPTER_PORT} - SMQ_HTTP_ADAPTER_SERVER_CERT: ${SMQ_HTTP_ADAPTER_SERVER_CERT} - SMQ_HTTP_ADAPTER_SERVER_KEY: ${SMQ_HTTP_ADAPTER_SERVER_KEY} - SMQ_HTTP_ADAPTER_CACHE_NUM_COUNTERS: ${SMQ_HTTP_ADAPTER_CACHE_NUM_COUNTERS} - SMQ_HTTP_ADAPTER_CACHE_MAX_COST: ${SMQ_HTTP_ADAPTER_CACHE_MAX_COST} - SMQ_HTTP_ADAPTER_CACHE_BUFFER_ITEMS: ${SMQ_HTTP_ADAPTER_CACHE_BUFFER_ITEMS} - SMQ_CLIENTS_GRPC_URL: ${SMQ_CLIENTS_GRPC_URL} - SMQ_CLIENTS_GRPC_TIMEOUT: ${SMQ_CLIENTS_GRPC_TIMEOUT} - SMQ_CLIENTS_GRPC_CLIENT_CERT: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} - SMQ_CLIENTS_GRPC_CLIENT_KEY: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} - SMQ_CLIENTS_GRPC_SERVER_CA_CERTS: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} - SMQ_CHANNELS_GRPC_URL: ${SMQ_CHANNELS_GRPC_URL} - SMQ_CHANNELS_GRPC_TIMEOUT: ${SMQ_CHANNELS_GRPC_TIMEOUT} - SMQ_CHANNELS_GRPC_CLIENT_CERT: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} - SMQ_CHANNELS_GRPC_CLIENT_KEY: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} - SMQ_CHANNELS_GRPC_SERVER_CA_CERTS: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} - SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL} - SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT} - SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} - SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} - SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL} - SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT} - SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} - SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} - SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} - SMQ_AUTH_KEYS_ALGORITHM: ${SMQ_AUTH_KEYS_ALGORITHM} - SMQ_MESSAGE_BROKER_URL: ${SMQ_MESSAGE_BROKER_URL} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} - SMQ_HTTP_ADAPTER_INSTANCE_ID: ${SMQ_HTTP_ADAPTER_INSTANCE_ID} - SMQ_ES_URL: ${SMQ_ES_URL} + ui: + image: ghcr.io/absmach/magistrala/ui-mg:${MG_RELEASE_TAG} + container_name: magistrala-ui ports: - - ${SMQ_HTTP_ADAPTER_PORT}:${SMQ_HTTP_ADAPTER_PORT} + - 3000:3000 networks: - - supermq-base-net + - magistrala-base-net + environment: + MG_AUTH_URL: ${MG_AUTH_URL} + MG_DOMAINS_URL: ${MG_DOMAINS_URL} + MG_USERS_URL: ${MG_USERS_URL} + MG_CLIENTS_URL: ${MG_CLIENTS_URL} + MG_CHANNELS_URL: ${MG_CHANNELS_URL} + MG_GROUPS_URL: ${MG_GROUPS_URL} + MG_BOOTSTRAP_URL: ${MG_BOOTSTRAP_URL} + MG_CERTS_URL: ${MG_CERTS_URL} + MG_HTTP_ADAPTER_URL: ${MG_HTTP_ADAPTER_URL} + MG_READER_URL: ${MG_READER_URL} + MG_BACKEND_URL: ${MG_UI_BACKEND_URL} + MG_JOURNAL_URL: ${MG_JOURNAL_URL} + MG_ALARMS_URL: ${MG_ALARMS_URL} + MG_RE_URL: ${MG_RE_URL} + MG_REPORTS_URL: ${MG_REPORTS_URL} + MG_GOOGLE_CLIENT_ID: ${MG_GOOGLE_CLIENT_ID} + MG_GOOGLE_CLIENT_SECRET: ${MG_GOOGLE_CLIENT_SECRET} + MG_GOOGLE_REDIRECT_URL: ${MG_GOOGLE_REDIRECT_URL} + MG_GOOGLE_STATE: ${MG_GOOGLE_STATE} + MG_UI_BASE_PATH: ${MG_UI_BASE_PATH} + MG_NEXTAUTH_BASE_PATH: ${MG_NEXTAUTH_BASE_PATH} + MG_UI_TYPE: ${MG_UI_TYPE} + MG_UI_BASEURL: ${MG_UI_BASEURL} + NEXTAUTH_URL: ${NEXTAUTH_URL} + NEXTAUTH_SECRET: ${NEXTAUTH_SECRET} + NEXT_LOG_LEVEL: "debug" + MG_HOST_URL: ${MG_HOST_URL} + MG_UI_IMAGE_URL: ${MG_UI_IMAGE_URL} + MG_UI_DOCKER_ACCEPT_EULA: ${MG_UI_DOCKER_ACCEPT_EULA} + MG_SUPPORT_EMAIL: ${MG_SUPPORT_EMAIL} + MG_SUPPORT_EMAIL_PASS: ${MG_SUPPORT_EMAIL_PASS} + MG_UI_CLI_MQTT_HOST: ${MG_UI_CLI_MQTT_HOST} + MG_UI_CLI_WS_URL: ${MG_UI_CLI_WS_URL} + MG_UI_CLI_COAP_HOST: ${MG_UI_CLI_COAP_HOST} + MG_UI_CLI_COAP_PORT: ${MG_UI_CLI_COAP_PORT} + MG_UI_CLI_HTTP_URL: ${MG_UI_CLI_HTTP_URL} + MG_UI_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} + MG_ACCESS_TOKEN_EXPIRY: ${MG_AUTH_ACCESS_TOKEN_DURATION} + MG_REFRESH_TOKEN_EXPIRY: ${MG_AUTH_REFRESH_TOKEN_DURATION} + MG_UI_SMTP_HOST: ${MG_UI_SMTP_HOST} + MG_UI_SMTP_PORT: ${MG_UI_SMTP_PORT} + MG_UI_SMTP_SECURE: ${MG_UI_SMTP_SECURE} + MG_UI_SUPPORT_FROM: ${MG_UI_SUPPORT_FROM} + + ui-backend: + image: ghcr.io/absmach/magistrala/ui-backend:${MG_RELEASE_TAG} + container_name: magistrala-ui-backend + ports: + - ${MG_UI_BACKEND_HTTP_PORT}:${MG_UI_BACKEND_HTTP_PORT} + networks: + - magistrala-base-net + restart: on-failure:3 + environment: + MG_BACKEND_LOG_LEVEL: ${MG_UI_BACKEND_LOG_LEVEL} + MG_BACKEND_HTTP_HOST: ${MG_UI_BACKEND_HTTP_HOST} + MG_BACKEND_HTTP_PORT: ${MG_UI_BACKEND_HTTP_PORT} + MG_BACKEND_HTTP_SERVER_CERT: ${MG_UI_BACKEND_HTTP_SERVER_CERT} + MG_BACKEND_HTTP_SERVER_KEY: ${MG_UI_BACKEND_HTTP_SERVER_KEY} + MG_BACKEND_DB_HOST: ${MG_UI_BACKEND_DB_HOST} + MG_BACKEND_DB_PORT: ${MG_UI_BACKEND_DB_PORT} + MG_BACKEND_DB_USER: ${MG_UI_BACKEND_DB_USER} + MG_BACKEND_DB_PASS: ${MG_UI_BACKEND_DB_PASS} + MG_BACKEND_DB_NAME: ${MG_UI_BACKEND_DB_NAME} + MG_BACKEND_DB_SSL_MODE: ${MG_UI_BACKEND_DB_SSL_MODE} + MG_BACKEND_DB_SSL_CERT: ${MG_UI_BACKEND_DB_SSL_CERT} + MG_BACKEND_DB_SSL_KEY: ${MG_UI_BACKEND_DB_SSL_KEY} + MG_BACKEND_DB_SSL_ROOT_CERT: ${MG_UI_BACKEND_DB_SSL_ROOT_CERT} + MG_BACKEND_INSTANCE_ID: ${MG_UI_BACKEND_INSTANCE_ID} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_UI_VERIFICATION_TLS: ${MG_UI_VERIFICATION_TLS} + MG_UI_CONTENT_TYPE: ${MG_UI_CONTENT_TYPE} + MG_READER_URL: ${MG_READER_URL} + MG_UI_DOCKER_ACCEPT_EULA: ${MG_UI_DOCKER_ACCEPT_EULA} + MG_CHANNELS_GRPC_URL: ${MG_CHANNELS_GRPC_URL} + MG_CHANNELS_GRPC_TIMEOUT: ${MG_CHANNELS_GRPC_TIMEOUT} + MG_CHANNELS_GRPC_CLIENT_CERT: ${MG_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} + MG_CHANNELS_GRPC_CLIENT_KEY: ${MG_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} + MG_CHANNELS_GRPC_SERVER_CA_CERTS: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} + MG_TIMESCALE_READER_GRPC_URL: ${MG_TIMESCALE_READER_GRPC_URL} + MG_TIMESCALE_READER_GRPC_TIMEOUT: ${MG_TIMESCALE_READER_GRPC_TIMEOUT} + MG_TIMESCALE_READER_GRPC_CLIENT_CERT: ${MG_TIMESCALE_READER_GRPC_CLIENT_CERT:+/readers-grpc-client.crt} + MG_TIMESCALE_READER_GRPC_CLIENT_KEY: ${MG_TIMESCALE_READER_GRPC_CLIENT_KEY:+/readers-grpc-client.key} + MG_TIMESCALE_READER_GRPC_SERVER_CA_CERTS: ${MG_TIMESCALE_READER_GRPC_SERVER_CA_CERTS:+/readers-grpc-server-ca.crt} + MG_BACKEND_OBJECT_STORAGE_REGION: ${MG_BACKEND_OBJECT_STORAGE_REGION} + MG_BACKEND_OBJECT_STORAGE_BUCKET: ${MG_BACKEND_OBJECT_STORAGE_BUCKET} + MG_BACKEND_OBJECT_STORAGE_ENDPOINT: ${MG_BACKEND_OBJECT_STORAGE_ENDPOINT} + MG_BACKEND_OBJECT_STORAGE_USE_PATH_STYLE: ${MG_BACKEND_OBJECT_STORAGE_USE_PATH_STYLE} + MG_BACKEND_OBJECT_STORAGE_PRESIGN_ENDPOINT: ${MG_BACKEND_OBJECT_STORAGE_PRESIGN_ENDPOINT} + MG_BACKEND_OBJECT_STORAGE_ACCESS_KEY: ${MG_BACKEND_OBJECT_STORAGE_ACCESS_KEY} + MG_BACKEND_OBJECT_STORAGE_SECRET_KEY: ${MG_BACKEND_OBJECT_STORAGE_SECRET_KEY} + MG_BACKEND_OBJECT_STORAGE_TTL: ${MG_BACKEND_OBJECT_STORAGE_TTL} + MG_BACKEND_OBJECT_STORAGE_READ_TTL: ${MG_BACKEND_OBJECT_STORAGE_READ_TTL} + depends_on: + ui-backend-db: + condition: service_healthy + seaweedfs-s3: + condition: service_started volumes: - # Clients gRPC mTLS client certificates + # Auth gRPC client certificates - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} - target: /clients-grpc-client.crt - bind: - create_host_path: true - - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} - target: /clients-grpc-client.key - bind: - create_host_path: true - - type: bind - source: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} - target: /clients-grpc-server-ca.crt - bind: - create_host_path: true - # Channels gRPC mTLS client certificates - - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} - target: /channels-grpc-client.crt - bind: - create_host_path: true - - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} - target: /channels-grpc-client.key - bind: - create_host_path: true - - type: bind - source: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} - target: /channels-grpc-server-ca.crt - bind: - create_host_path: true - # Auth gRPC mTLS client certificates - - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /auth-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /auth-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /auth-grpc-server-ca.crt bind: create_host_path: true - # Domains gRPC mTLS client certificates + # Channels gRPC client certificates - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} - target: /domains-grpc-client.crt - bind: - create_host_path: true - - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} - target: /domains-grpc-client.key - bind: - create_host_path: true - - type: bind - source: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} - target: /domains-grpc-server-ca.crt - bind: - create_host_path: true - - coap-adapter: - image: docker.io/supermq/coap:${SMQ_RELEASE_TAG} - container_name: supermq-coap - depends_on: - - clients - - nats - restart: on-failure - environment: - SMQ_COAP_ADAPTER_LOG_LEVEL: ${SMQ_COAP_ADAPTER_LOG_LEVEL} - SMQ_COAP_ADAPTER_HOST: ${SMQ_COAP_ADAPTER_HOST} - SMQ_COAP_ADAPTER_PORT: ${SMQ_COAP_ADAPTER_PORT} - SMQ_COAP_ADAPTER_SERVER_CERT_FILE: ${SMQ_COAP_ADAPTER_SERVER_CERT_FILE:+/coap-server.crt} - SMQ_COAP_ADAPTER_SERVER_KEY_FILE: ${SMQ_COAP_ADAPTER_SERVER_KEY_FILE:+/coap-server.key} - SMQ_COAP_ADAPTER_SERVER_CA_FILE: ${SMQ_COAP_ADAPTER_SERVER_CA_FILE:+/coap-server-ca.crt} - SMQ_COAP_ADAPTER_HTTP_HOST: ${SMQ_COAP_ADAPTER_HTTP_HOST} - SMQ_COAP_ADAPTER_HTTP_PORT: ${SMQ_COAP_ADAPTER_HTTP_PORT} - SMQ_COAP_ADAPTER_HTTP_SERVER_CERT: ${SMQ_COAP_ADAPTER_HTTP_SERVER_CERT} - SMQ_COAP_ADAPTER_HTTP_SERVER_KEY: ${SMQ_COAP_ADAPTER_HTTP_SERVER_KEY} - SMQ_COAP_ADAPTER_CACHE_NUM_COUNTERS: ${SMQ_COAP_ADAPTER_CACHE_NUM_COUNTERS} - SMQ_COAP_ADAPTER_CACHE_MAX_COST: ${SMQ_COAP_ADAPTER_CACHE_MAX_COST} - SMQ_COAP_ADAPTER_CACHE_BUFFER_ITEMS: ${SMQ_COAP_ADAPTER_CACHE_BUFFER_ITEMS} - SMQ_CLIENTS_GRPC_URL: ${SMQ_CLIENTS_GRPC_URL} - SMQ_CLIENTS_GRPC_TIMEOUT: ${SMQ_CLIENTS_GRPC_TIMEOUT} - SMQ_CLIENTS_GRPC_CLIENT_CERT: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} - SMQ_CLIENTS_GRPC_CLIENT_KEY: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} - SMQ_CLIENTS_GRPC_SERVER_CA_CERTS: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} - SMQ_CHANNELS_GRPC_URL: ${SMQ_CHANNELS_GRPC_URL} - SMQ_CHANNELS_GRPC_TIMEOUT: ${SMQ_CHANNELS_GRPC_TIMEOUT} - SMQ_CHANNELS_GRPC_CLIENT_CERT: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} - SMQ_CHANNELS_GRPC_CLIENT_KEY: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} - SMQ_CHANNELS_GRPC_SERVER_CA_CERTS: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} - SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL} - SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT} - SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} - SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} - SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - SMQ_MESSAGE_BROKER_URL: ${SMQ_MESSAGE_BROKER_URL} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} - SMQ_COAP_ADAPTER_INSTANCE_ID: ${SMQ_COAP_ADAPTER_INSTANCE_ID} - SMQ_ES_URL: ${SMQ_ES_URL} - ports: - - ${SMQ_COAP_ADAPTER_PORT}:${SMQ_COAP_ADAPTER_PORT}/udp - - ${SMQ_COAP_ADAPTER_HTTP_PORT}:${SMQ_COAP_ADAPTER_HTTP_PORT}/tcp - networks: - - supermq-base-net - volumes: - # DTLS certificates for CoAP - - type: bind - source: ${SMQ_COAP_ADAPTER_SERVER_CERT_FILE:-ssl/certs/dummy/server_cert} - target: /coap-server.crt - bind: - create_host_path: true - - type: bind - source: ${SMQ_COAP_ADAPTER_SERVER_KEY_FILE:-ssl/certs/dummy/server_key} - target: /coap-server.key - bind: - create_host_path: true - - type: bind - source: ${SMQ_COAP_ADAPTER_SERVER_CA_FILE:-ssl/certs/dummy/server_ca} - target: /coap-server-ca.crt - bind: - create_host_path: true - # Clients gRPC mTLS client certificates - - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} - target: /clients-grpc-client.crt - bind: - create_host_path: true - - type: bind - source: ${SMQ_CLIENTS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} - target: /clients-grpc-client.key - bind: - create_host_path: true - - type: bind - source: ${SMQ_CLIENTS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} - target: /clients-grpc-server-ca.crt - bind: - create_host_path: true - # Channels gRPC mTLS client certificates - - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_CHANNELS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /channels-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_CHANNELS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /channels-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_CHANNELS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /channels-grpc-server-ca.crt bind: create_host_path: true + # Reader gRPC client certificates - type: bind - source: ${SMQ_CHANNELS_GRPC_CLIENT_CA_CERTS:-ssl/certs/dummy/client_ca} - target: /channels-grpc-client-ca.crt + source: ${MG_TIMESCALE_READER_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /readers-grpc-client.crt bind: create_host_path: true - # Domains gRPC mTLS client certificates - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} + source: ${MG_TIMESCALE_READER_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /readers-grpc-client.key + bind: + create_host_path: true + - type: bind + source: ${MG_TIMESCALE_READER_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /readers-grpc-server-ca.crt + bind: + create_host_path: true + + ui-backend-db: + image: docker.io/postgres:18.0-alpine3.22 + container_name: magistrala-ui-backend-db + restart: on-failure + command: postgres -c "max_connections=${MG_POSTGRES_MAX_CONNECTIONS}" + environment: + POSTGRES_USER: ${MG_UI_BACKEND_DB_USER} + POSTGRES_PASSWORD: ${MG_UI_BACKEND_DB_PASS} + POSTGRES_DB: ${MG_UI_BACKEND_DB_NAME} + MG_POSTGRES_MAX_CONNECTIONS: ${MG_POSTGRES_MAX_CONNECTIONS} + ports: + - 6008:5432 + networks: + - magistrala-base-net + volumes: + - magistrala-ui-backend-db-volume:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U $${POSTGRES_USER} -d $${POSTGRES_DB}"] + interval: 5s + timeout: 3s + retries: 60 + + seaweedfs-s3: + image: chrislusf/seaweedfs:4.16 + container_name: magistrala-seaweedfs-s3 + command: server -s3 -s3.config=/etc/seaweedfs/s3.json -dir=/data + ports: + - "8333:8333" + - "9333:9333" + - "19333:19333" + - "8888:8888" + volumes: + - ./data/seaweedfs:/data + - ./seaweedfs/s3.json:/etc/seaweedfs/s3.json:ro + networks: + - magistrala-base-net + + seaweedfs-init: + image: amazon/aws-cli + container_name: magistrala-seaweedfs-init + entrypoint: /bin/sh + depends_on: + - seaweedfs-s3 + command: + - -c + - | + echo "[INIT] Waiting 20s for SeaweedFS S3 to be ready..."; + sleep 20; + OUT=$(aws --endpoint-url http://seaweedfs-s3:8333 s3api create-bucket --bucket $${BUCKET} 2>&1); + EXIT=$$?; + if [ $$EXIT -eq 0 ]; then + echo "[INIT] Bucket $${BUCKET} created successfully."; + elif echo "$$OUT" | grep -q 'BucketAlreadyOwnedByYou\|BucketAlreadyExists'; then + echo "[INIT] Bucket $${BUCKET} already exists, skipping."; + else + echo "[INIT] Failed to create bucket $${BUCKET}: $$OUT" >&2; + exit 1; + fi + networks: + - magistrala-base-net + environment: + BUCKET: ${MG_BACKEND_OBJECT_STORAGE_BUCKET} + AWS_ACCESS_KEY_ID: ${MG_BACKEND_OBJECT_STORAGE_ACCESS_KEY} + AWS_SECRET_ACCESS_KEY: ${MG_BACKEND_OBJECT_STORAGE_SECRET_KEY} + AWS_DEFAULT_REGION: ${MG_BACKEND_OBJECT_STORAGE_REGION} + AWS_EC2_METADATA_DISABLED: "true" + + timescale: + image: timescale/timescaledb:2.19.3-pg16-oss + container_name: magistrala-timescale + restart: on-failure + environment: + POSTGRES_PASSWORD: ${MG_TIMESCALE_PASS} + POSTGRES_USER: ${MG_TIMESCALE_USER} + POSTGRES_DB: ${MG_TIMESCALE_NAME} + ports: + - 5433:5432 + networks: + - magistrala-base-net + volumes: + - magistrala-timescale-writer-volume:/var/lib/postgresql/data + + timescale-reader: + image: docker.io/magistrala/timescale-reader:${MG_RELEASE_TAG} + container_name: magistrala-timescale-reader + depends_on: + - timescale + restart: on-failure + environment: + MG_TIMESCALE_READER_LOG_LEVEL: ${MG_TIMESCALE_READER_LOG_LEVEL} + MG_TIMESCALE_READER_HTTP_HOST: ${MG_TIMESCALE_READER_HTTP_HOST} + MG_TIMESCALE_READER_HTTP_PORT: ${MG_TIMESCALE_READER_HTTP_PORT} + MG_TIMESCALE_READER_HTTP_SERVER_CERT: ${MG_TIMESCALE_READER_HTTP_SERVER_CERT} + MG_TIMESCALE_READER_HTTP_SERVER_KEY: ${MG_TIMESCALE_READER_HTTP_SERVER_KEY} + MG_TIMESCALE_HOST: ${MG_TIMESCALE_HOST} + MG_TIMESCALE_PORT: ${MG_TIMESCALE_PORT} + MG_TIMESCALE_USER: ${MG_TIMESCALE_USER} + MG_TIMESCALE_PASS: ${MG_TIMESCALE_PASS} + MG_TIMESCALE_NAME: ${MG_TIMESCALE_NAME} + MG_TIMESCALE_SSL_MODE: ${MG_TIMESCALE_SSL_MODE} + MG_TIMESCALE_SSL_CERT: ${MG_TIMESCALE_SSL_CERT} + MG_TIMESCALE_SSL_KEY: ${MG_TIMESCALE_SSL_KEY} + MG_TIMESCALE_SSL_ROOT_CERT: ${MG_TIMESCALE_SSL_ROOT_CERT} + MG_CLIENTS_GRPC_URL: ${MG_CLIENTS_GRPC_URL} + MG_CLIENTS_GRPC_TIMEOUT: ${MG_CLIENTS_GRPC_TIMEOUT} + MG_CLIENTS_GRPC_CLIENT_CERT: ${MG_CLIENTS_GRPC_CLIENT_CERT:+/clients-grpc-client.crt} + MG_CLIENTS_GRPC_CLIENT_KEY: ${MG_CLIENTS_GRPC_CLIENT_KEY:+/clients-grpc-client.key} + MG_CLIENTS_GRPC_SERVER_CA_CERTS: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:+/clients-grpc-server-ca.crt} + MG_CHANNELS_GRPC_URL: ${MG_CHANNELS_GRPC_URL} + MG_CHANNELS_GRPC_TIMEOUT: ${MG_CHANNELS_GRPC_TIMEOUT} + MG_CHANNELS_GRPC_CLIENT_CERT: ${MG_CHANNELS_GRPC_CLIENT_CERT:+/channels-grpc-client.crt} + MG_CHANNELS_GRPC_CLIENT_KEY: ${MG_CHANNELS_GRPC_CLIENT_KEY:+/channels-grpc-client.key} + MG_CHANNELS_GRPC_SERVER_CA_CERTS: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:+/channels-grpc-server-ca.crt} + MG_TIMESCALE_READER_GRPC_URL: ${MG_TIMESCALE_READER_GRPC_URL} + MG_TIMESCALE_READER_GRPC_PORT: ${MG_TIMESCALE_READER_GRPC_PORT} + MG_TIMESCALE_READER_GRPC_HOST: ${MG_TIMESCALE_READER_GRPC_HOST} + MG_TIMESCALE_READER_GRPC_TIMEOUT: ${MG_TIMESCALE_READER_GRPC_TIMEOUT} + MG_TIMESCALE_READER_GRPC_CLIENT_CERT: ${MG_TIMESCALE_READER_GRPC_CLIENT_CERT:+/readers-grpc-client.crt} + MG_TIMESCALE_READER_GRPC_CLIENT_CA_CERTS: ${MG_TIMESCALE_READER_GRPC_CLIENT_CA_CERTS:+/readers-grpc-client-ca.crt} + MG_TIMESCALE_READER_GRPC_SERVER_CA_CERTS: ${MG_TIMESCALE_READER_GRPC_SERVER_CA_CERTS:+/readers-grpc-server-ca.crt} + MG_TIMESCALE_READER_GRPC_CLIENT_KEY: ${MG_TIMESCALE_READER_GRPC_CLIENT_KEY:+/readers-grpc-client.key} + MG_TIMESCALE_READER_GRPC_SERVER_CERT: ${MG_TIMESCALE_READER_GRPC_SERVER_CERT:+/readers-grpc-server.crt} + MG_TIMESCALE_READER_GRPC_SERVER_KEY: ${MG_TIMESCALE_READER_GRPC_SERVER_KEY:+/readers-grpc-server.key} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_TIMESCALE_READER_INSTANCE_ID: ${MG_TIMESCALE_READER_INSTANCE_ID} + ports: + - ${MG_TIMESCALE_READER_HTTP_PORT}:${MG_TIMESCALE_READER_HTTP_PORT} + - ${MG_TIMESCALE_READER_GRPC_PORT}:${MG_TIMESCALE_READER_GRPC_PORT} + networks: + - magistrala-base-net + volumes: + # Auth gRPC client certificates + - type: bind + source: ${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /auth-grpc-client${MG_AUTH_GRPC_CLIENT_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /auth-grpc-client${MG_AUTH_GRPC_CLIENT_KEY:+.key} + bind: + create_host_path: true + - type: bind + source: ${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /auth-grpc-server-ca${MG_AUTH_GRPC_SERVER_CA_CERTS:+.crt} + bind: + create_host_path: true + # Clients gRPC client certificates + - type: bind + source: ${MG_CLIENTS_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /clients-grpc-client${MG_CLIENTS_GRPC_CLIENT_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_CLIENTS_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /clients-grpc-client${MG_CLIENTS_GRPC_CLIENT_KEY:+.key} + bind: + create_host_path: true + - type: bind + source: ${MG_CLIENTS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /clients-grpc-server-ca${MG_CLIENTS_GRPC_SERVER_CA_CERTS:+.crt} + bind: + create_host_path: true + # Channels gRPC client certificates + - type: bind + source: ${MG_CHANNELS_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /channels-grpc-client${MG_CHANNELS_GRPC_CLIENT_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_CHANNELS_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /channels-grpc-client${MG_CHANNELS_GRPC_CLIENT_KEY:+.key} + bind: + create_host_path: true + - type: bind + source: ${MG_CHANNELS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /channels-grpc-server-ca${MG_CHANNELS_GRPC_SERVER_CA_CERTS:+.crt} + bind: + create_host_path: true + # Reader gRPC server and client certificates + - type: bind + source: ${MG_TIMESCALE_READER_GRPC_SERVER_CERT:-./ssl/placeholder} + target: /readers-grpc-server${MG_TIMESCALE_READER_GRPC_SERVER_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_TIMESCALE_READER_GRPC_SERVER_KEY:-./ssl/placeholder} + target: /readers-grpc-server${MG_TIMESCALE_READER_GRPC_SERVER_KEY:+.key} + bind: + create_host_path: true + - type: bind + source: ${MG_TIMESCALE_READER_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /readers-grpc-server-ca${MG_TIMESCALE_READER_GRPC_SERVER_CA_CERTS:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_TIMESCALE_READER_GRPC_CLIENT_CA_CERTS:-./ssl/placeholder} + target: /readers-grpc-client-ca${MG_TIMESCALE_READER_GRPC_CLIENT_CA_CERTS:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_TIMESCALE_READER_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /readers-grpc-client${MG_TIMESCALE_READER_GRPC_CLIENT_CERT:+.crt} + bind: + create_host_path: true + - type: bind + source: ${MG_TIMESCALE_READER_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /readers-grpc-client${MG_TIMESCALE_READER_GRPC_CLIENT_KEY:+.key} + bind: + create_host_path: true + + timescale-writer: + image: docker.io/magistrala/timescale-writer:${MG_RELEASE_TAG} + container_name: magistrala-timescale-writer + depends_on: + - timescale + restart: on-failure + environment: + MG_TIMESCALE_WRITER_LOG_LEVEL: ${MG_TIMESCALE_WRITER_LOG_LEVEL} + MG_TIMESCALE_WRITER_CONFIG_PATH: ${MG_TIMESCALE_WRITER_CONFIG_PATH} + MG_TIMESCALE_WRITER_HTTP_HOST: ${MG_TIMESCALE_WRITER_HTTP_HOST} + MG_TIMESCALE_WRITER_HTTP_PORT: ${MG_TIMESCALE_WRITER_HTTP_PORT} + MG_TIMESCALE_WRITER_HTTP_SERVER_CERT: ${MG_TIMESCALE_WRITER_HTTP_SERVER_CERT} + MG_TIMESCALE_WRITER_HTTP_SERVER_KEY: ${MG_TIMESCALE_WRITER_HTTP_SERVER_KEY} + MG_TIMESCALE_HOST: ${MG_TIMESCALE_HOST} + MG_TIMESCALE_PORT: ${MG_TIMESCALE_PORT} + MG_TIMESCALE_USER: ${MG_TIMESCALE_USER} + MG_TIMESCALE_PASS: ${MG_TIMESCALE_PASS} + MG_TIMESCALE_NAME: ${MG_TIMESCALE_NAME} + MG_TIMESCALE_SSL_MODE: ${MG_TIMESCALE_SSL_MODE} + MG_TIMESCALE_SSL_CERT: ${MG_TIMESCALE_SSL_CERT} + MG_TIMESCALE_SSL_KEY: ${MG_TIMESCALE_SSL_KEY} + MG_TIMESCALE_SSL_ROOT_CERT: ${MG_TIMESCALE_SSL_ROOT_CERT} + MG_MESSAGE_BROKER_URL: ${MG_MESSAGE_BROKER_URL} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_TIMESCALE_WRITER_INSTANCE_ID: ${MG_TIMESCALE_WRITER_INSTANCE_ID} + ports: + - ${MG_TIMESCALE_WRITER_HTTP_PORT}:${MG_TIMESCALE_WRITER_HTTP_PORT} + networks: + - magistrala-base-net + volumes: + - ./addons/timescale-writer/config.toml:${MG_TIMESCALE_WRITER_CONFIG_PATH} + re-db: + image: docker.io/postgres:18.0-alpine3.22 + container_name: magistrala-re-db + restart: on-failure + command: postgres -c "max_connections=${MG_POSTGRES_MAX_CONNECTIONS}" + environment: + POSTGRES_USER: ${MG_RE_DB_USER} + POSTGRES_PASSWORD: ${MG_RE_DB_PASS} + POSTGRES_DB: ${MG_RE_DB_NAME} + ports: + - 6009:5432 + networks: + - magistrala-base-net + volumes: + - magistrala-re-db-volume:/var/lib/postgresql/data + + re: + image: docker.io/magistrala/re:${MG_RELEASE_TAG} + container_name: magistrala-re + depends_on: + - re-db + - spicedb-migrate + - nginx + restart: on-failure + environment: + MG_RE_LOG_LEVEL: ${MG_RE_LOG_LEVEL} + MG_RE_HTTP_PORT: ${MG_RE_HTTP_PORT} + MG_RE_HTTP_HOST: ${MG_RE_HTTP_HOST} + MG_RE_HTTP_SERVER_CERT: ${MG_RE_HTTP_SERVER_CERT} + MG_RE_HTTP_SERVER_KEY: ${MG_RE_HTTP_SERVER_KEY} + MG_RE_DB_HOST: ${MG_RE_DB_HOST} + MG_RE_DB_PORT: ${MG_RE_DB_PORT} + MG_RE_DB_USER: ${MG_RE_DB_USER} + MG_RE_DB_PASS: ${MG_RE_DB_PASS} + MG_RE_DB_NAME: ${MG_RE_DB_NAME} + MG_RE_DB_SSL_MODE: ${MG_RE_DB_SSL_MODE} + MG_RE_DB_SSL_CERT: ${MG_RE_DB_SSL_CERT} + MG_RE_DB_SSL_KEY: ${MG_RE_DB_SSL_KEY} + MG_RE_DB_SSL_ROOT_CERT: ${MG_RE_DB_SSL_ROOT_CERT} + MG_RE_CALLOUT_URLS: ${MG_RE_CALLOUT_URLS} + MG_RE_CALLOUT_METHOD: ${MG_RE_CALLOUT_METHOD} + MG_RE_CALLOUT_TLS_VERIFICATION: ${MG_RE_CALLOUT_TLS_VERIFICATION} + MG_RE_CALLOUT_TIMEOUT: ${MG_RE_CALLOUT_TIMEOUT} + MG_RE_CALLOUT_CA_CERT: ${MG_RE_CALLOUT_CA_CERT} + MG_RE_CALLOUT_CERT: ${MG_RE_CALLOUT_CERT} + MG_RE_CALLOUT_KEY: ${MG_RE_CALLOUT_KEY} + MG_RE_CALLOUT_OPERATIONS: ${MG_RE_CALLOUT_OPERATIONS} + MG_MESSAGE_BROKER_URL: ${MG_MESSAGE_BROKER_URL} + MG_ES_URL: ${MG_ES_URL} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_SPICEDB_PRE_SHARED_KEY: ${MG_SPICEDB_PRE_SHARED_KEY} + MG_SPICEDB_HOST: ${MG_SPICEDB_HOST} + MG_SPICEDB_PORT: ${MG_SPICEDB_PORT} + MG_SPICEDB_SCHEMA_FILE: ${MG_SPICEDB_SCHEMA_FILE} + MG_PERMISSIONS_FILE: ${MG_PERMISSIONS_FILE} + MG_RE_INSTANCE_ID: ${MG_RE_INSTANCE_ID} + MG_EMAIL_HOST: ${MG_EMAIL_HOST} + MG_EMAIL_PORT: ${MG_EMAIL_PORT} + MG_EMAIL_USERNAME: ${MG_EMAIL_USERNAME} + MG_EMAIL_PASSWORD: ${MG_EMAIL_PASSWORD} + MG_EMAIL_FROM_ADDRESS: ${MG_EMAIL_FROM_ADDRESS} + MG_EMAIL_FROM_NAME: ${MG_EMAIL_FROM_NAME} + MG_EMAIL_TEMPLATE: ${MG_EMAIL_TEMPLATE} + MG_TIMESCALE_READER_GRPC_URL: ${MG_TIMESCALE_READER_GRPC_URL} + MG_TIMESCALE_READER_GRPC_TIMEOUT: ${MG_TIMESCALE_READER_GRPC_TIMEOUT} + MG_TIMESCALE_READER_GRPC_CLIENT_CERT: ${MG_TIMESCALE_READER_GRPC_CLIENT_CERT} + MG_TIMESCALE_READER_GRPC_CLIENT_CA_CERTS: ${MG_TIMESCALE_READER_GRPC_CLIENT_CA_CERTS} + MG_TIMESCALE_READER_GRPC_CLIENT_KEY: ${MG_TIMESCALE_READER_GRPC_CLIENT_KEY} + MG_DOMAINS_GRPC_URL: ${MG_DOMAINS_GRPC_URL} + MG_DOMAINS_GRPC_TIMEOUT: ${MG_DOMAINS_GRPC_TIMEOUT} + MG_DOMAINS_GRPC_CLIENT_CERT: ${MG_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + MG_DOMAINS_GRPC_CLIENT_KEY: ${MG_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + MG_DOMAINS_GRPC_SERVER_CA_CERTS: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} + ports: + - ${MG_RE_HTTP_PORT}:${MG_RE_HTTP_PORT} + networks: + - magistrala-base-net + volumes: + - ./permission.yaml:${MG_PERMISSIONS_FILE} + - ./spicedb/schema.zed:${MG_SPICEDB_SCHEMA_FILE} + - ./templates/${MG_RE_EMAIL_TEMPLATE}:/email.tmpl + # Auth gRPC client certificates + - type: bind + source: ${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /auth-grpc-client.crt + bind: + create_host_path: true + - type: bind + source: ${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /auth-grpc-client.key + bind: + create_host_path: true + - type: bind + source: ${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /auth-grpc-server-ca.crt + bind: + create_host_path: true + # Domains gRPC client certificates + - type: bind + source: ${MG_DOMAINS_GRPC_CLIENT_CERT:-./ssl/placeholder} target: /domains-grpc-client.crt bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} + source: ${MG_DOMAINS_GRPC_CLIENT_KEY:-./ssl/placeholder} target: /domains-grpc-client.key bind: create_host_path: true - type: bind - source: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} + source: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} target: /domains-grpc-server-ca.crt bind: create_host_path: true - rabbitmq: - image: docker.io/rabbitmq:4.1.4-management-alpine - container_name: supermq-rabbitmq + alarms-db: + image: docker.io/postgres:18.0-alpine3.22 + container_name: magistrala-alarms-db restart: on-failure + command: postgres -c "max_connections=${MG_POSTGRES_MAX_CONNECTIONS}" environment: - RABBITMQ_ERLANG_COOKIE: ${SMQ_RABBITMQ_COOKIE} - RABBITMQ_DEFAULT_USER: ${SMQ_RABBITMQ_USER} - RABBITMQ_DEFAULT_PASS: ${SMQ_RABBITMQ_PASS} - RABBITMQ_DEFAULT_VHOST: ${SMQ_RABBITMQ_VHOST} - RABBITMQ_CONFIG_FILES: /etc/rabbitmq/conf.d/ + POSTGRES_USER: ${MG_ALARMS_DB_USER} + POSTGRES_PASSWORD: ${MG_ALARMS_DB_PASS} + POSTGRES_DB: ${MG_ALARMS_DB_NAME} ports: - - ${SMQ_RABBITMQ_PORT}:${SMQ_RABBITMQ_PORT} - - ${SMQ_RABBITMQ_HTTP_PORT}:${SMQ_RABBITMQ_HTTP_PORT} - - ${SMQ_RABBITMQ_WS_PORT}:${SMQ_RABBITMQ_WS_PORT} - volumes: - - ./rabbitmq/enabled_plugins:/etc/rabbitmq/enabled_plugins - - ./rabbitmq/rabbitmq.conf:/etc/rabbitmq/conf.d/10-defaults.conf - - supermq-mqtt-broker-volume:/var/lib/rabbitmq + - 6019:5432 networks: - - supermq-base-net + - magistrala-base-net + volumes: + - magistrala-alarms-db-volume:/var/lib/postgresql/data - nats: - image: docker.io/nats:2.12.0-alpine3.22 - container_name: supermq-nats + alarms: + image: docker.io/magistrala/alarms:${MG_RELEASE_TAG} + container_name: magistrala-alarms + depends_on: + - alarms-db + - spicedb-migrate + - nginx restart: on-failure - command: "--config=/etc/nats/nats.conf" environment: - - SMQ_NATS_PORT=${SMQ_NATS_PORT} - - SMQ_NATS_HTTP_PORT=${SMQ_NATS_HTTP_PORT} - - SMQ_NATS_JETSTREAM_KEY=${SMQ_NATS_JETSTREAM_KEY} + MG_ALARMS_LOG_LEVEL: ${MG_ALARMS_LOG_LEVEL} + MG_ALARMS_HTTP_PORT: ${MG_ALARMS_HTTP_PORT} + MG_ALARMS_HTTP_HOST: ${MG_ALARMS_HTTP_HOST} + MG_ALARMS_HTTP_SERVER_CERT: ${MG_ALARMS_HTTP_SERVER_CERT} + MG_ALARMS_HTTP_SERVER_KEY: ${MG_ALARMS_HTTP_SERVER_KEY} + MG_ALARMS_DB_HOST: ${MG_ALARMS_DB_HOST} + MG_ALARMS_DB_PORT: ${MG_ALARMS_DB_PORT} + MG_ALARMS_DB_USER: ${MG_ALARMS_DB_USER} + MG_ALARMS_DB_PASS: ${MG_ALARMS_DB_PASS} + MG_ALARMS_DB_NAME: ${MG_ALARMS_DB_NAME} + MG_ALARMS_DB_SSL_MODE: ${MG_ALARMS_DB_SSL_MODE} + MG_ALARMS_DB_SSL_CERT: ${MG_ALARMS_DB_SSL_CERT} + MG_ALARMS_DB_SSL_KEY: ${MG_ALARMS_DB_SSL_KEY} + MG_ALARMS_DB_SSL_ROOT_CERT: ${MG_ALARMS_DB_SSL_ROOT_CERT} + MG_MESSAGE_BROKER_URL: ${MG_MESSAGE_BROKER_URL} + MG_ES_URL: ${MG_ES_URL} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_DOMAINS_GRPC_URL: ${MG_DOMAINS_GRPC_URL} + MG_DOMAINS_GRPC_TIMEOUT: ${MG_DOMAINS_GRPC_TIMEOUT} + MG_DOMAINS_GRPC_CLIENT_CERT: ${MG_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + MG_DOMAINS_GRPC_CLIENT_KEY: ${MG_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + MG_DOMAINS_GRPC_SERVER_CA_CERTS: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_SPICEDB_PRE_SHARED_KEY: ${MG_SPICEDB_PRE_SHARED_KEY} + MG_SPICEDB_HOST: ${MG_SPICEDB_HOST} + MG_SPICEDB_PORT: ${MG_SPICEDB_PORT} + MG_SPICEDB_SCHEMA_FILE: ${MG_SPICEDB_SCHEMA_FILE} + MG_PERMISSIONS_FILE: ${MG_PERMISSIONS_FILE} + MG_ALARMS_INSTANCE_ID: ${MG_ALARMS_INSTANCE_ID} + MG_ALARMS_EVENT_CONSUMER: ${MG_ALARMS_EVENT_CONSUMER} + MG_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} ports: - - ${SMQ_NATS_PORT}:${SMQ_NATS_PORT} - - ${SMQ_NATS_HTTP_PORT}:${SMQ_NATS_HTTP_PORT} - volumes: - - supermq-broker-volume:/data - - ./nats:/etc/nats + - ${MG_ALARMS_HTTP_PORT}:${MG_ALARMS_HTTP_PORT} networks: - - supermq-base-net + - magistrala-base-net + volumes: + - ./permission.yaml:${MG_PERMISSIONS_FILE} + - ./spicedb/schema.zed:${MG_SPICEDB_SCHEMA_FILE} + # Auth gRPC client certificates + - type: bind + source: ${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /auth-grpc-client.crt + bind: + create_host_path: true + - type: bind + source: ${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /auth-grpc-client.key + bind: + create_host_path: true + - type: bind + source: ${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /auth-grpc-server-ca.crt + bind: + create_host_path: true + # Domains gRPC client certificates + - type: bind + source: ${MG_DOMAINS_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /domains-grpc-client.crt + bind: + create_host_path: true + - type: bind + source: ${MG_DOMAINS_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /domains-grpc-client.key + bind: + create_host_path: true + - type: bind + source: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /domains-grpc-server-ca.crt + bind: + create_host_path: true + + reports-db: + image: docker.io/postgres:18.0-alpine3.22 + container_name: magistrala-reports-db + restart: on-failure + command: postgres -c "max_connections=${MG_POSTGRES_MAX_CONNECTIONS}" + environment: + POSTGRES_USER: ${MG_REPORTS_DB_USER} + POSTGRES_PASSWORD: ${MG_REPORTS_DB_PASS} + POSTGRES_DB: ${MG_REPORTS_DB_NAME} + ports: + - 6020:5432 + networks: + - magistrala-base-net + volumes: + - magistrala-reports-db-volume:/var/lib/postgresql/data + + reports: + image: docker.io/magistrala/reports:${MG_RELEASE_TAG} + container_name: magistrala-reports + depends_on: + - reports-db + - spicedb-migrate + - nginx + restart: on-failure + environment: + MG_REPORTS_LOG_LEVEL: ${MG_REPORTS_LOG_LEVEL} + MG_REPORTS_HTTP_PORT: ${MG_REPORTS_HTTP_PORT} + MG_REPORTS_HTTP_HOST: ${MG_REPORTS_HTTP_HOST} + MG_REPORTS_HTTP_SERVER_CERT: ${MG_REPORTS_HTTP_SERVER_CERT} + MG_REPORTS_HTTP_SERVER_KEY: ${MG_REPORTS_HTTP_SERVER_KEY} + MG_REPORTS_DB_HOST: ${MG_REPORTS_DB_HOST} + MG_REPORTS_DB_PORT: ${MG_REPORTS_DB_PORT} + MG_REPORTS_DB_USER: ${MG_REPORTS_DB_USER} + MG_REPORTS_DB_PASS: ${MG_REPORTS_DB_PASS} + MG_REPORTS_DB_NAME: ${MG_REPORTS_DB_NAME} + MG_REPORTS_DB_SSL_MODE: ${MG_REPORTS_DB_SSL_MODE} + MG_REPORTS_DB_SSL_CERT: ${MG_REPORTS_DB_SSL_CERT} + MG_REPORTS_DB_SSL_KEY: ${MG_REPORTS_DB_SSL_KEY} + MG_REPORTS_DB_SSL_ROOT_CERT: ${MG_REPORTS_DB_SSL_ROOT_CERT} + MG_REPORTS_DEFAULT_TEMPLATE: ${MG_REPORTS_DEFAULT_TEMPLATE} + MG_PDF_CONVERTER_URL: ${MG_PDF_CONVERTER_URL} + MG_MESSAGE_BROKER_URL: ${MG_MESSAGE_BROKER_URL} + MG_ES_URL: ${MG_ES_URL} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_SPICEDB_PRE_SHARED_KEY: ${MG_SPICEDB_PRE_SHARED_KEY} + MG_SPICEDB_HOST: ${MG_SPICEDB_HOST} + MG_SPICEDB_PORT: ${MG_SPICEDB_PORT} + MG_SPICEDB_SCHEMA_FILE: ${MG_SPICEDB_SCHEMA_FILE} + MG_PERMISSIONS_FILE: ${MG_PERMISSIONS_FILE} + MG_REPORTS_INSTANCE_ID: ${MG_RE_INSTANCE_ID} + MG_EMAIL_HOST: ${MG_EMAIL_HOST} + MG_EMAIL_PORT: ${MG_EMAIL_PORT} + MG_EMAIL_USERNAME: ${MG_EMAIL_USERNAME} + MG_EMAIL_PASSWORD: ${MG_EMAIL_PASSWORD} + MG_EMAIL_FROM_ADDRESS: ${MG_EMAIL_FROM_ADDRESS} + MG_EMAIL_FROM_NAME: ${MG_EMAIL_FROM_NAME} + MG_EMAIL_TEMPLATE: ${MG_EMAIL_TEMPLATE} + MG_TIMESCALE_READER_GRPC_URL: ${MG_TIMESCALE_READER_GRPC_URL} + MG_TIMESCALE_READER_GRPC_TIMEOUT: ${MG_TIMESCALE_READER_GRPC_TIMEOUT} + MG_TIMESCALE_READER_GRPC_CLIENT_CERT: ${MG_TIMESCALE_READER_GRPC_CLIENT_CERT} + MG_TIMESCALE_READER_GRPC_SERVER_CA_CERTS: ${MG_TIMESCALE_READER_GRPC_SERVER_CA_CERTS} + MG_TIMESCALE_READER_GRPC_CLIENT_KEY: ${MG_TIMESCALE_READER_GRPC_CLIENT_KEY} + MG_DOMAINS_GRPC_URL: ${MG_DOMAINS_GRPC_URL} + MG_DOMAINS_GRPC_TIMEOUT: ${MG_DOMAINS_GRPC_TIMEOUT} + MG_DOMAINS_GRPC_CLIENT_CERT: ${MG_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + MG_DOMAINS_GRPC_CLIENT_KEY: ${MG_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + MG_DOMAINS_GRPC_SERVER_CA_CERTS: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} + ports: + - ${MG_REPORTS_HTTP_PORT}:${MG_REPORTS_HTTP_PORT} + networks: + - magistrala-base-net + volumes: + - ./permission.yaml:${MG_PERMISSIONS_FILE} + - ./spicedb/schema.zed:${MG_SPICEDB_SCHEMA_FILE} + - ./templates/${MG_REPORTS_EMAIL_TEMPLATE}:/email.tmpl + # Auth gRPC client certificates + - type: bind + source: ${MG_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /auth-grpc-client.crt + bind: + create_host_path: true + - type: bind + source: ${MG_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /auth-grpc-client.key + bind: + create_host_path: true + - type: bind + source: ${MG_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /auth-grpc-server-ca.crt + bind: + create_host_path: true + # Domains gRPC client certificates + - type: bind + source: ${MG_DOMAINS_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /domains-grpc-client.crt + bind: + create_host_path: true + - type: bind + source: ${MG_DOMAINS_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /domains-grpc-client.key + bind: + create_host_path: true + - type: bind + source: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /domains-grpc-server-ca.crt + bind: + create_host_path: true + + pdf-generator: + image: gotenberg/gotenberg:8.25.1 + container_name: magistrala-pdf + ports: + - "4000:3000" + networks: + - magistrala-base-net + + certs: + image: docker.io/magistrala/certs:${MG_RELEASE_TAG} + container_name: magistrala-certs + depends_on: + openbao: + condition: service_healthy + certs-db: + condition: service_started + restart: on-failure + networks: + - magistrala-base-net + environment: + MG_CERTS_LOG_LEVEL: ${MG_CERTS_LOG_LEVEL} + MG_CERTS_HTTP_HOST: ${MG_CERTS_HTTP_HOST} + MG_CERTS_HTTP_PORT: ${MG_CERTS_HTTP_PORT} + MG_CERTS_GRPC_HOST: ${MG_CERTS_GRPC_HOST} + MG_CERTS_GRPC_PORT: ${MG_CERTS_GRPC_PORT} + MG_JAEGER_URL: ${MG_JAEGER_URL} + MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} + MG_CERTS_OPENBAO_HOST: ${MG_CERTS_OPENBAO_HOST} + MG_CERTS_OPENBAO_APP_ROLE: ${MG_CERTS_OPENBAO_APP_ROLE} + MG_CERTS_OPENBAO_APP_SECRET: ${MG_CERTS_OPENBAO_APP_SECRET} + MG_CERTS_OPENBAO_NAMESPACE: ${MG_CERTS_OPENBAO_NAMESPACE} + MG_CERTS_OPENBAO_PKI_PATH: ${MG_CERTS_OPENBAO_PKI_PATH} + MG_CERTS_OPENBAO_ROLE: ${MG_CERTS_OPENBAO_ROLE} + MG_CERTS_OPENBAO_SECRET_ID_TTL: ${MG_CERTS_OPENBAO_SECRET_ID_TTL} + MG_CERTS_DB_HOST: ${MG_CERTS_DB_HOST} + MG_CERTS_DB_PORT: ${MG_CERTS_DB_PORT} + MG_CERTS_DB_USER: ${MG_CERTS_DB_USER} + MG_CERTS_DB_PASS: ${MG_CERTS_DB_PASS} + MG_CERTS_DB: ${MG_CERTS_DB} + MG_CERTS_DB_SSL_MODE: ${MG_CERTS_DB_SSL_MODE} + MG_AUTH_GRPC_URL: ${MG_AUTH_GRPC_URL} + MG_AUTH_GRPC_TIMEOUT: ${MG_AUTH_GRPC_TIMEOUT} + MG_AUTH_GRPC_CLIENT_CERT: ${MG_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} + MG_AUTH_GRPC_CLIENT_KEY: ${MG_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} + MG_AUTH_GRPC_SERVER_CA_CERTS: ${MG_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} + MG_DOMAINS_GRPC_URL: ${MG_DOMAINS_GRPC_URL} + MG_DOMAINS_GRPC_TIMEOUT: ${MG_DOMAINS_GRPC_TIMEOUT} + MG_DOMAINS_GRPC_CLIENT_CERT: ${MG_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} + MG_DOMAINS_GRPC_CLIENT_KEY: ${MG_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} + MG_DOMAINS_GRPC_SERVER_CA_CERTS: ${MG_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + MG_CERTS_SECRET: ${MG_CERTS_SECRET} + MG_CERTS_SERVICE_TOKEN_PATH: ${MG_CERTS_SERVICE_TOKEN_PATH} + MG_CERTS_SECRET_ID_PATH: ${MG_CERTS_SECRET_ID_PATH} + MG_CERTS_SECRET_RENEW_THRESHOLD: ${MG_CERTS_SECRET_RENEW_THRESHOLD} + MG_CERTS_SECRET_CHECK_INTERVAL: ${MG_CERTS_SECRET_CHECK_INTERVAL} + MG_ALLOW_UNVERIFIED_USER: ${MG_ALLOW_UNVERIFIED_USER} + ports: + - ${MG_CERTS_HTTP_PORT}:${MG_CERTS_HTTP_PORT} + - ${MG_CERTS_GRPC_PORT}:${MG_CERTS_GRPC_PORT} + volumes: + - magistrala-openbao-data:/openbao:ro + # Auth gRPC client certificates + - type: bind + source: ${AM_AUTH_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /auth-grpc-client.crt + bind: + create_host_path: true + - type: bind + source: ${AM_AUTH_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /auth-grpc-client.key + bind: + create_host_path: true + - type: bind + source: ${AM_AUTH_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /auth-grpc-server-ca.crt + bind: + create_host_path: true + # Domains gRPC client certificates + - type: bind + source: ${AM_DOMAINS_GRPC_CLIENT_CERT:-./ssl/placeholder} + target: /domains-grpc-client.crt + bind: + create_host_path: true + - type: bind + source: ${AM_DOMAINS_GRPC_CLIENT_KEY:-./ssl/placeholder} + target: /domains-grpc-client.key + bind: + create_host_path: true + - type: bind + source: ${AM_DOMAINS_GRPC_SERVER_CA_CERTS:-./ssl/placeholder} + target: /domains-grpc-server-ca.crt + bind: + create_host_path: true + + certs-db: + image: docker.io/postgres:16.2-alpine + container_name: magistrala-certs-db + restart: on-failure + networks: + - magistrala-base-net + command: postgres -c "max_connections=${MG_CERTS_DB_MAX_CONNECTIONS}" + environment: + POSTGRES_USER: ${MG_CERTS_DB_USER} + POSTGRES_PASSWORD: ${MG_CERTS_DB_PASS} + POSTGRES_DB: ${MG_CERTS_DB} + ports: + - 5454:5432 + volumes: + - magistrala-certs-db-volume:/var/lib/postgresql/data + + openbao: + image: openbao/openbao:2.4.0 + container_name: magistrala-openbao + restart: on-failure + networks: + - magistrala-base-net + ports: + - 8200:8200 + healthcheck: + test: ["CMD", "sh", "-c", "test -f /opt/openbao/data/service_token"] + interval: 5s + timeout: 3s + retries: 20 + start_period: 30s + environment: + - BAO_ADDR=http://127.0.0.1:8200 + - BAO_LOG_LEVEL=info + - MG_CERTS_OPENBAO_PKI_ROLE=${MG_CERTS_OPENBAO_ROLE} + - MG_CERTS_OPENBAO_APP_ROLE=${MG_CERTS_OPENBAO_APP_ROLE} + - MG_CERTS_OPENBAO_APP_SECRET=${MG_CERTS_OPENBAO_APP_SECRET} + - MG_CERTS_OPENBAO_SECRET_ID_TTL=${MG_CERTS_OPENBAO_SECRET_ID_TTL} + - MG_CERTS_OPENBAO_NAMESPACE=${MG_CERTS_OPENBAO_NAMESPACE} + - MG_CERTS_OPENBAO_PKI_CA_CN=${MG_CERTS_OPENBAO_PKI_CA_CN} + - MG_CERTS_OPENBAO_PKI_CA_OU=${MG_CERTS_OPENBAO_PKI_CA_OU} + - MG_CERTS_OPENBAO_PKI_CA_O=${MG_CERTS_OPENBAO_PKI_CA_O} + - MG_CERTS_OPENBAO_PKI_CA_C=${MG_CERTS_OPENBAO_PKI_CA_C} + - MG_CERTS_OPENBAO_PKI_CA_L=${MG_CERTS_OPENBAO_PKI_CA_L} + - MG_CERTS_OPENBAO_PKI_CA_ST=${MG_CERTS_OPENBAO_PKI_CA_ST} + - MG_CERTS_OPENBAO_PKI_CA_ADDR=${MG_CERTS_OPENBAO_PKI_CA_ADDR} + - MG_CERTS_OPENBAO_PKI_CA_PO=${MG_CERTS_OPENBAO_PKI_CA_PO} + - MG_CERTS_OPENBAO_PKI_CA_DNS_NAMES=${MG_CERTS_OPENBAO_PKI_CA_DNS_NAMES} + - MG_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES=${MG_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES} + - MG_CERTS_OPENBAO_PKI_CA_URI_SANS=${MG_CERTS_OPENBAO_PKI_CA_URI_SANS} + - MG_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES=${MG_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES} + - MG_CERTS_OPENBAO_UNSEAL_KEY_1=${MG_CERTS_OPENBAO_UNSEAL_KEY_1} + - MG_CERTS_OPENBAO_UNSEAL_KEY_2=${MG_CERTS_OPENBAO_UNSEAL_KEY_2} + - MG_CERTS_OPENBAO_UNSEAL_KEY_3=${MG_CERTS_OPENBAO_UNSEAL_KEY_3} + - MG_CERTS_OPENBAO_ROOT_TOKEN=${MG_CERTS_OPENBAO_ROOT_TOKEN} + cap_add: + - IPC_LOCK + mem_swappiness: 0 + volumes: + - magistrala-openbao-data:/opt/openbao/data + - magistrala-openbao-data:/opt/openbao/config + - ./openbao-entrypoint.sh:/entrypoint.sh + entrypoint: /bin/sh + command: /entrypoint.sh diff --git a/docker/fluxmq/node1.yaml b/docker/fluxmq/node1.yaml new file mode 100644 index 000000000..680820e3f --- /dev/null +++ b/docker/fluxmq/node1.yaml @@ -0,0 +1,117 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +server: + tcp: + v3: + addr: "0.0.0.0:1883" + protocol: "v3" + max_connections: 10000 + read_timeout: 60s + write_timeout: 60s + v5: + addr: "0.0.0.0:1884" + protocol: "v5" + max_connections: 10000 + read_timeout: 60s + write_timeout: 60s + websocket: + v3: + addr: "0.0.0.0:8883" + path: "/mqtt" + protocol: "v3" + v5: + addr: "0.0.0.0:8884" + path: "/mqtt" + protocol: "v5" + http: + plain: + addr: "0.0.0.0:8090" + coap: + plain: + addr: "0.0.0.0:5683" + amqp: + plain: + addr: "0.0.0.0:5672" + max_connections: 10000 + amqp091: + plain: + addr: "0.0.0.0:5682" + health_addr: "0.0.0.0:8081" + health_enabled: true + shutdown_timeout: 30s + +broker: + max_message_size: 1048576 + max_retained_messages: 10000 + retry_interval: 20s + max_retries: 0 + max_qos: 2 + async_fan_out: false + fan_out_workers: 0 + +session: + max_sessions: 10000 + default_expiry_interval: 300 + max_offline_queue_size: 10000 + max_inflight_messages: 1000 + max_send_queue_size: 1000 + offline_queue_policy: "evict" + inflight_overflow: 1 + pending_queue_size: 1000 + +log: + level: "info" + format: "text" + +storage: + type: "badger" + badger_dir: "/tmp/fluxmq/data" + sync_writes: false + +cluster: + enabled: true + node_id: "node1" + etcd: + data_dir: "/tmp/fluxmq/etcd" + bind_addr: "172.30.0.201:2380" + client_addr: "172.30.0.201:2379" + initial_cluster: "node1=http://172.30.0.201:2380,node2=http://172.30.0.202:2380,node3=http://172.30.0.203:2380" + bootstrap: true + hybrid_retained_size_threshold: 1024 + transport: + bind_addr: "0.0.0.0:7948" + peers: + node2: "fluxmq-node2:7948" + node3: "fluxmq-node3:7948" + route_batch_max_size: 256 + route_batch_max_delay: 50ms + route_batch_flush_workers: 8 + route_publish_timeout: 15s + +queue_manager: + auto_commit_interval: 5s + +queues: + - name: "mqtt" + topics: + - "$queue/#" + reserved: true + - name: "events" + topics: + - "$queue/events/#" + type: "stream" + retention: + max_age: 168h + max_length_bytes: 1073741824 + +auth: + url: "http://fluxmq-auth:7016" + transport: "grpc" + timeout: 5s + protocols: + mqtt: true + http: true + coap: true + amqp: true + amqp091: false diff --git a/docker/fluxmq/node2.yaml b/docker/fluxmq/node2.yaml new file mode 100644 index 000000000..e57c477d9 --- /dev/null +++ b/docker/fluxmq/node2.yaml @@ -0,0 +1,114 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +server: + tcp: + v3: + addr: "0.0.0.0:1883" + protocol: "v3" + max_connections: 10000 + read_timeout: 60s + write_timeout: 60s + v5: + addr: "0.0.0.0:1884" + protocol: "v5" + max_connections: 10000 + read_timeout: 60s + write_timeout: 60s + websocket: + v3: + addr: "0.0.0.0:8883" + path: "/mqtt" + protocol: "v3" + v5: + addr: "0.0.0.0:8884" + path: "/mqtt" + protocol: "v5" + http: + plain: + addr: "0.0.0.0:8090" + amqp: + plain: + addr: "0.0.0.0:5672" + max_connections: 10000 + amqp091: + plain: + addr: "0.0.0.0:5682" + health_addr: "0.0.0.0:8084" + health_enabled: true + shutdown_timeout: 30s + +broker: + max_message_size: 1048576 + max_retained_messages: 10000 + retry_interval: 20s + max_retries: 0 + max_qos: 2 + async_fan_out: false + fan_out_workers: 0 + +session: + max_sessions: 10000 + default_expiry_interval: 300 + max_offline_queue_size: 10000 + max_inflight_messages: 1000 + max_send_queue_size: 1000 + offline_queue_policy: "evict" + inflight_overflow: 1 + pending_queue_size: 1000 + +log: + level: "info" + format: "text" + +storage: + type: "badger" + badger_dir: "/tmp/fluxmq/data" + sync_writes: false + +cluster: + enabled: true + node_id: "node2" + etcd: + data_dir: "/tmp/fluxmq/etcd" + bind_addr: "172.30.0.202:2380" + client_addr: "172.30.0.202:2379" + initial_cluster: "node1=http://172.30.0.201:2380,node2=http://172.30.0.202:2380,node3=http://172.30.0.203:2380" + bootstrap: true + hybrid_retained_size_threshold: 1024 + transport: + bind_addr: "0.0.0.0:7948" + peers: + node1: "fluxmq-node1:7948" + node3: "fluxmq-node3:7948" + route_batch_max_size: 256 + route_batch_max_delay: 50ms + route_batch_flush_workers: 8 + route_publish_timeout: 15s + +queue_manager: + auto_commit_interval: 5s + +queues: + - name: "mqtt" + topics: + - "$queue/#" + reserved: true + - name: "events" + topics: + - "$queue/events/#" + type: "stream" + retention: + max_age: 168h + max_length_bytes: 1073741824 + +auth: + url: "http://fluxmq-auth:7016" + transport: "grpc" + timeout: 5s + protocols: + mqtt: true + http: true + coap: true + amqp: true + amqp091: false diff --git a/docker/fluxmq/node3.yaml b/docker/fluxmq/node3.yaml new file mode 100644 index 000000000..7d2f5b687 --- /dev/null +++ b/docker/fluxmq/node3.yaml @@ -0,0 +1,114 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +server: + tcp: + v3: + addr: "0.0.0.0:1883" + protocol: "v3" + max_connections: 10000 + read_timeout: 60s + write_timeout: 60s + v5: + addr: "0.0.0.0:1884" + protocol: "v5" + max_connections: 10000 + read_timeout: 60s + write_timeout: 60s + websocket: + v3: + addr: "0.0.0.0:8883" + path: "/mqtt" + protocol: "v3" + v5: + addr: "0.0.0.0:8884" + path: "/mqtt" + protocol: "v5" + http: + plain: + addr: "0.0.0.0:8090" + amqp: + plain: + addr: "0.0.0.0:5672" + max_connections: 10000 + amqp091: + plain: + addr: "0.0.0.0:5682" + health_addr: "0.0.0.0:8083" + health_enabled: true + shutdown_timeout: 30s + +broker: + max_message_size: 1048576 + max_retained_messages: 10000 + retry_interval: 20s + max_retries: 0 + max_qos: 2 + async_fan_out: false + fan_out_workers: 0 + +session: + max_sessions: 10000 + default_expiry_interval: 300 + max_offline_queue_size: 10000 + max_inflight_messages: 1000 + max_send_queue_size: 1000 + offline_queue_policy: "evict" + inflight_overflow: 1 + pending_queue_size: 1000 + +log: + level: "info" + format: "text" + +storage: + type: "badger" + badger_dir: "/tmp/fluxmq/data" + sync_writes: false + +cluster: + enabled: true + node_id: "node3" + etcd: + data_dir: "/tmp/fluxmq/etcd" + bind_addr: "172.30.0.203:2380" + client_addr: "172.30.0.203:2379" + initial_cluster: "node1=http://172.30.0.201:2380,node2=http://172.30.0.202:2380,node3=http://172.30.0.203:2380" + bootstrap: true + hybrid_retained_size_threshold: 1024 + transport: + bind_addr: "0.0.0.0:7948" + peers: + node1: "fluxmq-node1:7948" + node2: "fluxmq-node2:7948" + route_batch_max_size: 256 + route_batch_max_delay: 50ms + route_batch_flush_workers: 8 + route_publish_timeout: 15s + +queue_manager: + auto_commit_interval: 5s + +queues: + - name: "mqtt" + topics: + - "$queue/#" + reserved: true + - name: "events" + topics: + - "$queue/events/#" + type: "stream" + retention: + max_age: 168h + max_length_bytes: 1073741824 + +auth: + url: "http://fluxmq-auth:7016" + transport: "grpc" + timeout: 5s + protocols: + mqtt: true + http: true + coap: true + amqp: true + amqp091: false diff --git a/docker/nats/nats.conf b/docker/nats/nats.conf index a547b6bf0..cf08e8186 100644 --- a/docker/nats/nats.conf +++ b/docker/nats/nats.conf @@ -4,14 +4,14 @@ server_name: "nats_internal_broker" max_payload: 10MB max_connections: 1M -port: $SMQ_NATS_PORT -http_port: $SMQ_NATS_HTTP_PORT +port: $MG_NATS_PORT +http_port: $MG_NATS_HTTP_PORT trace: true jetstream { store_dir: "/data" cipher: "aes" - key: $SMQ_NATS_JETSTREAM_KEY + key: $MG_NATS_JETSTREAM_KEY max_mem: 1G } diff --git a/docker/nginx/.gitignore b/docker/nginx/.gitignore deleted file mode 100644 index 9453269cc..000000000 --- a/docker/nginx/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 - -snippets/mqtt-upstream.conf -snippets/mqtt-ws-upstream.conf \ No newline at end of file diff --git a/docker/nginx/entrypoint.sh b/docker/nginx/entrypoint.sh index a220673f1..b04e65a0d 100755 --- a/docker/nginx/entrypoint.sh +++ b/docker/nginx/entrypoint.sh @@ -2,26 +2,26 @@ # Copyright (c) Abstract Machines # SPDX-License-Identifier: Apache-2.0 -if [ -z "$SMQ_MQTT_CLUSTER" ] -then - envsubst '${SMQ_MQTT_ADAPTER_MQTT_PORT}' < /etc/nginx/snippets/mqtt-upstream-single.conf > /etc/nginx/snippets/mqtt-upstream.conf - envsubst '${SMQ_MQTT_ADAPTER_WS_PORT}' < /etc/nginx/snippets/mqtt-ws-upstream-single.conf > /etc/nginx/snippets/mqtt-ws-upstream.conf -else - envsubst '${SMQ_MQTT_ADAPTER_MQTT_PORT}' < /etc/nginx/snippets/mqtt-upstream-cluster.conf > /etc/nginx/snippets/mqtt-upstream.conf - envsubst '${SMQ_MQTT_ADAPTER_WS_PORT}' < /etc/nginx/snippets/mqtt-ws-upstream-cluster.conf > /etc/nginx/snippets/mqtt-ws-upstream.conf +if [ ! -f /etc/nginx/snippets/mqtt-upstream.conf ] || [ ! -f /etc/nginx/snippets/mqtt-ws-upstream.conf ]; then + echo "Missing MQTT upstream snippets; cannot start nginx." >&2 + exit 1 fi envsubst ' - ${SMQ_NGINX_SERVER_NAME} - ${SMQ_AUTH_HTTP_PORT} - ${SMQ_DOMAINS_HTTP_PORT} - ${SMQ_GROUPS_HTTP_PORT} - ${SMQ_USERS_HTTP_PORT} - ${SMQ_CLIENTS_HTTP_PORT} - ${SMQ_CLIENTS_AUTH_HTTP_PORT} - ${SMQ_CHANNELS_HTTP_PORT} - ${SMQ_HTTP_ADAPTER_PORT} - ${SMQ_NGINX_MQTT_PORT} - ${SMQ_NGINX_MQTTS_PORT}' < /etc/nginx/nginx.conf.template > /etc/nginx/nginx.conf + ${MG_NGINX_SERVER_NAME} + ${MG_AUTH_HTTP_PORT} + ${MG_DOMAINS_HTTP_PORT} + ${MG_GROUPS_HTTP_PORT} + ${MG_USERS_HTTP_PORT} + ${MG_CLIENTS_HTTP_PORT} + ${MG_CLIENTS_AUTH_HTTP_PORT} + ${MG_CHANNELS_HTTP_PORT} + ${MG_HTTP_ADAPTER_PORT} + ${MG_NGINX_MQTT_PORT} + ${MG_NGINX_MQTTS_PORT} + ${MG_RE_HTTP_PORT} + ${MG_ALARMS_HTTP_PORT} + ${MG_REPORTS_HTTP_PORT} + ${MG_NGINX_AMQP_PORT}' < /etc/nginx/nginx.conf.template > /etc/nginx/nginx.conf exec nginx -g "daemon off;" diff --git a/docker/nginx/nginx-key.conf b/docker/nginx/nginx-key.conf index 4c35ffb05..bdc84fd36 100644 --- a/docker/nginx/nginx-key.conf +++ b/docker/nginx/nginx-key.conf @@ -6,6 +6,7 @@ user nginx; worker_processes auto; worker_cpu_affinity auto; +worker_rlimit_nofile 65535; pid /run/nginx.pid; include /etc/nginx/modules-enabled/*.conf; @@ -30,9 +31,12 @@ http { ssl_protocols TLSv1.2 TLSv1.3; ssl_prefer_server_ciphers on; + resolver 127.0.0.11 ipv6=off valid=10s; + resolver_timeout 5s; # Include single-node or multiple-node (cluster) upstream include snippets/mqtt-ws-upstream.conf; + include snippets/fluxmq-http-upstream.conf; server { listen 80 default_server; @@ -41,13 +45,22 @@ http { listen [::]:443 ssl default_server; http2 on; - set $dynamic_server_name "$SMQ_NGINX_SERVER_NAME"; + set $dynamic_server_name "$MG_NGINX_SERVER_NAME"; if ($dynamic_server_name = '') { set $dynamic_server_name "localhost"; } server_name $dynamic_server_name; + set $auth_upstream "auth:${MG_AUTH_HTTP_PORT}"; + set $domains_upstream "domains:${MG_DOMAINS_HTTP_PORT}"; + set $users_upstream "users:${MG_USERS_HTTP_PORT}"; + set $groups_upstream "groups:${MG_GROUPS_HTTP_PORT}"; + set $clients_upstream "clients:${MG_CLIENTS_HTTP_PORT}"; + set $channels_upstream "channels:${MG_CHANNELS_HTTP_PORT}"; + set $rules_upstream "re:${MG_RE_HTTP_PORT}"; + set $alarms_upstream "alarms:${MG_ALARMS_HTTP_PORT}"; + set $reports_upstream "reports:${MG_REPORTS_HTTP_PORT}"; include snippets/ssl.conf; @@ -62,77 +75,87 @@ http { location ~ ^/(pats) { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; - proxy_pass http://auth:${SMQ_AUTH_HTTP_PORT}; + proxy_pass http://$auth_upstream; } # Proxy pass to domains service location ~ ^/(domains|invitations) { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; - proxy_pass http://domains:${SMQ_DOMAINS_HTTP_PORT}; + proxy_pass http://$domains_upstream; } # Proxy pass to users service location ~ ^/(users|password|verify-email|authorize|oauth/callback/[^/]+) { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; - proxy_pass http://users:${SMQ_USERS_HTTP_PORT}; + proxy_pass http://$users_upstream; } # Proxy pass to groups service location ~ "^/([a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12})/(groups)" { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; - proxy_pass http://groups:${SMQ_GROUPS_HTTP_PORT}; + proxy_pass http://$groups_upstream; } # Proxy pass to clients service location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(clients)" { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; - proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT}; + proxy_pass http://$clients_upstream; } # Proxy pass to channels service location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(channels)" { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; - proxy_pass http://channels:${SMQ_CHANNELS_HTTP_PORT}; + proxy_pass http://$channels_upstream; + } + + # Proxy pass to rule engine service + location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(rules)" { + include snippets/proxy-headers.conf; + add_header Access-Control-Expose-Headers Location; + proxy_pass http://$rules_upstream; + } + + # Proxy pass to alarm service + location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(alarms)" { + include snippets/proxy-headers.conf; + add_header Access-Control-Expose-Headers Location; + proxy_pass http://$alarms_upstream; + } + + # Proxy pass to reports service + location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(reports)" { + include snippets/proxy-headers.conf; + add_header Access-Control-Expose-Headers Location; + proxy_pass http://$reports_upstream; } location /health { include snippets/proxy-headers.conf; - proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT}; + proxy_pass http://$clients_upstream; } location /metrics { include snippets/proxy-headers.conf; - proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT}; + proxy_pass http://$clients_upstream; } - # Proxy pass to supermq-http-adapter + # Proxy pass to FluxMQ HTTP API location /http/ { include snippets/proxy-headers.conf; - - # Trailing `/` is mandatory. Refer to the http://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_pass - # If the proxy_pass directive is specified with a URI, then when a request is passed to the server, - # the part of a normalized request URI matching the location is replaced by a URI specified in the directive - proxy_pass http://http-adapter:${SMQ_HTTP_ADAPTER_PORT}/; + proxy_pass http://fluxmq_http_cluster/; } - # Proxy pass to supermq-mqtt-adapter over WS + # Proxy pass to FluxMQ MQTT over WebSocket location /mqtt { include snippets/proxy-headers.conf; include snippets/ws-upgrade.conf; proxy_pass http://mqtt_ws_cluster; } - - # Proxy pass to supermq-ws-adapter - location /ws/ { - include snippets/proxy-headers.conf; - include snippets/ws-upgrade.conf; - proxy_pass http://http-adapter:${SMQ_HTTP_ADAPTER_PORT}/; - } } } @@ -142,17 +165,26 @@ stream { # Include single-node or multiple-node (cluster) upstream include snippets/mqtt-upstream.conf; + include snippets/fluxmq-amqp-upstream.conf; server { - listen ${SMQ_NGINX_MQTT_PORT}; - listen [::]:${SMQ_NGINX_MQTT_PORT}; - listen ${SMQ_NGINX_MQTTS_PORT} ssl; - listen [::]:${SMQ_NGINX_MQTTS_PORT} ssl; + listen ${MG_NGINX_MQTT_PORT}; + listen [::]:${MG_NGINX_MQTT_PORT}; + listen ${MG_NGINX_MQTTS_PORT} ssl; + listen [::]:${MG_NGINX_MQTTS_PORT} ssl; include snippets/ssl.conf; proxy_pass mqtt_cluster; } + + # FluxMQ AMQP 0.9.1 (event store) + server { + listen ${MG_NGINX_AMQP_PORT}; + listen [::]:${MG_NGINX_AMQP_PORT}; + + proxy_pass fluxmq_amqp_cluster; + } } error_log info.log info; diff --git a/docker/nginx/nginx-x509.conf b/docker/nginx/nginx-x509.conf index e18ee0baf..41c88161f 100644 --- a/docker/nginx/nginx-x509.conf +++ b/docker/nginx/nginx-x509.conf @@ -6,6 +6,7 @@ user nginx; worker_processes auto; worker_cpu_affinity auto; +worker_rlimit_nofile 65535; pid /run/nginx.pid; load_module /etc/nginx/modules/ngx_stream_js_module.so; load_module /etc/nginx/modules/ngx_http_js_module.so; @@ -37,9 +38,12 @@ http { ssl_protocols TLSv1.2 TLSv1.3; ssl_prefer_server_ciphers on; + resolver 127.0.0.11 ipv6=off valid=10s; + resolver_timeout 5s; # Include single-node or multiple-node (cluster) upstream include snippets/mqtt-ws-upstream.conf; + include snippets/fluxmq-http-upstream.conf; server { listen 80 default_server; @@ -48,13 +52,22 @@ http { listen [::]:443 ssl default_server; http2 on; - set $dynamic_server_name "$SMQ_NGINX_SERVER_NAME"; + set $dynamic_server_name "$MG_NGINX_SERVER_NAME"; if ($dynamic_server_name = '') { set $dynamic_server_name "localhost"; } server_name $dynamic_server_name; + set $auth_upstream "auth:${MG_AUTH_HTTP_PORT}"; + set $domains_upstream "domains:${MG_DOMAINS_HTTP_PORT}"; + set $users_upstream "users:${MG_USERS_HTTP_PORT}"; + set $groups_upstream "groups:${MG_GROUPS_HTTP_PORT}"; + set $clients_upstream "clients:${MG_CLIENTS_HTTP_PORT}"; + set $channels_upstream "channels:${MG_CHANNELS_HTTP_PORT}"; + set $rules_upstream "re:${MG_RE_HTTP_PORT}"; + set $alarms_upstream "alarms:${MG_ALARMS_HTTP_PORT}"; + set $reports_upstream "reports:${MG_REPORTS_HTTP_PORT}"; ssl_verify_client optional; include snippets/ssl.conf; @@ -71,81 +84,90 @@ http { location ~ ^/(pats) { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; - proxy_pass http://auth:${SMQ_AUTH_HTTP_PORT}; + proxy_pass http://$auth_upstream; } # Proxy pass to domains service location ~ ^/(domains|invitations) { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; - proxy_pass http://domains:${SMQ_DOMAINS_HTTP_PORT}; + proxy_pass http://$domains_upstream; } # Proxy pass to users service location ~ ^/(users|password|verify-email|authorize|oauth/callback/[^/]+) { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; - proxy_pass http://users:${SMQ_USERS_HTTP_PORT}; + proxy_pass http://$users_upstream; } # Proxy pass to groups service location ~ "^/([a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12})/(groups)" { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; - proxy_pass http://groups:${SMQ_GROUPS_HTTP_PORT}; + proxy_pass http://$groups_upstream; } # Proxy pass to clients service location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(clients)" { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; - proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT}; + proxy_pass http://$clients_upstream; } # Proxy pass to channels service location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(channels)" { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; - proxy_pass http://channels:${SMQ_CHANNELS_HTTP_PORT}; + proxy_pass http://$channels_upstream; + } + + # Proxy pass to rule engine service + location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(rules)" { + include snippets/proxy-headers.conf; + add_header Access-Control-Expose-Headers Location; + proxy_pass http://$rules_upstream; + } + + # Proxy pass to alarms service + location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(alarms)" { + include snippets/proxy-headers.conf; + add_header Access-Control-Expose-Headers Location; + proxy_pass http://$alarms_upstream; + } + + # Proxy pass to reports service + location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(reports)" { + include snippets/proxy-headers.conf; + add_header Access-Control-Expose-Headers Location; + proxy_pass http://$reports_upstream; } location /health { include snippets/proxy-headers.conf; - proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT}; + proxy_pass http://$clients_upstream; } location /metrics { include snippets/proxy-headers.conf; - proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT}; + proxy_pass http://$clients_upstream; } - # Proxy pass to supermq-http-adapter + # Proxy pass to FluxMQ HTTP API location /http/ { include snippets/verify-ssl-client.conf; include snippets/proxy-headers.conf; proxy_set_header Authorization $auth_key; - - # Trailing `/` is mandatory. Refer to the http://nginx.org/en/docs/http/ngx_http_proxy_module.html#proxy_pass - # If the proxy_pass directive is specified with a URI, then when a request is passed to the server, - # the part of a normalized request URI matching the location is replaced by a URI specified in the directive - proxy_pass http://http-adapter:${SMQ_HTTP_ADAPTER_PORT}/; + proxy_pass http://fluxmq_http_cluster/; } - # Proxy pass to supermq-mqtt-adapter over WS + # Proxy pass to FluxMQ MQTT over WebSocket location /mqtt { include snippets/verify-ssl-client.conf; include snippets/proxy-headers.conf; include snippets/ws-upgrade.conf; proxy_pass http://mqtt_ws_cluster; } - - # Proxy pass to supermq-ws-adapter - location /ws/ { - include snippets/verify-ssl-client.conf; - include snippets/proxy-headers.conf; - include snippets/ws-upgrade.conf; - proxy_pass http://http-adapter:${SMQ_HTTP_ADAPTER_PORT}/; - } } } @@ -160,20 +182,29 @@ stream { # Include single-node or multiple-node (cluster) upstream include snippets/mqtt-upstream.conf; + include snippets/fluxmq-amqp-upstream.conf; ssl_verify_client on; include snippets/ssl-client.conf; server { - listen ${SMQ_NGINX_MQTT_PORT}; - listen [::]:${SMQ_NGINX_MQTT_PORT}; - listen ${SMQ_NGINX_MQTTS_PORT} ssl; - listen [::]:${SMQ_NGINX_MQTTS_PORT} ssl; + listen ${MG_NGINX_MQTT_PORT}; + listen [::]:${MG_NGINX_MQTT_PORT}; + listen ${MG_NGINX_MQTTS_PORT} ssl; + listen [::]:${MG_NGINX_MQTTS_PORT} ssl; include snippets/ssl.conf; js_preread authorization.authenticate; proxy_pass mqtt_cluster; } + + # FluxMQ AMQP 0.9.1 (event store) + server { + listen ${MG_NGINX_AMQP_PORT}; + listen [::]:${MG_NGINX_AMQP_PORT}; + + proxy_pass fluxmq_amqp_cluster; + } } error_log info.log info; diff --git a/docker/nginx/snippets/fluxmq-amqp-upstream.conf b/docker/nginx/snippets/fluxmq-amqp-upstream.conf new file mode 100644 index 000000000..9706b38e8 --- /dev/null +++ b/docker/nginx/snippets/fluxmq-amqp-upstream.conf @@ -0,0 +1,10 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +upstream fluxmq_amqp_cluster { + zone fluxmq_amqp_cluster_zone 64k; + random two least_conn; + server 172.30.0.201:5682; + server 172.30.0.202:5682; + server 172.30.0.203:5682; +} diff --git a/docker/nginx/snippets/fluxmq-http-upstream.conf b/docker/nginx/snippets/fluxmq-http-upstream.conf new file mode 100644 index 000000000..671fd08e7 --- /dev/null +++ b/docker/nginx/snippets/fluxmq-http-upstream.conf @@ -0,0 +1,10 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +upstream fluxmq_http_cluster { + zone fluxmq_http_cluster_zone 64k; + random two least_conn; + server 172.30.0.201:8090; + server 172.30.0.202:8090; + server 172.30.0.203:8090; +} diff --git a/docker/nginx/snippets/mqtt-upstream-cluster.conf b/docker/nginx/snippets/mqtt-upstream-cluster.conf deleted file mode 100644 index 48159e25c..000000000 --- a/docker/nginx/snippets/mqtt-upstream-cluster.conf +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 - -upstream mqtt_cluster { - least_conn; - server mqtt-adapter-1:${SMQ_MQTT_ADAPTER_MQTT_PORT}; - server mqtt-adapter-2:${SMQ_MQTT_ADAPTER_MQTT_PORT}; - server mqtt-adapter-3:${SMQ_MQTT_ADAPTER_MQTT_PORT}; -} diff --git a/docker/nginx/snippets/mqtt-upstream-single.conf b/docker/nginx/snippets/mqtt-upstream-single.conf deleted file mode 100644 index f714eb5f6..000000000 --- a/docker/nginx/snippets/mqtt-upstream-single.conf +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 - -upstream mqtt_cluster { - server mqtt-adapter:${SMQ_MQTT_ADAPTER_MQTT_PORT}; -} diff --git a/docker/nginx/snippets/mqtt-upstream.conf b/docker/nginx/snippets/mqtt-upstream.conf new file mode 100644 index 000000000..a98172891 --- /dev/null +++ b/docker/nginx/snippets/mqtt-upstream.conf @@ -0,0 +1,10 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +upstream mqtt_cluster { + zone mqtt_cluster_zone 64k; + random two least_conn; + server 172.30.0.201:1883; + server 172.30.0.202:1883; + server 172.30.0.203:1883; +} diff --git a/docker/nginx/snippets/mqtt-ws-upstream-cluster.conf b/docker/nginx/snippets/mqtt-ws-upstream-cluster.conf deleted file mode 100644 index 293bf1d86..000000000 --- a/docker/nginx/snippets/mqtt-ws-upstream-cluster.conf +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 - -upstream mqtt_ws_cluster { - least_conn; - server mqtt-adapter-1:${SMQ_MQTT_ADAPTER_WS_PORT}; - server mqtt-adapter-2:${SMQ_MQTT_ADAPTER_WS_PORT}; - server mqtt-adapter-3:${SMQ_MQTT_ADAPTER_WS_PORT}; -} diff --git a/docker/nginx/snippets/mqtt-ws-upstream-single.conf b/docker/nginx/snippets/mqtt-ws-upstream-single.conf deleted file mode 100644 index 2df359925..000000000 --- a/docker/nginx/snippets/mqtt-ws-upstream-single.conf +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 - -upstream mqtt_ws_cluster { - server mqtt-adapter:${SMQ_MQTT_ADAPTER_WS_PORT}; -} diff --git a/docker/nginx/snippets/mqtt-ws-upstream.conf b/docker/nginx/snippets/mqtt-ws-upstream.conf new file mode 100644 index 000000000..dcde2a0da --- /dev/null +++ b/docker/nginx/snippets/mqtt-ws-upstream.conf @@ -0,0 +1,10 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +upstream mqtt_ws_cluster { + zone mqtt_ws_cluster_zone 64k; + random two least_conn; + server 172.30.0.201:8883; + server 172.30.0.202:8883; + server 172.30.0.203:8883; +} diff --git a/docker/nginx/snippets/ssl.conf b/docker/nginx/snippets/ssl.conf index d8cf234c5..109f6795b 100644 --- a/docker/nginx/snippets/ssl.conf +++ b/docker/nginx/snippets/ssl.conf @@ -3,8 +3,8 @@ # These paths are set to its default values as # a volume in the docker/docker-compose.yaml file. -ssl_certificate /etc/ssl/certs/supermq-server.crt; -ssl_certificate_key /etc/ssl/private/supermq-server.key; +ssl_certificate /etc/ssl/certs/magistrala-server.crt; +ssl_certificate_key /etc/ssl/private/magistrala-server.key; ssl_dhparam /etc/ssl/certs/dhparam.pem; ssl_protocols TLSv1.2 TLSv1.3; diff --git a/docker/addons/certs/openbao-entrypoint.sh b/docker/openbao-entrypoint.sh similarity index 69% rename from docker/addons/certs/openbao-entrypoint.sh rename to docker/openbao-entrypoint.sh index 20344ff3c..c3f3d59bf 100755 --- a/docker/addons/certs/openbao-entrypoint.sh +++ b/docker/openbao-entrypoint.sh @@ -29,13 +29,13 @@ export BAO_ADDR=http://127.0.0.1:8200 create_pki_policy() { cat > /opt/openbao/config/pki-policy.hcl << EOF -path "pki_int/issue/${AM_CERTS_OPENBAO_PKI_ROLE}" { +path "pki_int/issue/${MG_CERTS_OPENBAO_PKI_ROLE}" { capabilities = ["create", "update"] } -path "pki_int/sign/${AM_CERTS_OPENBAO_PKI_ROLE}" { +path "pki_int/sign/${MG_CERTS_OPENBAO_PKI_ROLE}" { capabilities = ["create", "update"] } -path "pki_int/sign-verbatim/${AM_CERTS_OPENBAO_PKI_ROLE}" { +path "pki_int/sign-verbatim/${MG_CERTS_OPENBAO_PKI_ROLE}" { capabilities = ["create", "update"] } path "pki_int/certs" { @@ -74,13 +74,13 @@ path "auth/token/lookup-self" { path "sys/renew/*" { capabilities = ["update"] } -path "auth/approle/role/${AM_CERTS_OPENBAO_PKI_ROLE}/secret-id" { +path "auth/approle/role/${MG_CERTS_OPENBAO_PKI_ROLE}/secret-id" { capabilities = ["create", "update"] } -path "auth/approle/role/${AM_CERTS_OPENBAO_PKI_ROLE}/secret-id-accessor/lookup" { +path "auth/approle/role/${MG_CERTS_OPENBAO_PKI_ROLE}/secret-id-accessor/lookup" { capabilities = ["create", "update"] } -path "auth/approle/role/${AM_CERTS_OPENBAO_PKI_ROLE}/secret-id-accessor/destroy" { +path "auth/approle/role/${MG_CERTS_OPENBAO_PKI_ROLE}/secret-id-accessor/destroy" { capabilities = ["create", "update"] } EOF @@ -88,17 +88,17 @@ EOF } # Check if we have pre-configured unseal keys and root token -if [ -n "$AM_CERTS_OPENBAO_UNSEAL_KEY_1" ] && [ -n "$AM_CERTS_OPENBAO_UNSEAL_KEY_2" ] && [ -n "$AM_CERTS_OPENBAO_UNSEAL_KEY_3" ] && [ -n "$AM_CERTS_OPENBAO_ROOT_TOKEN" ]; then +if [ -n "$MG_CERTS_OPENBAO_UNSEAL_KEY_1" ] && [ -n "$MG_CERTS_OPENBAO_UNSEAL_KEY_2" ] && [ -n "$MG_CERTS_OPENBAO_UNSEAL_KEY_3" ] && [ -n "$MG_CERTS_OPENBAO_ROOT_TOKEN" ]; then echo "Using pre-configured unseal keys and root token..." bao server -config=/opt/openbao/config/config.hcl > /opt/openbao/logs/server.log 2>&1 & BAO_PID=$! sleep 5 - bao operator unseal "$AM_CERTS_OPENBAO_UNSEAL_KEY_1" - bao operator unseal "$AM_CERTS_OPENBAO_UNSEAL_KEY_2" - bao operator unseal "$AM_CERTS_OPENBAO_UNSEAL_KEY_3" + bao operator unseal "$MG_CERTS_OPENBAO_UNSEAL_KEY_1" + bao operator unseal "$MG_CERTS_OPENBAO_UNSEAL_KEY_2" + bao operator unseal "$MG_CERTS_OPENBAO_UNSEAL_KEY_3" - export BAO_TOKEN=$AM_CERTS_OPENBAO_ROOT_TOKEN + export BAO_TOKEN=$MG_CERTS_OPENBAO_ROOT_TOKEN else # Initialize OpenBao if not already done if [ ! -f /opt/openbao/data/init.json ]; then @@ -154,18 +154,18 @@ if [ ! -f /opt/openbao/data/configured ]; then echo "Configuring OpenBao PKI and AppRole..." # Create namespace if specified - if [ -n "$AM_CERTS_OPENBAO_NAMESPACE" ]; then - if bao namespace create "$AM_CERTS_OPENBAO_NAMESPACE" 2>/tmp/ns_error; then - export BAO_NAMESPACE="$AM_CERTS_OPENBAO_NAMESPACE" - echo "$AM_CERTS_OPENBAO_NAMESPACE" > /opt/openbao/data/namespace - echo "Created namespace: $AM_CERTS_OPENBAO_NAMESPACE" + if [ -n "$MG_CERTS_OPENBAO_NAMESPACE" ]; then + if bao namespace create "$MG_CERTS_OPENBAO_NAMESPACE" 2>/tmp/ns_error; then + export BAO_NAMESPACE="$MG_CERTS_OPENBAO_NAMESPACE" + echo "$MG_CERTS_OPENBAO_NAMESPACE" > /opt/openbao/data/namespace + echo "Created namespace: $MG_CERTS_OPENBAO_NAMESPACE" else if grep -q "namespace already exists" /tmp/ns_error; then - export BAO_NAMESPACE="$AM_CERTS_OPENBAO_NAMESPACE" - echo "$AM_CERTS_OPENBAO_NAMESPACE" > /opt/openbao/data/namespace - echo "Using existing namespace: $AM_CERTS_OPENBAO_NAMESPACE" + export BAO_NAMESPACE="$MG_CERTS_OPENBAO_NAMESPACE" + echo "$MG_CERTS_OPENBAO_NAMESPACE" > /opt/openbao/data/namespace + echo "Using existing namespace: $MG_CERTS_OPENBAO_NAMESPACE" else - echo "ERROR: Failed to create namespace $AM_CERTS_OPENBAO_NAMESPACE:" >&2 + echo "ERROR: Failed to create namespace $MG_CERTS_OPENBAO_NAMESPACE:" >&2 cat /tmp/ns_error >&2 exit 1 fi @@ -200,7 +200,7 @@ if [ ! -f /opt/openbao/data/configured ]; then bao secrets tune -max-lease-ttl=87600h pki > /dev/null # Validate required CA environment variables - for var in AM_CERTS_OPENBAO_PKI_CA_CN AM_CERTS_OPENBAO_PKI_CA_O AM_CERTS_OPENBAO_PKI_CA_C; do + for var in MG_CERTS_OPENBAO_PKI_CA_CN MG_CERTS_OPENBAO_PKI_CA_O MG_CERTS_OPENBAO_PKI_CA_C; do eval "value=\$var" if [ -z "$value" ]; then echo "ERROR: Required environment variable $var is not set" >&2 @@ -209,23 +209,23 @@ if [ ! -f /opt/openbao/data/configured ]; then done PKI_CMD="bao write -field=certificate pki/root/generate/internal \ - common_name=\"$AM_CERTS_OPENBAO_PKI_CA_CN\" \ - organization=\"$AM_CERTS_OPENBAO_PKI_CA_O\" \ - country=\"$AM_CERTS_OPENBAO_PKI_CA_C\" \ + common_name=\"$MG_CERTS_OPENBAO_PKI_CA_CN\" \ + organization=\"$MG_CERTS_OPENBAO_PKI_CA_O\" \ + country=\"$MG_CERTS_OPENBAO_PKI_CA_C\" \ ttl=87600h \ key_bits=2048 \ exclude_cn_from_sans=false" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_OU" ] && PKI_CMD="$PKI_CMD ou=\"$AM_CERTS_OPENBAO_PKI_CA_OU\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_L" ] && PKI_CMD="$PKI_CMD locality=\"$AM_CERTS_OPENBAO_PKI_CA_L\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_ST" ] && PKI_CMD="$PKI_CMD province=\"$AM_CERTS_OPENBAO_PKI_CA_ST\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_ADDR" ] && PKI_CMD="$PKI_CMD street_address=\"$AM_CERTS_OPENBAO_PKI_CA_ADDR\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_PO" ] && PKI_CMD="$PKI_CMD postal_code=\"$AM_CERTS_OPENBAO_PKI_CA_PO\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_OU" ] && PKI_CMD="$PKI_CMD ou=\"$MG_CERTS_OPENBAO_PKI_CA_OU\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_L" ] && PKI_CMD="$PKI_CMD locality=\"$MG_CERTS_OPENBAO_PKI_CA_L\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_ST" ] && PKI_CMD="$PKI_CMD province=\"$MG_CERTS_OPENBAO_PKI_CA_ST\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_ADDR" ] && PKI_CMD="$PKI_CMD street_address=\"$MG_CERTS_OPENBAO_PKI_CA_ADDR\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_PO" ] && PKI_CMD="$PKI_CMD postal_code=\"$MG_CERTS_OPENBAO_PKI_CA_PO\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_DNS_NAMES" ] && PKI_CMD="$PKI_CMD alt_names=\"$AM_CERTS_OPENBAO_PKI_CA_DNS_NAMES\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES" ] && PKI_CMD="$PKI_CMD ip_sans=\"$AM_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_URI_SANS" ] && PKI_CMD="$PKI_CMD uri_sans=\"$AM_CERTS_OPENBAO_PKI_CA_URI_SANS\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES" ] && PKI_CMD="$PKI_CMD email_sans=\"$AM_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_DNS_NAMES" ] && PKI_CMD="$PKI_CMD alt_names=\"$MG_CERTS_OPENBAO_PKI_CA_DNS_NAMES\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES" ] && PKI_CMD="$PKI_CMD ip_sans=\"$MG_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_URI_SANS" ] && PKI_CMD="$PKI_CMD uri_sans=\"$MG_CERTS_OPENBAO_PKI_CA_URI_SANS\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES" ] && PKI_CMD="$PKI_CMD email_sans=\"$MG_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES\"" eval $PKI_CMD > /dev/null @@ -248,24 +248,24 @@ if [ ! -f /opt/openbao/data/configured ]; then bao secrets tune -max-lease-ttl=8760h pki_int > /dev/null - INTERMEDIATE_CN="${AM_CERTS_OPENBAO_PKI_CA_CN} Intermediate" + INTERMEDIATE_CN="${MG_CERTS_OPENBAO_PKI_CA_CN} Intermediate" INTERMEDIATE_CSR_CMD="bao write -field=csr pki_int/intermediate/generate/internal \ common_name=\"$INTERMEDIATE_CN\" \ - organization=\"$AM_CERTS_OPENBAO_PKI_CA_O\" \ - country=\"$AM_CERTS_OPENBAO_PKI_CA_C\" \ + organization=\"$MG_CERTS_OPENBAO_PKI_CA_O\" \ + country=\"$MG_CERTS_OPENBAO_PKI_CA_C\" \ ttl=8760h \ key_bits=2048" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_OU" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD ou=\"$AM_CERTS_OPENBAO_PKI_CA_OU\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_L" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD locality=\"$AM_CERTS_OPENBAO_PKI_CA_L\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_ST" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD province=\"$AM_CERTS_OPENBAO_PKI_CA_ST\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_ADDR" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD street_address=\"$AM_CERTS_OPENBAO_PKI_CA_ADDR\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_PO" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD postal_code=\"$AM_CERTS_OPENBAO_PKI_CA_PO\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_OU" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD ou=\"$MG_CERTS_OPENBAO_PKI_CA_OU\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_L" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD locality=\"$MG_CERTS_OPENBAO_PKI_CA_L\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_ST" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD province=\"$MG_CERTS_OPENBAO_PKI_CA_ST\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_ADDR" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD street_address=\"$MG_CERTS_OPENBAO_PKI_CA_ADDR\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_PO" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD postal_code=\"$MG_CERTS_OPENBAO_PKI_CA_PO\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_DNS_NAMES" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD alt_names=\"$AM_CERTS_OPENBAO_PKI_CA_DNS_NAMES\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD ip_sans=\"$AM_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_URI_SANS" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD uri_sans=\"$AM_CERTS_OPENBAO_PKI_CA_URI_SANS\"" - [ -n "$AM_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD email_sans=\"$AM_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_DNS_NAMES" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD alt_names=\"$MG_CERTS_OPENBAO_PKI_CA_DNS_NAMES\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD ip_sans=\"$MG_CERTS_OPENBAO_PKI_CA_IP_ADDRESSES\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_URI_SANS" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD uri_sans=\"$MG_CERTS_OPENBAO_PKI_CA_URI_SANS\"" + [ -n "$MG_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES" ] && INTERMEDIATE_CSR_CMD="$INTERMEDIATE_CSR_CMD email_sans=\"$MG_CERTS_OPENBAO_PKI_CA_EMAIL_ADDRESSES\"" INTERMEDIATE_CSR=$(eval $INTERMEDIATE_CSR_CMD) @@ -310,7 +310,7 @@ if [ ! -f /opt/openbao/data/configured ]; then echo "$INTERMEDIATE_CERT" > /opt/openbao/data/intermediate_ca.pem - ROLE_CMD="bao write pki_int/roles/${AM_CERTS_OPENBAO_PKI_ROLE} \ + ROLE_CMD="bao write pki_int/roles/${MG_CERTS_OPENBAO_PKI_ROLE} \ allow_any_name=true \ enforce_hostnames=false \ allow_ip_sans=true \ @@ -340,8 +340,8 @@ if [ ! -f /opt/openbao/data/configured ]; then create_pki_policy # Create AppRole - SECRET_ID_TTL="${AM_CERTS_OPENBAO_SECRET_ID_TTL}" - bao write auth/approle/role/"${AM_CERTS_OPENBAO_PKI_ROLE}" \ + SECRET_ID_TTL="${MG_CERTS_OPENBAO_SECRET_ID_TTL}" + bao write auth/approle/role/"${MG_CERTS_OPENBAO_PKI_ROLE}" \ token_policies=pki-policy \ token_ttl=1h \ token_max_ttl=4h \ @@ -349,16 +349,16 @@ if [ ! -f /opt/openbao/data/configured ]; then secret_id_ttl="$SECRET_ID_TTL" > /dev/null # Set custom role ID if provided - if [ -n "$AM_CERTS_OPENBAO_APP_ROLE" ]; then - bao write auth/approle/role/"${AM_CERTS_OPENBAO_PKI_ROLE}"/role-id role_id="$AM_CERTS_OPENBAO_APP_ROLE" > /dev/null + if [ -n "$MG_CERTS_OPENBAO_APP_ROLE" ]; then + bao write auth/approle/role/"${MG_CERTS_OPENBAO_PKI_ROLE}"/role-id role_id="$MG_CERTS_OPENBAO_APP_ROLE" > /dev/null fi # Set custom secret ID if provided, otherwise generate one - if [ -n "$AM_CERTS_OPENBAO_APP_SECRET" ]; then - bao write auth/approle/role/"${AM_CERTS_OPENBAO_PKI_ROLE}"/custom-secret-id secret_id="$AM_CERTS_OPENBAO_APP_SECRET" > /dev/null - echo "$AM_CERTS_OPENBAO_APP_SECRET" > /opt/openbao/data/secret_id + if [ -n "$MG_CERTS_OPENBAO_APP_SECRET" ]; then + bao write auth/approle/role/"${MG_CERTS_OPENBAO_PKI_ROLE}"/custom-secret-id secret_id="$MG_CERTS_OPENBAO_APP_SECRET" > /dev/null + echo "$MG_CERTS_OPENBAO_APP_SECRET" > /opt/openbao/data/secret_id else - GENERATED_SECRET_ID=$(bao write -field=secret_id -force auth/approle/role/"${AM_CERTS_OPENBAO_PKI_ROLE}"/secret-id) + GENERATED_SECRET_ID=$(bao write -field=secret_id -force auth/approle/role/"${MG_CERTS_OPENBAO_PKI_ROLE}"/secret-id) echo "$GENERATED_SECRET_ID" > /opt/openbao/data/secret_id fi @@ -378,15 +378,15 @@ else echo "OpenBao already configured, verifying and updating configuration..." # Restore namespace if it exists - if [ -f /opt/openbao/data/namespace ] && [ -n "$AM_CERTS_OPENBAO_NAMESPACE" ]; then + if [ -f /opt/openbao/data/namespace ] && [ -n "$MG_CERTS_OPENBAO_NAMESPACE" ]; then SAVED_NAMESPACE=$(cat /opt/openbao/data/namespace) - if [ "$SAVED_NAMESPACE" = "$AM_CERTS_OPENBAO_NAMESPACE" ]; then - export BAO_NAMESPACE="$AM_CERTS_OPENBAO_NAMESPACE" + if [ "$SAVED_NAMESPACE" = "$MG_CERTS_OPENBAO_NAMESPACE" ]; then + export BAO_NAMESPACE="$MG_CERTS_OPENBAO_NAMESPACE" fi fi # Check if AppRole role exists, create if missing - if ! bao read auth/approle/role/"${AM_CERTS_OPENBAO_PKI_ROLE}" > /dev/null 2>&1; then + if ! bao read auth/approle/role/"${MG_CERTS_OPENBAO_PKI_ROLE}" > /dev/null 2>&1; then if ! bao auth enable approle > /tmp/auth_success 2>/tmp/auth_error; then if ! grep -q "already in use" /tmp/auth_error; then echo "ERROR: Failed to enable AppRole auth method:" >&2 @@ -398,29 +398,29 @@ else create_pki_policy - SECRET_ID_TTL="${AM_CERTS_OPENBAO_SECRET_ID_TTL}" - bao write auth/approle/role/"${AM_CERTS_OPENBAO_PKI_ROLE}" \ + SECRET_ID_TTL="${MG_CERTS_OPENBAO_SECRET_ID_TTL}" + bao write auth/approle/role/"${MG_CERTS_OPENBAO_PKI_ROLE}" \ token_policies=pki-policy \ token_ttl=1h \ token_max_ttl=4h \ bind_secret_id=true \ secret_id_ttl="$SECRET_ID_TTL" > /dev/null - if [ -n "$AM_CERTS_OPENBAO_APP_ROLE" ]; then - bao write auth/approle/role/"${AM_CERTS_OPENBAO_PKI_ROLE}"/role-id role_id="$AM_CERTS_OPENBAO_APP_ROLE" > /dev/null + if [ -n "$MG_CERTS_OPENBAO_APP_ROLE" ]; then + bao write auth/approle/role/"${MG_CERTS_OPENBAO_PKI_ROLE}"/role-id role_id="$MG_CERTS_OPENBAO_APP_ROLE" > /dev/null fi fi SECRET_ID_VALID=false - if [ -n "$AM_CERTS_OPENBAO_APP_SECRET" ]; then - if bao write -field=client_token auth/approle/login role_id="$AM_CERTS_OPENBAO_APP_ROLE" secret_id="$AM_CERTS_OPENBAO_APP_SECRET" > /dev/null 2>&1; then + if [ -n "$MG_CERTS_OPENBAO_APP_SECRET" ]; then + if bao write -field=client_token auth/approle/login role_id="$MG_CERTS_OPENBAO_APP_ROLE" secret_id="$MG_CERTS_OPENBAO_APP_SECRET" > /dev/null 2>&1; then SECRET_ID_VALID=true - echo "$AM_CERTS_OPENBAO_APP_SECRET" > /opt/openbao/data/secret_id + echo "$MG_CERTS_OPENBAO_APP_SECRET" > /opt/openbao/data/secret_id fi elif [ -f /opt/openbao/data/secret_id ]; then STORED_SECRET_ID=$(cat /opt/openbao/data/secret_id) if [ -n "$STORED_SECRET_ID" ]; then - ROLE_ID=$(bao read -field=role_id auth/approle/role/"${AM_CERTS_OPENBAO_PKI_ROLE}"/role-id) + ROLE_ID=$(bao read -field=role_id auth/approle/role/"${MG_CERTS_OPENBAO_PKI_ROLE}"/role-id) if bao write -field=client_token auth/approle/login role_id="$ROLE_ID" secret_id="$STORED_SECRET_ID" > /dev/null 2>&1; then SECRET_ID_VALID=true fi @@ -428,7 +428,7 @@ else fi if [ "$SECRET_ID_VALID" = "false" ]; then - NEW_SECRET_ID=$(bao write -field=secret_id -force auth/approle/role/"${AM_CERTS_OPENBAO_PKI_ROLE}"/secret-id) + NEW_SECRET_ID=$(bao write -field=secret_id -force auth/approle/role/"${MG_CERTS_OPENBAO_PKI_ROLE}"/secret-id) if [ -z "$NEW_SECRET_ID" ]; then echo "ERROR: Failed to generate new secret ID" >&2 diff --git a/docker/permission.yaml b/docker/permission.yaml index 9c895dc32..32268c5db 100644 --- a/docker/permission.yaml +++ b/docker/permission.yaml @@ -130,3 +130,75 @@ domains: - check_members_exists: view_role_users_permission - remove_members: remove_role_users_permission - remove_all_members: remove_role_users_permission + +alarm: + operations: + - list: alarm_read_permission + - view: alarm_read_permission + - update: alarm_update_permission + - delete: alarm_delete_permission + - assign: alarm_assign_permission + - acknowledge: alarm_acknowledge_permission + - resolve: alarm_resolve_permission + +rule: + operations: + - add: rule_create_permission + - list: rule_read_permission + - view: read_permission + - update: update_permission + - update_tags: update_permission + - update_schedule: update_permission + - enable: update_permission + - disable: update_permission + - delete: delete_permission + - alarm_assign: alarm_assign_permission + - alarm_acknowledge: alarm_acknowledge_permission + - alarm_resolve: alarm_resolve_permission + roles_operations: + - add: manage_role_permission + - remove: manage_role_permission + - update: manage_role_permission + - retrieve: view_role_users_permission + - retrieve_all: view_role_users_permission + - add_actions: manage_role_permission + - list_actions: view_role_users_permission + - check_actions_exists: view_role_users_permission + - remove_actions: manage_role_permission + - remove_all_actions: manage_role_permission + - add_members: add_role_users_permission + - list_members: view_role_users_permission + - check_members_exists: view_role_users_permission + - remove_members: remove_role_users_permission + - remove_all_members: remove_role_users_permission + +report: + operations: + - add: report_create_permission + - list: report_read_permission + - generate: report_read_permission + - view: read_permission + - update: update_permission + - update_schedule: update_permission + - enable: update_permission + - disable: update_permission + - delete: delete_permission + - update_template: update_permission + - view_template: read_permission + - delete_template: delete_permission + roles_operations: + - add: manage_role_permission + - remove: manage_role_permission + - update: manage_role_permission + - retrieve: view_role_users_permission + - retrieve_all: view_role_users_permission + - add_actions: manage_role_permission + - list_actions: view_role_users_permission + - check_actions_exists: view_role_users_permission + - remove_actions: manage_role_permission + - remove_all_actions: manage_role_permission + - add_members: add_role_users_permission + - list_members: view_role_users_permission + - check_members_exists: view_role_users_permission + - remove_members: remove_role_users_permission + - remove_all_members: remove_role_users_permission diff --git a/docker/rabbitmq/enabled_plugins b/docker/rabbitmq/enabled_plugins deleted file mode 100644 index 2561f4974..000000000 --- a/docker/rabbitmq/enabled_plugins +++ /dev/null @@ -1 +0,0 @@ -[rabbitmq_management,rabbitmq_mqtt,rabbitmq_web_mqtt]. diff --git a/docker/rabbitmq/rabbitmq.conf b/docker/rabbitmq/rabbitmq.conf deleted file mode 100644 index 31e326c6b..000000000 --- a/docker/rabbitmq/rabbitmq.conf +++ /dev/null @@ -1,15 +0,0 @@ -## DEFAULT SETTINGS ARE NOT MEANT TO BE TAKEN STRAIGHT INTO PRODUCTION -## see https://www.rabbitmq.com/configure.html for further information -## on configuring RabbitMQ - -## allow access to the guest user from anywhere on the network -## https://www.rabbitmq.com/access-control.html#loopback-users -## https://www.rabbitmq.com/production-checklist.html#users -loopback_users.guest = false - -## Send all logs to stdout/TTY. Necessary to see logs when running via -## a container -log.console = true - -## Enable anonymous connection -mqtt.allow_anonymous = true diff --git a/docker/seaweedfs/s3.json b/docker/seaweedfs/s3.json new file mode 100644 index 000000000..98b5e34ea --- /dev/null +++ b/docker/seaweedfs/s3.json @@ -0,0 +1,17 @@ +{ + "identities": [ + { + "name": "localuser", + "credentials": [ + { + "accessKey": "localKey", + "secretKey": "localSecret" + } + ], + "actions": ["Admin", "Read", "Write"] + } + ], + "s3": { + "region": "fra1" + } +} diff --git a/docker/spicedb/schema.zed b/docker/spicedb/schema.zed index ff0df7cda..b4336cea5 100644 --- a/docker/spicedb/schema.zed +++ b/docker/spicedb/schema.zed @@ -1,10 +1,13 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + definition user {} definition role { - relation entity: domain | group | channel | client + relation entity: domain | group | channel | client | rule | report relation member: user - relation built_in_role: domain | group | channel | client + relation built_in_role: domain | group | channel | client | rule | report permission delete = entity->manage_role_permission - built_in_role->manage_role_permission permission update = entity->manage_role_permission - built_in_role->manage_role_permission @@ -299,6 +302,29 @@ definition domain { relation group_remove_role_users: role#member | team#member relation group_view_role_users: role#member | team#member + relation alarm_update: role#member | team#member + relation alarm_read: role#member | team#member + relation alarm_delete: role#member | team#member + relation rule_create: role#member | team#member + relation rule_update: role#member | team#member + relation rule_read: role#member | team#member + relation rule_delete: role#member | team#member + relation rule_manage_role: role#member | team#member + relation rule_add_role_users: role#member | team#member + relation rule_remove_role_users: role#member | team#member + relation rule_view_role_users: role#member | team#member + relation alarm_assign: role#member | team#member + relation alarm_acknowledge: role#member | team#member + relation alarm_resolve: role#member | team#member + relation report_create: role#member | team#member + relation report_update: role#member | team#member + relation report_read: role#member | team#member + relation report_delete: role#member | team#member + relation report_manage_role: role#member | team#member + relation report_add_role_users: role#member | team#member + relation report_remove_role_users: role#member | team#member + relation report_view_role_users: role#member | team#member + permission update_permission = update + team->domain_update + organization->admin permission read_permission = read + team->domain_read + organization->admin permission enable_permission = enable + team->domain_update + organization->admin @@ -318,7 +344,9 @@ definition domain { channel_update + channel_read + channel_delete + channel_set_parent_group + channel_connect_to_client + channel_publish + channel_subscribe + channel_manage_role + channel_add_role_users + channel_remove_role_users + channel_view_role_users + group_update + group_membership + group_read + group_delete + group_set_child + group_set_parent + - group_manage_role + group_add_role_users + group_remove_role_users + group_view_role_users + organization->admin + group_manage_role + group_add_role_users + group_remove_role_users + group_view_role_users + + alarm_update + alarm_read + alarm_delete + rule_create + rule_update + rule_read + rule_delete + rule_manage_role + rule_add_role_users + rule_remove_role_users + rule_view_role_users + alarm_assign + alarm_acknowledge + alarm_resolve + report_create + report_update + report_read + report_delete + report_manage_role + report_add_role_users + report_remove_role_users + report_view_role_users + + organization->admin permission admin = (read & update & enable & disable & delete & manage_role & add_role_users & remove_role_users & view_role_users) + organization->admin @@ -362,6 +390,29 @@ definition domain { permission group_remove_role_users_permission = group_remove_role_users + team->group_remove_role_users + organization->admin permission group_view_role_users_permission = group_view_role_users + team->group_view_role_users + organization->admin + permission alarm_update_permission = alarm_update + team->alarm_update + organization->admin + permission alarm_read_permission = alarm_read + team->alarm_read + organization->admin + permission alarm_delete_permission = alarm_delete + team->alarm_delete + organization->admin + permission rule_create_permission = rule_create + team->rule_create + organization->admin + permission rule_update_permission = rule_update + team->rule_update + organization->admin + permission rule_read_permission = rule_read + team->rule_read + organization->admin + permission rule_delete_permission = rule_delete + team->rule_delete + organization->admin + permission rule_manage_role_permission = rule_manage_role + team->rule_manage_role + organization->admin + permission rule_add_role_users_permission = rule_add_role_users + team->rule_add_role_users + organization->admin + permission rule_remove_role_users_permission = rule_remove_role_users + team->rule_remove_role_users + organization->admin + permission rule_view_role_users_permission = rule_view_role_users + team->rule_view_role_users + organization->admin + permission alarm_assign_permission = alarm_assign + team->alarm_assign + organization->admin + permission alarm_acknowledge_permission = alarm_acknowledge + team->alarm_acknowledge + organization->admin + permission alarm_resolve_permission = alarm_resolve + team->alarm_resolve + organization->admin + permission report_create_permission = report_create + team->report_create + organization->admin + permission report_update_permission = report_update + team->report_update + organization->admin + permission report_read_permission = report_read + team->report_read + organization->admin + permission report_delete_permission = report_delete + team->report_delete + organization->admin + permission report_manage_role_permission = report_manage_role + team->report_manage_role + organization->admin + permission report_add_role_users_permission = report_add_role_users + team->report_add_role_users + organization->admin + permission report_remove_role_users_permission = report_remove_role_users + team->report_remove_role_users + organization->admin + permission report_view_role_users_permission = report_view_role_users + team->report_view_role_users + organization->admin + } // Add this relation and permission in future while adding organization @@ -451,6 +502,29 @@ definition team { relation group_remove_role_users: role#member | team#member relation group_view_role_users: role#member | team#member + relation alarm_update: role#member | team#member + relation alarm_read: role#member | team#member + relation alarm_delete: role#member | team#member + relation rule_create: role#member | team#member + relation rule_update: role#member | team#member + relation rule_read: role#member | team#member + relation rule_delete: role#member | team#member + relation rule_manage_role: role#member | team#member + relation rule_add_role_users: role#member | team#member + relation rule_remove_role_users: role#member | team#member + relation rule_view_role_users: role#member | team#member + relation alarm_assign: role#member | team#member + relation alarm_acknowledge: role#member | team#member + relation alarm_resolve: role#member | team#member + relation report_create: role#member | team#member + relation report_update: role#member | team#member + relation report_read: role#member | team#member + relation report_delete: role#member | team#member + relation report_manage_role: role#member | team#member + relation report_add_role_users: role#member | team#member + relation report_remove_role_users: role#member | team#member + relation report_view_role_users: role#member | team#member + permission delete_permission = delete + organization->team_delete + parent_team->subteam_delete + organization->admin permission update_permission = update + organization->team_update + parent_team->subteam_update + organization->admin permission read_permission = read + organization->team_read + parent_team->subteam_read + organization->admin @@ -521,3 +595,57 @@ definition platform { permission admin = administrator permission membership = administrator + member } + +definition rule { +relation domain: domain + +relation update: role#member +relation read: role#member +relation delete: role#member + +relation manage_role: role#member +relation add_role_users: role#member +relation remove_role_users: role#member +relation view_role_users: role#member + +relation alarm_read: role#member +relation alarm_assign: role#member +relation alarm_acknowledge: role#member +relation alarm_resolve: role#member + +permission update_permission = update + domain->rule_update_permission +permission read_permission = read + domain->rule_read_permission +permission delete_permission = delete + domain->rule_delete_permission + +permission manage_role_permission = manage_role + domain->rule_manage_role_permission +permission add_role_users_permission = add_role_users + domain->rule_add_role_users_permission +permission remove_role_users_permission = remove_role_users + domain->rule_remove_role_users_permission +permission view_role_users_permission = view_role_users + domain->rule_view_role_users_permission + +permission alarm_read_permission = alarm_read + domain->alarm_read_permission +permission alarm_assign_permission = alarm_assign + domain->alarm_assign_permission +permission alarm_acknowledge_permission = alarm_acknowledge + domain->alarm_acknowledge_permission +permission alarm_resolve_permission = alarm_resolve + domain->alarm_resolve_permission +} + +definition report { +relation domain: domain + +relation update: role#member +relation read: role#member +relation delete: role#member + +relation manage_role: role#member +relation add_role_users: role#member +relation remove_role_users: role#member +relation view_role_users: role#member + +permission update_permission = update + domain->report_update_permission +permission read_permission = read + domain->report_read_permission +permission delete_permission = delete + domain->report_delete_permission + +permission manage_role_permission = manage_role + domain->report_manage_role_permission +permission add_role_users_permission = add_role_users + domain->report_add_role_users_permission +permission remove_role_users_permission = remove_role_users + domain->report_remove_role_users_permission +permission view_role_users_permission = view_role_users + domain->report_view_role_users_permission +} diff --git a/docker/ssl/Makefile b/docker/ssl/Makefile index 0378d8dbb..965ae4862 100644 --- a/docker/ssl/Makefile +++ b/docker/ssl/Makefile @@ -79,17 +79,17 @@ ca: openssl req -newkey rsa:2048 -x509 -nodes -sha512 -days 1095 \ -keyout $(CRT_LOCATION)/ca.key -out $(CRT_LOCATION)/ca.crt -subj "/CN=$(CN_CA)/O=$(O)/OU=$(OU_CA)/emailAddress=$(EA)" -# Server cert and key name is "supermq-server". +# Server cert and key name is "magistrala-server". server_cert: - # Create supermq server key and CSR. - openssl req -new -sha256 -newkey rsa:4096 -nodes -keyout $(CRT_LOCATION)/supermq-server.key \ - -out $(CRT_LOCATION)/supermq-server.csr -subj "/CN=$(CN_SRV)/O=$(O)/OU=$(OU_CRT)/emailAddress=$(EA)" + # Create magistrala server key and CSR. + openssl req -new -sha256 -newkey rsa:4096 -nodes -keyout $(CRT_LOCATION)/magistrala-server.key \ + -out $(CRT_LOCATION)/magistrala-server.csr -subj "/CN=$(CN_SRV)/O=$(O)/OU=$(OU_CRT)/emailAddress=$(EA)" # Sign server CSR. - openssl x509 -req -days 1000 -in $(CRT_LOCATION)/supermq-server.csr -CA $(CRT_LOCATION)/ca.crt -CAkey $(CRT_LOCATION)/ca.key -CAcreateserial -out $(CRT_LOCATION)/supermq-server.crt + openssl x509 -req -days 1000 -in $(CRT_LOCATION)/magistrala-server.csr -CA $(CRT_LOCATION)/ca.crt -CAkey $(CRT_LOCATION)/ca.key -CAcreateserial -out $(CRT_LOCATION)/magistrala-server.crt # Remove CSR. - rm $(CRT_LOCATION)/supermq-server.csr + rm $(CRT_LOCATION)/magistrala-server.csr client_cert: # Create supermq server key and CSR. diff --git a/docker/ssl/certs/supermq-server.crt b/docker/ssl/certs/magistrala-server.crt similarity index 100% rename from docker/ssl/certs/supermq-server.crt rename to docker/ssl/certs/magistrala-server.crt diff --git a/docker/ssl/certs/supermq-server.key b/docker/ssl/certs/magistrala-server.key similarity index 100% rename from docker/ssl/certs/supermq-server.key rename to docker/ssl/certs/magistrala-server.key diff --git a/docker/ssl/placeholder b/docker/ssl/placeholder new file mode 100644 index 000000000..f5f101481 --- /dev/null +++ b/docker/ssl/placeholder @@ -0,0 +1 @@ +optional bind-mount placeholder diff --git a/docker/templates/re.tmpl b/docker/templates/re.tmpl new file mode 100644 index 000000000..3dad54580 --- /dev/null +++ b/docker/templates/re.tmpl @@ -0,0 +1,3 @@ +{{.Header}} +{{.Content}} +{{.Footer}} diff --git a/docker/templates/reports.tmpl b/docker/templates/reports.tmpl new file mode 100644 index 000000000..3dad54580 --- /dev/null +++ b/docker/templates/reports.tmpl @@ -0,0 +1,3 @@ +{{.Header}} +{{.Content}} +{{.Footer}} diff --git a/domains/README.md b/domains/README.md index 561603882..11f7fe371 100644 --- a/domains/README.md +++ b/domains/README.md @@ -10,49 +10,49 @@ The service is configured through environment variables (unset variables fall ba | Variable | Description | Default | | ------------------------------------ | -------------------------------------------------------------------------------------------- | ------------------------------------- | -| `SMQ_DOMAINS_LOG_LEVEL` | Log level for Domains (debug, info, warn, error) | debug | -| `SMQ_DOMAINS_HTTP_HOST` | Domains service HTTP host | domains | -| `SMQ_DOMAINS_HTTP_PORT` | Domains service HTTP port | 9003 | -| `SMQ_DOMAINS_HTTP_SERVER_CERT` | Path to PEM-encoded HTTP server certificate | "" | -| `SMQ_DOMAINS_HTTP_SERVER_KEY` | Path to PEM-encoded HTTP server key | "" | -| `SMQ_DOMAINS_GRPC_PORT` | Domains service gRPC port | 7003 | -| `SMQ_DOMAINS_GRPC_SERVER_CERT` | Path to PEM-encoded gRPC server certificate | "" | -| `SMQ_DOMAINS_GRPC_SERVER_KEY` | Path to PEM-encoded gRPC server key | "" | -| `SMQ_DOMAINS_GRPC_SERVER_CA_CERTS` | Path to trusted CA bundle for the gRPC server | "" | -| `SMQ_DOMAINS_GRPC_CLIENT_CA_CERTS` | Path to client CA bundle to require gRPC mTLS | "" | -| `SMQ_DOMAINS_DB_HOST` | Database host address | domains-db | -| `SMQ_DOMAINS_DB_PORT` | Database host port | 5432 | -| `SMQ_DOMAINS_DB_USER` | Database user | supermq | -| `SMQ_DOMAINS_DB_PASS` | Database password | supermq | -| `SMQ_DOMAINS_DB_NAME` | Name of the database used by the service | domains | -| `SMQ_DOMAINS_DB_SSL_MODE` | Database connection SSL mode (disable, require, verify-ca, verify-full) | "" | -| `SMQ_DOMAINS_DB_SSL_CERT` | Path to the PEM-encoded certificate file | "" | -| `SMQ_DOMAINS_DB_SSL_KEY` | Path to the PEM-encoded key file | "" | -| `SMQ_DOMAINS_DB_SSL_ROOT_CERT` | Path to the PEM-encoded root certificate file | "" | -| `SMQ_DOMAINS_CACHE_URL` | Cache database URL | redis://domains-redis:6379/0 | -| `SMQ_DOMAINS_CACHE_KEY_DURATION` | Cache key duration for domain status/route lookups | 10m | -| `SMQ_DOMAINS_INSTANCE_ID` | Domains instance ID (auto-generated when empty) | "" | -| `SMQ_SPICEDB_HOST` | SpiceDB host for policy checks | supermq-spicedb | -| `SMQ_SPICEDB_PORT` | SpiceDB port | 50051 | -| `SMQ_SPICEDB_SCHEMA_FILE` | Path to SpiceDB schema file used to seed available actions | ./docker/spicedb/schema.schema.zed | -| `SMQ_SPICEDB_PRE_SHARED_KEY` | SpiceDB preshared key | 12345678 | -| `SMQ_ES_URL` | Event store URL | nats://localhost:4222 | -| `SMQ_JAEGER_URL` | Jaeger server URL | | -| `SMQ_JAEGER_TRACE_RATIO` | Trace sampling ratio | 1.0 | -| `SMQ_SEND_TELEMETRY` | Send telemetry to the SuperMQ call-home server | true | -| `SMQ_AUTH_GRPC_URL` | Auth service gRPC URL | "" | -| `SMQ_AUTH_GRPC_TIMEOUT` | Auth service gRPC request timeout | 1s | -| `SMQ_AUTH_GRPC_CLIENT_CERT` | Path to the PEM-encoded Auth gRPC client certificate | "" | -| `SMQ_AUTH_GRPC_CLIENT_KEY` | Path to the PEM-encoded Auth gRPC client key | "" | -| `SMQ_AUTH_GRPC_SERVER_CA_CERTS` | Path to the PEM-encoded Auth gRPC trusted CA bundle | "" | -| `SMQ_DOMAINS_CALLOUT_URLS` | Comma-separated list of HTTP callout targets invoked on domain operations | "" | -| `SMQ_DOMAINS_CALLOUT_METHOD` | HTTP method for callouts (POST or GET) | POST | -| `SMQ_DOMAINS_CALLOUT_TLS_VERIFICATION` | Verify TLS certificates for callouts | true | -| `SMQ_DOMAINS_CALLOUT_TIMEOUT` | Callout request timeout | 10s | -| `SMQ_DOMAINS_CALLOUT_KEY` | Client key for mTLS callouts | "" | -| `SMQ_DOMAINS_CALLOUT_OPERATIONS` | Comma-separated list of operation names that should trigger callouts | "" | +| `MG_DOMAINS_LOG_LEVEL` | Log level for Domains (debug, info, warn, error) | debug | +| `MG_DOMAINS_HTTP_HOST` | Domains service HTTP host | domains | +| `MG_DOMAINS_HTTP_PORT` | Domains service HTTP port | 9003 | +| `MG_DOMAINS_HTTP_SERVER_CERT` | Path to PEM-encoded HTTP server certificate | "" | +| `MG_DOMAINS_HTTP_SERVER_KEY` | Path to PEM-encoded HTTP server key | "" | +| `MG_DOMAINS_GRPC_PORT` | Domains service gRPC port | 7003 | +| `MG_DOMAINS_GRPC_SERVER_CERT` | Path to PEM-encoded gRPC server certificate | "" | +| `MG_DOMAINS_GRPC_SERVER_KEY` | Path to PEM-encoded gRPC server key | "" | +| `MG_DOMAINS_GRPC_SERVER_CA_CERTS` | Path to trusted CA bundle for the gRPC server | "" | +| `MG_DOMAINS_GRPC_CLIENT_CA_CERTS` | Path to client CA bundle to require gRPC mTLS | "" | +| `MG_DOMAINS_DB_HOST` | Database host address | domains-db | +| `MG_DOMAINS_DB_PORT` | Database host port | 5432 | +| `MG_DOMAINS_DB_USER` | Database user | supermq | +| `MG_DOMAINS_DB_PASS` | Database password | supermq | +| `MG_DOMAINS_DB_NAME` | Name of the database used by the service | domains | +| `MG_DOMAINS_DB_SSL_MODE` | Database connection SSL mode (disable, require, verify-ca, verify-full) | "" | +| `MG_DOMAINS_DB_SSL_CERT` | Path to the PEM-encoded certificate file | "" | +| `MG_DOMAINS_DB_SSL_KEY` | Path to the PEM-encoded key file | "" | +| `MG_DOMAINS_DB_SSL_ROOT_CERT` | Path to the PEM-encoded root certificate file | "" | +| `MG_DOMAINS_CACHE_URL` | Cache database URL | redis://domains-redis:6379/0 | +| `MG_DOMAINS_CACHE_KEY_DURATION` | Cache key duration for domain status/route lookups | 10m | +| `MG_DOMAINS_INSTANCE_ID` | Domains instance ID (auto-generated when empty) | "" | +| `MG_SPICEDB_HOST` | SpiceDB host for policy checks | supermq-spicedb | +| `MG_SPICEDB_PORT` | SpiceDB port | 50051 | +| `MG_SPICEDB_SCHEMA_FILE` | Path to SpiceDB schema file used to seed available actions | ./docker/spicedb/schema.schema.zed | +| `MG_SPICEDB_PRE_SHARED_KEY` | SpiceDB preshared key | 12345678 | +| `MG_ES_URL` | Event store URL | nats://localhost:4222 | +| `MG_JAEGER_URL` | Jaeger server URL | | +| `MG_JAEGER_TRACE_RATIO` | Trace sampling ratio | 1.0 | +| `MG_SEND_TELEMETRY` | Send telemetry to the SuperMQ call-home server | true | +| `MG_AUTH_GRPC_URL` | Auth service gRPC URL | "" | +| `MG_AUTH_GRPC_TIMEOUT` | Auth service gRPC request timeout | 1s | +| `MG_AUTH_GRPC_CLIENT_CERT` | Path to the PEM-encoded Auth gRPC client certificate | "" | +| `MG_AUTH_GRPC_CLIENT_KEY` | Path to the PEM-encoded Auth gRPC client key | "" | +| `MG_AUTH_GRPC_SERVER_CA_CERTS` | Path to the PEM-encoded Auth gRPC trusted CA bundle | "" | +| `MG_DOMAINS_CALLOUT_URLS` | Comma-separated list of HTTP callout targets invoked on domain operations | "" | +| `MG_DOMAINS_CALLOUT_METHOD` | HTTP method for callouts (POST or GET) | POST | +| `MG_DOMAINS_CALLOUT_TLS_VERIFICATION` | Verify TLS certificates for callouts | true | +| `MG_DOMAINS_CALLOUT_TIMEOUT` | Callout request timeout | 10s | +| `MG_DOMAINS_CALLOUT_KEY` | Client key for mTLS callouts | "" | +| `MG_DOMAINS_CALLOUT_OPERATIONS` | Comma-separated list of operation names that should trigger callouts | "" | -**Note**: Set `SMQ_DOMAINS_CALLOUT_OPERATIONS` to a subset of `OpCreateDomain`, `OpRetrieveDomain`, `OpUpdateDomain`, `OpEnableDomain`, `OpDisableDomain`, `OpFreezeDomain`, `OpListDomains`, `OpViewDomainInvitation`, `OpSendInvitation`, `OpAcceptInvitation`, `OpListInvitations`, `OpListDomainInvitations`, `OpRejectInvitation`, or `OpDeleteInvitation` to filter which actions produce callouts. +**Note**: Set `MG_DOMAINS_CALLOUT_OPERATIONS` to a subset of `OpCreateDomain`, `OpRetrieveDomain`, `OpUpdateDomain`, `OpEnableDomain`, `OpDisableDomain`, `OpFreezeDomain`, `OpListDomains`, `OpViewDomainInvitation`, `OpSendInvitation`, `OpAcceptInvitation`, `OpListInvitations`, `OpListDomainInvitations`, `OpRejectInvitation`, or `OpDeleteInvitation` to filter which actions produce callouts. ## Deployment @@ -72,48 +72,48 @@ make domains make install # set the environment variables and run the service -SMQ_DOMAINS_LOG_LEVEL=debug \ -SMQ_DOMAINS_CACHE_URL=redis://domains-redis:6379/0 \ -SMQ_DOMAINS_CACHE_KEY_DURATION=10m \ -SMQ_DOMAINS_HTTP_HOST=domains \ -SMQ_DOMAINS_HTTP_PORT=9003 \ -SMQ_DOMAINS_HTTP_SERVER_CERT="" \ -SMQ_DOMAINS_HTTP_SERVER_KEY="" \ -SMQ_DOMAINS_GRPC_HOST=domains \ -SMQ_DOMAINS_GRPC_PORT=7003 \ -SMQ_DOMAINS_GRPC_SERVER_CERT="" \ -SMQ_DOMAINS_GRPC_SERVER_KEY="" \ -SMQ_DOMAINS_GRPC_SERVER_CA_CERTS="" \ -SMQ_DOMAINS_GRPC_CLIENT_CA_CERTS="" \ -SMQ_DOMAINS_DB_HOST=domains-db \ -SMQ_DOMAINS_DB_PORT=5432 \ -SMQ_DOMAINS_DB_USER=supermq \ -SMQ_DOMAINS_DB_PASS=supermq \ -SMQ_DOMAINS_DB_NAME=domains \ -SMQ_DOMAINS_DB_SSL_MODE="" \ -SMQ_DOMAINS_DB_SSL_CERT="" \ -SMQ_DOMAINS_DB_SSL_KEY="" \ -SMQ_DOMAINS_DB_SSL_ROOT_CERT="" \ -SMQ_AUTH_GRPC_URL="" \ -SMQ_AUTH_GRPC_TIMEOUT=1s \ -SMQ_AUTH_GRPC_CLIENT_CERT="" \ -SMQ_AUTH_GRPC_CLIENT_KEY="" \ -SMQ_AUTH_GRPC_SERVER_CA_CERTS="" \ -SMQ_SPICEDB_HOST=localhost \ -SMQ_SPICEDB_PORT=50051 \ -SMQ_SPICEDB_SCHEMA_FILE=./docker/spicedb/schema.schema.zed \ -SMQ_SPICEDB_PRE_SHARED_KEY=12345678 \ -SMQ_ES_URL=nats://localhost:4222 \ -SMQ_JAEGER_URL= \ -SMQ_JAEGER_TRACE_RATIO=1.0 \ -SMQ_DOMAINS_CALLOUT_URLS="" \ -SMQ_DOMAINS_CALLOUT_METHOD=POST \ -SMQ_DOMAINS_CALLOUT_TLS_VERIFICATION=true \ -SMQ_DOMAINS_CALLOUT_TIMEOUT=10s \ -SMQ_DOMAINS_CALLOUT_KEY="" \ -SMQ_DOMAINS_CALLOUT_OPERATIONS="" \ -SMQ_SEND_TELEMETRY=true \ -SMQ_DOMAINS_INSTANCE_ID="" \ +MG_DOMAINS_LOG_LEVEL=debug \ +MG_DOMAINS_CACHE_URL=redis://domains-redis:6379/0 \ +MG_DOMAINS_CACHE_KEY_DURATION=10m \ +MG_DOMAINS_HTTP_HOST=domains \ +MG_DOMAINS_HTTP_PORT=9003 \ +MG_DOMAINS_HTTP_SERVER_CERT="" \ +MG_DOMAINS_HTTP_SERVER_KEY="" \ +MG_DOMAINS_GRPC_HOST=domains \ +MG_DOMAINS_GRPC_PORT=7003 \ +MG_DOMAINS_GRPC_SERVER_CERT="" \ +MG_DOMAINS_GRPC_SERVER_KEY="" \ +MG_DOMAINS_GRPC_SERVER_CA_CERTS="" \ +MG_DOMAINS_GRPC_CLIENT_CA_CERTS="" \ +MG_DOMAINS_DB_HOST=domains-db \ +MG_DOMAINS_DB_PORT=5432 \ +MG_DOMAINS_DB_USER=supermq \ +MG_DOMAINS_DB_PASS=supermq \ +MG_DOMAINS_DB_NAME=domains \ +MG_DOMAINS_DB_SSL_MODE="" \ +MG_DOMAINS_DB_SSL_CERT="" \ +MG_DOMAINS_DB_SSL_KEY="" \ +MG_DOMAINS_DB_SSL_ROOT_CERT="" \ +MG_AUTH_GRPC_URL="" \ +MG_AUTH_GRPC_TIMEOUT=1s \ +MG_AUTH_GRPC_CLIENT_CERT="" \ +MG_AUTH_GRPC_CLIENT_KEY="" \ +MG_AUTH_GRPC_SERVER_CA_CERTS="" \ +MG_SPICEDB_HOST=localhost \ +MG_SPICEDB_PORT=50051 \ +MG_SPICEDB_SCHEMA_FILE=./docker/spicedb/schema.schema.zed \ +MG_SPICEDB_PRE_SHARED_KEY=12345678 \ +MG_ES_URL=nats://localhost:4222 \ +MG_JAEGER_URL= \ +MG_JAEGER_TRACE_RATIO=1.0 \ +MG_DOMAINS_CALLOUT_URLS="" \ +MG_DOMAINS_CALLOUT_METHOD=POST \ +MG_DOMAINS_CALLOUT_TLS_VERIFICATION=true \ +MG_DOMAINS_CALLOUT_TIMEOUT=10s \ +MG_DOMAINS_CALLOUT_KEY="" \ +MG_DOMAINS_CALLOUT_OPERATIONS="" \ +MG_SEND_TELEMETRY=true \ +MG_DOMAINS_INSTANCE_ID="" \ $GOBIN/supermq-domains ``` @@ -360,9 +360,9 @@ curl -X GET http://localhost:9004/domains/roles/available-actions \ - Domains and invitations are persisted in PostgreSQL; migrations also create role tables with a `domains_` prefix. - Redis caches domain status and route-to-ID lookups to speed up authorization. -- Domain lifecycle events are published to the configured event store (`SMQ_ES_URL`). +- Domain lifecycle events are published to the configured event store (`MG_ES_URL`). - Authorization and role checks are enforced via SpiceDB-backed policy service. -- Optional HTTP callouts can be triggered before operations, using the `SMQ_DOMAINS_CALLOUT_*` settings. +- Optional HTTP callouts can be triggered before operations, using the `MG_DOMAINS_CALLOUT_*` settings. - Observability: Jaeger tracing, Prometheus metrics at `/metrics`, and a `/health` endpoint. ### Domains Table @@ -400,7 +400,7 @@ curl -X GET http://localhost:9004/domains/roles/available-actions \ - Prefer `disable` over delete when you need reversible off-boarding; use `freeze` for emergency locks by admins. - Keep role definitions minimal; grant only the actions needed and audit with `list-role-members`. - Clean up stale invitations regularly using the domain/user invitation listing endpoints. -- When enabling callouts, narrow `SMQ_DOMAINS_CALLOUT_OPERATIONS` to the events you must observe. +- When enabling callouts, narrow `MG_DOMAINS_CALLOUT_OPERATIONS` to the events you must observe. ## Versioning and Health Check diff --git a/domains/events/streams.go b/domains/events/streams.go index e6ab29285..f4dc79fd2 100644 --- a/domains/events/streams.go +++ b/domains/events/streams.go @@ -43,7 +43,7 @@ type eventStore struct { // NewEventStoreMiddleware returns wrapper around auth service that sends // events to event store. func NewEventStoreMiddleware(ctx context.Context, svc domains.Service, url string) (domains.Service, error) { - publisher, err := store.NewPublisher(ctx, url) + publisher, err := store.NewPublisher(ctx, url, "domains-es-pub") if err != nil { return nil, err } diff --git a/domains/middleware/authorization.go b/domains/middleware/authorization.go index 37bd3b603..7c9c7a738 100644 --- a/domains/middleware/authorization.go +++ b/domains/middleware/authorization.go @@ -124,7 +124,7 @@ func (am *authorizationMiddleware) FreezeDomain(ctx context.Context, session aut SubjectType: policies.UserType, SubjectKind: policies.UsersKind, Permission: policies.AdminPermission, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }, nil); err != nil { return domains.Domain{}, err @@ -250,7 +250,7 @@ func (am *authorizationMiddleware) checkAdmin(ctx context.Context, session authn Subject: session.UserID, Permission: policies.AdminPermission, ObjectType: policies.PlatformType, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, } if err := am.authz.Authorize(ctx, req, nil); err == nil { @@ -269,7 +269,7 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session Subject: session.UserID, Permission: policies.AdminPermission, ObjectType: policies.PlatformType, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, }, nil); err != nil { return err } diff --git a/domains/service.go b/domains/service.go index 0803a6bf5..ec9854431 100644 --- a/domains/service.go +++ b/domains/service.go @@ -85,7 +85,7 @@ func (svc service) CreateDomain(ctx context.Context, session authn.Session, d Do optionalPolicies := []policies.Policy{ { - Subject: policies.SuperMQObject, + Subject: policies.MagistralaObject, SubjectType: policies.PlatformType, Relation: "organization", Object: d.ID, diff --git a/fluxmq/api/grpc/doc.go b/fluxmq/api/grpc/doc.go new file mode 100644 index 000000000..2810433da --- /dev/null +++ b/fluxmq/api/grpc/doc.go @@ -0,0 +1,7 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package grpc contains the FluxMQ auth callout gRPC server implementation. +// It bridges FluxMQ broker authentication and authorization requests to +// SuperMQ's Clients and Channels services. +package grpc diff --git a/fluxmq/api/grpc/server.go b/fluxmq/api/grpc/server.go new file mode 100644 index 000000000..bbfab7d9f --- /dev/null +++ b/fluxmq/api/grpc/server.go @@ -0,0 +1,163 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package grpc + +import ( + "context" + "strings" + + "connectrpc.com/connect" + authv1 "github.com/absmach/fluxmq/pkg/proto/auth/v1" + "github.com/absmach/fluxmq/pkg/proto/auth/v1/authv1connect" + grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" + grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" + apiutil "github.com/absmach/supermq/api/http/util" + smqauth "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/connections" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/policies" +) + +var _ authv1connect.AuthServiceHandler = (*connectServer)(nil) + +type connectServer struct { + authv1connect.UnimplementedAuthServiceHandler + clients grpcClientsV1.ClientsServiceClient + channels grpcChannelsV1.ChannelsServiceClient + parser messaging.TopicParser +} + +// NewServer creates a FluxMQ AuthService Connect handler that bridges to +// SuperMQ's Clients (authn) and Channels (authz) services. +func NewServer( + clients grpcClientsV1.ClientsServiceClient, + channels grpcChannelsV1.ChannelsServiceClient, + parser messaging.TopicParser, +) authv1connect.AuthServiceHandler { + return &connectServer{ + clients: clients, + channels: channels, + parser: parser, + } +} + +func (s *connectServer) Authenticate(ctx context.Context, req *connect.Request[authv1.AuthnReq]) (*connect.Response[authv1.AuthnRes], error) { + username := req.Msg.GetUsername() + password := req.Msg.GetPassword() + + token := authn.AuthPack(authn.BasicAuth, username, password) + res, err := s.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: token}) + if err != nil { + if !shouldTryDomainAuth(req.Msg, username, password) { + return nil, encodeError(err) + } + + token = authn.AuthPack(authn.DomainAuth, username, password) + res, err = s.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: token}) + if err != nil { + return nil, encodeError(err) + } + } + + return connect.NewResponse(&authv1.AuthnRes{ + Authenticated: res.GetAuthenticated(), + Id: res.GetId(), + }), nil +} + +func (s *connectServer) Authorize(ctx context.Context, req *connect.Request[authv1.AuthzReq]) (*connect.Response[authv1.AuthzRes], error) { + connType := connections.ConnType(req.Msg.GetAction()) + if err := connections.CheckConnType(connType); err != nil { + return nil, encodeError(err) + } + + var domainID, channelID string + var topicType messaging.TopicType + var err error + + switch connType { + case connections.Publish: + domainID, channelID, _, topicType, err = s.parser.ParsePublishTopic(ctx, req.Msg.GetTopic(), true) + case connections.Subscribe: + domainID, channelID, _, topicType, err = s.parser.ParseSubscribeTopic(ctx, req.Msg.GetTopic(), true) + } + if err != nil { + if shouldDenyAuthorize(err) { + return connect.NewResponse(&authv1.AuthzRes{Authorized: false}), nil + } + return nil, encodeError(err) + } + + if topicType == messaging.HealthType { + return connect.NewResponse(&authv1.AuthzRes{Authorized: true}), nil + } + + ar := &grpcChannelsV1.AuthzReq{ + Type: uint32(connType), + ClientId: req.Msg.GetExternalId(), + ClientType: policies.ClientType, + ChannelId: channelID, + DomainId: domainID, + } + res, err := s.channels.Authorize(ctx, ar) + if err != nil { + if shouldDenyAuthorize(err) { + return connect.NewResponse(&authv1.AuthzRes{Authorized: false}), nil + } + return nil, encodeError(err) + } + + return connect.NewResponse(&authv1.AuthzRes{ + Authorized: res.GetAuthorized(), + }), nil +} + +func shouldTryDomainAuth(msg *authv1.AuthnReq, username, password string) bool { + if username == "" || password == "" { + return false + } + + return strings.HasPrefix(msg.GetClientId(), "http:") +} + +func shouldDenyAuthorize(err error) bool { + if err == nil { + return false + } + + switch { + case errors.Contains(err, svcerr.ErrAuthorization), + errors.Contains(err, svcerr.ErrNotFound), + errors.Contains(err, errors.ErrMalformedEntity), + errors.Contains(err, messaging.ErrMalformedTopic), + err == apiutil.ErrMissingID: + return true + } + + // Backward compatibility for gRPC client layers that may return + // Internal with a payload containing "entity not found". + return strings.Contains(err.Error(), svcerr.ErrNotFound.Error()) +} + +func encodeError(err error) error { + switch { + case errors.Contains(err, nil): + return nil + case errors.Contains(err, errors.ErrMalformedEntity), + err == apiutil.ErrMissingID: + return connect.NewError(connect.CodeInvalidArgument, err) + case errors.Contains(err, svcerr.ErrAuthentication), + errors.Contains(err, smqauth.ErrKeyExpired): + return connect.NewError(connect.CodeUnauthenticated, err) + case errors.Contains(err, svcerr.ErrAuthorization): + return connect.NewError(connect.CodePermissionDenied, err) + case errors.Contains(err, messaging.ErrMalformedTopic): + return connect.NewError(connect.CodeInvalidArgument, err) + default: + return connect.NewError(connect.CodeInternal, err) + } +} diff --git a/go.mod b/go.mod index 929ad9f63..b20c10102 100644 --- a/go.mod +++ b/go.mod @@ -3,42 +3,53 @@ module github.com/absmach/supermq go 1.26.0 require ( + connectrpc.com/otelconnect v0.9.0 github.com/0x6flab/namegenerator v1.4.0 github.com/absmach/callhome v0.18.2 - github.com/absmach/mgate v0.5.0 + github.com/absmach/certs v0.18.5 + github.com/absmach/fluxmq v0.0.0-20260401001416-3f9be65b7db7 github.com/absmach/senml v1.0.8 github.com/authzed/authzed-go v1.8.0 github.com/authzed/grpcutil v0.0.0-20250221190651-1985b19b35b8 github.com/authzed/spicedb v1.50.0 + github.com/caarlos0/env/v10 v10.0.0 github.com/caarlos0/env/v11 v11.4.0 - github.com/cenkalti/backoff/v4 v4.3.0 github.com/dgraph-io/ristretto/v2 v2.4.0 github.com/eclipse/paho.mqtt.golang v1.5.1 github.com/fatih/color v1.19.0 + github.com/fiorix/go-smpp v0.0.0-20210403173735-2894b96e70ba github.com/go-chi/chi/v5 v5.2.5 github.com/go-kit/kit v0.13.0 github.com/gofrs/uuid/v5 v5.4.0 github.com/google/uuid v1.6.0 + github.com/gookit/color v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f + github.com/ivanpirog/coloredcobra v1.0.1 + github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6 github.com/jackc/pgtype v1.14.4 github.com/jackc/pgx/v5 v5.9.1 github.com/jmoiron/sqlx v1.4.0 github.com/lestrrat-go/jwx/v2 v2.1.6 github.com/lib/pq v1.12.0 + github.com/mitchellh/mapstructure v1.5.0 github.com/nats-io/nats.go v1.49.0 github.com/oklog/ulid/v2 v2.1.1 + github.com/openbao/openbao/api/v2 v2.5.1 github.com/ory/dockertest/v3 v3.12.0 github.com/pelletier/go-toml v1.9.5 - github.com/pion/dtls/v3 v3.1.2 github.com/plgd-dev/go-coap/v3 v3.4.2 github.com/prometheus/client_golang v1.23.2 - github.com/rabbitmq/amqp091-go v1.10.0 github.com/redis/go-redis/v9 v9.18.0 github.com/rubenv/sql-migrate v1.8.1 + github.com/slack-go/slack v0.19.0 github.com/spf13/cobra v1.10.2 + github.com/spf13/viper v1.21.0 github.com/sqids/sqids-go v0.4.1 github.com/stretchr/testify v1.11.1 + github.com/traefik/yaegi v0.16.1 + github.com/vadv/gopher-lua-libs v0.8.0 + github.com/yuin/gopher-lua v1.1.1 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 go.opentelemetry.io/otel v1.42.0 @@ -49,6 +60,7 @@ require ( golang.org/x/crypto v0.49.0 golang.org/x/oauth2 v0.36.0 golang.org/x/sync v0.20.0 + gonum.org/v1/gonum v0.17.0 google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 google.golang.org/grpc v1.79.3 google.golang.org/protobuf v1.36.11 @@ -62,7 +74,9 @@ require ( buf.build/go/protovalidate v1.1.0 // indirect cel.dev/expr v0.25.1 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect + connectrpc.com/connect v1.19.1 dario.cat/mergo v1.0.2 // indirect + filippo.io/edwards25519 v1.1.1 // indirect github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect @@ -70,6 +84,7 @@ require ( github.com/authzed/cel-go v0.20.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/ccoveille/go-safecast/v2 v2.0.0 // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -88,27 +103,38 @@ require ( github.com/emirpasic/gods v1.18.1 // indirect github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // indirect github.com/go-errors/errors v1.5.1 // indirect github.com/go-gorp/gorp/v3 v3.1.0 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-kit/log v0.2.1 // indirect github.com/go-logfmt/logfmt v0.6.1 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/zerologr v1.2.3 // indirect + github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/google/cel-go v0.26.1 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/hashicorp/go-retryablehttp v0.7.8 // indirect + github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect + github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect + github.com/hashicorp/go-sockaddr v1.0.7 // indirect + github.com/hashicorp/hcl v1.0.1-vault-7 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jzelinskie/stringz v0.0.3 // indirect - github.com/klauspost/compress v1.18.2 // indirect + github.com/klauspost/compress v1.18.5 // indirect github.com/lestrrat-go/blackmagic v1.0.4 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect github.com/lestrrat-go/httprc v1.0.6 // indirect @@ -116,6 +142,7 @@ require ( github.com/lestrrat-go/option v1.0.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.30 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/moby/api v1.54.0 // indirect github.com/moby/moby/client v0.3.0 // indirect @@ -129,6 +156,8 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect github.com/opencontainers/runc v1.2.8 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/pion/dtls/v3 v3.1.2 // indirect github.com/pion/logging v0.2.4 // indirect github.com/pion/transport/v4 v4.0.1 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240917153116-6f2963f01587 // indirect @@ -136,18 +165,26 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/procfs v0.19.2 // indirect + github.com/rabbitmq/amqp091-go v1.10.0 github.com/rs/zerolog v1.34.0 // indirect + github.com/ryanuber/go-glob v1.0.0 // indirect + github.com/sagikazarmark/locafero v0.11.0 // indirect github.com/segmentio/asm v1.2.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/smarty/assertions v1.16.0 // indirect github.com/smartystreets/goconvey v1.8.1 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/stoewer/go-strcase v1.3.1 // indirect github.com/stretchr/objx v0.5.3 // indirect + github.com/subosito/gotenv v1.6.0 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel/metric v1.42.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect @@ -155,10 +192,12 @@ require ( go.yaml.in/yaml/v2 v2.4.3 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/exp v0.0.0-20251017212417-90e834f514db // indirect - golang.org/x/net v0.51.0 // indirect + golang.org/x/net v0.52.0 golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect + golang.org/x/time v0.15.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect sigs.k8s.io/controller-runtime v0.22.4 // indirect ) diff --git a/go.sum b/go.sum index 7f97eed6c..ffcec2f54 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +al.essio.dev/pkg/shellescape v1.5.1/go.mod h1:6sIqp7X2P6mThCQ7twERpZTuigpr6KbZWtls1U8I890= buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.11-20251209175733-2a1774d88802.1 h1:j9yeqTWEFrtimt8Nng2MIeRrpoCvQzM9/g25XTvqUGg= buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.11-20251209175733-2a1774d88802.1/go.mod h1:tvtbpgaVXZX4g6Pn+AnzFycuRK3MOz5HJfEGeEllXYM= buf.build/go/protovalidate v1.1.0 h1:pQqEQRpOo4SqS60qkvmhLTTQU9JwzEvdyiqAtXa5SeY= @@ -5,8 +6,13 @@ buf.build/go/protovalidate v1.1.0/go.mod h1:bGZcPiAQDC3ErCHK3t74jSoJDFOs2JH3d7LW cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4= cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14= +connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= +connectrpc.com/otelconnect v0.9.0 h1:NggB3pzRC3pukQWaYbRHJulxuXvmCKCKkQ9hbrHAWoA= +connectrpc.com/otelconnect v0.9.0/go.mod h1:AEkVLjCPXra+ObGFCOClcJkNjS7zPaQSqvO0lCyjfZc= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= @@ -25,14 +31,22 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= +github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA= github.com/VividCortex/gohistogram v1.0.0 h1:6+hBz+qvs0JOrrNhhmR7lFxo5sINxBCGXrdtl/UvroE= github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= github.com/absmach/callhome v0.18.2 h1:dmopRHm2qTheHN1hdUKRRYpKwRrj7X9d8AWCFrb+K6s= github.com/absmach/callhome v0.18.2/go.mod h1:LEXKhES9JJtj3tBgTZv7VPNjOi5ukJQB0mFic0QP60Q= -github.com/absmach/mgate v0.5.0 h1:RV2Aalra3xIm+XTs13TM7iE7v4WTL2SKhKcPbKr22Ac= -github.com/absmach/mgate v0.5.0/go.mod h1:0KVq7mxM0wayosmyXPPxp1EL0c2d9kRp5V8NZCKdetA= +github.com/absmach/certs v0.18.5 h1:eYlvitou+LoDtt7ETVLTp6d/1xCejGL3EmVOg+rHGTU= +github.com/absmach/certs v0.18.5/go.mod h1:31dtVe1VYF16W+IvjAE/uPAIz4f3uLHgh+moBezjqIc= +github.com/absmach/fluxmq v0.0.0-20260401001416-3f9be65b7db7 h1:cNeNb3ngHX6mfPrbxbDD8dpNh82UJ0wDVEiyBjP+C8c= +github.com/absmach/fluxmq v0.0.0-20260401001416-3f9be65b7db7/go.mod h1:MSpCAYY2IHv5fQovhQr24610E0AHHgM2A9UUP0zbKco= github.com/absmach/senml v1.0.8 h1:+opem/r4g6c6eA/JLyCIuksyEhj7eBdysY3pEmy1mqo= github.com/absmach/senml v1.0.8/go.mod h1:DRhzHLgvQoIUHroBgpFrSWso+bJZO9E96RlHAHy+VRI= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= github.com/authzed/authzed-go v1.8.0 h1:cRka8J8QXGl+nyNrhsiPSFJUluIG1tuTXnG8ad2LZ1Y= @@ -43,7 +57,12 @@ github.com/authzed/grpcutil v0.0.0-20250221190651-1985b19b35b8 h1:y17oq4U8n+k1Oc github.com/authzed/grpcutil v0.0.0-20250221190651-1985b19b35b8/go.mod h1:Pf1ZSi41EePvx1GC1DeEJw5dn35iUcxZHqpHuG1Rpic= github.com/authzed/spicedb v1.50.0 h1:dnIGrYaWDN8KRWABdidiFaaN1h4y0lNDIbsh5bZbhNY= github.com/authzed/spicedb v1.50.0/go.mod h1:kV6L+7b1bDVeoHfKPSJt+uLHDUl9hAT/yjNYayF3iyM= +github.com/aws/aws-sdk-go v1.34.0/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= +github.com/aws/aws-sdk-go v1.40.45 h1:QN1nsY27ssD/JmW4s83qmSb+uL6DG4GmCDzjmJB4xUI= +github.com/aws/aws-sdk-go v1.40.45/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/brianvoe/gofakeit/v6 v6.28.0 h1:Xib46XXuQfmlLS2EXRuJpqcw8St6qSZz75OUo0tgAW4= @@ -52,8 +71,12 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/caarlos0/env/v10 v10.0.0 h1:yIHUBZGsyqCnpTkbjk8asUlx6RFhhEs+h7TOBdgdzXA= +github.com/caarlos0/env/v10 v10.0.0/go.mod h1:ZfulV76NvVPw3tm591U4SwL3Xx9ldzBP9aGxzeN7G18= github.com/caarlos0/env/v11 v11.4.0 h1:Kcb6t5kIIr4XkoQC9AF2j+8E1Jsrl3Wz/hhm1LtoGAc= github.com/caarlos0/env/v11 v11.4.0/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U= +github.com/cbroglie/mustache v1.0.1 h1:ivMg8MguXq/rrz2eu3tw6g3b16+PQhoTn6EZAhst2mw= +github.com/cbroglie/mustache v1.0.1/go.mod h1:R/RUa+SobQ14qkP4jtx5Vke5sDytONDQXNLPY/PO69g= github.com/ccoveille/go-safecast/v2 v2.0.0 h1:+5eyITXAUj3wMjad6cRVJKGnC7vDS55zk0INzJagub0= github.com/ccoveille/go-safecast/v2 v2.0.0/go.mod h1:JIYA4CAR33blIDuE6fSwCp2sz1oOBahXnvmdBhOAABs= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= @@ -63,8 +86,13 @@ github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F9 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d h1:S2NE3iHSwP0XV47EEXL8mWmRdEfGscSJ+7EgePNgt0s= github.com/certifi/gocertifi v0.0.0-20210507211836-431795d63e8d/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cheggaaa/pb/v3 v3.0.5/go.mod h1:X1L61/+36nz9bjIsrDU52qHKOQukUQe2Ge+YvGuquCw= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= @@ -77,6 +105,8 @@ github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmC github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= @@ -103,6 +133,7 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dsnet/golib/memfile v1.0.0 h1:J9pUspY2bDCbF9o+YGwcf3uG6MdyITfh/Fk3/CaEiFs= github.com/dsnet/golib/memfile v1.0.0/go.mod h1:tXGNW9q3RwvWt1VV2qrRKlSSz0npnh12yftCSCy2T64= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/eclipse/paho.mqtt.golang v1.5.1 h1:/VSOv3oDLlpqR2Epjn1Q7b2bSTplJIeV2ISgCl2W7nE= @@ -115,10 +146,18 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4= github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= +github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.19.0 h1:Zp3PiM21/9Ld6FzSKyL5c/BULoe/ONr9KlbYVOfG8+w= github.com/fatih/color v1.19.0/go.mod h1:zNk67I0ZUT1bEGsSGyCZYZNrHuTkJJB+r6Q9VuMi0LE= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fiorix/go-smpp v0.0.0-20210403173735-2894b96e70ba h1:vBqABUa2HUSc6tj22Tw+ZMVGHuBzKtljM38kbRanmrM= +github.com/fiorix/go-smpp v0.0.0-20210403173735-2894b96e70ba/go.mod h1:VfKFK7fGeCP81xEhbrOqUEh45n73Yy6jaPWwTVbxprI= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= @@ -127,11 +166,17 @@ github.com/go-errors/errors v1.5.1 h1:ZwEMSLRCapFLflTpT7NKaAc7ukJ8ZPEjzlxt8rPN8b github.com/go-errors/errors v1.5.1/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-gorp/gorp/v3 v3.1.0 h1:ItKF/Vbuj31dmV4jxA1qblpSwkl9g1typ24xoe70IGs= github.com/go-gorp/gorp/v3 v3.1.0/go.mod h1:dLEjIyyRNiXvNZ8PSmzpt1GsWAUK8kjVhEpjH8TixEw= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.13.0 h1:OoneCcHKHQ03LfBpoQCUfCluwd2Vt3ohz+kvbJneZAU= github.com/go-kit/kit v0.13.0/go.mod h1:phqEHMMUbyrCFCTgH48JueqrM3md2HcAZ8N3XE4FKDg= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU= github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logfmt/logfmt v0.6.1 h1:4hvbpePJKnIzH1B+8OR/JPbTx37NktoI9LE2QZBBkvE= github.com/go-logfmt/logfmt v0.6.1/go.mod h1:EV2pOAQoZaT1ZXZbqDl5hrymndi4SY9ED9/z6CO0XAk= @@ -142,12 +187,15 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-logr/zerologr v1.2.3 h1:up5N9vcH9Xck3jJkXzgyOxozT14R47IyDODz8LM1KSs= github.com/go-logr/zerologr v1.2.3/go.mod h1:BxwGo7y5zgSHYR1BjbnHPyF/5ZjVKfKxAZANVu6E8Ho= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= +github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= +github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= @@ -156,19 +204,34 @@ github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid/v5 v5.4.0 h1:EfbpCTjqMuGyq5ZJwxqzn3Cbr2d0rUZU7v5ycAk/e/0= github.com/gofrs/uuid/v5 v5.4.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/cel-go v0.26.1 h1:iPbVVEdkhTX++hpe3lzSk7D3G3QSYqLGoHOcEio+UXQ= github.com/google/cel-go v0.26.1/go.mod h1:A9O8OU9rdvrK5MQyrqfIxo1a0u4g3sF8KB6PUIaryMM= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250820193118-f64d9cf942d6 h1:EEHtgt9IwisQ2AZ4pIsMjahcegHh6rmhqxzIRQIyepY= github.com/google/pprof v0.0.0-20250820193118-f64d9cf942d6/go.mod h1:I6V7YzU0XDpsHqbsyrghnFZLO1gwK6NPTNvmetQIk9U= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= @@ -176,6 +239,10 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaU github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gookit/assert v0.1.1 h1:lh3GcawXe/p+cU7ESTZ5Ui3Sm/x8JWpIis4/1aF0mY0= +github.com/gookit/assert v0.1.1/go.mod h1:jS5bmIVQZTIwk42uXl4lyj4iaaxx32tqH16CFj0VX2E= +github.com/gookit/color v1.6.0 h1:JjJXBTk1ETNyqyilJhkTXJYYigHG24TM9Xa2M1xAhRA= +github.com/gookit/color v1.6.0/go.mod h1:9ACFc7/1IpHGBW8RwuDm/0YEnhg3dwwXpoMsmtyHfjs= github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= @@ -184,10 +251,32 @@ github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 h1:UH//fgunKIs4JdUbpDl1VZCDa github.com/grpc-ecosystem/go-grpc-middleware v1.4.0/go.mod h1:g5qyo/la0ALbONm6Vbp88Yd8NsDy6rZz+RcrMPxvld8= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= +github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= +github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM= +github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0/go.mod h1:Ll013mhdmsVDuoIXVfBtvgGJsXDYkTw1kooNcoCXuE0= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4= +github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9dbT+Fw= +github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw= +github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I= +github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM= github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f h1:7LYC+Yfkj3CTRcShK0KOL/w6iTiKyqqBA9a41Wnggw8= github.com/hokaccha/go-prettyjson v0.0.0-20211117102719-0474bc63780f/go.mod h1:pFlLw2CfqZiIBOx6BuCeRLCrfxBJipTY0nIOF/VbGcI= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/ivanpirog/coloredcobra v1.0.1 h1:aURSdEmlR90/tSiWS0dMjdwOvCVUeYLfltLfbgNxrN4= +github.com/ivanpirog/coloredcobra v1.0.1/go.mod h1:iho4nEKcnwZFiniGSdcgdvRgZNjxm+h20acv8vqmN6Q= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= @@ -201,6 +290,8 @@ github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8 github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgconn v1.14.3 h1:bVoTr12EGANZz66nZPkMInAV/KHD2TxH9npjXXgiB3w= github.com/jackc/pgconn v1.14.3/go.mod h1:RZbme4uasqzybK2RK5c65VsHxoyaml09lx3tXOcO/VM= +github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6 h1:D/V0gu4zQ3cL2WKeVNVM4r2gLxGGf6McLwgXzRTo2RQ= +github.com/jackc/pgerrcode v0.0.0-20250907135507-afb5586c32a6/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= @@ -243,20 +334,31 @@ github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dv github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmoiron/sqlx v1.4.0 h1:1PLqN7S1UYp5t4SrVVnt4nUVNemrDAtxlulVe+Qgm3o= github.com/jmoiron/sqlx v1.4.0/go.mod h1:ZrZ7UsYB/weZdl2Bxg6jCRO9c3YHl8r3ahlKmRT4JLY= +github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/jzelinskie/stringz v0.0.3 h1:0GhG3lVMYrYtIvRbxvQI6zqRTT1P1xyQlpa0FhfUXas= github.com/jzelinskie/stringz v0.0.3/go.mod h1:hHYbgxJuNLRw91CmpuFsYEOyQqpDVFg8pvEh23vy4P0= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= -github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -287,20 +389,30 @@ github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.12.0 h1:mC1zeiNamwKBecjHarAr26c/+d8V5w/u4J0I/yASbJo= github.com/lib/pq v1.12.0/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +github.com/mattn/go-sqlite3 v1.14.3/go.mod h1:WVKg1VTActs4Qso6iwGbiFih2UIHo0ENGwNd0Lj+XmI= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.30 h1:bVreufq3EAIG1Quvws73du3/QgdeZ3myglJlrzSYYCY= github.com/mattn/go-sqlite3 v1.14.30/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/mitchellh/mapstructure v1.3.2/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/moby/api v1.54.0 h1:7kbUgyiKcoBhm0UrWbdrMs7RX8dnwzURKVbZGy2GnL0= @@ -311,8 +423,15 @@ github.com/moby/sys/user v0.3.0 h1:9ni5DlcW5an3SvRSx4MouotOygvzaXbaSrc/wGDFWPo= github.com/moby/sys/user v0.3.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs= github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ= github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/montanaflynn/stats v0.6.3/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nats-io/nats.go v1.49.0 h1:yh/WvY59gXqYpgl33ZI+XoVPKyut/IcEaqtsiuTJpoE= github.com/nats-io/nats.go v1.49.0/go.mod h1:fDCn3mN5cY8HooHwE2ukiLb4p4G4ImmzvXyJt+tGwdw= github.com/nats-io/nkeys v0.4.12 h1:nssm7JKOG9/x4J8II47VWCL1Ds29avyiQDRn0ckMvDc= @@ -325,6 +444,8 @@ github.com/onsi/ginkgo/v2 v2.27.2 h1:LzwLj0b89qtIy6SSASkzlNvX6WktqurSHwkk2ipF/Ns github.com/onsi/ginkgo/v2 v2.27.2/go.mod h1:ArE1D/XhNXBXCBkKOLkbsb2c81dQHCRcF5zwn/ykDRo= github.com/onsi/gomega v1.38.2 h1:eZCjf2xjZAqe+LeWvKb5weQ+NcPwX84kqJ0cZNxok2A= github.com/onsi/gomega v1.38.2/go.mod h1:W2MJcYxRGV63b418Ai34Ud0hEdTVXq9NW9+Sx6uXf3k= +github.com/openbao/openbao/api/v2 v2.5.1 h1:Br79D6L20SbAa5P7xqENxmvv8LyI4HoKosPy7klhn4o= +github.com/openbao/openbao/api/v2 v2.5.1/go.mod h1:Dh5un77tqGgMbmlVEqjqN+8/dMyUohnkaQVg/wXW0Ig= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -337,12 +458,15 @@ github.com/ory/dockertest/v3 v3.12.0/go.mod h1:aKNDTva3cp8dwOWwb9cWuX84aH5akkxXR github.com/pborman/getopt v0.0.0-20170112200414-7148bc3a4c30/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pion/dtls/v3 v3.1.2 h1:gqEdOUXLtCGW+afsBLO0LtDD8GnuBBjEy6HRtyofZTc= github.com/pion/dtls/v3 v3.1.2/go.mod h1:Hw/igcX4pdY69z1Hgv5x7wJFrUkdgHwAn/Q/uo7YHRo= github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o= github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/planetscale/vtprotobuf v0.6.1-0.20240917153116-6f2963f01587 h1:xzZOeCMQLA/W198ZkdVdt4EKFKJtS26B773zNU377ZY= @@ -354,13 +478,27 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/poy/onpar v1.1.2 h1:QaNrNiZx0+Nar5dLgTVp5mXkyoVFIbepjyEoGSnhbAY= github.com/poy/onpar v1.1.2/go.mod h1:6X8FLNoxyr9kkmnlqpK6LSoiOtrO6MICtWwEuWkLjzg= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_golang v1.11.1/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= +github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= +github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= github.com/rabbitmq/amqp091-go v1.10.0 h1:STpn5XsHlHGcecLmMFCtg7mqq0RnD+zFr4uzukfVhBw= @@ -380,25 +518,45 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/rubenv/sql-migrate v1.8.1 h1:EPNwCvjAowHI3TnZ+4fQu3a915OpnQoPAjTXCGOy2U0= github.com/rubenv/sql-migrate v1.8.1/go.mod h1:BTIKBORjzyxZDS6dzoiw6eAFYJ1iNlGAtjn4LGeVjS8= +github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= +github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0= github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/slack-go/slack v0.19.0 h1:J8lL/nGTsIUX53HU8YxZeI3PDkA+sxZsFrI2Dew7h44= +github.com/slack-go/slack v0.19.0/go.mod h1:K81UmCivcYd/5Jmz8vLBfuyoZ3B4rQC2GHVXHteXiAE= github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY= github.com/smarty/assertions v1.16.0/go.mod h1:duaaFdCS0K9dnoM50iyek/eYINOZ64gbh1Xlf6LG7AI= github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v1.4.0/go.mod h1:Wo4iy3BUC+X2Fybo0PDqwJIv3dNRiZLHQymsfxlB84g= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/sqids/sqids-go v0.4.1 h1:eQKYzmAZbLlRwHeHYPF35QhgxwZHLnlmVj9AkIj/rrw= github.com/sqids/sqids-go v0.4.1/go.mod h1:EMwHuPQgSNFS0A49jESTfIQS+066XQTVhukrzEPScl8= github.com/stoewer/go-strcase v1.3.1 h1:iS0MdW+kVTxgMoE1LAZyMiYJFKlOzLooE4MxjirtkAs= @@ -421,6 +579,15 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/technoweenie/multipartstreamer v1.0.1 h1:XRztA5MXiR1TIRHxH2uNxXxaIkKQDeX7m2XsSOlQEnM= +github.com/technoweenie/multipartstreamer v1.0.1/go.mod h1:jNVxdtShOxzAsukZwTSw6MDx5eUJoiEBsSvzDU9uzog= +github.com/traefik/yaegi v0.16.1 h1:f1De3DVJqIDKmnasUF6MwmWv1dSEEat0wcpXhD2On3E= +github.com/traefik/yaegi v0.16.1/go.mod h1:4eVhbPb3LnD2VigQjhYbEJ69vDRFdT2HQNrXx8eEwUY= +github.com/urfave/cli v1.22.5/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= +github.com/vadv/gopher-lua-libs v0.8.0 h1:u2GVTj32Wnmu8RpSxeAdlTf9mYZrrm9ALKGYmvnvvZQ= +github.com/vadv/gopher-lua-libs v0.8.0/go.mod h1:iNYvPoNV6ur7xJj4Uj3hEVebv8Z0/MoeM1igsXQbv8g= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= @@ -430,9 +597,16 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHo github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +github.com/yuin/gluamapper v0.0.0-20150323120927-d836955830e7 h1:noHsffKZsNfU38DwcXWEPldrTjIZ8FPNKx8mYMGnqjs= +github.com/yuin/gluamapper v0.0.0-20150323120927-d836955830e7/go.mod h1:bbMEM6aU1WDF1ErA5YJ0p91652pGv140gGw4Ww3RGp8= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/gopher-lua v0.0.0-20200816102855-ee81675732da/go.mod h1:E1AXubJBdNmFERAOucpDIxNzeGfLzg0mYh+UfMWdChA= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= @@ -481,6 +655,7 @@ go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -512,34 +687,44 @@ golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= -golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -547,13 +732,21 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -582,6 +775,9 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -607,8 +803,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= @@ -625,20 +821,37 @@ google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8 google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE= gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/xmlpath.v2 v2.0.0-20150820204837-860cbeca3ebc h1:LMEBgNcZUqXaP7evD1PZcL6EcDVa2QOFuI+cqM3+AJM= +gopkg.in/xmlpath.v2 v2.0.0-20150820204837-860cbeca3ebc/go.mod h1:N8UOSI6/c2yOpa/XDz3KVUiegocTziPiqNkeNTMiG1k= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/groups/README.md b/groups/README.md index 22470f680..5f6e4ea13 100644 --- a/groups/README.md +++ b/groups/README.md @@ -10,53 +10,53 @@ The service is configured via environment variables (unset values fall back to d | Variable | Description | Default | | -------------------------------------- | ------------------------------------------------------------------------------------------------- | ------------------------------------- | -| `SMQ_GROUPS_LOG_LEVEL` | Log level for Groups (debug, info, warn, error) | debug | -| `SMQ_GROUPS_HTTP_HOST` | Groups service HTTP host | groups | -| `SMQ_GROUPS_HTTP_PORT` | Groups service HTTP port | 9004 | -| `SMQ_GROUPS_HTTP_SERVER_CERT` | Path to PEM-encoded HTTP server certificate | "" | -| `SMQ_GROUPS_HTTP_SERVER_KEY` | Path to PEM-encoded HTTP server key | "" | -| `SMQ_GROUPS_HTTP_SERVER_CA_CERTS` | Path to trusted CA bundle for the HTTP server | "" | -| `SMQ_GROUPS_HTTP_CLIENT_CA_CERTS` | Path to client CA bundle to require HTTP mTLS | "" | -| `SMQ_GROUPS_GRPC_HOST` | Groups service gRPC host | groups | -| `SMQ_GROUPS_GRPC_PORT` | Groups service gRPC port | 7004 | -| `SMQ_GROUPS_GRPC_SERVER_CERT` | Path to PEM-encoded gRPC server certificate | "" | -| `SMQ_GROUPS_GRPC_SERVER_KEY` | Path to PEM-encoded gRPC server key | "" | -| `SMQ_GROUPS_GRPC_SERVER_CA_CERTS` | Path to trusted CA bundle for the gRPC server | "" | -| `SMQ_GROUPS_GRPC_CLIENT_CA_CERTS` | Path to client CA bundle to require gRPC mTLS | "" | -| `SMQ_GROUPS_DB_HOST` | Database host address | groups-db | -| `SMQ_GROUPS_DB_PORT` | Database host port | 5432 | -| `SMQ_GROUPS_DB_USER` | Database user | supermq | -| `SMQ_GROUPS_DB_PASS` | Database password | supermq | -| `SMQ_GROUPS_DB_NAME` | Name of the database used by the service | groups | -| `SMQ_GROUPS_DB_SSL_MODE` | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | -| `SMQ_GROUPS_DB_SSL_CERT` | Path to the PEM-encoded certificate file | "" | -| `SMQ_GROUPS_DB_SSL_KEY` | Path to the PEM-encoded key file | "" | -| `SMQ_GROUPS_DB_SSL_ROOT_CERT` | Path to the PEM-encoded root certificate file | "" | -| `SMQ_GROUPS_INSTANCE_ID` | Groups instance ID (auto-generated when empty) | "" | -| `SMQ_GROUPS_EVENT_CONSUMER` | NATS consumer name for domain events | groups | -| `SMQ_SPICEDB_HOST` | SpiceDB host for policy checks | supermq-spicedb | -| `SMQ_SPICEDB_PORT` | SpiceDB port | 50051 | -| `SMQ_SPICEDB_SCHEMA_FILE` | Path to SpiceDB schema file used to seed available actions | "/schema.zed" | -| `SMQ_SPICEDB_PRE_SHARED_KEY` | SpiceDB preshared key | 12345678 | -| `SMQ_ES_URL` | Event store URL | nats://nats:4222 | -| `SMQ_JAEGER_URL` | Jaeger server URL | | -| `SMQ_JAEGER_TRACE_RATIO` | Trace sampling ratio | 1.0 | -| `SMQ_SEND_TELEMETRY` | Send telemetry to the SuperMQ call-home server | true | -| `SMQ_AUTH_GRPC_URL` | Auth service gRPC URL | "" | -| `SMQ_AUTH_GRPC_TIMEOUT` | Auth service gRPC request timeout | 1s | -| `SMQ_AUTH_GRPC_CLIENT_CERT` | Path to the PEM-encoded Auth gRPC client certificate | "" | -| `SMQ_AUTH_GRPC_CLIENT_KEY` | Path to the PEM-encoded Auth gRPC client key | "" | -| `SMQ_AUTH_GRPC_SERVER_CA_CERTS` | Path to the PEM-encoded Auth gRPC trusted CA bundle | "" | -| `SMQ_GROUPS_CALLOUT_URLS` | Comma-separated list of HTTP callout targets invoked on group operations | "" | -| `SMQ_GROUPS_CALLOUT_METHOD` | HTTP method for callouts (POST or GET) | POST | -| `SMQ_GROUPS_CALLOUT_TLS_VERIFICATION` | Verify TLS certificates for callouts | false | -| `SMQ_GROUPS_CALLOUT_TIMEOUT` | Callout request timeout | 10s | -| `SMQ_GROUPS_CALLOUT_CA_CERT` | CA bundle for verifying callout targets | "" | -| `SMQ_GROUPS_CALLOUT_CERT` | Client certificate for mTLS callouts | "" | -| `SMQ_GROUPS_CALLOUT_KEY` | Client key for mTLS callouts | "" | -| `SMQ_GROUPS_CALLOUT_OPERATIONS` | Comma-separated list of operation names that should trigger callouts | "" | +| `MG_GROUPS_LOG_LEVEL` | Log level for Groups (debug, info, warn, error) | debug | +| `MG_GROUPS_HTTP_HOST` | Groups service HTTP host | groups | +| `MG_GROUPS_HTTP_PORT` | Groups service HTTP port | 9004 | +| `MG_GROUPS_HTTP_SERVER_CERT` | Path to PEM-encoded HTTP server certificate | "" | +| `MG_GROUPS_HTTP_SERVER_KEY` | Path to PEM-encoded HTTP server key | "" | +| `MG_GROUPS_HTTP_SERVER_CA_CERTS` | Path to trusted CA bundle for the HTTP server | "" | +| `MG_GROUPS_HTTP_CLIENT_CA_CERTS` | Path to client CA bundle to require HTTP mTLS | "" | +| `MG_GROUPS_GRPC_HOST` | Groups service gRPC host | groups | +| `MG_GROUPS_GRPC_PORT` | Groups service gRPC port | 7004 | +| `MG_GROUPS_GRPC_SERVER_CERT` | Path to PEM-encoded gRPC server certificate | "" | +| `MG_GROUPS_GRPC_SERVER_KEY` | Path to PEM-encoded gRPC server key | "" | +| `MG_GROUPS_GRPC_SERVER_CA_CERTS` | Path to trusted CA bundle for the gRPC server | "" | +| `MG_GROUPS_GRPC_CLIENT_CA_CERTS` | Path to client CA bundle to require gRPC mTLS | "" | +| `MG_GROUPS_DB_HOST` | Database host address | groups-db | +| `MG_GROUPS_DB_PORT` | Database host port | 5432 | +| `MG_GROUPS_DB_USER` | Database user | supermq | +| `MG_GROUPS_DB_PASS` | Database password | supermq | +| `MG_GROUPS_DB_NAME` | Name of the database used by the service | groups | +| `MG_GROUPS_DB_SSL_MODE` | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | +| `MG_GROUPS_DB_SSL_CERT` | Path to the PEM-encoded certificate file | "" | +| `MG_GROUPS_DB_SSL_KEY` | Path to the PEM-encoded key file | "" | +| `MG_GROUPS_DB_SSL_ROOT_CERT` | Path to the PEM-encoded root certificate file | "" | +| `MG_GROUPS_INSTANCE_ID` | Groups instance ID (auto-generated when empty) | "" | +| `MG_GROUPS_EVENT_CONSUMER` | NATS consumer name for domain events | groups | +| `MG_SPICEDB_HOST` | SpiceDB host for policy checks | supermq-spicedb | +| `MG_SPICEDB_PORT` | SpiceDB port | 50051 | +| `MG_SPICEDB_SCHEMA_FILE` | Path to SpiceDB schema file used to seed available actions | "/schema.zed" | +| `MG_SPICEDB_PRE_SHARED_KEY` | SpiceDB preshared key | 12345678 | +| `MG_ES_URL` | Event store URL | nats://nats:4222 | +| `MG_JAEGER_URL` | Jaeger server URL | | +| `MG_JAEGER_TRACE_RATIO` | Trace sampling ratio | 1.0 | +| `MG_SEND_TELEMETRY` | Send telemetry to the SuperMQ call-home server | true | +| `MG_AUTH_GRPC_URL` | Auth service gRPC URL | "" | +| `MG_AUTH_GRPC_TIMEOUT` | Auth service gRPC request timeout | 1s | +| `MG_AUTH_GRPC_CLIENT_CERT` | Path to the PEM-encoded Auth gRPC client certificate | "" | +| `MG_AUTH_GRPC_CLIENT_KEY` | Path to the PEM-encoded Auth gRPC client key | "" | +| `MG_AUTH_GRPC_SERVER_CA_CERTS` | Path to the PEM-encoded Auth gRPC trusted CA bundle | "" | +| `MG_GROUPS_CALLOUT_URLS` | Comma-separated list of HTTP callout targets invoked on group operations | "" | +| `MG_GROUPS_CALLOUT_METHOD` | HTTP method for callouts (POST or GET) | POST | +| `MG_GROUPS_CALLOUT_TLS_VERIFICATION` | Verify TLS certificates for callouts | false | +| `MG_GROUPS_CALLOUT_TIMEOUT` | Callout request timeout | 10s | +| `MG_GROUPS_CALLOUT_CA_CERT` | CA bundle for verifying callout targets | "" | +| `MG_GROUPS_CALLOUT_CERT` | Client certificate for mTLS callouts | "" | +| `MG_GROUPS_CALLOUT_KEY` | Client key for mTLS callouts | "" | +| `MG_GROUPS_CALLOUT_OPERATIONS` | Comma-separated list of operation names that should trigger callouts | "" | -**Note**: Set `SMQ_GROUPS_CALLOUT_OPERATIONS` to a subset of `OpCreateGroup`, `OpViewGroup`, `OpUpdateGroup`, `OpEnableGroup`, `OpDisableGroup`, `OpDeleteGroup`, `OpListGroups`, `OpHierarchy`, `OpAddParentGroup`, `OpRemoveParentGroup`, `OpAddChildrenGroups`, `OpRemoveChildrenGroups`, `OpRemoveAllChildrenGroups`, or `OpListChildrenGroups` to filter which actions produce callouts. +**Note**: Set `MG_GROUPS_CALLOUT_OPERATIONS` to a subset of `OpCreateGroup`, `OpViewGroup`, `OpUpdateGroup`, `OpEnableGroup`, `OpDisableGroup`, `OpDeleteGroup`, `OpListGroups`, `OpHierarchy`, `OpAddParentGroup`, `OpRemoveParentGroup`, `OpAddChildrenGroups`, `OpRemoveChildrenGroups`, `OpRemoveAllChildrenGroups`, or `OpListChildrenGroups` to filter which actions produce callouts. ## Deployment @@ -76,63 +76,63 @@ make groups make install # set the environment variables and run the service -SMQ_GROUPS_LOG_LEVEL=debug \ -SMQ_GROUPS_HTTP_HOST=groups \ -SMQ_GROUPS_HTTP_PORT=9004 \ -SMQ_GROUPS_HTTP_SERVER_CERT="" \ -SMQ_GROUPS_HTTP_SERVER_KEY="" \ -SMQ_GROUPS_GRPC_HOST=groups \ -SMQ_GROUPS_GRPC_PORT=7004 \ -SMQ_GROUPS_GRPC_SERVER_CERT="" \ -SMQ_GROUPS_GRPC_SERVER_KEY="" \ -SMQ_GROUPS_GRPC_SERVER_CA_CERTS="" \ -SMQ_GROUPS_GRPC_CLIENT_CA_CERTS="" \ -SMQ_GROUPS_DB_HOST=groups-db \ -SMQ_GROUPS_DB_PORT=5432 \ -SMQ_GROUPS_DB_USER=supermq \ -SMQ_GROUPS_DB_PASS=supermq \ -SMQ_GROUPS_DB_NAME=groups \ -SMQ_GROUPS_DB_SSL_MODE=disable \ -SMQ_GROUPS_DB_SSL_CERT="" \ -SMQ_GROUPS_DB_SSL_KEY="" \ -SMQ_GROUPS_DB_SSL_ROOT_CERT="" \ -SMQ_AUTH_GRPC_URL="" \ -SMQ_AUTH_GRPC_TIMEOUT=1s \ -SMQ_AUTH_GRPC_CLIENT_CERT="" \ -SMQ_AUTH_GRPC_CLIENT_KEY="" \ -SMQ_AUTH_GRPC_SERVER_CA_CERTS="" \ -SMQ_DOMAINS_GRPC_URL=domains:7003 \ -SMQ_DOMAINS_GRPC_TIMEOUT=1s \ -SMQ_DOMAINS_GRPC_CLIENT_CERT="" \ -SMQ_DOMAINS_GRPC_CLIENT_KEY="" \ -SMQ_DOMAINS_GRPC_SERVER_CA_CERTS="" \ -SMQ_CHANNELS_GRPC_URL=channels:7005 \ -SMQ_CHANNELS_GRPC_TIMEOUT=1s \ -SMQ_CHANNELS_GRPC_CLIENT_CERT="" \ -SMQ_CHANNELS_GRPC_CLIENT_KEY="" \ -SMQ_CHANNELS_GRPC_SERVER_CA_CERTS="" \ -SMQ_CLIENTS_GRPC_URL=clients:7000 \ -SMQ_CLIENTS_GRPC_TIMEOUT=1s \ -SMQ_CLIENTS_GRPC_CLIENT_CERT="" \ -SMQ_CLIENTS_GRPC_CLIENT_KEY="" \ -SMQ_CLIENTS_GRPC_SERVER_CA_CERTS="" \ -SMQ_SPICEDB_HOST=localhost \ -SMQ_SPICEDB_PORT=50051 \ -SMQ_SPICEDB_SCHEMA_FILE=schema.zed \ -SMQ_SPICEDB_PRE_SHARED_KEY=12345678 \ -SMQ_ES_URL=nats://localhost:4222 \ -SMQ_JAEGER_URL= \ -SMQ_JAEGER_TRACE_RATIO=1.0 \ -SMQ_GROUPS_CALLOUT_URLS="" \ -SMQ_GROUPS_CALLOUT_METHOD=POST \ -SMQ_GROUPS_CALLOUT_TLS_VERIFICATION=false \ -SMQ_GROUPS_CALLOUT_TIMEOUT=10s \ -SMQ_GROUPS_CALLOUT_CA_CERT="" \ -SMQ_GROUPS_CALLOUT_CERT="" \ -SMQ_GROUPS_CALLOUT_KEY="" \ -SMQ_GROUPS_CALLOUT_OPERATIONS="" \ -SMQ_SEND_TELEMETRY=true \ -SMQ_GROUPS_INSTANCE_ID="" \ +MG_GROUPS_LOG_LEVEL=debug \ +MG_GROUPS_HTTP_HOST=groups \ +MG_GROUPS_HTTP_PORT=9004 \ +MG_GROUPS_HTTP_SERVER_CERT="" \ +MG_GROUPS_HTTP_SERVER_KEY="" \ +MG_GROUPS_GRPC_HOST=groups \ +MG_GROUPS_GRPC_PORT=7004 \ +MG_GROUPS_GRPC_SERVER_CERT="" \ +MG_GROUPS_GRPC_SERVER_KEY="" \ +MG_GROUPS_GRPC_SERVER_CA_CERTS="" \ +MG_GROUPS_GRPC_CLIENT_CA_CERTS="" \ +MG_GROUPS_DB_HOST=groups-db \ +MG_GROUPS_DB_PORT=5432 \ +MG_GROUPS_DB_USER=supermq \ +MG_GROUPS_DB_PASS=supermq \ +MG_GROUPS_DB_NAME=groups \ +MG_GROUPS_DB_SSL_MODE=disable \ +MG_GROUPS_DB_SSL_CERT="" \ +MG_GROUPS_DB_SSL_KEY="" \ +MG_GROUPS_DB_SSL_ROOT_CERT="" \ +MG_AUTH_GRPC_URL="" \ +MG_AUTH_GRPC_TIMEOUT=1s \ +MG_AUTH_GRPC_CLIENT_CERT="" \ +MG_AUTH_GRPC_CLIENT_KEY="" \ +MG_AUTH_GRPC_SERVER_CA_CERTS="" \ +MG_DOMAINS_GRPC_URL=domains:7003 \ +MG_DOMAINS_GRPC_TIMEOUT=1s \ +MG_DOMAINS_GRPC_CLIENT_CERT="" \ +MG_DOMAINS_GRPC_CLIENT_KEY="" \ +MG_DOMAINS_GRPC_SERVER_CA_CERTS="" \ +MG_CHANNELS_GRPC_URL=channels:7005 \ +MG_CHANNELS_GRPC_TIMEOUT=1s \ +MG_CHANNELS_GRPC_CLIENT_CERT="" \ +MG_CHANNELS_GRPC_CLIENT_KEY="" \ +MG_CHANNELS_GRPC_SERVER_CA_CERTS="" \ +MG_CLIENTS_GRPC_URL=clients:7000 \ +MG_CLIENTS_GRPC_TIMEOUT=1s \ +MG_CLIENTS_GRPC_CLIENT_CERT="" \ +MG_CLIENTS_GRPC_CLIENT_KEY="" \ +MG_CLIENTS_GRPC_SERVER_CA_CERTS="" \ +MG_SPICEDB_HOST=localhost \ +MG_SPICEDB_PORT=50051 \ +MG_SPICEDB_SCHEMA_FILE=schema.zed \ +MG_SPICEDB_PRE_SHARED_KEY=12345678 \ +MG_ES_URL=nats://localhost:4222 \ +MG_JAEGER_URL= \ +MG_JAEGER_TRACE_RATIO=1.0 \ +MG_GROUPS_CALLOUT_URLS="" \ +MG_GROUPS_CALLOUT_METHOD=POST \ +MG_GROUPS_CALLOUT_TLS_VERIFICATION=false \ +MG_GROUPS_CALLOUT_TIMEOUT=10s \ +MG_GROUPS_CALLOUT_CA_CERT="" \ +MG_GROUPS_CALLOUT_CERT="" \ +MG_GROUPS_CALLOUT_KEY="" \ +MG_GROUPS_CALLOUT_OPERATIONS="" \ +MG_SEND_TELEMETRY=true \ +MG_GROUPS_INSTANCE_ID="" \ $GOBIN/supermq-groups ``` @@ -252,9 +252,9 @@ curl -X GET http://localhost:9004//groups/roles/available-actions \ - Groups are stored in PostgreSQL with `ltree` paths for hierarchy queries; domain migrations are applied alongside group migrations for referential integrity. - Role tables are provisioned per entity with a `groups_` prefix. -- Event notifications are published to `SMQ_ES_URL`; domain events are consumed to keep group data aligned. +- Event notifications are published to `MG_ES_URL`; domain events are consumed to keep group data aligned. - Authorization and roles are enforced through SpiceDB and shared policy middleware. -- Optional HTTP callouts (pre-operation hooks) are controlled via `SMQ_GROUPS_CALLOUT_*`. +- Optional HTTP callouts (pre-operation hooks) are controlled via `MG_GROUPS_CALLOUT_*`. - Observability: Jaeger tracing, Prometheus metrics at `/metrics`, and a `/health` endpoint. ### Groups Table @@ -281,7 +281,7 @@ curl -X GET http://localhost:9004//groups/roles/available-actions \ - Prefer `disable` before `delete` when you need reversible off-boarding. - Use roles sparingly and audit with `list-role-members`; grant only required actions. - Fetch children with bounded levels to keep queries efficient. -- Limit callouts to necessary operations via `SMQ_GROUPS_CALLOUT_OPERATIONS`. +- Limit callouts to necessary operations via `MG_GROUPS_CALLOUT_OPERATIONS`. ## Versioning and Health Check diff --git a/groups/events/streams.go b/groups/events/streams.go index ae480af3a..a1f468c8e 100644 --- a/groups/events/streams.go +++ b/groups/events/streams.go @@ -46,7 +46,7 @@ type eventStore struct { // NewEventStoreMiddleware returns wrapper around clients service that sends // events to event store. func New(ctx context.Context, svc groups.Service, url string) (groups.Service, error) { - publisher, err := store.NewPublisher(ctx, url) + publisher, err := store.NewPublisher(ctx, url, "groups-es-pub") if err != nil { return nil, err } diff --git a/groups/middleware/authorization.go b/groups/middleware/authorization.go index 6e9b3e989..3a775dd73 100644 --- a/groups/middleware/authorization.go +++ b/groups/middleware/authorization.go @@ -375,7 +375,7 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session Subject: session.UserID, Permission: policies.AdminPermission, ObjectType: policies.PlatformType, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, }, nil); err != nil { return err } diff --git a/http/README.md b/http/README.md deleted file mode 100644 index 2246574bc..000000000 --- a/http/README.md +++ /dev/null @@ -1,170 +0,0 @@ -# HTTP Adapter - -The HTTP Adapter exposes HTTP endpoints for publishing messages and WebSocket capabilities for publishing and subscribing to messages from SuperMQ channels. It authenticates clients via tokens or Basic auth, resolves domains/channels over gRPC, and forwards payloads to the message broker. - -For more on SuperMQ, see the [official documentation][doc]. - -## Configuration - -Environment variables (unset values fall back to defaults): - -| Variable | Description | Default | -| ------------------------------------- | ---------------------------------------------------- | ------------------------------ | -| `SMQ_HTTP_ADAPTER_LOG_LEVEL` | Log level (debug, info, warn, error) | debug | -| `SMQ_HTTP_ADAPTER_HOST` | HTTP Adapter host | http-adapter | -| `SMQ_HTTP_ADAPTER_PORT` | HTTP Adapter port | 8008 | -| `SMQ_HTTP_ADAPTER_SERVER_CERT` | Path to PEM-encoded server certificate (enables TLS) | "" | -| `SMQ_HTTP_ADAPTER_SERVER_KEY` | Path to PEM-encoded server key | "" | -| `SMQ_HTTP_ADAPTER_SERVER_CA_CERTS` | Trusted CA bundle for HTTPS server | "" | -| `SMQ_HTTP_ADAPTER_CLIENT_CA_CERTS` | Client CA bundle to require mTLS on HTTPS server | "" | -| `SMQ_HTTP_ADAPTER_CACHE_NUM_COUNTERS` | Cache counters for topic parsing | 200000 | -| `SMQ_HTTP_ADAPTER_CACHE_MAX_COST` | Maximum cache size (bytes) | 1048576 | -| `SMQ_HTTP_ADAPTER_CACHE_BUFFER_ITEMS` | Cache buffer items | 64 | -| `SMQ_MESSAGE_BROKER_URL` | Message broker URL (publishing target) | nats://nats:4222 | -| `SMQ_ES_URL` | Event store URL (publishing middleware) | nats://nats:4222 | -| `SMQ_JAEGER_URL` | Jaeger tracing endpoint | | -| `SMQ_JAEGER_TRACE_RATIO` | Trace sampling ratio | 1.0 | -| `SMQ_SEND_TELEMETRY` | Send telemetry to SuperMQ call-home server | true | -| `SMQ_HTTP_ADAPTER_INSTANCE_ID` | Service instance ID (auto-generated when empty) | "" | -| `SMQ_CLIENTS_GRPC_URL` | Clients service gRPC URL | clients:7006 | -| `SMQ_CLIENTS_GRPC_TIMEOUT` | Clients gRPC request timeout | 300s | -| `SMQ_CLIENTS_GRPC_CLIENT_CERT` | Clients gRPC client certificate | "" | -| `SMQ_CLIENTS_GRPC_CLIENT_KEY` | Clients gRPC client key | "" | -| `SMQ_CLIENTS_GRPC_SERVER_CA_CERTS` | Clients gRPC trusted CA bundle | "" | -| `SMQ_CHANNELS_GRPC_URL` | Channels service gRPC URL | channels:7005 | -| `SMQ_CHANNELS_GRPC_TIMEOUT` | Channels gRPC request timeout | 300s | -| `SMQ_CHANNELS_GRPC_CLIENT_CERT` | Channels gRPC client certificate | "" | -| `SMQ_CHANNELS_GRPC_CLIENT_KEY` | Channels gRPC client key | "" | -| `SMQ_CHANNELS_GRPC_SERVER_CA_CERTS` | Channels gRPC trusted CA bundle | "" | -| `SMQ_DOMAINS_GRPC_URL` | Domains service gRPC URL | domains:7003 | -| `SMQ_DOMAINS_GRPC_TIMEOUT` | Domains gRPC request timeout | 300s | -| `SMQ_DOMAINS_GRPC_CLIENT_CERT` | Domains gRPC client certificate | "" | -| `SMQ_DOMAINS_GRPC_CLIENT_KEY` | Domains gRPC client key | "" | -| `SMQ_DOMAINS_GRPC_SERVER_CA_CERTS` | Domains gRPC trusted CA bundle | "" | -| `SMQ_AUTH_GRPC_URL` | Auth service gRPC URL | auth:7001 | -| `SMQ_AUTH_GRPC_TIMEOUT` | Auth service gRPC request timeout | 300s | -| `SMQ_AUTH_GRPC_CLIENT_CERT` | Auth gRPC client certificate | "" | -| `SMQ_AUTH_GRPC_CLIENT_KEY` | Auth gRPC client key | "" | -| `SMQ_AUTH_GRPC_SERVER_CA_CERTS` | Auth gRPC trusted CA bundle | "" | - -## Deployment - -The adapter is shipped as a Docker container. See the [`http-adapter` section](https://github.com/absmach/supermq/blob/main/docker/docker-compose.yaml#L1226-L1365) of `docker-compose.yaml` for deployment details. - -To build and run locally: - -```bash -# download the latest version of the service -git clone https://github.com/absmach/supermq -cd supermq - -# compile the http adapter -make http - -# copy binary to $GOBIN -make install - -# set the environment variables and run the service -SMQ_HTTP_ADAPTER_LOG_LEVEL=debug \ -SMQ_HTTP_ADAPTER_HOST=http-adapter \ -SMQ_HTTP_ADAPTER_PORT=8008 \ -SMQ_HTTP_ADAPTER_SERVER_CERT="" \ -SMQ_HTTP_ADAPTER_SERVER_KEY="" \ -SMQ_HTTP_ADAPTER_CACHE_NUM_COUNTERS=200000 \ -SMQ_HTTP_ADAPTER_CACHE_MAX_COST=1048576 \ -SMQ_HTTP_ADAPTER_CACHE_BUFFER_ITEMS=64 \ -SMQ_MESSAGE_BROKER_URL=nats://nats:4222 \ -SMQ_ES_URL=nats://nats:4222 \ -SMQ_JAEGER_URL= \ -SMQ_JAEGER_TRACE_RATIO=1.0 \ -SMQ_CLIENTS_GRPC_URL=clients:7006 \ -SMQ_CLIENTS_GRPC_TIMEOUT=300s \ -SMQ_CLIENTS_GRPC_CLIENT_CERT="" \ -SMQ_CLIENTS_GRPC_CLIENT_KEY="" \ -SMQ_CLIENTS_GRPC_SERVER_CA_CERTS="" \ -SMQ_CHANNELS_GRPC_URL=channels:7005 \ -SMQ_CHANNELS_GRPC_TIMEOUT=300s \ -SMQ_CHANNELS_GRPC_CLIENT_CERT="" \ -SMQ_CHANNELS_GRPC_CLIENT_KEY="" \ -SMQ_CHANNELS_GRPC_SERVER_CA_CERTS="" \ -SMQ_DOMAINS_GRPC_URL=domains:7003 \ -SMQ_DOMAINS_GRPC_TIMEOUT=300s \ -SMQ_DOMAINS_GRPC_CLIENT_CERT="" \ -SMQ_DOMAINS_GRPC_CLIENT_KEY="" \ -SMQ_DOMAINS_GRPC_SERVER_CA_CERTS="" \ -SMQ_AUTH_GRPC_URL=auth:7001 \ -SMQ_AUTH_GRPC_TIMEOUT=300s \ -SMQ_AUTH_GRPC_CLIENT_CERT="" \ -SMQ_AUTH_GRPC_CLIENT_KEY="" \ -SMQ_AUTH_GRPC_SERVER_CA_CERTS="" \ -SMQ_SEND_TELEMETRY=true \ -SMQ_HTTP_ADAPTER_INSTANCE_ID="" \ -$GOBIN/supermq-http -``` - -TLS is enabled by setting `SMQ_HTTP_ADAPTER_SERVER_CERT` and `SMQ_HTTP_ADAPTER_SERVER_KEY`. mTLS is enabled when `SMQ_HTTP_ADAPTER_CLIENT_CA_CERTS` is provided. gRPC client TLS/mTLS is enabled by setting the corresponding client cert/key/CA variables. - -## Usage - -Endpoints: - -- `POST /m/{domain}/c/{channel}` (and wildcard `/m/{domain}/c/{channel}/*`): publish a message. -- `POST /hc/{domain}`: health-check message path (authenticated). -- `GET /health`: service health probe. -- `GET /metrics`: Prometheus metrics. - -Authentication: - -- Bearer token in `Authorization` header, or -- Basic auth where the password is the token (username ignored). - -Supported content types: `application/json`, `application/senml+json`, `application/senml+cbor`. - -Example publish: - -```bash -curl -X POST http://localhost:8008/m//c//sub/topic \ - -H "Authorization: Bearer " \ - -H "Content-Type: application/json" \ - -d '{ "temp": 22.5, "unit": "C" }' -``` - -## Implementation Details - -- Publishes to the configured message broker (`SMQ_MESSAGE_BROKER_URL`) with optional event-store middleware (`SMQ_ES_URL`). -- Resolves domains and channels over gRPC to validate/route topics; authenticates via Auth gRPC; validates client identity via Clients gRPC. -- Topic parsing is cached (Ristretto) with configurable counters/cost/buffers to reduce resolver calls. -- Observability: Jaeger tracing, Prometheus metrics at `/metrics`, service health at `/health`. -- Optional call-home telemetry is enabled by default. - -## Best Practices - -- Use domain/channel routes consistently in publish paths; include subtopics to segment data. -- Keep cache defaults unless load patterns require tuning; monitor `/metrics` for cache hit ratios. -- Enable TLS/mTLS for production deployments (HTTP server and gRPC clients). -- Reuse a single broker URL across services (often NATS) to simplify operations. - -## Versioning and Health Check - -The adapter exposes `/health` with status and build metadata. - -```bash -curl -X GET http://localhost:8008/health \ - -H "accept: application/health+json" -``` - -Example response: - -```json -{ - "status": "pass", - "version": "0.18.0", - "commit": "7d6f4dc4f7f0c1fa3dc24eddfb18bb5073ff4f62", - "description": "http adapter", - "build_time": "1970-01-01_00:00:00" -} -``` - -For endpoint details, see the [HTTP Adapter API documentation](https://docs.api.supermq.absmach.eu/?urls.primaryName=http.yaml). - -[doc]: https://docs.supermq.absmach.eu/ diff --git a/http/adapter.go b/http/adapter.go deleted file mode 100644 index 21c5ad9bc..000000000 --- a/http/adapter.go +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package http - -import ( - "context" - "strings" - - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - apiutil "github.com/absmach/supermq/api/http/util" - smqauthn "github.com/absmach/supermq/pkg/authn" - "github.com/absmach/supermq/pkg/connections" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/policies" -) - -var ( - // ErrFailedSubscription indicates that client couldn't subscribe to specified channel. - ErrFailedSubscription = errors.New("failed to subscribe to a channel") - // ErrFailedPublish indicates that client couldn't publish to specified channel. - ErrFailedSubscribe = errors.New("failed to unsubscribe from topic") - // ErrEmptyTopic indicate absence of clientKey in the request. - ErrEmptyTopic = errors.New("empty topic") -) - -// Service specifies web socket service API. -type Service interface { - // Subscribe subscribes message from the broker using the clientKey for authorization, - // the channelID for subscription and domainID specifies the domain for authorization. - // Subtopic is optional. - // If the subscription is successful, nil is returned otherwise error is returned. - Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, client *Client) error - - Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string, topicType messaging.TopicType) error -} - -var _ Service = (*adapterService)(nil) - -type adapterService struct { - clients grpcClientsV1.ClientsServiceClient - channels grpcChannelsV1.ChannelsServiceClient - authn smqauthn.Authentication - pubsub messaging.PubSub -} - -// NewService instantiates the HTTP adapter implementation. -func NewService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, authn smqauthn.Authentication, pubsub messaging.PubSub) Service { - return &adapterService{ - clients: clients, - channels: channels, - authn: authn, - pubsub: pubsub, - } -} - -func (svc *adapterService) Subscribe(ctx context.Context, sessionID, username, password, domainID, channelID, subtopic string, topicType messaging.TopicType, c *Client) error { - if (channelID == "" && topicType != messaging.HealthType) || password == "" || domainID == "" { - return svcerr.ErrAuthentication - } - - clientID, err := svc.authorize(ctx, username, password, domainID, channelID, connections.Subscribe, topicType) - if err != nil { - return svcerr.ErrAuthorization - } - - c.id = clientID - - // Health check topics do not subscribe to the message broker. - if topicType == messaging.HealthType { - return nil - } - - subject := messaging.EncodeTopic(domainID, channelID, subtopic) - subCfg := messaging.SubscriberConfig{ - ID: sessionID, - ClientID: clientID, - Topic: subject, - Handler: c, - } - if err := svc.pubsub.Subscribe(ctx, subCfg); err != nil { - return errors.Wrap(ErrFailedSubscription, err) - } - - return nil -} - -func (svc *adapterService) Unsubscribe(ctx context.Context, sessionID, domainID, channelID, subtopic string, topicType messaging.TopicType) error { - topic := messaging.EncodeTopic(domainID, channelID, subtopic) - - // Health check topics do not subscribe to the message broker. - if topicType == messaging.MessageType { - if err := svc.pubsub.Unsubscribe(ctx, sessionID, topic); err != nil { - return errors.Wrap(ErrFailedSubscribe, err) - } - } - return nil -} - -// authorize checks if the authKey is authorized to access the channel -// and returns the clientID or userID if it is. -func (svc *adapterService) authorize(ctx context.Context, username, password, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) { - var token, clientType string - var err error - switch { - case strings.HasPrefix(password, apiutil.BearerPrefix): - token = strings.TrimPrefix(password, apiutil.BearerPrefix) - clientType = policies.UserType - case username != "" && password != "": - token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password) - clientType = policies.ClientType - case strings.HasPrefix(password, apiutil.BasicAuthPrefix): - username, password, err := decodeAuth(strings.TrimPrefix(password, apiutil.BasicAuthPrefix)) - if err != nil { - return "", errors.Wrap(svcerr.ErrAuthentication, err) - } - token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password) - clientType = policies.ClientType - default: - token = smqauthn.AuthPack(smqauthn.DomainAuth, domainID, strings.TrimPrefix(password, apiutil.ClientPrefix)) - clientType = policies.ClientType - } - - id, err := svc.authenticate(ctx, clientType, token) - if err != nil { - return "", errors.Wrap(svcerr.ErrAuthentication, err) - } - - // Health check topics do not require channel authorization. - if topicType == messaging.HealthType { - return id, nil - } - - authzReq := &grpcChannelsV1.AuthzReq{ - ClientType: clientType, - ClientId: id, - Type: uint32(msgType), - ChannelId: chanID, - DomainId: domainID, - } - authzRes, err := svc.channels.Authorize(ctx, authzReq) - if err != nil { - return "", errors.Wrap(svcerr.ErrAuthorization, err) - } - if !authzRes.GetAuthorized() { - return "", errors.Wrap(svcerr.ErrAuthorization, err) - } - - return id, nil -} - -func (svc *adapterService) authenticate(ctx context.Context, authType, token string) (string, error) { - switch authType { - case policies.UserType: - authnSession, err := svc.authn.Authenticate(ctx, token) - if err != nil { - return "", err - } - return authnSession.UserID, nil - case policies.ClientType: - authnRes, err := svc.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: token}) - if err != nil { - return "", err - } - if !authnRes.Authenticated { - return "", svcerr.ErrAuthentication - } - - return authnRes.GetId(), nil - default: - return "", errInvalidClientType - } -} diff --git a/http/adapter_test.go b/http/adapter_test.go deleted file mode 100644 index 174cdf70c..000000000 --- a/http/adapter_test.go +++ /dev/null @@ -1,370 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package http_test - -import ( - "context" - "encoding/base64" - "fmt" - "log/slog" - "strings" - "testing" - - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - apiutil "github.com/absmach/supermq/api/http/util" - chmocks "github.com/absmach/supermq/channels/mocks" - climocks "github.com/absmach/supermq/clients/mocks" - smqhttp "github.com/absmach/supermq/http" - "github.com/absmach/supermq/internal/testsutil" - smqauthn "github.com/absmach/supermq/pkg/authn" - authnmocks "github.com/absmach/supermq/pkg/authn/mocks" - "github.com/absmach/supermq/pkg/connections" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/mocks" - "github.com/absmach/supermq/pkg/policies" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -const ( - invalidID = "invalidID" - invalidKey = "invalidKey" - id = "1" - clientKey = "client_key" - subTopic = "subtopic" - protocol = "ws" - token = "Bearer token" - invalidToken = "Bearer invalid_token" -) - -var ( - domainID = testsutil.GenerateUUID(&testing.T{}) - clientID = testsutil.GenerateUUID(&testing.T{}) - userID = testsutil.GenerateUUID(&testing.T{}) - chanID = testsutil.GenerateUUID(&testing.T{}) - msg = messaging.Message{ - Channel: chanID, - Domain: domainID, - Publisher: id, - Subtopic: "", - Protocol: protocol, - Payload: []byte(`[{"n":"current","t":-5,"v":1.2}]`), - } - sessionID = "sessionID" - validEncodedCreds = base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey))) - invalidEncodedCreds = base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", invalidID, invalidKey))) -) - -func newService() (smqhttp.Service, *mocks.PubSub, *climocks.ClientsServiceClient, *chmocks.ChannelsServiceClient, *authnmocks.Authentication) { - pubsub := new(mocks.PubSub) - clients := new(climocks.ClientsServiceClient) - channels := new(chmocks.ChannelsServiceClient) - authn := new(authnmocks.Authentication) - - return smqhttp.NewService(clients, channels, authn, pubsub), pubsub, clients, channels, authn -} - -func TestSubscribe(t *testing.T) { - svc, pubsub, clients, channels, auth := newService() - - c := smqhttp.NewClient(slog.Default(), nil, sessionID) - - cases := []struct { - desc string - username string - password string - chanID string - domainID string - subtopic string - clientType string - clientID string - topicType messaging.TopicType - authNToken string - authNRes *grpcClientsV1.AuthnRes - authNErr error - authNRes1 smqauthn.Session - authZRes *grpcChannelsV1.AuthzRes - authZErr error - subErr error - err error - }{ - { - desc: "subscribe to channel with valid clientKey, chanID, subtopic", - password: clientKey, - chanID: chanID, - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "subscribe to channel with valid token, chanID, subtopic", - password: token, - chanID: chanID, - domainID: domainID, - clientID: userID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNRes1: smqauthn.Session{UserID: userID}, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "subscribe to channel with invalid token", - password: invalidToken, - chanID: chanID, - domainID: domainID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNRes1: smqauthn.Session{}, - authNErr: svcerr.ErrAuthentication, - err: svcerr.ErrAuthorization, - }, - { - desc: "subscribe again to channel with valid clientKey, chanID, subtopic", - password: clientKey, - chanID: chanID, - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "subscribe to channel with subscribe set to fail", - password: clientKey, - chanID: chanID, - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - subErr: smqhttp.ErrFailedSubscription, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: smqhttp.ErrFailedSubscription, - }, - { - desc: "subscribe to channel with invalid clientKey", - password: invalidKey, - chanID: chanID, - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey), - authNRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - authNErr: svcerr.ErrAuthentication, - err: svcerr.ErrAuthorization, - }, - { - desc: "subscribe to channel with empty channel", - password: clientKey, - chanID: "", - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - err: svcerr.ErrAuthentication, - }, - { - desc: "subscribe to channel with empty clientKey", - password: "", - chanID: chanID, - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - err: svcerr.ErrAuthentication, - }, - { - desc: "subscribe to channel with empty clientKey and empty channel", - password: "", - chanID: "", - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - err: svcerr.ErrAuthentication, - }, - { - desc: "subscribe to channel with invalid channel", - password: clientKey, - chanID: invalidID, - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - authZErr: svcerr.ErrAuthorization, - err: svcerr.ErrAuthorization, - }, - { - desc: "subscribe to channel with failed authentication", - password: clientKey, - chanID: chanID, - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - err: svcerr.ErrAuthorization, - }, - { - desc: "subscribe to channel with failed authorization", - password: clientKey, - chanID: chanID, - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - err: svcerr.ErrAuthorization, - }, - { - desc: "subscribe to channel with valid clientKey prefixed with 'client_', chanID, subtopic", - password: "Client " + clientKey, - chanID: chanID, - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "subscribe to channel with basic auth", - username: clientID, - password: clientKey, - chanID: chanID, - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "subscribe to channel with basic auth and invalid credentials", - username: invalidID, - password: invalidKey, - chanID: chanID, - domainID: domainID, - clientID: invalidID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, invalidID, invalidKey), - authNRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - authNErr: svcerr.ErrAuthentication, - err: svcerr.ErrAuthorization, - }, - { - desc: "subscribe to channel with b64 encoded credentials", - password: apiutil.BasicAuthPrefix + validEncodedCreds, - chanID: chanID, - domainID: domainID, - clientID: clientID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "subscribe to channel with b64 encoded credentials and invalid credentials", - password: apiutil.BasicAuthPrefix + invalidEncodedCreds, - chanID: chanID, - domainID: domainID, - clientID: invalidID, - subtopic: subTopic, - topicType: messaging.MessageType, - authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, invalidID, invalidKey), - authNRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - authNErr: svcerr.ErrAuthentication, - err: svcerr.ErrAuthorization, - }, - { - desc: "subscribe to health check topic with empty channel and valid clientKey", - password: clientKey, - chanID: "", - domainID: domainID, - clientID: clientID, - topicType: messaging.HealthType, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - err: nil, - }, - { - desc: "subscribe to health check topic with empty channel and valid token", - password: token, - chanID: "", - domainID: domainID, - clientID: userID, - topicType: messaging.HealthType, - authNRes1: smqauthn.Session{UserID: userID}, - err: nil, - }, - { - desc: "subscribe to health check topic with empty domain and valid clientKey", - password: clientKey, - chanID: "", - domainID: "", - clientID: clientID, - topicType: messaging.HealthType, - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - err: svcerr.ErrAuthentication, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - subConfig := messaging.SubscriberConfig{ - ID: sessionID, - Topic: "m." + tc.domainID + ".c." + tc.chanID + "." + subTopic, - ClientID: tc.clientID, - Handler: c, - } - tc.clientType = policies.ClientType - if strings.HasPrefix(tc.password, apiutil.BearerPrefix) { - tc.clientType = policies.UserType - } - clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr) - authCall := auth.On("Authenticate", mock.Anything, strings.TrimPrefix(tc.password, apiutil.BearerPrefix)).Return(tc.authNRes1, tc.authNErr) - channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{ - ClientType: tc.clientType, - ClientId: tc.clientID, - Type: uint32(connections.Subscribe), - ChannelId: tc.chanID, - DomainId: tc.domainID, - }).Return(tc.authZRes, tc.authZErr) - repoCall := pubsub.On("Subscribe", mock.Anything, subConfig).Return(tc.subErr) - err := svc.Subscribe(context.Background(), sessionID, tc.username, tc.password, tc.domainID, tc.chanID, tc.subtopic, tc.topicType, c) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - repoCall.Unset() - clientsCall.Unset() - authCall.Unset() - channelsCall.Unset() - }) - } -} diff --git a/http/api/endpoint.go b/http/api/endpoint.go deleted file mode 100644 index 417ec0f5a..000000000 --- a/http/api/endpoint.go +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -import ( - "context" - "crypto/rand" - "encoding/hex" - "fmt" - "log/slog" - "net/http" - "strings" - - api "github.com/absmach/supermq/api/http" - apiutil "github.com/absmach/supermq/api/http/util" - smqhttp "github.com/absmach/supermq/http" - "github.com/absmach/supermq/pkg/errors" - "github.com/absmach/supermq/pkg/messaging" - "github.com/go-kit/kit/endpoint" -) - -func messageHandler(ctx context.Context, svc smqhttp.Service, resolver messaging.TopicResolver, logger *slog.Logger) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if isWebSocketRequest(r) { - handleWebSocket(ctx, svc, resolver, logger, w, r) - return - } - if r.Method != http.MethodPost { - encodeError(ctx, w, errMethodNotAllowed) - return - } - req, err := decodePublishReq(ctx, r) - if err != nil { - encodeError(ctx, w, err) - return - } - _, err = sendMessageEndpoint()(ctx, req) - if err != nil { - encodeError(ctx, w, err) - return - } - - err = api.EncodeResponse(ctx, w, publishMessageRes{}) - if err != nil { - encodeError(ctx, w, err) - } - } -} - -func sendMessageEndpoint() endpoint.Endpoint { - return func(ctx context.Context, request any) (any, error) { - req := request.(publishReq) - if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - - return publishMessageRes{}, nil - } -} - -func healthCheckEndpoint() endpoint.Endpoint { - return func(ctx context.Context, request any) (any, error) { - req := request.(healthCheckReq) - if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - - return healthCheckRes{}, nil - } -} - -func handleWebSocket(ctx context.Context, svc smqhttp.Service, resolver messaging.TopicResolver, logger *slog.Logger, w http.ResponseWriter, r *http.Request) { - req, err := decodeWSReq(r, resolver, logger) - if err != nil { - encodeError(ctx, w, err) - return - } - - sessionID, err := generateSessionID() - if err != nil { - logger.Warn(fmt.Sprintf("Failed to generate session id: %s", err.Error())) - http.Error(w, "", http.StatusInternalServerError) - return - } - - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - logger.Warn(fmt.Sprintf("Failed to upgrade connection to websocket: %s", err.Error())) - return - } - - client := smqhttp.NewClient(logger, conn, sessionID) - - client.SetCloseHandler(func(code int, text string) error { - return svc.Unsubscribe(ctx, sessionID, req.domainID, req.channelID, req.subtopic, messaging.MessageType) - }) - - go client.Start(ctx) - - if err := svc.Subscribe(ctx, sessionID, req.username, req.password, req.domainID, req.channelID, req.subtopic, messaging.MessageType, client); err != nil { - conn.Close() - return - } - - logger.Debug(fmt.Sprintf("Successfully upgraded communication to WS on channel %s", req.channelID)) -} - -func isWebSocketRequest(r *http.Request) bool { - return strings.EqualFold(r.Header.Get(connHeaderKey), connHeaderVal) && - strings.EqualFold(r.Header.Get(upgradeHeaderKey), upgradeHeaderVal) -} - -func generateSessionID() (string, error) { - b := make([]byte, 32) - if _, err := rand.Read(b); err != nil { - return "", errors.Wrap(errGenSessionID, err) - } - return hex.EncodeToString(b), nil -} diff --git a/http/api/endpoint_test.go b/http/api/endpoint_test.go deleted file mode 100644 index 3d22c86de..000000000 --- a/http/api/endpoint_test.go +++ /dev/null @@ -1,531 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api_test - -import ( - "context" - "fmt" - "io" - "net" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - - "github.com/absmach/mgate" - proxy "github.com/absmach/mgate/pkg/http" - "github.com/absmach/mgate/pkg/session" - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1" - grpcDomainsV1 "github.com/absmach/supermq/api/grpc/domains/v1" - apiutil "github.com/absmach/supermq/api/http/util" - chmocks "github.com/absmach/supermq/channels/mocks" - climocks "github.com/absmach/supermq/clients/mocks" - dmocks "github.com/absmach/supermq/domains/mocks" - server "github.com/absmach/supermq/http" - "github.com/absmach/supermq/http/api" - "github.com/absmach/supermq/internal/testsutil" - smqlog "github.com/absmach/supermq/logger" - smqauthn "github.com/absmach/supermq/pkg/authn" - authnMocks "github.com/absmach/supermq/pkg/authn/mocks" - "github.com/absmach/supermq/pkg/connections" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/messaging" - pubsub "github.com/absmach/supermq/pkg/messaging/mocks" - "github.com/absmach/supermq/pkg/policies" - "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -const ( - instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002" - invalidValue = "invalid" - clientKey = "c02ff576-ccd5-40f6-ba5f-c85377aad529" - wsProtocol = "ws" - invalidKey = "invalid-key" - validToken = "valid-token" - invalidToken = "invalid-token" - ctSenmlJSON = "application/senml+json" - ctSenmlCBOR = "application/senml+cbor" - ctJSON = "application/json" - msgJSON = `{"field1":"val1","field2":"val2"}` - msgCBOR = `81A3616E6763757272656E746174206176FB3FF999999999999A` - msg = `[{"n":"current","t":-1,"v":1.6}]` -) - -var ( - clientID = testsutil.GenerateUUID(&testing.T{}) - chanID = testsutil.GenerateUUID(&testing.T{}) - domainID = testsutil.GenerateUUID(&testing.T{}) - userID = testsutil.GenerateUUID(&testing.T{}) -) - -func newService(clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, authn smqauthn.Authentication, pubsub *pubsub.PubSub) server.Service { - return server.NewService(clients, channels, authn, pubsub) -} - -func newHandler(authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, domains grpcDomainsV1.DomainsServiceClient) (session.Handler, *pubsub.PubSub, error) { - pub := new(pubsub.PubSub) - parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains) - if err != nil { - return nil, nil, err - } - - return server.NewHandler(pub, smqlog.NewMock(), authn, clients, channels, parser), pub, nil -} - -func newTargetHTTPServer(resolver messaging.TopicResolver, svc server.Service) *httptest.Server { - mux := api.MakeHandler(context.Background(), svc, resolver, smqlog.NewMock(), instanceID) - return httptest.NewServer(mux) -} - -func newProxyHTPPServer(svc session.Handler, targetServer *httptest.Server) (*httptest.Server, error) { - ptUrl, _ := url.Parse(targetServer.URL) - ptHost, ptPort, _ := net.SplitHostPort(ptUrl.Host) - config := mgate.Config{ - Host: "", - Port: "", - PathPrefix: "", - TargetHost: ptHost, - TargetPort: ptPort, - TargetProtocol: ptUrl.Scheme, - TargetPath: ptUrl.Path, - } - mp, err := proxy.NewProxy(config, svc, smqlog.NewMock(), []string{}, []string{}) - if err != nil { - return nil, err - } - return httptest.NewServer(http.HandlerFunc(mp.ServeHTTP)), nil -} - -type testRequest struct { - client *http.Client - method string - url string - contentType string - token string - body io.Reader - basicAuth bool - bearerToken bool -} - -func (tr testRequest) make() (*http.Response, error) { - req, err := http.NewRequest(tr.method, tr.url, tr.body) - if err != nil { - return nil, err - } - - if tr.token != "" { - switch { - case tr.basicAuth: - req.SetBasicAuth("", apiutil.ClientPrefix+tr.token) - case tr.bearerToken: - req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token) - default: - req.Header.Set("Authorization", apiutil.ClientPrefix+tr.token) - } - } - if tr.contentType != "" { - req.Header.Set("Content-Type", tr.contentType) - } - return tr.client.Do(req) -} - -func TestPublish(t *testing.T) { - clients := new(climocks.ClientsServiceClient) - authn := new(authnMocks.Authentication) - channels := new(chmocks.ChannelsServiceClient) - domains := new(dmocks.DomainsServiceClient) - resolver := messaging.NewTopicResolver(channels, domains) - handler, pubsub, err := newHandler(authn, clients, channels, domains) - assert.Nil(t, err, fmt.Sprintf("failed to create handler with err: %v", err)) - svc := newService(clients, channels, authn, pubsub) - target := newTargetHTTPServer(resolver, svc) - defer target.Close() - ts, err := newProxyHTPPServer(handler, target) - require.Nil(t, err) - defer ts.Close() - - cases := []struct { - desc string - domainID string - chanID string - clientID string - clientType string - msg string - contentType string - key string - status int - basicAuth bool - bearerToken bool - authnErr error - authnRes *grpcClientsV1.AuthnRes - authnRes1 smqauthn.Session - authzRes *grpcChannelsV1.AuthzRes - authzErr error - err error - }{ - { - desc: "publish message successfully", - domainID: domainID, - chanID: chanID, - clientID: clientID, - msg: msg, - contentType: ctSenmlJSON, - key: clientKey, - status: http.StatusAccepted, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - }, - { - desc: "publish message with application/senml+cbor content-type", - domainID: domainID, - chanID: chanID, - clientID: clientID, - msg: msgCBOR, - contentType: ctSenmlCBOR, - key: clientKey, - status: http.StatusAccepted, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - }, - { - desc: "publish message with application/json content-type", - domainID: domainID, - chanID: chanID, - clientID: clientID, - msg: msgJSON, - contentType: ctJSON, - key: clientKey, - status: http.StatusAccepted, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - }, - { - desc: "publish message with empty key", - domainID: domainID, - chanID: chanID, - clientID: clientID, - msg: msg, - contentType: ctSenmlJSON, - key: "", - status: http.StatusBadRequest, - }, - { - desc: "publish message with basic auth", - domainID: domainID, - chanID: chanID, - clientID: clientID, - msg: msg, - contentType: ctSenmlJSON, - key: clientKey, - basicAuth: true, - status: http.StatusAccepted, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - }, - { - desc: "publish message with invalid key", - domainID: domainID, - chanID: chanID, - clientID: clientID, - msg: msg, - contentType: ctSenmlJSON, - key: invalidKey, - status: http.StatusUnauthorized, - authnRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - }, - { - desc: "publish message with invalid basic auth", - domainID: domainID, - chanID: chanID, - clientID: clientID, - msg: msg, - contentType: ctSenmlJSON, - key: invalidKey, - basicAuth: true, - status: http.StatusUnauthorized, - authnRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - }, - { - desc: "publish message with valid bearer token", - domainID: domainID, - chanID: chanID, - clientID: userID, - msg: msg, - contentType: ctSenmlJSON, - key: validToken, - bearerToken: true, - status: http.StatusAccepted, - authnRes1: smqauthn.Session{UserID: userID}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - }, - { - desc: "publish message with invalid bearer token", - domainID: domainID, - chanID: chanID, - clientID: userID, - msg: msg, - contentType: ctSenmlJSON, - key: invalidToken, - bearerToken: true, - status: http.StatusUnauthorized, - authnRes1: smqauthn.Session{}, - authnErr: svcerr.ErrAuthentication, - }, - { - desc: "publish message without content type", - domainID: domainID, - chanID: chanID, - clientID: clientID, - msg: msg, - contentType: "", - key: clientKey, - status: http.StatusUnsupportedMediaType, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - }, - { - desc: "publish message to empty channel", - domainID: domainID, - chanID: "", - clientID: clientID, - msg: msg, - contentType: ctSenmlJSON, - key: clientKey, - status: http.StatusBadRequest, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - }, - { - desc: "publish message with invalid domain ID", - domainID: invalidValue, - chanID: chanID, - clientID: clientID, - msg: msg, - contentType: ctSenmlJSON, - key: clientKey, - status: http.StatusUnauthorized, - authnRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{Token: smqauthn.AuthPack(smqauthn.DomainAuth, tc.domainID, tc.key)}).Return(tc.authnRes, tc.authnErr) - authCall := authn.On("Authenticate", mock.Anything, tc.key).Return(tc.authnRes1, tc.authnErr) - domainsCall := domains.On("RetrieveIDByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: tc.domainID}}, nil) - tc.clientType = policies.ClientType - clientID := tc.clientID - if tc.bearerToken { - tc.clientType = policies.UserType - clientID = policies.EncodeDomainUserID(tc.domainID, tc.clientID) - } - channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{ - DomainId: tc.domainID, - ChannelId: tc.chanID, - ClientId: clientID, - ClientType: tc.clientType, - Type: uint32(connections.Publish), - }).Return(tc.authzRes, tc.authzErr) - svcCall := pubsub.On("Publish", mock.Anything, messaging.EncodeTopicSuffix(tc.domainID, tc.chanID, ""), mock.Anything).Return(nil) - req := testRequest{ - client: ts.Client(), - method: http.MethodPost, - url: fmt.Sprintf("%s/m/%s/c/%s", ts.URL, tc.domainID, tc.chanID), - contentType: tc.contentType, - token: tc.key, - body: strings.NewReader(tc.msg), - basicAuth: tc.basicAuth, - bearerToken: tc.bearerToken, - } - res, err := req.make() - assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) - assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) - svcCall.Unset() - clientsCall.Unset() - authCall.Unset() - channelsCall.Unset() - domainsCall.Unset() - }) - } -} - -func TestHandshake(t *testing.T) { - clients := new(climocks.ClientsServiceClient) - channels := new(chmocks.ChannelsServiceClient) - authn := new(authnMocks.Authentication) - domains := new(dmocks.DomainsServiceClient) - resolver := messaging.NewTopicResolver(channels, domains) - handler, pubsub, err := newHandler(authn, clients, channels, domains) - assert.Nil(t, err, fmt.Sprintf("failed to create handler with err: %v", err)) - svc := newService(clients, channels, authn, pubsub) - target := newTargetHTTPServer(resolver, svc) - defer target.Close() - ts, err := newProxyHTPPServer(handler, target) - require.Nil(t, err) - defer ts.Close() - msg := []byte(`[{"n":"current","t":-1,"v":1.6}]`) - pubsub.On("Subscribe", mock.Anything, mock.Anything).Return(nil) - pubsub.On("Unsubscribe", mock.Anything, mock.Anything, mock.Anything).Return(nil) - pubsub.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil) - clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: true}, nil) - clients.On("Authenticate", mock.Anything, mock.Anything).Return(&grpcClientsV1.AuthnRes{Authenticated: false}, nil) - authn.On("Authenticate", mock.Anything, mock.Anything).Return(smqauthn.Session{}, nil) - channels.On("Authorize", mock.Anything, mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, nil) - - cases := []struct { - desc string - domainID string - chanID string - subtopic string - header bool - clientKey string - status int - err error - msg []byte - }{ - { - desc: "connect and send message", - domainID: domainID, - chanID: chanID, - subtopic: "", - header: true, - clientKey: clientKey, - status: http.StatusSwitchingProtocols, - msg: msg, - }, - { - desc: "connect and send message with clientKey as query parameter", - domainID: domainID, - chanID: chanID, - subtopic: "", - header: false, - clientKey: clientKey, - status: http.StatusSwitchingProtocols, - msg: msg, - }, - { - desc: "connect and send message that cannot be published", - domainID: domainID, - chanID: chanID, - subtopic: "", - header: true, - clientKey: clientKey, - status: http.StatusSwitchingProtocols, - msg: []byte{}, - }, - { - desc: "connect and send message to subtopic", - domainID: domainID, - chanID: chanID, - subtopic: "subtopic", - header: true, - clientKey: clientKey, - status: http.StatusSwitchingProtocols, - msg: msg, - }, - { - desc: "connect and send message to nested subtopic", - domainID: domainID, - chanID: chanID, - subtopic: "subtopic/nested", - header: true, - clientKey: clientKey, - status: http.StatusSwitchingProtocols, - msg: msg, - }, - { - desc: "connect and send message to all subtopics", - domainID: domainID, - chanID: chanID, - subtopic: ">", - header: true, - clientKey: clientKey, - status: http.StatusSwitchingProtocols, - msg: msg, - }, - { - desc: "connect to empty channel", - domainID: domainID, - chanID: "", - subtopic: "", - header: true, - clientKey: clientKey, - status: http.StatusUnauthorized, - msg: []byte{}, - }, - { - desc: "connect with empty clientKey", - domainID: domainID, - chanID: chanID, - subtopic: "", - header: true, - clientKey: "", - status: http.StatusBadRequest, - msg: []byte{}, - }, - { - desc: "connect and send message to subtopic with invalid name", - domainID: domainID, - chanID: chanID, - subtopic: "sub/a*b/topic", - header: true, - clientKey: clientKey, - status: http.StatusUnauthorized, - msg: msg, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - conn, res, err := handshake(ts.URL, tc.domainID, tc.chanID, tc.subtopic, tc.clientKey, tc.header) - assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code '%d' got '%d'\n", tc.desc, tc.status, res.StatusCode)) - if tc.status == http.StatusSwitchingProtocols { - assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err)) - err = conn.WriteMessage(websocket.TextMessage, tc.msg) - assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error %s\n", tc.desc, err)) - } - }) - } -} - -func makeURL(tsURL, domainID, chanID, subtopic, clientKey string, header bool) (string, error) { - u, _ := url.Parse(tsURL) - u.Scheme = wsProtocol - - if chanID == "0" || chanID == "" { - if header { - return fmt.Sprintf("%s/m/%s/c/%s", u, domainID, chanID), fmt.Errorf("invalid channel id") - } - return fmt.Sprintf("%s/m/%s/c/%s?authorization=%s", u, domainID, chanID, clientKey), fmt.Errorf("invalid channel id") - } - - subtopicPart := "" - if subtopic != "" { - subtopicPart = fmt.Sprintf("/%s", subtopic) - } - if header { - return fmt.Sprintf("%s/m/%s/c/%s%s", u, domainID, chanID, subtopicPart), nil - } - - return fmt.Sprintf("%s/m/%s/c/%s%s?authorization=%s", u, domainID, chanID, subtopicPart, clientKey), nil -} - -func handshake(tsURL, domainID, chanID, subtopic, clientKey string, addHeader bool) (*websocket.Conn, *http.Response, error) { - header := http.Header{} - if addHeader { - header.Add("Authorization", clientKey) - } - - turl, _ := makeURL(tsURL, domainID, chanID, subtopic, clientKey, addHeader) - conn, res, errRet := websocket.DefaultDialer.Dial(turl, header) - - return conn, res, errRet -} diff --git a/http/api/request.go b/http/api/request.go deleted file mode 100644 index e8a2f2237..000000000 --- a/http/api/request.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -import ( - apiutil "github.com/absmach/supermq/api/http/util" - "github.com/absmach/supermq/pkg/messaging" -) - -type publishReq struct { - msg *messaging.Message - token string -} - -func (req publishReq) validate() error { - if req.token == "" { - return apiutil.ErrBearerKey - } - if len(req.msg.Payload) == 0 { - return apiutil.ErrEmptyMessage - } - - return nil -} - -type connReq struct { - username string - password string - channelID string - domainID string - subtopic string -} - -type healthCheckReq struct { - domain string - token string -} - -func (req healthCheckReq) validate() error { - if req.token == "" { - return apiutil.ErrBearerKey - } - if req.domain == "" { - return apiutil.ErrMissingDomainID - } - - return nil -} diff --git a/http/api/response.go b/http/api/response.go deleted file mode 100644 index 425da450e..000000000 --- a/http/api/response.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -import ( - "net/http" - - "github.com/absmach/supermq" -) - -var ( - _ supermq.Response = (*publishMessageRes)(nil) - _ supermq.Response = (*healthCheckRes)(nil) -) - -type publishMessageRes struct{} - -func (res publishMessageRes) Code() int { - return http.StatusAccepted -} - -func (res publishMessageRes) Headers() map[string]string { - return map[string]string{} -} - -func (res publishMessageRes) Empty() bool { - return true -} - -type healthCheckRes struct{} - -func (res healthCheckRes) Code() int { - return http.StatusOK -} - -func (res healthCheckRes) Headers() map[string]string { - return map[string]string{} -} - -func (res healthCheckRes) Empty() bool { - return true -} diff --git a/http/api/transport.go b/http/api/transport.go deleted file mode 100644 index a223d7f74..000000000 --- a/http/api/transport.go +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -import ( - "context" - "encoding/json" - "io" - "log/slog" - "net/http" - - "github.com/absmach/supermq" - api "github.com/absmach/supermq/api/http" - apiutil "github.com/absmach/supermq/api/http/util" - smqhttp "github.com/absmach/supermq/http" - "github.com/absmach/supermq/pkg/errors" - "github.com/absmach/supermq/pkg/messaging" - "github.com/go-chi/chi/v5" - kithttp "github.com/go-kit/kit/transport/http" - "github.com/gorilla/websocket" - "github.com/prometheus/client_golang/prometheus/promhttp" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" -) - -const ( - ctSenmlJSON = "application/senml+json" - ctSenmlCBOR = "application/senml+cbor" - contentType = "application/json" - authzHeaderKey = "Authorization" - authzQueryKey = "authorization" - connHeaderKey = "Connection" - connHeaderVal = "upgrade" - upgradeHeaderKey = "Upgrade" - upgradeHeaderVal = "websocket" - readwriteBufferSize = 1024 -) - -var ( - upgrader = websocket.Upgrader{ - ReadBufferSize: readwriteBufferSize, - WriteBufferSize: readwriteBufferSize, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - errUnauthorizedAccess = errors.New("missing or invalid credentials provided") - errMalformedSubtopic = errors.New("malformed subtopic") - errGenSessionID = errors.New("failed to generate session id") - errMethodNotAllowed = errors.New("method not allowed") -) - -// MakeHandler returns a HTTP handler for API endpoints. -func MakeHandler(ctx context.Context, svc smqhttp.Service, resolver messaging.TopicResolver, logger *slog.Logger, instanceID string) http.Handler { - opts := []kithttp.ServerOption{ - kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), - } - r := chi.NewRouter() - - r.Handle("/m/{domain}/c/{channel}", messageHandler(ctx, svc, resolver, logger)) - - r.Handle("/m/{domain}/c/{channel}/*", messageHandler(ctx, svc, resolver, logger)) - - r.Post("/hc/{domain}", otelhttp.NewHandler(kithttp.NewServer( - healthCheckEndpoint(), - decodeHealthCheckRequest, - api.EncodeResponse, - opts..., - ), "health_check").ServeHTTP) - - r.Get("/health", supermq.Health("http", instanceID)) - r.Handle("/metrics", promhttp.Handler()) - - return r -} - -func decodePublishReq(_ context.Context, r *http.Request) (any, error) { - ct := r.Header.Get("Content-Type") - if ct != ctSenmlJSON && ct != contentType && ct != ctSenmlCBOR { - return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) - } - - var req publishReq - _, pass, ok := r.BasicAuth() - switch { - case ok: - req.token = pass - case !ok: - req.token = r.Header.Get(authzHeaderKey) - } - - payload, err := io.ReadAll(r.Body) - if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, errors.ErrMalformedEntity) - } - defer r.Body.Close() - - req.msg = &messaging.Message{Payload: payload} - - return req, nil -} - -func decodeWSReq(r *http.Request, resolver messaging.TopicResolver, logger *slog.Logger) (connReq, error) { - username, password, ok := r.BasicAuth() - if !ok { - switch { - case r.URL.Query().Get(authzQueryKey) != "": - password = r.URL.Query().Get(authzQueryKey) - case r.Header.Get(authzHeaderKey) != "": - password = r.Header.Get(authzHeaderKey) - default: - logger.Debug("Missing authorization key.") - return connReq{}, errUnauthorizedAccess - } - } - - domain := chi.URLParam(r, "domain") - channel := chi.URLParam(r, "channel") - - domainID, channelID, _, err := resolver.Resolve(r.Context(), domain, channel) - if err != nil { - return connReq{}, err - } - - req := connReq{ - username: username, - password: password, - channelID: channelID, - domainID: domainID, - } - - subTopic := chi.URLParam(r, "*") - - if subTopic != "" { - subTopic, err := messaging.ParseSubscribeSubtopic(subTopic) - if err != nil { - return connReq{}, err - } - req.subtopic = subTopic - } - - return req, nil -} - -func decodeHealthCheckRequest(_ context.Context, r *http.Request) (any, error) { - var req healthCheckReq - req.domain = chi.URLParam(r, "domain") - _, pass, ok := r.BasicAuth() - switch { - case ok: - req.token = pass - case !ok: - req.token = r.Header.Get(authzHeaderKey) - } - - if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - - return req, nil -} - -func encodeError(ctx context.Context, w http.ResponseWriter, err error) { - switch err { - case smqhttp.ErrEmptyTopic: - w.WriteHeader(http.StatusBadRequest) - case errUnauthorizedAccess: - w.WriteHeader(http.StatusForbidden) - case errMalformedSubtopic, errors.ErrMalformedEntity: - w.WriteHeader(http.StatusBadRequest) - default: - api.EncodeError(ctx, err, w) - return - } - - if errorVal, ok := err.(errors.Error); ok { - if err := json.NewEncoder(w).Encode(errorVal); err != nil { - w.WriteHeader(http.StatusInternalServerError) - } - } -} diff --git a/http/client.go b/http/client.go deleted file mode 100644 index e4a403ca5..000000000 --- a/http/client.go +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package http - -import ( - "context" - "log/slog" - "time" - - "github.com/absmach/supermq/pkg/errors" - "github.com/absmach/supermq/pkg/messaging" - "github.com/gorilla/websocket" - "golang.org/x/sync/errgroup" -) - -var ( - errHandlerBlockedMsgChan = errors.New("message handler msg chan blocked (full)") - errHandlerClosedMsgChan = errors.New("message handler closed msg chan") - errFailedToWriteMsg = errors.New("failed to write message to connection") - errFailedToWritePing = errors.New("failed to write ping to connection") - errReadMsg = errors.New("failed to read messages ") -) - -const ( - // Time allowed to write a message to the peer. - writeWait = 10 * time.Second - - // Send pings to peer with this period. Must be less than pongWait. - pingPeriod = 30 * time.Second - - // Time allowed to read the next pong message from the peer. - pongWait = 60 * time.Second -) - -// Client handles messaging and websocket connection. -type Client struct { - logger *slog.Logger - conn *websocket.Conn - id string - msg chan *messaging.Message - handledClose bool -} - -// NewClient returns a new websocket client. -func NewClient(logger *slog.Logger, conn *websocket.Conn, sessionID string) *Client { - c := &Client{ - logger: logger, - conn: conn, - id: sessionID, - msg: make(chan *messaging.Message, 1024), - } - return c -} - -// Cancel handles the websocket connection after unsubscribing. -func (c *Client) Cancel() error { - if c.conn == nil { - return nil - } - return c.conn.Close() -} - -// Close handles the websocket connection after unsubscribing. -func (c *Client) Close() error { - err := c.conn.Close() - if err != nil { - c.logger.Debug("failed to close websocket client", slog.String("session_id", c.id), slog.String("error", err.Error())) - } - ch := c.conn.CloseHandler() - err = ch(0, "") - if err != nil { - c.logger.Debug("failed to execute websocket connection close handler", slog.String("session_id", c.id), slog.String("error", err.Error())) - } - return nil -} - -// Handle handles the sending and receiving of messages via the broker. -func (c *Client) Handle(msg *messaging.Message) error { - select { - case c.msg <- msg: - return nil - default: - return errHandlerBlockedMsgChan - } -} - -// CloseHandler will work only if messages are read. -func (c *Client) readPump(ctx context.Context, cancel context.CancelFunc) error { - defer cancel() - c.conn.SetPongHandler(func(string) error { - if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { - return err - } - return nil - }) - - errCh := make(chan error, 1) - go func() { - errCh <- c.readMessage() - }() - - for { - select { - case <-ctx.Done(): - c.logger.Debug("read_pump: received context Done") - return nil - case err := <-errCh: - return err - } - } -} - -func (c *Client) readMessage() error { - for { - msgType, msg, err := c.conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - c.logger.Debug("read_pump: unexpected close error", slog.String("error", err.Error())) - return nil - } - return errors.Wrap(errReadMsg, err) - } - c.logger.Debug("read_pump: received message ", slog.Int("message_type", msgType), slog.String("message", string(msg))) - } -} - -func (c *Client) writePump(ctx context.Context, cancel context.CancelFunc) error { - defer cancel() - ticker := time.NewTicker(pingPeriod) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - c.logger.Debug("write_pump: received context Done ") - return nil - case msg, ok := <-c.msg: - _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait)) - if !ok { - if err := c.conn.WriteMessage(websocket.CloseMessage, []byte{}); err != nil { - return errors.Wrap(errHandlerClosedMsgChan, err) - } - return errHandlerClosedMsgChan - } - if err := c.conn.WriteMessage(websocket.BinaryMessage, msg.GetPayload()); err != nil { - return errors.Wrap(errFailedToWriteMsg, err) - } - case <-ticker.C: - if err := c.conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(writeWait)); err != nil { - return errors.Wrap(errFailedToWritePing, err) - } - } - } -} - -// SetCloseHandler sets a close handler for the WebSocket connection. -func (c *Client) SetCloseHandler(handler func(code int, text string) error) { - c.conn.SetCloseHandler(func(code int, text string) error { - if !c.handledClose { - if err := handler(code, text); err != nil { - c.logger.Warn("Error in close handler", slog.String("error", err.Error())) - } - c.handledClose = true - } - return nil - }) -} - -func (c *Client) Start(ctx context.Context) { - defer c.Close() - ctx, cancel := context.WithCancel(ctx) - g, ctx := errgroup.WithContext(ctx) - - g.Go(func() error { - return c.readPump(ctx, cancel) - }) - - g.Go(func() error { - return c.writePump(ctx, cancel) - }) - - err := g.Wait() - if err != nil { - c.logger.Warn("websocket client error", slog.String("session_id", c.id), slog.String("error", err.Error())) - } -} diff --git a/http/client_test.go b/http/client_test.go deleted file mode 100644 index 58d6fce1a..000000000 --- a/http/client_test.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package http_test - -import ( - "context" - "fmt" - "log/slog" - "net/http" - "net/http/httptest" - "strings" - "sync/atomic" - "testing" - "time" - - smqhttp "github.com/absmach/supermq/http" - "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" -) - -const expectedCount = uint64(2) - -var ( - msgChan = make(chan []byte) - c *smqhttp.Client - count uint64 - - upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } -) - -func handler(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - defer conn.Close() - for { - _, message, err := conn.ReadMessage() - if err != nil { - break - } - atomic.AddUint64(&count, 1) - msgChan <- message - } -} - -func TestHandle(t *testing.T) { - s := httptest.NewServer(http.HandlerFunc(handler)) - defer s.Close() - - // Convert http://127.0.0.1 to ws://127.0.0.1 - u := strings.Replace(s.URL, "http", "ws", 1) - - // Connect to the server - wsConn, _, err := websocket.DefaultDialer.Dial(u, nil) - if err != nil { - t.Fatalf("%v", err) - } - defer wsConn.Close() - - c = smqhttp.NewClient(slog.Default(), wsConn, "sessionID") - go c.Start(context.Background()) - - cases := []struct { - desc string - publisher string - expectedPayload []byte - expectMsg bool - }{ - { - desc: "handling with different id from ws.Client", - publisher: msg.Publisher, - expectedPayload: msg.Payload, - expectMsg: true, - }, - { - desc: "handling with same id as ws.Client (empty by default) drops message", - publisher: "", - expectedPayload: []byte{}, - expectMsg: false, - }, - } - - for _, tc := range cases { - msg.Publisher = tc.publisher - err = c.Handle(&msg) - assert.Nil(t, err, fmt.Sprintf("expected nil error from handle, got: %s", err)) - receivedMsg := []byte{} - switch tc.expectMsg { - case true: - rec := <-msgChan // Wait for the message to be received. - receivedMsg = rec - case false: - time.Sleep(100 * time.Millisecond) // Give time to server to process c.Handle call. - } - assert.Equal(t, tc.expectedPayload, receivedMsg, fmt.Sprintf("%s: expected %+v, got %+v", tc.desc, &msg, receivedMsg)) - } - c := atomic.LoadUint64(&count) - assert.Equal(t, expectedCount, c, fmt.Sprintf("expected message count %d, got %d", expectedCount, c)) -} diff --git a/http/doc.go b/http/doc.go deleted file mode 100644 index 51b9e900e..000000000 --- a/http/doc.go +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package http contains the domain concept definitions needed to support -// SuperMQ HTTP Adapter functionality. -package http diff --git a/http/handler.go b/http/handler.go deleted file mode 100644 index d00a7f99f..000000000 --- a/http/handler.go +++ /dev/null @@ -1,300 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package http - -import ( - "context" - "encoding/base64" - "fmt" - "log/slog" - "net/http" - "strings" - "time" - - mgate "github.com/absmach/mgate/pkg/http" - "github.com/absmach/mgate/pkg/session" - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - apiutil "github.com/absmach/supermq/api/http/util" - smqauthn "github.com/absmach/supermq/pkg/authn" - "github.com/absmach/supermq/pkg/connections" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/policies" -) - -var _ session.Handler = (*handler)(nil) - -const protocol = "http" - -// Log message formats. -const ( - LogInfoSubscribed = "subscribed with client_id %s to topics %s" - LogInfoConnected = "connected with client_id %s" - LogInfoDisconnected = "disconnected client_id %s and username %s" - LogInfoPublished = "published with client_id %s to the topic %s" -) - -// Error wrappers for MQTT errors. -var ( - errClientNotInitialized = errors.New("client is not initialized") - errMissingTopicPub = errors.New("failed to publish due to missing topic") - errMissingTopicSub = errors.New("failed to subscribe due to missing topic") - errFailedPublish = errors.New("failed to publish") - errFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker") - errInvalidAuthFormat = errors.New("invalid basic auth format") - errInvalidClientType = errors.New("invalid client type") -) - -// Event implements events.Event interface. -type handler struct { - pubsub messaging.PubSub - clients grpcClientsV1.ClientsServiceClient - channels grpcChannelsV1.ChannelsServiceClient - authn smqauthn.Authentication - logger *slog.Logger - parser messaging.TopicParser -} - -// NewHandler creates new Handler entity. -func NewHandler(pubsub messaging.PubSub, logger *slog.Logger, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, parser messaging.TopicParser) session.Handler { - return &handler{ - logger: logger, - pubsub: pubsub, - authn: authn, - clients: clients, - channels: channels, - parser: parser, - } -} - -// AuthConnect is called on device connection, -// prior forwarding to the http server. -func (h *handler) AuthConnect(ctx context.Context) error { - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - - var tok string - switch { - case string(s.Password) == "": - return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerKey)) - case strings.HasPrefix(string(s.Password), apiutil.ClientPrefix): - tok = strings.TrimPrefix(string(s.Password), apiutil.ClientPrefix) - default: - tok = string(s.Password) - } - - h.logger.Info(fmt.Sprintf(LogInfoConnected, tok)) - return nil -} - -// AuthPublish is called on device publish, -// prior forwarding to the http server. -func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error { - if topic == nil { - return errMissingTopicPub - } - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - - domainID, channelID, _, topicType, err := h.parser.ParsePublishTopic(ctx, *topic, true) - if err != nil { - return mgate.NewHTTPProxyError(http.StatusBadRequest, errors.Wrap(errFailedPublish, err)) - } - - clientID, err := h.authAccess(ctx, s.Username, string(s.Password), domainID, channelID, connections.Publish, topicType) - if err != nil { - return err - } - - if s.Username == "" { - s.Username = clientID - } - - return nil -} - -// AuthSubscribe is called on device publish, -// prior forwarding to the MQTT broker. -func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error { - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - if topics == nil || *topics == nil { - return errMissingTopicSub - } - - for _, topic := range *topics { - domainID, channelID, _, topicType, err := h.parser.ParseSubscribeTopic(ctx, topic, true) - if err != nil { - return err - } - if _, err := h.authAccess(ctx, s.Username, string(s.Password), domainID, channelID, connections.Subscribe, topicType); err != nil { - return err - } - } - return nil -} - -// Connect - after client successfully connected. -func (h *handler) Connect(ctx context.Context) error { - return nil -} - -// Publish - after client successfully published. -func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) error { - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - - if len(*payload) == 0 { - h.logger.Warn("Empty payload, not publishing to broker", slog.String("client_id", s.Username)) - return nil - } - - domainID, channelID, subtopic, topicType, err := h.parser.ParsePublishTopic(ctx, *topic, true) - if err != nil { - return errors.Wrap(errFailedPublish, err) - } - - msg := messaging.Message{ - Protocol: protocol, - Domain: domainID, - Channel: channelID, - Subtopic: subtopic, - Payload: *payload, - Publisher: s.Username, - Created: time.Now().UnixNano(), - } - - // Health check topic messages do not get published to message broker. - if topicType == messaging.MessageType { - if err := h.pubsub.Publish(ctx, messaging.EncodeMessageTopic(&msg), &msg); err != nil { - return mgate.NewHTTPProxyError(http.StatusInternalServerError, errors.Wrap(errFailedPublishToMsgBroker, err)) - } - } - - h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic)) - - return nil -} - -// Subscribe - after client successfully subscribed. -func (h *handler) Subscribe(ctx context.Context, topics *[]string) error { - s, ok := session.FromContext(ctx) - if !ok { - return errClientNotInitialized - } - h.logger.Info(fmt.Sprintf(LogInfoSubscribed, s.ID, strings.Join(*topics, ","))) - return nil -} - -// Unsubscribe - after client unsubscribed. -func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error { - return nil -} - -// Disconnect - connection with broker or client lost. -func (h *handler) Disconnect(ctx context.Context) error { - return nil -} - -func (h *handler) authAccess(ctx context.Context, username, password, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) (string, error) { - var token, clientType string - var err error - switch { - case strings.HasPrefix(password, apiutil.BearerPrefix): - token = strings.TrimPrefix(password, apiutil.BearerPrefix) - clientType = policies.UserType - case username != "" && password != "": - token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password) - clientType = policies.ClientType - case strings.HasPrefix(password, apiutil.BasicAuthPrefix): - username, password, err := decodeAuth(strings.TrimPrefix(password, apiutil.BasicAuthPrefix)) - if err != nil { - return "", errors.Wrap(svcerr.ErrAuthentication, err) - } - token = smqauthn.AuthPack(smqauthn.BasicAuth, username, password) - clientType = policies.ClientType - default: - token = smqauthn.AuthPack(smqauthn.DomainAuth, domainID, strings.TrimPrefix(password, apiutil.ClientPrefix)) - clientType = policies.ClientType - } - - id, subject, err := h.authenticate(ctx, clientType, token, domainID) - if err != nil { - return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err)) - } - - // Health check topics do not require channel authorization. - if topicType == messaging.HealthType { - return id, nil - } - - ar := &grpcChannelsV1.AuthzReq{ - Type: uint32(msgType), - ClientId: subject, - ClientType: clientType, - ChannelId: chanID, - DomainId: domainID, - } - res, err := h.channels.Authorize(ctx, ar) - if err != nil { - return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, errors.Wrap(svcerr.ErrAuthentication, err)) - } - if !res.GetAuthorized() { - return "", mgate.NewHTTPProxyError(http.StatusUnauthorized, svcerr.ErrAuthentication) - } - - return id, nil -} - -func (h *handler) authenticate(ctx context.Context, authType, token, domainID string) (string, string, error) { - switch authType { - case policies.UserType: - authnSession, err := h.authn.Authenticate(ctx, token) - if err != nil { - return "", "", err - } - if authnSession.Role == smqauthn.SuperAdminRole { - return authnSession.UserID, authnSession.UserID, nil - } - return authnSession.UserID, policies.EncodeDomainUserID(domainID, authnSession.UserID), nil - case policies.ClientType: - authnRes, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: token}) - if err != nil { - return "", "", err - } - if !authnRes.Authenticated { - return "", "", svcerr.ErrAuthentication - } - - return authnRes.GetId(), authnRes.GetId(), nil - default: - return "", "", errInvalidClientType - } -} - -// decodeAuth decodes the base64 encoded string in the format "clientID:secret". -func decodeAuth(s string) (string, string, error) { - db, err := base64.URLEncoding.DecodeString(s) - if err != nil { - return "", "", err - } - parts := strings.SplitN(string(db), ":", 2) - if len(parts) != 2 { - return "", "", errInvalidAuthFormat - } - clientID := parts[0] - secret := parts[1] - - return clientID, secret, nil -} diff --git a/http/handler_test.go b/http/handler_test.go deleted file mode 100644 index 26c3fdb00..000000000 --- a/http/handler_test.go +++ /dev/null @@ -1,783 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package http_test - -import ( - "context" - "encoding/base64" - "fmt" - "net/http" - "strings" - "testing" - - mgate "github.com/absmach/mgate/pkg/http" - "github.com/absmach/mgate/pkg/session" - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - apiutil "github.com/absmach/supermq/api/http/util" - chmocks "github.com/absmach/supermq/channels/mocks" - clmocks "github.com/absmach/supermq/clients/mocks" - dmocks "github.com/absmach/supermq/domains/mocks" - smqhttp "github.com/absmach/supermq/http" - smqlog "github.com/absmach/supermq/logger" - smqauthn "github.com/absmach/supermq/pkg/authn" - authnmocks "github.com/absmach/supermq/pkg/authn/mocks" - "github.com/absmach/supermq/pkg/connections" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/mocks" - "github.com/absmach/supermq/pkg/policies" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -var ( - invalidValue = "invalid" - topicMsg = "/m/%s/c/%s" - subtopicMsg = "/m/%s/c/%s/subtopic" - topic = fmt.Sprintf(topicMsg, domainID, chanID) - subtopic = fmt.Sprintf(subtopicMsg, domainID, chanID) - hcTopicFmt = "/hc/%s" - hcTopic = fmt.Sprintf(hcTopicFmt, domainID) - invalidHCTopic = "/hc" - invalidTopic = invalidValue - topics = []string{topic} - payload = []byte("[{'n':'test-name', 'v': 1.2}]") - sessionClient = session.Session{ - ID: clientID, - Password: []byte(clientKey), - } - invalidChannelIDTopic = "m/**/c" - validToken = "token" - errClientNotInitialized = errors.New("client is not initialized") - errMissingTopicPub = errors.New("failed to publish due to missing topic") - errMissingTopicSub = errors.New("failed to subscribe due to missing topic") -) - -var ( - clients = new(clmocks.ClientsServiceClient) - channels = new(chmocks.ChannelsServiceClient) - authn = new(authnmocks.Authentication) - publisher = new(mocks.PubSub) - domains = new(dmocks.DomainsServiceClient) -) - -func newHandler(t *testing.T) session.Handler { - logger := smqlog.NewMock() - authn = new(authnmocks.Authentication) - clients = new(clmocks.ClientsServiceClient) - channels = new(chmocks.ChannelsServiceClient) - publisher = new(mocks.PubSub) - parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains) - assert.Nil(t, err, fmt.Sprintf("unexpected error while creating topic parser: %v", err)) - - return smqhttp.NewHandler(publisher, logger, authn, clients, channels, parser) -} - -func TestAuthPublish(t *testing.T) { - handler := newHandler(t) - - clientKeySession := session.Session{ - Password: []byte("Client " + clientKey), - } - unauthorizedKeySession := session.Session{ - Password: []byte("Client " + clientKey), - } - invalidClientKeySession := session.Session{ - Password: []byte("Client " + invalidKey), - } - tokenSession := session.Session{ - Password: []byte(apiutil.BearerPrefix + validToken), - } - invalidTokenSession := session.Session{ - Password: []byte(apiutil.BearerPrefix + invalidToken), - } - basicAuthSession := session.Session{ - Username: clientID, - Password: []byte(clientKey), - } - invalidBasicAuthSession := session.Session{ - Username: clientID, - Password: []byte(invalidValue), - } - creds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey))) - encodedCredsSession := session.Session{ - Password: []byte(apiutil.BasicAuthPrefix + creds), - } - invalidCreds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, invalidValue))) - invalidEncodedCredsSession := session.Session{ - Password: []byte(apiutil.BasicAuthPrefix + invalidCreds), - } - hcClientKeySession := session.Session{ - Password: []byte("Client " + clientKey), - } - - tests := []struct { - desc string - session *session.Session - topic *string - payload *[]byte - authKey string - status int - clientType string - chanID string - domainID string - clientID string - authNToken string - superAdmin bool - authNRes *grpcClientsV1.AuthnRes - authNRes1 smqauthn.Session - authNErr error - authZRes *grpcChannelsV1.AuthzRes - authZErr error - err error - }{ - { - desc: "publish with client key successfully", - session: &clientKeySession, - topic: &topic, - authKey: clientKey, - payload: &payload, - status: http.StatusOK, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "publish with invalid client key", - session: &invalidClientKeySession, - topic: &topic, - authKey: invalidKey, - payload: &payload, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey), - authNRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "publish with nil session", - session: nil, - topic: &topic, - authKey: clientKey, - status: http.StatusInternalServerError, - err: errClientNotInitialized, - }, - { - desc: "publish with empty topic", - session: &clientKeySession, - topic: nil, - authKey: clientKey, - status: http.StatusBadRequest, - err: errMissingTopicPub, - }, - { - desc: "publish with unauthorized client key", - session: &unauthorizedKeySession, - topic: &topic, - authKey: clientKey, - payload: &payload, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "publish with token successfully", - session: &tokenSession, - topic: &topic, - authKey: token, - payload: &payload, - status: http.StatusOK, - clientType: policies.UserType, - chanID: chanID, - domainID: domainID, - clientID: userID, - authNRes1: smqauthn.Session{UserID: userID}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "publish with superadmin token successfully", - session: &tokenSession, - topic: &topic, - authKey: token, - payload: &payload, - status: http.StatusOK, - clientType: policies.UserType, - chanID: chanID, - domainID: domainID, - clientID: userID, - superAdmin: true, - authNRes1: smqauthn.Session{UserID: userID, Role: smqauthn.SuperAdminRole}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "publish with invalid token", - session: &invalidTokenSession, - topic: &topic, - authKey: invalidToken, - payload: &payload, - clientType: policies.UserType, - chanID: chanID, - domainID: domainID, - clientID: userID, - authNRes1: smqauthn.Session{}, - authNErr: svcerr.ErrAuthentication, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "publish with unauthorized token", - session: &tokenSession, - topic: &topic, - authKey: token, - payload: &payload, - clientType: policies.UserType, - chanID: chanID, - domainID: domainID, - clientID: userID, - authNRes1: smqauthn.Session{UserID: userID}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "publish with basic auth successfully", - session: &basicAuthSession, - topic: &topic, - payload: &payload, - status: http.StatusOK, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "publish with invalid basic auth", - session: &invalidBasicAuthSession, - topic: &topic, - payload: &payload, - authKey: invalidValue, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue), - authNRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "publish with b64 encoded credentials", - session: &encodedCredsSession, - topic: &topic, - payload: &payload, - status: http.StatusOK, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "publish with invalid b64 encoded credentials", - session: &invalidEncodedCredsSession, - topic: &topic, - payload: &payload, - authKey: invalidValue, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue), - authNRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "publish with health check topic successfully", - session: &hcClientKeySession, - topic: &hcTopic, - authKey: clientKey, - payload: &payload, - status: http.StatusOK, - clientType: policies.ClientType, - chanID: "", - domainID: domainID, - clientID: userID, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Authenticated: true}, - authNErr: nil, - }, - { - desc: "publish with invalid health check topic", - session: &hcClientKeySession, - topic: &invalidHCTopic, - authKey: clientKey, - payload: &payload, - status: http.StatusBadRequest, - err: messaging.ErrMalformedTopic, - clientType: policies.ClientType, - }, - } - - for _, tc := range tests { - t.Run(tc.desc, func(t *testing.T) { - ctx := context.TODO() - if tc.session != nil { - ctx = session.NewContext(ctx, tc.session) - } - tc.clientType = policies.ClientType - clientID := tc.clientID - if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) { - tc.clientType = policies.UserType - clientID = policies.EncodeDomainUserID(tc.domainID, tc.clientID) - if tc.superAdmin { - clientID = tc.clientID - } - } - clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr) - authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr) - channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{ - ClientType: tc.clientType, - ClientId: clientID, - Type: uint32(connections.Publish), - ChannelId: tc.chanID, - DomainId: tc.domainID, - }).Return(tc.authZRes, tc.authZErr) - err := handler.AuthPublish(ctx, tc.topic, tc.payload) - hpe, ok := err.(mgate.HTTPProxyError) - if ok { - assert.Equal(t, tc.status, hpe.StatusCode()) - } - if tc.err != nil { - assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err)) - } - authCall.Unset() - clientsCall.Unset() - channelsCall.Unset() - }) - } -} - -func TestAuthSubscribe(t *testing.T) { - handler := newHandler(t) - - clientKeySession := session.Session{ - Password: []byte("Client " + clientKey), - } - unauthorizedKeySession := session.Session{ - Password: []byte("Client " + clientKey), - } - invalidClientKeySession := session.Session{ - Password: []byte("Client " + invalidKey), - } - tokenSession := session.Session{ - Password: []byte(apiutil.BearerPrefix + validToken), - } - invalidTokenSession := session.Session{ - Password: []byte(apiutil.BearerPrefix + invalidToken), - } - basicAuthSession := session.Session{ - Username: clientID, - Password: []byte(clientKey), - } - invalidBasicAuthSession := session.Session{ - Username: clientID, - Password: []byte(invalidValue), - } - creds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, clientKey))) - encodedCredsSession := session.Session{ - Password: []byte(apiutil.BasicAuthPrefix + creds), - } - invalidCreds := base64.URLEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", clientID, invalidValue))) - invalidEncodedCredsSession := session.Session{ - Password: []byte(apiutil.BasicAuthPrefix + invalidCreds), - } - hcClientKeySession := session.Session{ - Password: []byte("Client " + clientKey), - } - - tests := []struct { - desc string - session *session.Session - topics *[]string - authKey string - status int - clientType string - chanID string - domainID string - clientID string - authNToken string - superAdmin bool - authNRes *grpcClientsV1.AuthnRes - authNRes1 smqauthn.Session - authNErr error - authZRes *grpcChannelsV1.AuthzRes - authZErr error - err error - }{ - { - desc: "subscribe with client key successfully", - session: &clientKeySession, - topics: &topics, - authKey: clientKey, - status: http.StatusOK, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "subscribe with invalid client key", - session: &invalidClientKeySession, - topics: &topics, - authKey: invalidKey, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, invalidKey), - authNRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "subscribe with empty topics", - session: &clientKeySession, - topics: nil, - authKey: clientKey, - status: http.StatusBadRequest, - err: errMissingTopicSub, - }, - { - desc: "subscribe with nil session", - session: nil, - topics: &topics, - authKey: clientKey, - status: http.StatusInternalServerError, - err: errClientNotInitialized, - }, - { - desc: "subscribe with unauthorized client key", - session: &unauthorizedKeySession, - topics: &topics, - authKey: clientKey, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "subscribe with token successfully", - session: &tokenSession, - topics: &topics, - authKey: token, - status: http.StatusOK, - clientType: policies.UserType, - chanID: chanID, - domainID: domainID, - clientID: userID, - authNRes1: smqauthn.Session{UserID: userID}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "subscribe with superadmin token successfully", - session: &tokenSession, - topics: &topics, - authKey: token, - status: http.StatusOK, - clientType: policies.UserType, - chanID: chanID, - domainID: domainID, - clientID: userID, - superAdmin: true, - authNRes1: smqauthn.Session{UserID: userID, Role: smqauthn.SuperAdminRole}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "subscribe with invalid token", - session: &invalidTokenSession, - topics: &topics, - authKey: invalidToken, - clientType: policies.UserType, - chanID: chanID, - domainID: domainID, - clientID: userID, - authNRes1: smqauthn.Session{}, - authNErr: svcerr.ErrAuthentication, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "subscribe with unauthorized token", - session: &tokenSession, - topics: &topics, - authKey: token, - clientType: policies.UserType, - chanID: chanID, - domainID: domainID, - clientID: userID, - authNRes1: smqauthn.Session{UserID: userID}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "subscribe with basic auth successfully", - session: &basicAuthSession, - topics: &topics, - status: http.StatusOK, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "subscribe with invalid basic auth", - session: &invalidBasicAuthSession, - topics: &topics, - authKey: invalidValue, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue), - authNRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "publish with b64 encoded credentials", - session: &encodedCredsSession, - topics: &topics, - status: http.StatusOK, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - authNErr: nil, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - err: nil, - }, - { - desc: "publish with invalid b64 encoded credentials", - session: &invalidEncodedCredsSession, - topics: &topics, - authKey: invalidValue, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.BasicAuth, clientID, invalidValue), - authNRes: &grpcClientsV1.AuthnRes{Authenticated: false}, - status: http.StatusUnauthorized, - err: svcerr.ErrAuthentication, - }, - { - desc: "subscribe with health check topic successfully", - session: &hcClientKeySession, - topics: &[]string{hcTopic}, - authKey: clientKey, - status: http.StatusOK, - clientType: policies.ClientType, - chanID: chanID, - domainID: domainID, - clientID: clientID, - authNToken: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, clientKey), - authNRes: &grpcClientsV1.AuthnRes{Id: clientID, Authenticated: true}, - err: nil, - }, - { - desc: "subscribe with invalid health check topic", - session: &hcClientKeySession, - topics: &[]string{invalidHCTopic}, - authKey: clientKey, - status: http.StatusBadRequest, - err: messaging.ErrMalformedTopic, - clientType: policies.ClientType, - }, - } - - for _, tc := range tests { - t.Run(tc.desc, func(t *testing.T) { - ctx := context.TODO() - if tc.session != nil { - ctx = session.NewContext(ctx, tc.session) - } - tc.clientType = policies.ClientType - clientID := tc.clientID - if tc.session != nil && strings.HasPrefix(string(tc.session.Password), apiutil.BearerPrefix) { - tc.clientType = policies.UserType - clientID = policies.EncodeDomainUserID(tc.domainID, tc.clientID) - if tc.superAdmin { - clientID = tc.clientID - } - } - clientsCall := clients.On("Authenticate", ctx, &grpcClientsV1.AuthnReq{Token: tc.authNToken}).Return(tc.authNRes, tc.authNErr) - authCall := authn.On("Authenticate", ctx, mock.Anything).Return(tc.authNRes1, tc.authNErr) - channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{ - ClientType: tc.clientType, - ClientId: clientID, - Type: uint32(connections.Subscribe), - ChannelId: tc.chanID, - DomainId: tc.domainID, - }).Return(tc.authZRes, tc.authZErr) - err := handler.AuthSubscribe(ctx, tc.topics) - hpe, ok := err.(mgate.HTTPProxyError) - if ok { - assert.Equal(t, tc.status, hpe.StatusCode()) - } - if tc.err != nil { - assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err)) - } - authCall.Unset() - clientsCall.Unset() - channelsCall.Unset() - }) - } -} - -func TestPublish(t *testing.T) { - handler := newHandler(t) - - malformedSubtopics := topic + "/" + subtopic + "%" - wrongCharSubtopics := topic + "/" + subtopic + ">" - validSubtopic := topic + "/" + subtopic - - cases := []struct { - desc string - session *session.Session - topic string - payload []byte - err error - }{ - { - desc: "publish without active session", - session: nil, - topic: topic, - payload: payload, - err: errClientNotInitialized, - }, - { - desc: "publish with invalid topic", - session: &sessionClient, - topic: invalidTopic, - payload: payload, - err: messaging.ErrMalformedTopic, - }, - { - desc: "publish with invalid channel ID", - session: &sessionClient, - topic: invalidChannelIDTopic, - payload: payload, - err: messaging.ErrMalformedTopic, - }, - { - desc: "publish with malformed subtopic", - session: &sessionClient, - topic: malformedSubtopics, - payload: payload, - err: messaging.ErrMalformedTopic, - }, - { - desc: "publish with subtopic containing wrong character", - session: &sessionClient, - topic: wrongCharSubtopics, - payload: payload, - err: messaging.ErrMalformedTopic, - }, - { - desc: "publish with subtopic", - session: &sessionClient, - topic: validSubtopic, - payload: payload, - }, - { - desc: "publish without subtopic", - session: &sessionClient, - topic: topic, - payload: payload, - }, - { - desc: "publish with health check topic", - session: &sessionClient, - topic: hcTopic, - payload: payload, - }, - { - desc: "puvlish with invalid health check topic", - session: &sessionClient, - topic: invalidHCTopic, - payload: payload, - err: messaging.ErrMalformedTopic, - }, - } - - for _, tc := range cases { - ctx := context.TODO() - if tc.session != nil { - ctx = session.NewContext(ctx, tc.session) - } - repoCall := publisher.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil) - err := handler.Publish(ctx, &tc.topic, &tc.payload) - if tc.err != nil { - assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err)) - } - repoCall.Unset() - } -} diff --git a/http/middleware/doc.go b/http/middleware/doc.go deleted file mode 100644 index 0f330e985..000000000 --- a/http/middleware/doc.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package middleware provides logging, metrics and tracing middleware -// for SuperMQ HTTP service. -// -// For more details about tracing instrumentation for SuperMQ refer to the -// documentation at https://docs.supermq.absmach.eu/tracing/. -package middleware diff --git a/http/middleware/logging.go b/http/middleware/logging.go deleted file mode 100644 index 6df5e4a12..000000000 --- a/http/middleware/logging.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package middleware - -import ( - "context" - "log/slog" - "time" - - smqhttp "github.com/absmach/supermq/http" - "github.com/absmach/supermq/pkg/messaging" -) - -var _ smqhttp.Service = (*loggingMiddleware)(nil) - -type loggingMiddleware struct { - logger *slog.Logger - svc smqhttp.Service -} - -// NewLogging adds logging facilities to the websocket service. -func NewLogging(svc smqhttp.Service, logger *slog.Logger) smqhttp.Service { - return &loggingMiddleware{logger, svc} -} - -// Subscribe logs the subscribe request. It logs the channel and subtopic(if present) and the time it took to complete the request. -// If the request fails, it logs the error. -func (lm *loggingMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, c *smqhttp.Client) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - slog.String("session_id", sessionID), - slog.String("channel_id", chanID), - slog.String("domain_id", domainID), - } - if subtopic != "" { - args = append(args, "subtopic", subtopic) - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("Subscribe failed", args...) - return - } - lm.logger.Info("Subscribe completed successfully", args...) - }(time.Now()) - - return lm.svc.Subscribe(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, c) -} - -func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string, topicType messaging.TopicType) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - slog.String("session_id", sessionID), - slog.String("channel_id", chanID), - slog.String("domain_id", domainID), - } - if subtopic != "" { - args = append(args, "subtopic", subtopic) - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("Unsubscribe failed", args...) - return - } - lm.logger.Info("Unsubscribe completed successfully", args...) - }(time.Now()) - - return lm.svc.Unsubscribe(ctx, sessionID, domainID, chanID, subtopic, topicType) -} diff --git a/http/middleware/metrics.go b/http/middleware/metrics.go deleted file mode 100644 index 2f5e8b47e..000000000 --- a/http/middleware/metrics.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -//go:build !test - -package middleware - -import ( - "context" - "time" - - smqhttp "github.com/absmach/supermq/http" - "github.com/absmach/supermq/pkg/messaging" - "github.com/go-kit/kit/metrics" -) - -var _ smqhttp.Service = (*metricsMiddleware)(nil) - -type metricsMiddleware struct { - counter metrics.Counter - latency metrics.Histogram - svc smqhttp.Service -} - -// NewMetrics instruments adapter by tracking request count and latency. -func NewMetrics(svc smqhttp.Service, counter metrics.Counter, latency metrics.Histogram) smqhttp.Service { - return &metricsMiddleware{ - counter: counter, - latency: latency, - svc: svc, - } -} - -// Subscribe instruments Subscribe method with metrics. -func (mm *metricsMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, c *smqhttp.Client) error { - defer func(begin time.Time) { - mm.counter.With("method", "subscribe").Add(1) - mm.latency.With("method", "subscribe").Observe(time.Since(begin).Seconds()) - }(time.Now()) - - return mm.svc.Subscribe(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, c) -} - -func (mm *metricsMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string, topicType messaging.TopicType) error { - defer func(begin time.Time) { - mm.counter.With("method", "unsubscribe").Add(1) - mm.latency.With("method", "unsubscribe").Observe(time.Since(begin).Seconds()) - }(time.Now()) - - return mm.svc.Unsubscribe(ctx, sessionID, domainID, chanID, subtopic, topicType) -} diff --git a/http/middleware/tracing.go b/http/middleware/tracing.go deleted file mode 100644 index 9476bb189..000000000 --- a/http/middleware/tracing.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package middleware - -import ( - "context" - - smqhttp "github.com/absmach/supermq/http" - "github.com/absmach/supermq/pkg/messaging" - "go.opentelemetry.io/otel/trace" -) - -var _ smqhttp.Service = (*tracingMiddleware)(nil) - -const ( - subscribeOP = "subscribe_op" - unsubscribeOP = "unsubscribe_op" -) - -type tracingMiddleware struct { - tracer trace.Tracer - svc smqhttp.Service -} - -// NewTracing returns a new websocket service with tracing capabilities. -func NewTracing(tracer trace.Tracer, svc smqhttp.Service) smqhttp.Service { - return &tracingMiddleware{ - tracer: tracer, - svc: svc, - } -} - -// Subscribe traces the "Subscribe" operation of the wrapped smqhttp.Service. -func (tm *tracingMiddleware) Subscribe(ctx context.Context, sessionID, username, password, domainID, chanID, subtopic string, topicType messaging.TopicType, client *smqhttp.Client) error { - ctx, span := tm.tracer.Start(ctx, subscribeOP) - defer span.End() - - return tm.svc.Subscribe(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, client) -} - -func (tm *tracingMiddleware) Unsubscribe(ctx context.Context, sessionID, domainID, chanID, subtopic string, topicType messaging.TopicType) error { - ctx, span := tm.tracer.Start(ctx, unsubscribeOP) - defer span.End() - - return tm.svc.Unsubscribe(ctx, sessionID, domainID, chanID, subtopic, topicType) -} diff --git a/http/mocks/service.go b/http/mocks/service.go deleted file mode 100644 index c1ba27c02..000000000 --- a/http/mocks/service.go +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright (c) Abstract Machines - -// SPDX-License-Identifier: Apache-2.0 - -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package mocks - -import ( - "context" - - "github.com/absmach/supermq/http" - "github.com/absmach/supermq/pkg/messaging" - mock "github.com/stretchr/testify/mock" -) - -// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewService(t interface { - mock.TestingT - Cleanup(func()) -}) *Service { - mock := &Service{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// Service is an autogenerated mock type for the Service type -type Service struct { - mock.Mock -} - -type Service_Expecter struct { - mock *mock.Mock -} - -func (_m *Service) EXPECT() *Service_Expecter { - return &Service_Expecter{mock: &_m.Mock} -} - -// Subscribe provides a mock function for the type Service -func (_mock *Service) Subscribe(ctx context.Context, sessionID string, username string, password string, domainID string, chanID string, subtopic string, topicType messaging.TopicType, client *http.Client) error { - ret := _mock.Called(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, client) - - if len(ret) == 0 { - panic("no return value specified for Subscribe") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string, string, messaging.TopicType, *http.Client) error); ok { - r0 = returnFunc(ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, client) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// Service_Subscribe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Subscribe' -type Service_Subscribe_Call struct { - *mock.Call -} - -// Subscribe is a helper method to define mock.On call -// - ctx context.Context -// - sessionID string -// - username string -// - password string -// - domainID string -// - chanID string -// - subtopic string -// - topicType messaging.TopicType -// - client *http.Client -func (_e *Service_Expecter) Subscribe(ctx interface{}, sessionID interface{}, username interface{}, password interface{}, domainID interface{}, chanID interface{}, subtopic interface{}, topicType interface{}, client interface{}) *Service_Subscribe_Call { - return &Service_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, sessionID, username, password, domainID, chanID, subtopic, topicType, client)} -} - -func (_c *Service_Subscribe_Call) Run(run func(ctx context.Context, sessionID string, username string, password string, domainID string, chanID string, subtopic string, topicType messaging.TopicType, client *http.Client)) *Service_Subscribe_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 string - if args[1] != nil { - arg1 = args[1].(string) - } - var arg2 string - if args[2] != nil { - arg2 = args[2].(string) - } - var arg3 string - if args[3] != nil { - arg3 = args[3].(string) - } - var arg4 string - if args[4] != nil { - arg4 = args[4].(string) - } - var arg5 string - if args[5] != nil { - arg5 = args[5].(string) - } - var arg6 string - if args[6] != nil { - arg6 = args[6].(string) - } - var arg7 messaging.TopicType - if args[7] != nil { - arg7 = args[7].(messaging.TopicType) - } - var arg8 *http.Client - if args[8] != nil { - arg8 = args[8].(*http.Client) - } - run( - arg0, - arg1, - arg2, - arg3, - arg4, - arg5, - arg6, - arg7, - arg8, - ) - }) - return _c -} - -func (_c *Service_Subscribe_Call) Return(err error) *Service_Subscribe_Call { - _c.Call.Return(err) - return _c -} - -func (_c *Service_Subscribe_Call) RunAndReturn(run func(ctx context.Context, sessionID string, username string, password string, domainID string, chanID string, subtopic string, topicType messaging.TopicType, client *http.Client) error) *Service_Subscribe_Call { - _c.Call.Return(run) - return _c -} - -// Unsubscribe provides a mock function for the type Service -func (_mock *Service) Unsubscribe(ctx context.Context, sessionID string, domainID string, chanID string, subtopic string, topicType messaging.TopicType) error { - ret := _mock.Called(ctx, sessionID, domainID, chanID, subtopic, topicType) - - if len(ret) == 0 { - panic("no return value specified for Unsubscribe") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, messaging.TopicType) error); ok { - r0 = returnFunc(ctx, sessionID, domainID, chanID, subtopic, topicType) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// Service_Unsubscribe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Unsubscribe' -type Service_Unsubscribe_Call struct { - *mock.Call -} - -// Unsubscribe is a helper method to define mock.On call -// - ctx context.Context -// - sessionID string -// - domainID string -// - chanID string -// - subtopic string -// - topicType messaging.TopicType -func (_e *Service_Expecter) Unsubscribe(ctx interface{}, sessionID interface{}, domainID interface{}, chanID interface{}, subtopic interface{}, topicType interface{}) *Service_Unsubscribe_Call { - return &Service_Unsubscribe_Call{Call: _e.mock.On("Unsubscribe", ctx, sessionID, domainID, chanID, subtopic, topicType)} -} - -func (_c *Service_Unsubscribe_Call) Run(run func(ctx context.Context, sessionID string, domainID string, chanID string, subtopic string, topicType messaging.TopicType)) *Service_Unsubscribe_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 string - if args[1] != nil { - arg1 = args[1].(string) - } - var arg2 string - if args[2] != nil { - arg2 = args[2].(string) - } - var arg3 string - if args[3] != nil { - arg3 = args[3].(string) - } - var arg4 string - if args[4] != nil { - arg4 = args[4].(string) - } - var arg5 messaging.TopicType - if args[5] != nil { - arg5 = args[5].(messaging.TopicType) - } - run( - arg0, - arg1, - arg2, - arg3, - arg4, - arg5, - ) - }) - return _c -} - -func (_c *Service_Unsubscribe_Call) Return(err error) *Service_Unsubscribe_Call { - _c.Call.Return(err) - return _c -} - -func (_c *Service_Unsubscribe_Call) RunAndReturn(run func(ctx context.Context, sessionID string, domainID string, chanID string, subtopic string, topicType messaging.TopicType) error) *Service_Unsubscribe_Call { - _c.Call.Return(run) - return _c -} diff --git a/internal/email/README.md b/internal/email/README.md index cb7375b43..d373e3353 100644 --- a/internal/email/README.md +++ b/internal/email/README.md @@ -9,13 +9,13 @@ SuperMQ Email Agent is configured using the following configuration parameters: | Parameter | Description | | ----------------------------------- | ----------------------------------------------------------------------- | -| SMQ_EMAIL_HOST | Mail server host | -| SMQ_EMAIL_PORT | Mail server port | -| SMQ_EMAIL_USERNAME | Mail server username | -| SMQ_EMAIL_PASSWORD | Mail server password | -| SMQ_EMAIL_FROM_ADDRESS | Email "from" address | -| SMQ_EMAIL_FROM_NAME | Email "from" name | -| SMQ_EMAIL_TEMPLATE | Email template for sending notification emails | +| MG_EMAIL_HOST | Mail server host | +| MG_EMAIL_PORT | Mail server port | +| MG_EMAIL_USERNAME | Mail server username | +| MG_EMAIL_PASSWORD | Mail server password | +| MG_EMAIL_FROM_ADDRESS | Email "from" address | +| MG_EMAIL_FROM_NAME | Email "from" name | +| MG_EMAIL_TEMPLATE | Email template for sending notification emails | There are two authentication methods supported: Basic Auth and CRAM-MD5. -If `SMQ_EMAIL_USERNAME` is empty, no authentication will be used. +If `MG_EMAIL_USERNAME` is empty, no authentication will be used. diff --git a/internal/email/email.go b/internal/email/email.go index a70c9d79c..2785b3e47 100644 --- a/internal/email/email.go +++ b/internal/email/email.go @@ -5,6 +5,8 @@ package email import ( "bytes" + "fmt" + "io" "net/mail" "strconv" "strings" @@ -35,13 +37,13 @@ type email struct { // Config email agent configuration. type Config struct { - Host string `env:"SMQ_EMAIL_HOST" envDefault:"localhost"` - Port string `env:"SMQ_EMAIL_PORT" envDefault:"25"` - Username string `env:"SMQ_EMAIL_USERNAME" envDefault:"root"` - Password string `env:"SMQ_EMAIL_PASSWORD" envDefault:""` - FromAddress string `env:"SMQ_EMAIL_FROM_ADDRESS" envDefault:""` - FromName string `env:"SMQ_EMAIL_FROM_NAME" envDefault:""` - Template string `env:"SMQ_EMAIL_TEMPLATE" envDefault:"email.tmpl"` + Host string `env:"MG_EMAIL_HOST" envDefault:"localhost"` + Port string `env:"MG_EMAIL_PORT" envDefault:"25"` + Username string `env:"MG_EMAIL_USERNAME" envDefault:"root"` + Password string `env:"MG_EMAIL_PASSWORD" envDefault:""` + FromAddress string `env:"MG_EMAIL_FROM_ADDRESS" envDefault:""` + FromName string `env:"MG_EMAIL_FROM_NAME" envDefault:""` + Template string `env:"MG_EMAIL_TEMPLATE" envDefault:"email.tmpl"` } // Agent for mailing. @@ -71,7 +73,7 @@ func New(c *Config) (*Agent, error) { } // Send sends e-mail. -func (a *Agent) Send(to []string, from, subject, header, user, content, footer string) error { +func (a *Agent) Send(to []string, from, subject, header, user, content, footer string, attachments map[string][]byte) error { if a.tmpl == nil { return errMissingEmailTemplate } @@ -102,6 +104,22 @@ func (a *Agent) Send(to []string, from, subject, header, user, content, footer s m.SetHeader("Subject", subject) m.SetBody("text/html", buff.String()) + for filename, data := range attachments { + reader := bytes.NewReader(data) + + settings := []gomail.FileSetting{ + gomail.SetHeader(map[string][]string{ + "Content-Disposition": {fmt.Sprintf(`attachment; filename="%s"`, filename)}, + }), + gomail.SetCopyFunc(func(w io.Writer) error { + _, err := io.Copy(w, reader) + return err + }), + } + + m.Attach(filename, settings...) + } + if err := a.dial.DialAndSend(m); err != nil { return errors.Wrap(errSendMail, err) } diff --git a/internal/proto/certs/v1/certs.proto b/internal/proto/certs/v1/certs.proto new file mode 100644 index 000000000..9185b3631 --- /dev/null +++ b/internal/proto/certs/v1/certs.proto @@ -0,0 +1,27 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package absmach.certs; + +import "google/protobuf/empty.proto"; + +option go_package = "github.com/absmach/supermq/api/grpc/certs/v1"; + +service CertsService { + rpc GetEntityID(EntityReq) returns (EntityRes) {} + rpc RevokeCerts(RevokeReq) returns (google.protobuf.Empty) {} +} + +message EntityReq { + string serial_number = 1; +} + +message EntityRes { + string entity_id = 1; +} + +message RevokeReq { + string entity_id = 1; +} diff --git a/internal/proto/readers/v1/readers.proto b/internal/proto/readers/v1/readers.proto new file mode 100644 index 000000000..03b614c84 --- /dev/null +++ b/internal/proto/readers/v1/readers.proto @@ -0,0 +1,92 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package readers.v1; + +option go_package = "github.com/absmach/supermq/api/grpc/readers/v1"; + +// ReadersService is a service that provides access to +// readers functionalities for SuperMQ services. +service ReadersService { + rpc ReadMessages(ReadMessagesReq) + returns (ReadMessagesRes) {} +} + +message PageMetadata { + uint64 limit = 1; + uint64 offset = 2; + string protocol = 3; + string name = 4; + double value = 5; + string publisher = 6; + bool bool_value = 7; + string string_value = 8; + string data_value = 9; + double from = 10; + double to = 11; + string subtopic = 12; + string interval = 13; + bool read = 14; + Aggregation aggregation = 15; + string comparator = 16; + string format = 17; + string order = 18; + string dir = 19; +} + +message ReadMessagesRes { + uint64 total = 1; + PageMetadata page_metadata = 2; + repeated Message messages = 3; +} + +message Message { + oneof payload { + SenMLMessage senml = 1; + JsonMessage json = 2; + } +} + +message BaseMessage { + string channel = 1; + string subtopic = 2; + string publisher = 3; + string protocol = 4; +} + +message SenMLMessage { + BaseMessage base = 1; + string name = 2; + string unit = 3; + double time = 4; + double update_time = 5; + optional double value = 6; + optional string string_value = 7; + optional string data_value = 8; + optional bool bool_value = 9; + optional double sum = 10; +} + +message JsonMessage { + BaseMessage base = 1; + int64 created = 2; + bytes payload = 3; +} + +message ReadMessagesReq { + string channel_id = 1; + string domain_id = 2; + PageMetadata page_metadata = 3; +} + +// Aggregation defines supported data aggregations. +enum Aggregation { + AGGREGATION_UNSPECIFIED = 0; + AGGREGATION_MAX = 1; + AGGREGATION_MIN = 2; + AGGREGATION_SUM = 3; + AGGREGATION_COUNT = 4; + AGGREGATION_AVG = 5; +} diff --git a/journal/README.md b/journal/README.md index 12efcf47d..5afd26cb5 100644 --- a/journal/README.md +++ b/journal/README.md @@ -8,42 +8,42 @@ The service is configured with the following environment variables (unset values | Variable | Description | Default | | --- | --- | --- | -| `SMQ_JOURNAL_LOG_LEVEL` | Log level for Journal (debug, info, warn, error) | info | -| `SMQ_JOURNAL_HTTP_HOST` | Journal HTTP host | localhost | -| `SMQ_JOURNAL_HTTP_PORT` | Journal HTTP port | 9021 | -| `SMQ_JOURNAL_HTTP_SERVER_CERT` | Path to PEM-encoded HTTP server certificate | "" | -| `SMQ_JOURNAL_HTTP_SERVER_KEY` | Path to PEM-encoded HTTP server key | "" | -| `SMQ_JOURNAL_HTTP_SERVER_CA_CERTS` | Path to trusted CA bundle for the HTTP server | "" | -| `SMQ_JOURNAL_HTTP_CLIENT_CA_CERTS` | Path to client CA bundle to require HTTP mTLS | "" | -| `SMQ_JOURNAL_DB_HOST` | Database host address | localhost | -| `SMQ_JOURNAL_DB_PORT` | Database host port | 5432 | -| `SMQ_JOURNAL_DB_USER` | Database user | supermq | -| `SMQ_JOURNAL_DB_PASS` | Database password | supermq | -| `SMQ_JOURNAL_DB_NAME` | Name of the database used by the service | journal | -| `SMQ_JOURNAL_DB_SSL_MODE` | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | -| `SMQ_JOURNAL_DB_SSL_CERT` | Path to the PEM-encoded certificate file | "" | -| `SMQ_JOURNAL_DB_SSL_KEY` | Path to the PEM-encoded key file | "" | -| `SMQ_JOURNAL_DB_SSL_ROOT_CERT` | Path to the PEM-encoded root certificate file | "" | -| `SMQ_ES_URL` | Event store URL (NATS) consumed for journal entries | nats://localhost:4222 | -| `SMQ_JAEGER_URL` | Jaeger tracing endpoint | | -| `SMQ_JAEGER_TRACE_RATIO` | Trace sampling ratio | 1.0 | -| `SMQ_SEND_TELEMETRY` | Send telemetry to the SuperMQ call-home server | true | -| `SMQ_AUTH_GRPC_URL` | Auth service gRPC URL | "" | -| `SMQ_AUTH_GRPC_TIMEOUT` | Auth service gRPC timeout | 1s | -| `SMQ_AUTH_GRPC_CLIENT_CERT` | Path to PEM-encoded Auth gRPC client certificate | "" | -| `SMQ_AUTH_GRPC_CLIENT_KEY` | Path to PEM-encoded Auth gRPC client key | "" | -| `SMQ_AUTH_GRPC_SERVER_CA_CERTS` | Path to PEM-encoded Auth gRPC trusted CA bundle | "" | -| `SMQ_DOMAINS_GRPC_URL` | Domains service gRPC URL | "" | -| `SMQ_DOMAINS_GRPC_TIMEOUT` | Domains service gRPC timeout | 1s | -| `SMQ_DOMAINS_GRPC_CLIENT_CERT` | Path to PEM-encoded Domains gRPC client certificate | "" | -| `SMQ_DOMAINS_GRPC_CLIENT_KEY` | Path to PEM-encoded Domains gRPC client key | "" | -| `SMQ_DOMAINS_GRPC_SERVER_CA_CERTS` | Path to PEM-encoded Domains gRPC trusted CA bundle | "" | -| `SMQ_JOURNAL_INSTANCE_ID` | Journal instance ID (auto-generated when empty) | "" | -| `SMQ_ALLOW_UNVERIFIED_USER` | Allow unverified users to authenticate (useful in dev) | false | +| `MG_JOURNAL_LOG_LEVEL` | Log level for Journal (debug, info, warn, error) | info | +| `MG_JOURNAL_HTTP_HOST` | Journal HTTP host | localhost | +| `MG_JOURNAL_HTTP_PORT` | Journal HTTP port | 9021 | +| `MG_JOURNAL_HTTP_SERVER_CERT` | Path to PEM-encoded HTTP server certificate | "" | +| `MG_JOURNAL_HTTP_SERVER_KEY` | Path to PEM-encoded HTTP server key | "" | +| `MG_JOURNAL_HTTP_SERVER_CA_CERTS` | Path to trusted CA bundle for the HTTP server | "" | +| `MG_JOURNAL_HTTP_CLIENT_CA_CERTS` | Path to client CA bundle to require HTTP mTLS | "" | +| `MG_JOURNAL_DB_HOST` | Database host address | localhost | +| `MG_JOURNAL_DB_PORT` | Database host port | 5432 | +| `MG_JOURNAL_DB_USER` | Database user | supermq | +| `MG_JOURNAL_DB_PASS` | Database password | supermq | +| `MG_JOURNAL_DB_NAME` | Name of the database used by the service | journal | +| `MG_JOURNAL_DB_SSL_MODE` | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | +| `MG_JOURNAL_DB_SSL_CERT` | Path to the PEM-encoded certificate file | "" | +| `MG_JOURNAL_DB_SSL_KEY` | Path to the PEM-encoded key file | "" | +| `MG_JOURNAL_DB_SSL_ROOT_CERT` | Path to the PEM-encoded root certificate file | "" | +| `MG_ES_URL` | Event store URL (NATS) consumed for journal entries | nats://localhost:4222 | +| `MG_JAEGER_URL` | Jaeger tracing endpoint | | +| `MG_JAEGER_TRACE_RATIO` | Trace sampling ratio | 1.0 | +| `MG_SEND_TELEMETRY` | Send telemetry to the SuperMQ call-home server | true | +| `MG_AUTH_GRPC_URL` | Auth service gRPC URL | "" | +| `MG_AUTH_GRPC_TIMEOUT` | Auth service gRPC timeout | 1s | +| `MG_AUTH_GRPC_CLIENT_CERT` | Path to PEM-encoded Auth gRPC client certificate | "" | +| `MG_AUTH_GRPC_CLIENT_KEY` | Path to PEM-encoded Auth gRPC client key | "" | +| `MG_AUTH_GRPC_SERVER_CA_CERTS` | Path to PEM-encoded Auth gRPC trusted CA bundle | "" | +| `MG_DOMAINS_GRPC_URL` | Domains service gRPC URL | "" | +| `MG_DOMAINS_GRPC_TIMEOUT` | Domains service gRPC timeout | 1s | +| `MG_DOMAINS_GRPC_CLIENT_CERT` | Path to PEM-encoded Domains gRPC client certificate | "" | +| `MG_DOMAINS_GRPC_CLIENT_KEY` | Path to PEM-encoded Domains gRPC client key | "" | +| `MG_DOMAINS_GRPC_SERVER_CA_CERTS` | Path to PEM-encoded Domains gRPC trusted CA bundle | "" | +| `MG_JOURNAL_INSTANCE_ID` | Journal instance ID (auto-generated when empty) | "" | +| `MG_ALLOW_UNVERIFIED_USER` | Allow unverified users to authenticate (useful in dev) | false | ## Deployment -The service is distributed as a Docker container. Check the [`journals`](https://github.com/absmach/supermq/tree/main/docker/addons/journal) for the `journal` and `journal-db` services and how they are wired into the base stack. +The service is distributed as a Docker container. Check [`docker/docker-compose.yaml`](https://github.com/absmach/supermq/tree/main/docker/docker-compose.yaml) for the `journal` and `journal-db` services and how they are wired into the base stack. To start the service outside of the container, execute the following shell script: @@ -56,16 +56,16 @@ make journal make install # run with the essentials; requires Postgres, Auth gRPC, Domains gRPC, and NATS running -SMQ_JOURNAL_HTTP_HOST=localhost \ -SMQ_JOURNAL_HTTP_PORT=9021 \ -SMQ_JOURNAL_DB_HOST=localhost \ -SMQ_JOURNAL_DB_PORT=5432 \ -SMQ_JOURNAL_DB_USER=supermq \ -SMQ_JOURNAL_DB_PASS=supermq \ -SMQ_JOURNAL_DB_NAME=journal \ -SMQ_AUTH_GRPC_URL=localhost:7001 \ -SMQ_DOMAINS_GRPC_URL=localhost:7003 \ -SMQ_ES_URL=nats://localhost:4222 \ +MG_JOURNAL_HTTP_HOST=localhost \ +MG_JOURNAL_HTTP_PORT=9021 \ +MG_JOURNAL_DB_HOST=localhost \ +MG_JOURNAL_DB_PORT=5432 \ +MG_JOURNAL_DB_USER=supermq \ +MG_JOURNAL_DB_PASS=supermq \ +MG_JOURNAL_DB_NAME=journal \ +MG_AUTH_GRPC_URL=localhost:7001 \ +MG_DOMAINS_GRPC_URL=localhost:7003 \ +MG_ES_URL=nats://localhost:4222 \ $GOBIN/supermq-journal ``` diff --git a/journal/middleware/authorization.go b/journal/middleware/authorization.go index 0725c452f..6abb4a3c6 100644 --- a/journal/middleware/authorization.go +++ b/journal/middleware/authorization.go @@ -45,7 +45,7 @@ func (am *authorizationMiddleware) RetrieveAll(ctx context.Context, session smqa if page.EntityType.String() == policies.UserType { permission = policies.AdminPermission objectType = policies.PlatformType - object = policies.SuperMQObject + object = policies.MagistralaObject subject = session.UserID } diff --git a/journal/postgres/init.go b/journal/postgres/init.go index 830239561..ce9819cf0 100644 --- a/journal/postgres/init.go +++ b/journal/postgres/init.go @@ -53,6 +53,22 @@ func Migration() *migrate.MemoryMigrationSource { { Id: "journal_02", Up: []string{ + `CREATE TABLE IF NOT EXISTS clients_telemetry ( + client_id VARCHAR(36) PRIMARY KEY, + domain_id VARCHAR(36) NOT NULL, + inbound_messages BIGINT DEFAULT 0, + outbound_messages BIGINT DEFAULT 0, + first_seen TIMESTAMP, + last_seen TIMESTAMP + )`, + `CREATE TABLE IF NOT EXISTS subscriptions ( + id VARCHAR(36) PRIMARY KEY, + subscriber_id VARCHAR(1024) NOT NULL, + channel_id VARCHAR(36) NOT NULL, + subtopic VARCHAR(1024), + client_id VARCHAR(36), + FOREIGN KEY (client_id) REFERENCES clients_telemetry(client_id) ON DELETE CASCADE ON UPDATE CASCADE + )`, `ALTER TABLE journal ALTER COLUMN occurred_at TYPE TIMESTAMPTZ;`, `ALTER TABLE clients_telemetry ALTER COLUMN first_seen TYPE TIMESTAMPTZ;`, `ALTER TABLE clients_telemetry ALTER COLUMN last_seen TYPE TIMESTAMPTZ;`, diff --git a/mqtt/README.md b/mqtt/README.md deleted file mode 100644 index fbd6a2ff2..000000000 --- a/mqtt/README.md +++ /dev/null @@ -1,97 +0,0 @@ -# MQTT adapter - -MQTT adapter provides an MQTT API for sending messages through the platform. MQTT adapter uses [mProxy](https://github.com/absmach/mproxy) for proxying traffic between client and MQTT broker. - -## Configuration - -The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values. - -| Variable | Description | Default | -| ------------------------------------------ | ----------------------------------------------------------------------------------- | ----------------------------------- | -| SMQ_MQTT_ADAPTER_LOG_LEVEL | Log level for the MQTT Adapter (debug, info, warn, error) | info | -| SMQ_MQTT_ADAPTER_MQTT_PORT | mProxy port | 1883 | -| SMQ_MQTT_ADAPTER_MQTT_TARGET_HOST | MQTT broker host | localhost | -| SMQ_MQTT_ADAPTER_MQTT_TARGET_PORT | MQTT broker port | 1883 | -| SMQ_MQTT_ADAPTER_MQTT_QOS | MQTT broker QoS | 1 | -| SMQ_MQTT_ADAPTER_FORWARDER_TIMEOUT | MQTT forwarder for multiprotocol communication timeout | 30s | -| SMQ_MQTT_ADAPTER_MQTT_TARGET_HEALTH_CHECK | URL of broker health check | "" | -| SMQ_MQTT_ADAPTER_WS_PORT | mProxy MQTT over WS port | 8080 | -| SMQ_MQTT_ADAPTER_WS_TARGET_HOST | MQTT broker host for MQTT over WS | localhost | -| SMQ_MQTT_ADAPTER_WS_TARGET_PORT | MQTT broker port for MQTT over WS | 8080 | -| SMQ_MQTT_ADAPTER_WS_TARGET_PATH | MQTT broker MQTT over WS path | /mqtt | -| SMQ_MQTT_ADAPTER_CACHE_NUM_COUNTERS | Number of cache counters to keep that hold access frequency information | 200000 | -| SMQ_MQTT_ADAPTER_CACHE_MAX_COST | Maximum size of the cache(in bytes) | 1048576 | -| SMQ_MQTT_ADAPTER_CACHE_BUFFER_ITEMS | Number of cache `Get` buffers | 64 | -| SMQ_MQTT_ADAPTER_INSTANCE | Instance name for MQTT adapter | "" | -| SMQ_CLIENTS_GRPC_URL | Clients service Auth gRPC URL | | -| SMQ_CLIENTS_GRPC_TIMEOUT | Clients service Auth gRPC request timeout in seconds | 1s | -| SMQ_CLIENTS_GRPC_CLIENT_CERT | Path to the PEM encoded clients service Auth gRPC client certificate file | "" | -| SMQ_CLIENTS_GRPC_CLIENT_KEY | Path to the PEM encoded clients service Auth gRPC client key file | "" | -| SMQ_CLIENTS_GRPC_SERVER_CERTS | Path to the PEM encoded clients server Auth gRPC server trusted CA certificate file | "" | -| SMQ_ES_URL | Event sourcing URL | | -| SMQ_MESSAGE_BROKER_URL | Message broker instance URL | | -| SMQ_JAEGER_URL | Jaeger server URL | | -| SMQ_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 | -| SMQ_SEND_TELEMETRY | Send telemetry to supermq call home server | true | -| SMQ_MQTT_ADAPTER_INSTANCE_ID | Service instance ID | "" | -| SMQ_MQTT_ADAPTER_CERT_FILE | Path to the PEM encoded TLS certificate file for MQTT adapter | "" | -| SMQ_MQTT_ADAPTER_KEY_FILE | Path to the PEM encoded TLS key file for MQTT adapter | "" | -| SMQ_MQTT_ADAPTER_SERVER_CA_FILE | Path to the PEM encoded server CA certificate file for MQTT adapter | "" | -| SMQ_MQTT_ADAPTER_CLIENT_CA_FILE | Path to the PEM encoded client CA certificate file for MQTT adapter | "" | -| SMQ_MQTT_ADAPTER_OCSP_RESPONDER_URL | URL of the OCSP responder for MQTT adapter | "" | -| SMQ_MQTT_ADAPTER_CERT_VERIFICATION_METHODS | Methods for certificate verification (e.g., ocsp) | "" | - -## Deployment - -The service itself is distributed as Docker container. Check the [`mqtt-adapter`](https://github.com/absmach/supermq/blob/main/docker/docker-compose.yaml) service section in docker-compose file to see how service is deployed. - -Running this service outside of container requires working instance of the message broker service, clients service and Jaeger server. -To start the service outside of the container, execute the following shell script: - -```bash -# download the latest version of the service -git clone https://github.com/absmach/supermq - -cd supermq - -# compile the mqtt -make mqtt - -# copy binary to bin -make install - -# set the environment variables and run the service -SMQ_MQTT_ADAPTER_LOG_LEVEL=info \ -SMQ_MQTT_ADAPTER_MQTT_PORT=1883 \ -SMQ_MQTT_ADAPTER_MQTT_TARGET_HOST=localhost \ -SMQ_MQTT_ADAPTER_MQTT_TARGET_PORT=1883 \ -SMQ_MQTT_ADAPTER_MQTT_QOS=1 \ -SMQ_MQTT_ADAPTER_FORWARDER_TIMEOUT=30s \ -SMQ_MQTT_ADAPTER_MQTT_TARGET_HEALTH_CHECK="" \ -SMQ_MQTT_ADAPTER_WS_PORT=8080 \ -SMQ_MQTT_ADAPTER_WS_TARGET_HOST=localhost \ -SMQ_MQTT_ADAPTER_WS_TARGET_PORT=8080 \ -SMQ_MQTT_ADAPTER_WS_TARGET_PATH=/mqtt \ -SMQ_MQTT_ADAPTER_CACHE_NUM_COUNTERS=200000 \ -SMQ_MQTT_ADAPTER_CACHE_MAX_COST=1048576 \ -SMQ_MQTT_ADAPTER_CACHE_BUFFER_ITEMS=64 \ -SMQ_MQTT_ADAPTER_INSTANCE="" \ -SMQ_CLIENTS_GRPC_URL=localhost:7000 \ -SMQ_CLIENTS_GRPC_TIMEOUT=1s \ -SMQ_CLIENTS_GRPC_CLIENT_CERT="" \ -SMQ_CLIENTS_GRPC_CLIENT_KEY="" \ -SMQ_CLIENTS_GRPC_SERVER_CERTS="" \ -SMQ_ES_URL=amqp://guest:guest@rabbitmq:5672/ \ -SMQ_MESSAGE_BROKER_URL=amqp://guest:guest@rabbitmq:5672/ \ -SMQ_JAEGER_URL=http://localhost:14268/api/traces \ -SMQ_JAEGER_TRACE_RATIO=1.0 \ -SMQ_SEND_TELEMETRY=true \ -SMQ_MQTT_ADAPTER_INSTANCE_ID="" \ -$GOBIN/supermq-mqtt -``` - -Setting `SMQ_CLIENTS_GRPC_CLIENT_CERT` and `SMQ_CLIENTS_GRPC_CLIENT_KEY` will enable TLS against the clients service. The service expects a file in PEM format for both the certificate and the key. Setting `SMQ_CLIENTS_GRPC_SERVER_CERTS` will enable TLS against the clients service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. - -Setting `SMQ_MQTT_ADAPTER_CERT_FILE`, `SMQ_MQTT_ADAPTER_KEY_FILE`, and `SMQ_MQTT_ADAPTER_SERVER_CA_FILE` will enable TLS for incoming MQTT connections. The service expects a file in PEM format for both the certificate and the key. The service expects a file in PEM format of trusted CAs. Setting `SMQ_MQTT_ADAPTER_CLIENT_CA_FILE` will enable client certificate verification for incoming MQTT connections trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. Setting `SMQ_MQTT_ADAPTER_CERT_VERIFICATION_METHODS` to "ocsp" will enable OCSP verification for incoming MQTT connections. Setting `SMQ_MQTT_ADAPTER_OCSP_RESPONDER_URL` will set the OCSP responder URL for OCSP verification. - -For more information about service capabilities and its usage, please check out the API documentation [API](https://github.com/absmach/supermq/blob/main/api/asyncapi/mqtt.yaml). diff --git a/mqtt/doc.go b/mqtt/doc.go deleted file mode 100644 index dc4938a76..000000000 --- a/mqtt/doc.go +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package mqtt contains the domain concept definitions needed to support -// SuperMQ MQTT service functionality. -package mqtt diff --git a/mqtt/events/events.go b/mqtt/events/events.go deleted file mode 100644 index 943ee00eb..000000000 --- a/mqtt/events/events.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package events - -import "github.com/absmach/supermq/pkg/events" - -const ( - mqttPrefix = "mqtt" - clientSubscribe = mqttPrefix + ".client_subscribe" - clientConnect = mqttPrefix + ".client_connect" - clientDisconnect = mqttPrefix + ".client_disconnect" -) - -var ( - _ events.Event = (*connectEvent)(nil) - _ events.Event = (*subscribeEvent)(nil) -) - -type connectEvent struct { - operation string - clientID string - subscriberID string - instance string -} - -func (ce connectEvent) Encode() (map[string]any, error) { - return map[string]any{ - "operation": ce.operation, - "client_id": ce.clientID, - "subscriber_id": ce.subscriberID, - "instance": ce.instance, - }, nil -} - -type subscribeEvent struct { - operation string - clientID string - subscriberID string - domainID string - channelID string - subtopic string -} - -func (se subscribeEvent) Encode() (map[string]any, error) { - return map[string]any{ - "operation": se.operation, - "client_id": se.clientID, - "subscriber_id": se.subscriberID, - "domainID": se.domainID, - "channel_id": se.channelID, - "subtopic": se.subtopic, - }, nil -} diff --git a/mqtt/events/streams.go b/mqtt/events/streams.go deleted file mode 100644 index 17f1ab7ae..000000000 --- a/mqtt/events/streams.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package events - -import ( - "context" - - "github.com/absmach/mgate/pkg/session" - "github.com/absmach/supermq/pkg/errors" - "github.com/absmach/supermq/pkg/events" - "github.com/absmach/supermq/pkg/events/store" - "github.com/absmach/supermq/pkg/messaging" -) - -const ( - supermqPrefix = "supermq." - subscribeStream = supermqPrefix + clientSubscribe - connectStream = supermqPrefix + clientConnect - disconnectStream = supermqPrefix + clientDisconnect -) - -var errFailedSession = errors.New("failed to obtain session from context") - -// EventStore is a struct used to store event streams in Redis. -type eventStore struct { - ep events.Publisher - handler session.Handler - instance string -} - -// NewEventStoreMiddleware returns middleware around mGate service that sends -// events to event store. -func NewEventStoreMiddleware(ctx context.Context, handler session.Handler, url, instance string) (session.Handler, error) { - publisher, err := store.NewPublisher(ctx, url) - if err != nil { - return nil, err - } - - return &eventStore{ - ep: publisher, - handler: handler, - instance: instance, - }, nil -} - -func (es *eventStore) AuthConnect(ctx context.Context) error { - if err := es.handler.AuthConnect(ctx); err != nil { - return err - } - s, ok := session.FromContext(ctx) - if !ok { - return errFailedSession - } - - ev := connectEvent{ - operation: clientConnect, - clientID: s.Username, - subscriberID: s.ID, - instance: es.instance, - } - - return es.ep.Publish(ctx, connectStream, ev) -} - -func (es *eventStore) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error { - return es.handler.AuthPublish(ctx, topic, payload) -} - -func (es *eventStore) AuthSubscribe(ctx context.Context, topics *[]string) error { - return es.handler.AuthSubscribe(ctx, topics) -} - -func (es *eventStore) Connect(ctx context.Context) error { - return es.handler.Connect(ctx) -} - -func (es *eventStore) Publish(ctx context.Context, topic *string, payload *[]byte) error { - return es.handler.Publish(ctx, topic, payload) -} - -func (es *eventStore) Subscribe(ctx context.Context, topics *[]string) error { - if err := es.handler.Subscribe(ctx, topics); err != nil { - return err - } - - s, ok := session.FromContext(ctx) - if !ok { - return errFailedSession - } - - for _, topic := range *topics { - domainID, channelID, subTopic, _, err := messaging.ParseSubscribeTopic(topic) - if err != nil { - return err - } - ev := subscribeEvent{ - operation: clientSubscribe, - clientID: s.Username, - domainID: domainID, - channelID: channelID, - subtopic: subTopic, - subscriberID: s.ID, - } - - if err := es.ep.Publish(ctx, subscribeStream, ev); err != nil { - return err - } - } - - return nil -} - -func (es *eventStore) Unsubscribe(ctx context.Context, topics *[]string) error { - return es.handler.Unsubscribe(ctx, topics) -} - -func (es *eventStore) Disconnect(ctx context.Context) error { - if err := es.handler.Disconnect(ctx); err != nil { - return err - } - - s, ok := session.FromContext(ctx) - if !ok { - return errFailedSession - } - - ev := connectEvent{ - operation: clientDisconnect, - clientID: s.Username, - subscriberID: s.ID, - instance: es.instance, - } - - return es.ep.Publish(ctx, disconnectStream, ev) -} diff --git a/mqtt/forwarder.go b/mqtt/forwarder.go deleted file mode 100644 index a790e2853..000000000 --- a/mqtt/forwarder.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package mqtt - -import ( - "context" - "fmt" - "log/slog" - - "github.com/absmach/supermq/pkg/messaging" -) - -// Forwarder specifies MQTT forwarder interface API. -type Forwarder interface { - // Forward subscribes to the Subscriber and - // publishes messages using provided Publisher. - Forward(ctx context.Context, id string, sub messaging.Subscriber, pub messaging.Publisher) error -} - -type forwarder struct { - topic string - logger *slog.Logger -} - -// NewForwarder returns new Forwarder implementation. -func NewForwarder(topic string, logger *slog.Logger) Forwarder { - return forwarder{ - topic: topic, - logger: logger, - } -} - -func (f forwarder) Forward(ctx context.Context, id string, sub messaging.Subscriber, pub messaging.Publisher) error { - subCfg := messaging.SubscriberConfig{ - ID: id, - Topic: f.topic, - Handler: handle(ctx, pub, f.logger), - } - - return sub.Subscribe(ctx, subCfg) -} - -func handle(ctx context.Context, pub messaging.Publisher, logger *slog.Logger) handleFunc { - return func(msg *messaging.Message) error { - if msg.GetProtocol() == protocol { - return nil - } - - topic := messaging.EncodeMessageMQTTTopic(msg) - - go func() { - if err := pub.Publish(ctx, topic, msg); err != nil { - logger.Warn(fmt.Sprintf("Failed to forward message: %s", err)) - } - }() - - return nil - } -} - -type handleFunc func(msg *messaging.Message) error - -func (h handleFunc) Handle(msg *messaging.Message) error { - return h(msg) -} - -func (h handleFunc) Cancel() error { - return nil -} diff --git a/mqtt/handler.go b/mqtt/handler.go deleted file mode 100644 index 53b5bd4e0..000000000 --- a/mqtt/handler.go +++ /dev/null @@ -1,246 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package mqtt - -import ( - "context" - "fmt" - "log/slog" - "strings" - "time" - - "github.com/absmach/mgate/pkg/session" - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - "github.com/absmach/supermq/pkg/authn" - "github.com/absmach/supermq/pkg/connections" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/policies" -) - -var _ session.Handler = (*handler)(nil) - -const protocol = "mqtt" - -// Log message formats. -const ( - LogInfoSubscribed = "subscribed with client_id %s to topics %s" - LogInfoUnsubscribed = "unsubscribed client_id %s from topics %s" - LogInfoConnected = "connected with client_id %s" - LogInfoDisconnected = "disconnected client_id %s and username %s" - LogInfoPublished = "published with client_id %s to the topic %s" -) - -// Error wrappers for MQTT errors. -var ( - ErrClientNotInitialized = errors.New("client is not initialized") - ErrMissingClientID = errors.New("client_id not found") - ErrMissingTopicPub = errors.New("failed to publish due to missing topic") - ErrMissingTopicSub = errors.New("failed to subscribe due to missing topic") - ErrFailedConnect = errors.New("failed to connect") - ErrFailedSubscribe = errors.New("failed to subscribe") - ErrFailedUnsubscribe = errors.New("failed to unsubscribe") - ErrFailedPublish = errors.New("failed to publish") - ErrFailedDisconnect = errors.New("failed to disconnect") - ErrFailedPublishDisconnectEvent = errors.New("failed to publish disconnect event") - ErrFailedPublishConnectEvent = errors.New("failed to publish connect event") - ErrFailedSubscribeEvent = errors.New("failed to publish subscribe event") - ErrFailedPublishToMsgBroker = errors.New("failed to publish to supermq message broker") - - errInvalidUserId = errors.New("invalid user id") -) - -// Event implements events.Event interface. -type handler struct { - publisher messaging.Publisher - clients grpcClientsV1.ClientsServiceClient - channels grpcChannelsV1.ChannelsServiceClient - parser messaging.TopicParser - logger *slog.Logger -} - -// NewHandler creates new Handler entity. -func NewHandler(publisher messaging.Publisher, logger *slog.Logger, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, parser messaging.TopicParser) session.Handler { - return &handler{ - logger: logger, - publisher: publisher, - clients: clients, - channels: channels, - parser: parser, - } -} - -// AuthConnect is called on device connection, -// prior forwarding to the MQTT broker. -func (h *handler) AuthConnect(ctx context.Context) error { - s, ok := session.FromContext(ctx) - if !ok { - return ErrClientNotInitialized - } - - if s.ID == "" { - return ErrMissingClientID - } - - pwd := string(s.Password) - - res, err := h.clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.BasicAuth, s.Username, pwd)}) - if err != nil { - return errors.Wrap(svcerr.ErrAuthentication, err) - } - if !res.GetAuthenticated() { - return svcerr.ErrAuthentication - } - - if s.Username != "" && res.GetId() != s.Username { - return errInvalidUserId - } - - return nil -} - -// AuthPublish is called on device publish, -// prior forwarding to the MQTT broker. -func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error { - if topic == nil { - return ErrMissingTopicPub - } - s, ok := session.FromContext(ctx) - if !ok { - return ErrClientNotInitialized - } - - domainID, chanID, _, topicType, err := h.parser.ParsePublishTopic(ctx, *topic, false) - if err != nil { - return err - } - - return h.authAccess(ctx, string(s.Username), domainID, chanID, connections.Publish, topicType) -} - -// AuthSubscribe is called on device subscribe, -// prior forwarding to the MQTT broker. -func (h *handler) AuthSubscribe(ctx context.Context, topics *[]string) error { - s, ok := session.FromContext(ctx) - if !ok { - return ErrClientNotInitialized - } - if topics == nil || *topics == nil { - return ErrMissingTopicSub - } - - for _, topic := range *topics { - domainID, chanID, _, topicType, err := h.parser.ParseSubscribeTopic(ctx, topic, false) - if err != nil { - return err - } - - if err := h.authAccess(ctx, string(s.Username), domainID, chanID, connections.Subscribe, topicType); err != nil { - return err - } - } - - return nil -} - -// Connect - after client successfully connected. -func (h *handler) Connect(ctx context.Context) error { - s, ok := session.FromContext(ctx) - if !ok { - return errors.Wrap(ErrFailedConnect, ErrClientNotInitialized) - } - h.logger.Info(fmt.Sprintf(LogInfoConnected, s.ID)) - return nil -} - -// Publish - after client successfully published. -func (h *handler) Publish(ctx context.Context, topic *string, payload *[]byte) error { - s, ok := session.FromContext(ctx) - if !ok { - return errors.Wrap(ErrFailedPublish, ErrClientNotInitialized) - } - h.logger.Info(fmt.Sprintf(LogInfoPublished, s.ID, *topic)) - - domainID, chanID, subTopic, topicType, err := h.parser.ParsePublishTopic(ctx, *topic, false) - if err != nil { - return errors.Wrap(ErrFailedPublish, err) - } - - msg := messaging.Message{ - Protocol: protocol, - Domain: domainID, - Channel: chanID, - Subtopic: subTopic, - Publisher: s.Username, - Payload: *payload, - Created: time.Now().UnixNano(), - } - - if topicType == messaging.MessageType { - if err := h.publisher.Publish(ctx, messaging.EncodeMessageTopic(&msg), &msg); err != nil { - return errors.Wrap(ErrFailedPublishToMsgBroker, err) - } - } - - return nil -} - -// Subscribe - after client successfully subscribed. -func (h *handler) Subscribe(ctx context.Context, topics *[]string) error { - s, ok := session.FromContext(ctx) - if !ok { - return errors.Wrap(ErrFailedSubscribe, ErrClientNotInitialized) - } - h.logger.Info(fmt.Sprintf(LogInfoSubscribed, s.ID, strings.Join(*topics, ","))) - - return nil -} - -// Unsubscribe - after client unsubscribed. -func (h *handler) Unsubscribe(ctx context.Context, topics *[]string) error { - s, ok := session.FromContext(ctx) - if !ok { - return errors.Wrap(ErrFailedUnsubscribe, ErrClientNotInitialized) - } - h.logger.Info(fmt.Sprintf(LogInfoUnsubscribed, s.ID, strings.Join(*topics, ","))) - - return nil -} - -// Disconnect - connection with broker or client lost. -func (h *handler) Disconnect(ctx context.Context) error { - s, ok := session.FromContext(ctx) - if !ok { - return errors.Wrap(ErrFailedDisconnect, ErrClientNotInitialized) - } - h.logger.Info(fmt.Sprintf(LogInfoDisconnected, s.ID, s.Username)) - - return nil -} - -func (h *handler) authAccess(ctx context.Context, clientID, domainID, chanID string, msgType connections.ConnType, topicType messaging.TopicType) error { - switch topicType { - case messaging.HealthType: - return nil - default: - ar := &grpcChannelsV1.AuthzReq{ - Type: uint32(msgType), - ClientId: clientID, - ClientType: policies.ClientType, - ChannelId: chanID, - DomainId: domainID, - } - res, err := h.channels.Authorize(ctx, ar) - if err != nil { - return err - } - if !res.GetAuthorized() { - return svcerr.ErrAuthorization - } - - return nil - } -} diff --git a/mqtt/handler_test.go b/mqtt/handler_test.go deleted file mode 100644 index af06c1102..000000000 --- a/mqtt/handler_test.go +++ /dev/null @@ -1,616 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package mqtt_test - -import ( - "bytes" - "context" - "fmt" - "log" - "testing" - - "github.com/absmach/mgate/pkg/session" - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - chmocks "github.com/absmach/supermq/channels/mocks" - climocks "github.com/absmach/supermq/clients/mocks" - dmocks "github.com/absmach/supermq/domains/mocks" - "github.com/absmach/supermq/internal/testsutil" - smqlog "github.com/absmach/supermq/logger" - "github.com/absmach/supermq/mqtt" - "github.com/absmach/supermq/pkg/authn" - "github.com/absmach/supermq/pkg/connections" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/mocks" - "github.com/absmach/supermq/pkg/policies" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -const ( - password = "password" - password1 = "password1" - chanID = "123e4567-e89b-12d3-a456-000000000001" - invalidID = "invalidID" - invalidValue = "invalidValue" - clientID = "clientID" - clientID1 = "clientID1" - subtopic = "testSubtopic" - invalidChannelIDTopic = "m/**/c" -) - -var ( - domainID = testsutil.GenerateUUID(&testing.T{}) - topicMsg = "/m/%s/c/%s" - topic = fmt.Sprintf(topicMsg, domainID, chanID) - hcTopicFmt = "/hc/%s" - hcTopic = fmt.Sprintf(hcTopicFmt, domainID) - invalidHCTopic = "/hc" - invalidTopic = invalidValue - payload = []byte("[{'n':'test-name', 'v': 1.2}]") - topics = []string{topic} - invalidTopics = []string{invalidValue} - invalidChanIDTopics = []string{fmt.Sprintf(topicMsg, domainID, invalidValue)} - // Test log messages for cases the handler does not provide a return value. - logBuffer = bytes.Buffer{} - sessionClient = session.Session{ - ID: clientID, - Username: clientID, - Password: []byte(password), - } - sessionClientSub = session.Session{ - ID: clientID1, - Username: clientID1, - Password: []byte(password1), - } - invalidClientSessionClient = session.Session{ - ID: clientID, - Username: invalidID, - Password: []byte(password), - } - errInvalidUserId = errors.New("invalid user id") -) - -var ( - clients *climocks.ClientsServiceClient - channels *chmocks.ChannelsServiceClient - publisher *mocks.PubSub -) - -func TestAuthConnect(t *testing.T) { - handler := newHandler() - - cases := []struct { - desc string - session *session.Session - authNRes *grpcClientsV1.AuthnRes - authNErr error - err error - }{ - { - desc: "connect without active session", - err: mqtt.ErrClientNotInitialized, - session: nil, - }, - { - desc: "connect without clientID", - err: mqtt.ErrMissingClientID, - session: &session.Session{ - ID: "", - Username: clientID, - Password: []byte(password), - }, - }, - { - desc: "connect with empty password", - session: &session.Session{ - ID: clientID, - Username: clientID, - Password: []byte(""), - }, - authNErr: svcerr.ErrAuthentication, - err: svcerr.ErrAuthentication, - }, - { - desc: "connect with invalid password", - session: &session.Session{ - ID: clientID, - Username: clientID, - Password: []byte("invalid"), - }, - authNRes: &grpcClientsV1.AuthnRes{ - Authenticated: false, - }, - err: svcerr.ErrAuthentication, - }, - { - desc: "connect with valid password and invalid username", - session: &invalidClientSessionClient, - authNRes: &grpcClientsV1.AuthnRes{ - Authenticated: true, - Id: testsutil.GenerateUUID(t), - }, - err: errInvalidUserId, - }, - { - desc: "connect with valid username and password", - err: nil, - session: &sessionClient, - authNRes: &grpcClientsV1.AuthnRes{ - Authenticated: true, - Id: clientID, - }, - }, - } - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - ctx := context.TODO() - password := "" - username := "" - if tc.session != nil { - ctx = session.NewContext(ctx, tc.session) - password = string(tc.session.Password) - username = tc.session.Username - } - clientsCall := clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{Token: authn.AuthPack(authn.BasicAuth, username, password)}).Return(tc.authNRes, tc.authNErr) - err := handler.AuthConnect(ctx) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - clientsCall.Unset() - }) - } -} - -func TestAuthPublish(t *testing.T) { - handler := newHandler() - - cases := []struct { - desc string - session *session.Session - err error - topic *string - payload []byte - authZRes *grpcChannelsV1.AuthzRes - authZErr error - }{ - { - desc: "publish successfully", - session: &sessionClient, - err: nil, - topic: &topic, - payload: payload, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - }, - { - desc: "publish with an inactive client", - session: nil, - err: mqtt.ErrClientNotInitialized, - topic: &topic, - payload: payload, - }, - { - desc: "publish without topic", - session: &sessionClient, - err: mqtt.ErrMissingTopicPub, - topic: nil, - payload: payload, - }, - { - desc: "publish with malformed topic", - session: &sessionClient, - err: messaging.ErrMalformedTopic, - topic: &invalidTopic, - payload: payload, - }, - { - desc: "publish with authorization error", - session: &sessionClient, - err: svcerr.ErrAuthorization, - topic: &topic, - payload: payload, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - authZErr: svcerr.ErrAuthorization, - }, - { - desc: "publish to health check topic", - session: &sessionClient, - err: nil, - topic: &hcTopic, - payload: payload, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - }, - { - desc: "publich with invalid health check topic", - session: &sessionClient, - err: messaging.ErrMalformedTopic, - topic: &invalidHCTopic, - payload: payload, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - ctx := context.TODO() - if tc.session != nil { - ctx = session.NewContext(ctx, tc.session) - } - channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{ - DomainId: domainID, - ChannelId: chanID, - ClientId: clientID, - ClientType: policies.ClientType, - Type: uint32(connections.Publish), - }).Return(tc.authZRes, tc.authZErr) - err := handler.AuthPublish(ctx, tc.topic, &tc.payload) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - channelsCall.Unset() - }) - } -} - -func TestAuthSubscribe(t *testing.T) { - handler := newHandler() - - cases := []struct { - desc string - session *session.Session - err error - topic *[]string - channelID string - authZRes *grpcChannelsV1.AuthzRes - authZErr error - }{ - { - desc: "subscribe without active session", - session: nil, - err: mqtt.ErrClientNotInitialized, - topic: &topics, - }, - { - desc: "subscribe without topics", - session: &sessionClient, - err: mqtt.ErrMissingTopicSub, - topic: nil, - }, - { - desc: "subscribe with invalid topics", - session: &sessionClient, - err: messaging.ErrMalformedTopic, - topic: &invalidTopics, - }, - { - desc: "subscribe with invalid channel ID", - session: &sessionClientSub, - err: svcerr.ErrAuthorization, - topic: &invalidChanIDTopics, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - channelID: invalidValue, - }, - { - desc: "subscribe successfully", - session: &sessionClientSub, - err: nil, - topic: &topics, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - channelID: chanID, - }, - { - desc: "subscribe with failed authorization", - session: &sessionClientSub, - err: svcerr.ErrAuthorization, - topic: &topics, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: false}, - channelID: chanID, - }, - { - desc: "subscribe successfully with health check topic", - session: &sessionClientSub, - err: nil, - topic: &[]string{hcTopic}, - authZRes: &grpcChannelsV1.AuthzRes{Authorized: true}, - channelID: "", - }, - { - desc: "subscribe with invalid health check topic", - session: &sessionClientSub, - err: messaging.ErrMalformedTopic, - topic: &[]string{invalidHCTopic}, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - ctx := context.TODO() - if tc.session != nil { - ctx = session.NewContext(ctx, tc.session) - } - channelsCall := channels.On("Authorize", mock.Anything, &grpcChannelsV1.AuthzReq{ - DomainId: domainID, - ChannelId: tc.channelID, - ClientId: clientID1, - ClientType: policies.ClientType, - Type: uint32(connections.Subscribe), - }).Return(tc.authZRes, tc.authZErr) - err := handler.AuthSubscribe(ctx, tc.topic) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - channelsCall.Unset() - }) - } -} - -func TestConnect(t *testing.T) { - handler := newHandler() - logBuffer.Reset() - - cases := []struct { - desc string - session *session.Session - err error - logMsg string - }{ - { - desc: "connect without active session", - session: nil, - err: errors.Wrap(mqtt.ErrFailedConnect, mqtt.ErrClientNotInitialized), - }, - { - desc: "connect with active session", - session: &sessionClient, - logMsg: fmt.Sprintf(mqtt.LogInfoConnected, clientID), - err: nil, - }, - } - - for _, tc := range cases { - ctx := context.TODO() - if tc.session != nil { - ctx = session.NewContext(ctx, tc.session) - } - err := handler.Connect(ctx) - assert.Contains(t, logBuffer.String(), tc.logMsg) - assert.Equal(t, tc.err, err) - } -} - -func TestPublish(t *testing.T) { - handler := newHandler() - logBuffer.Reset() - - malformedSubtopics := topic + "/" + subtopic + "%" - wrongCharSubtopics := topic + "/" + subtopic + ">" - validSubtopic := topic + "/" + subtopic - - cases := []struct { - desc string - session *session.Session - topic string - payload []byte - logMsg string - err error - }{ - { - desc: "publish without active session", - session: nil, - topic: topic, - payload: payload, - err: errors.Wrap(mqtt.ErrFailedPublish, mqtt.ErrClientNotInitialized), - }, - { - desc: "publish with invalid topic", - session: &sessionClient, - topic: invalidTopic, - payload: payload, - logMsg: fmt.Sprintf(mqtt.LogInfoPublished, clientID, invalidTopic), - err: errors.Wrap(mqtt.ErrFailedPublish, messaging.ErrMalformedTopic), - }, - { - desc: "publish with invalid channel ID", - session: &sessionClient, - topic: invalidChannelIDTopic, - payload: payload, - err: errors.Wrap(mqtt.ErrFailedPublish, messaging.ErrMalformedTopic), - }, - { - desc: "publish with malformed subtopic", - session: &sessionClient, - topic: malformedSubtopics, - payload: payload, - err: errors.New("invalid URL escape \"%\""), - }, - { - desc: "publish with subtopic containing wrong character", - session: &sessionClient, - topic: wrongCharSubtopics, - payload: payload, - err: errors.Wrap(mqtt.ErrFailedPublish, errors.Wrap(messaging.ErrMalformedTopic, messaging.ErrMalformedSubtopic)), - }, - { - desc: "publish with subtopic", - session: &sessionClient, - topic: validSubtopic, - payload: payload, - logMsg: subtopic, - }, - { - desc: "publish without subtopic", - session: &sessionClient, - topic: topic, - payload: payload, - logMsg: "", - }, - { - desc: "publish with health check topic", - session: &sessionClient, - topic: hcTopic, - payload: payload, - logMsg: "", - }, - { - desc: "publish with invalid health check topic", - session: &sessionClient, - topic: invalidHCTopic, - payload: payload, - err: errors.Wrap(mqtt.ErrFailedPublish, messaging.ErrMalformedTopic), - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - ctx := context.TODO() - if tc.session != nil { - ctx = session.NewContext(ctx, tc.session) - } - repoCall := publisher.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil) - err := handler.Publish(ctx, &tc.topic, &tc.payload) - assert.Contains(t, logBuffer.String(), tc.logMsg) - if tc.err != nil { - assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error containing: %v, got: %v", tc.err, err)) - } - repoCall.Unset() - }) - } -} - -func TestSubscribe(t *testing.T) { - handler := newHandler() - logBuffer.Reset() - - cases := []struct { - desc string - session *session.Session - topic []string - logMsg string - err error - }{ - { - desc: "subscribe without active session", - session: nil, - topic: topics, - err: errors.Wrap(mqtt.ErrFailedSubscribe, mqtt.ErrClientNotInitialized), - }, - { - desc: "subscribe with valid session and topics", - session: &sessionClient, - topic: topics, - logMsg: fmt.Sprintf(mqtt.LogInfoSubscribed, clientID, topics[0]), - }, - } - - for _, tc := range cases { - ctx := context.TODO() - if tc.session != nil { - ctx = session.NewContext(ctx, tc.session) - } - err := handler.Subscribe(ctx, &tc.topic) - assert.Contains(t, logBuffer.String(), tc.logMsg) - assert.Equal(t, tc.err, err) - } -} - -func TestUnsubscribe(t *testing.T) { - handler := newHandler() - logBuffer.Reset() - - cases := []struct { - desc string - session *session.Session - topic []string - logMsg string - err error - }{ - { - desc: "unsubscribe without active session", - session: nil, - topic: topics, - err: errors.Wrap(mqtt.ErrFailedUnsubscribe, mqtt.ErrClientNotInitialized), - }, - { - desc: "unsubscribe with valid session and topics", - session: &sessionClient, - topic: topics, - logMsg: fmt.Sprintf(mqtt.LogInfoUnsubscribed, clientID, topics[0]), - }, - } - - for _, tc := range cases { - ctx := context.TODO() - if tc.session != nil { - ctx = session.NewContext(ctx, tc.session) - } - err := handler.Unsubscribe(ctx, &tc.topic) - assert.Contains(t, logBuffer.String(), tc.logMsg) - assert.Equal(t, tc.err, err) - } -} - -func TestDisconnect(t *testing.T) { - handler := newHandler() - logBuffer.Reset() - - cases := []struct { - desc string - session *session.Session - topic []string - logMsg string - err error - }{ - { - desc: "disconnect without active session", - session: nil, - topic: topics, - err: errors.Wrap(mqtt.ErrFailedDisconnect, mqtt.ErrClientNotInitialized), - }, - { - desc: "disconnect with valid session", - session: &sessionClient, - topic: topics, - err: nil, - }, - { - desc: "disconnect logs username not password", - session: &session.Session{ - ID: "testClient", - Username: "testUser", - Password: []byte("secretPassword123"), - }, - topic: topics, - logMsg: fmt.Sprintf(mqtt.LogInfoDisconnected, "testClient", "testUser"), - err: nil, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - logBuffer.Reset() - ctx := context.TODO() - if tc.session != nil { - ctx = session.NewContext(ctx, tc.session) - } - err := handler.Disconnect(ctx) - assert.Contains(t, logBuffer.String(), tc.logMsg) - assert.Equal(t, tc.err, err) - - if tc.session != nil { - assert.NotContains(t, logBuffer.String(), string(tc.session.Password), "password should not be logged") - } - }) - } -} - -func newHandler() session.Handler { - logger, err := smqlog.New(&logBuffer, "debug") - if err != nil { - log.Fatalf("failed to create logger: %s", err) - } - clients = new(climocks.ClientsServiceClient) - channels = new(chmocks.ChannelsServiceClient) - domains := new(dmocks.DomainsServiceClient) - parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channels, domains) - if err != nil { - log.Fatalf("failed to create topic parser: %s", err) - } - publisher = new(mocks.PubSub) - return mqtt.NewHandler(publisher, logger, clients, channels, parser) -} diff --git a/mqtt/tracing/forwarder.go b/mqtt/tracing/forwarder.go deleted file mode 100644 index b904edc50..000000000 --- a/mqtt/tracing/forwarder.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package tracing - -import ( - "context" - "fmt" - - "github.com/absmach/supermq/mqtt" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/server" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" -) - -const forwardOP = "process" - -var _ mqtt.Forwarder = (*forwarderMiddleware)(nil) - -type forwarderMiddleware struct { - topic string - forwarder mqtt.Forwarder - tracer trace.Tracer - host server.Config -} - -// New creates new mqtt forwarder tracing middleware. -func New(config server.Config, tracer trace.Tracer, forwarder mqtt.Forwarder, topic string) mqtt.Forwarder { - return &forwarderMiddleware{ - forwarder: forwarder, - tracer: tracer, - topic: topic, - host: config, - } -} - -// Forward traces mqtt forward operations. -func (fm *forwarderMiddleware) Forward(ctx context.Context, id string, sub messaging.Subscriber, pub messaging.Publisher) error { - spanName := fmt.Sprintf("%s %s", fm.topic, forwardOP) - - ctx, span := fm.tracer.Start(ctx, - spanName, - trace.WithAttributes( - attribute.String("messaging.system", "mqtt"), - attribute.Bool("messaging.destination.anonymous", false), - attribute.String("messaging.destination.template", "m/{domainID}/c/{channelID}/*"), - attribute.Bool("messaging.destination.temporary", true), - attribute.String("network.protocol.name", "mqtt"), - attribute.String("network.protocol.version", "3.1.1"), - attribute.String("network.transport", "tcp"), - attribute.String("network.type", "ipv4"), - attribute.String("messaging.operation", forwardOP), - attribute.String("messaging.client_id", id), - attribute.String("server.address", fm.host.Host), - attribute.String("server.socket.port", fm.host.Port), - ), - ) - defer span.End() - - return fm.forwarder.Forward(ctx, id, sub, pub) -} diff --git a/notifications/README.md b/notifications/README.md index 88d618033..ba137f063 100644 --- a/notifications/README.md +++ b/notifications/README.md @@ -31,30 +31,30 @@ domains service → event store → notifications service → users service (gRP The service is configured using environment variables: ### General Configuration -- `SMQ_NOTIFICATIONS_LOG_LEVEL` - Log level (default: "info") -- `SMQ_NOTIFICATIONS_INSTANCE_ID` - Instance ID for the service -- `SMQ_NOTIFICATIONS_DOMAIN_ALT_NAME` - Alternative name for domains such as, say, workspaces or tenants (default: "domains") -- `SMQ_ES_URL` - Event store URL (default: "nats://localhost:4222") +- `MG_NOTIFICATIONS_LOG_LEVEL` - Log level (default: "info") +- `MG_NOTIFICATIONS_INSTANCE_ID` - Instance ID for the service +- `MG_NOTIFICATIONS_DOMAIN_ALT_NAME` - Alternative name for domains such as, say, workspaces or tenants (default: "domains") +- `MG_ES_URL` - Event store URL (default: "nats://localhost:4222") ### Email Configuration -- `SMQ_EMAIL_HOST` - SMTP server host (default: "localhost") -- `SMQ_EMAIL_PORT` - SMTP server port (default: "25") -- `SMQ_EMAIL_USERNAME` - SMTP username -- `SMQ_EMAIL_PASSWORD` - SMTP password -- `SMQ_EMAIL_FROM_ADDRESS` - From email address (default: "noreply@supermq.com") -- `SMQ_EMAIL_FROM_NAME` - From name (default: "SuperMQ Notifications") +- `MG_EMAIL_HOST` - SMTP server host (default: "localhost") +- `MG_EMAIL_PORT` - SMTP server port (default: "25") +- `MG_EMAIL_USERNAME` - SMTP username +- `MG_EMAIL_PASSWORD` - SMTP password +- `MG_EMAIL_FROM_ADDRESS` - From email address (default: "noreply@supermq.com") +- `MG_EMAIL_FROM_NAME` - From name (default: "SuperMQ Notifications") ### Template Configuration -- `SMQ_EMAIL_INVITATION_TEMPLATE` - Path to invitation email template -- `SMQ_EMAIL_ACCEPTANCE_TEMPLATE` - Path to acceptance email template -- `SMQ_EMAIL_REJECTION_TEMPLATE` - Path to rejection email template +- `MG_EMAIL_INVITATION_TEMPLATE` - Path to invitation email template +- `MG_EMAIL_ACCEPTANCE_TEMPLATE` - Path to acceptance email template +- `MG_EMAIL_REJECTION_TEMPLATE` - Path to rejection email template ### gRPC Configuration (Users Service) -- `SMQ_USERS_GRPC_URL` - Users service gRPC URL -- `SMQ_USERS_GRPC_TIMEOUT` - gRPC request timeout -- `SMQ_USERS_GRPC_CLIENT_CERT` - Client certificate path -- `SMQ_USERS_GRPC_CLIENT_KEY` - Client key path -- `SMQ_USERS_GRPC_SERVER_CA_CERTS` - Server CA certificates path +- `MG_USERS_GRPC_URL` - Users service gRPC URL +- `MG_USERS_GRPC_TIMEOUT` - gRPC request timeout +- `MG_USERS_GRPC_CLIENT_CERT` - Client certificate path +- `MG_USERS_GRPC_CLIENT_KEY` - Client key path +- `MG_USERS_GRPC_SERVER_CA_CERTS` - Server CA certificates path ## Running the Service @@ -94,7 +94,7 @@ go test ./notifications/... -v To run email integration tests (requires SMTP server): ```bash -SMQ_RUN_EMAIL_TESTS=true go test ./notifications/emailer -v +MG_RUN_EMAIL_TESTS=true go test ./notifications/emailer -v ``` ## Development diff --git a/notifications/emailer/emailer.go b/notifications/emailer/emailer.go index 51cf73df1..66ad08843 100644 --- a/notifications/emailer/emailer.go +++ b/notifications/emailer/emailer.go @@ -127,7 +127,7 @@ func (n *notifier) Notify(ctx context.Context, notif notifications.Notification) return errors.Wrap(errSendingEmail, fmt.Errorf("no email agent configured for notification type: %d", notif.Type)) } - if err := agent.Send([]string{recipientEmail}, "", subject, "", recipientName, content, n.fromName); err != nil { + if err := agent.Send([]string{recipientEmail}, "", subject, "", recipientName, content, n.fromName, nil); err != nil { return errors.Wrap(errSendingEmail, err) } diff --git a/notifications/emailer/emailer_test.go b/notifications/emailer/emailer_test.go index f319fa9df..6e3836939 100644 --- a/notifications/emailer/emailer_test.go +++ b/notifications/emailer/emailer_test.go @@ -49,8 +49,8 @@ func (m *mockUsersClient) RetrieveUsers(ctx context.Context, req *grpcUsersV1.Re } func TestNotify(t *testing.T) { - if os.Getenv("SMQ_RUN_EMAIL_TESTS") != envTrue { - t.Skip("Skipping email tests. Set SMQ_RUN_EMAIL_TESTS=true to run.") + if os.Getenv("MG_RUN_EMAIL_TESTS") != envTrue { + t.Skip("Skipping email tests. Set MG_RUN_EMAIL_TESTS=true to run.") } usersClient := new(mockUsersClient) diff --git a/pkg/authn/middleware.go b/pkg/authn/middleware.go index eecb4be9a..3be6df609 100644 --- a/pkg/authn/middleware.go +++ b/pkg/authn/middleware.go @@ -19,7 +19,7 @@ import ( type sessionKeyType string const ( - allowUnverifiedUserEnv = "SMQ_ALLOW_UNVERIFIED_USER" + allowUnverifiedUserEnv = "MG_ALLOW_UNVERIFIED_USER" jsonContentType = "application/json" SessionKey = sessionKeyType("session") @@ -81,12 +81,12 @@ type authnMiddleware struct { // NewAuthNMiddleware creates a new authenticated service with middleware support. // The order of precedence for options is as follows, with later options overriding earlier ones: // 1. Default options (lowest precedence). -// 2. Options from environment variables (e.g., SMQ_ALLOW_UNVERIFIED_USER). +// 2. Options from environment variables (e.g., MG_ALLOW_UNVERIFIED_USER). // 3. Options passed as arguments to this function (highest precedence). // // For example, consider the 'allowUnverifiedUser' option: // - By default, it is 'false'. -// - If the SMQ_ALLOW_UNVERIFIED_USER environment variable is set to "true", +// - If the MG_ALLOW_UNVERIFIED_USER environment variable is set to "true", // it becomes 'true'. // - If NewAuthNMiddleware is called with WithAllowUnverifiedUser(false), it will be 'false', // regardless of the environment variable, as function arguments have the highest precedence. diff --git a/pkg/authz/authsvc/authz.go b/pkg/authz/authsvc/authz.go index bcb2bd7f5..9e23763b6 100644 --- a/pkg/authz/authsvc/authz.go +++ b/pkg/authz/authsvc/authz.go @@ -108,7 +108,7 @@ func (a authorization) checkDomain(ctx context.Context, subjectType, subject, do Subject: subject, SubjectType: subjectType, Permission: policies.AdminPermission, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }, }) diff --git a/pkg/channels/events/consumer/streams.go b/pkg/channels/events/consumer/streams.go index 4ce64be1c..9b5376b29 100644 --- a/pkg/channels/events/consumer/streams.go +++ b/pkg/channels/events/consumer/streams.go @@ -47,7 +47,7 @@ type eventHandler struct { } func ChannelsEventsSubscribe(ctx context.Context, repo channels.Repository, esURL, esConsumerName string, logger *slog.Logger) error { - subscriber, err := store.NewSubscriber(ctx, esURL, logger) + subscriber, err := store.NewSubscriber(ctx, esURL, "channels-es-sub", logger) if err != nil { return err } diff --git a/pkg/clients/events/consumer/streams.go b/pkg/clients/events/consumer/streams.go index 52fd8773a..0349ef369 100644 --- a/pkg/clients/events/consumer/streams.go +++ b/pkg/clients/events/consumer/streams.go @@ -43,7 +43,7 @@ type eventHandler struct { } func ClientsEventsSubscribe(ctx context.Context, repo clients.Repository, esURL, esConsumerName string, logger *slog.Logger) error { - subscriber, err := store.NewSubscriber(ctx, esURL, logger) + subscriber, err := store.NewSubscriber(ctx, esURL, "clients-es-sub", logger) if err != nil { return err } diff --git a/pkg/domains/events/consumer/stream.go b/pkg/domains/events/consumer/stream.go index bb1fa87b8..45ec5bed6 100644 --- a/pkg/domains/events/consumer/stream.go +++ b/pkg/domains/events/consumer/stream.go @@ -45,7 +45,7 @@ type eventHandler struct { } func DomainsEventsSubscribe(ctx context.Context, repo domains.Repository, esURL, esConsumerName string, logger *slog.Logger) error { - subscriber, err := store.NewSubscriber(ctx, esURL, logger) + subscriber, err := store.NewSubscriber(ctx, esURL, "domains-es-sub", logger) if err != nil { return err } diff --git a/pkg/emailer/emailer.go b/pkg/emailer/emailer.go new file mode 100644 index 000000000..abd564ae0 --- /dev/null +++ b/pkg/emailer/emailer.go @@ -0,0 +1,28 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package emailer + +import ( + "github.com/absmach/supermq/internal/email" +) + +var _ Emailer = (*emailer)(nil) + +type Emailer interface { + // SendEmailNotification sends an email to the recipients based on a trigger. + SendEmailNotification(to []string, from, subject, header, user, content, footer string, attachments map[string][]byte) error +} + +type emailer struct { + agent *email.Agent +} + +func New(a *email.Config) (Emailer, error) { + e, err := email.New(a) + return &emailer{agent: e}, err +} + +func (e *emailer) SendEmailNotification(to []string, from, subject, header, user, content, footer string, attachments map[string][]byte) error { + return e.agent.Send(to, from, subject, header, user, content, footer, attachments) +} diff --git a/pkg/emailer/mocks/emailer.go b/pkg/emailer/mocks/emailer.go new file mode 100644 index 000000000..94f753450 --- /dev/null +++ b/pkg/emailer/mocks/emailer.go @@ -0,0 +1,133 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" +) + +// NewEmailer creates a new instance of Emailer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewEmailer(t interface { + mock.TestingT + Cleanup(func()) +}) *Emailer { + mock := &Emailer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Emailer is an autogenerated mock type for the Emailer type +type Emailer struct { + mock.Mock +} + +type Emailer_Expecter struct { + mock *mock.Mock +} + +func (_m *Emailer) EXPECT() *Emailer_Expecter { + return &Emailer_Expecter{mock: &_m.Mock} +} + +// SendEmailNotification provides a mock function for the type Emailer +func (_mock *Emailer) SendEmailNotification(to []string, from string, subject string, header string, user string, content string, footer string, attachments map[string][]byte) error { + ret := _mock.Called(to, from, subject, header, user, content, footer, attachments) + + if len(ret) == 0 { + panic("no return value specified for SendEmailNotification") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func([]string, string, string, string, string, string, string, map[string][]byte) error); ok { + r0 = returnFunc(to, from, subject, header, user, content, footer, attachments) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Emailer_SendEmailNotification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendEmailNotification' +type Emailer_SendEmailNotification_Call struct { + *mock.Call +} + +// SendEmailNotification is a helper method to define mock.On call +// - to []string +// - from string +// - subject string +// - header string +// - user string +// - content string +// - footer string +// - attachments map[string][]byte +func (_e *Emailer_Expecter) SendEmailNotification(to interface{}, from interface{}, subject interface{}, header interface{}, user interface{}, content interface{}, footer interface{}, attachments interface{}) *Emailer_SendEmailNotification_Call { + return &Emailer_SendEmailNotification_Call{Call: _e.mock.On("SendEmailNotification", to, from, subject, header, user, content, footer, attachments)} +} + +func (_c *Emailer_SendEmailNotification_Call) Run(run func(to []string, from string, subject string, header string, user string, content string, footer string, attachments map[string][]byte)) *Emailer_SendEmailNotification_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []string + if args[0] != nil { + arg0 = args[0].([]string) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + var arg5 string + if args[5] != nil { + arg5 = args[5].(string) + } + var arg6 string + if args[6] != nil { + arg6 = args[6].(string) + } + var arg7 map[string][]byte + if args[7] != nil { + arg7 = args[7].(map[string][]byte) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + arg6, + arg7, + ) + }) + return _c +} + +func (_c *Emailer_SendEmailNotification_Call) Return(err error) *Emailer_SendEmailNotification_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Emailer_SendEmailNotification_Call) RunAndReturn(run func(to []string, from string, subject string, header string, user string, content string, footer string, attachments map[string][]byte) error) *Emailer_SendEmailNotification_Call { + _c.Call.Return(run) + return _c +} diff --git a/pkg/events/fluxmq/publisher.go b/pkg/events/fluxmq/publisher.go new file mode 100644 index 000000000..e4a5437d5 --- /dev/null +++ b/pkg/events/fluxmq/publisher.go @@ -0,0 +1,74 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package fluxmq + +import ( + "context" + "encoding/json" + "log/slog" + "time" + + fluxamqp "github.com/absmach/fluxmq/client/amqp" + "github.com/absmach/supermq/pkg/events" +) + +const ( + eventsQueue = "events" + eventsPrefix = "events." + queuePrefix = "$queue/" +) + +type pubEventStore struct { + client *fluxamqp.Client +} + +// NewPublisher creates a FluxMQ-backed event publisher. +func NewPublisher(_ context.Context, url, connectionName string) (events.Publisher, error) { + logger := slog.Default() + opts := fluxamqp.NewOptions().SetURL(url). + SetConnectionName(connectionName). + SetOnConnectionLost(func(err error) { + logger.Warn("FluxMQ event publisher connection lost", "error", err) + }). + SetOnReconnecting(func(attempt int) { + logger.Info("FluxMQ event publisher reconnecting", "attempt", attempt) + }). + SetOnConnect(func() { + logger.Info("FluxMQ event publisher connected") + }) + + client, err := fluxamqp.New(opts) + if err != nil { + return nil, err + } + if err := client.Connect(); err != nil { + return nil, err + } + if err := declareEventsStream(client); err != nil { + return nil, err + } + + return &pubEventStore{client: client}, nil +} + +func (es *pubEventStore) Publish(ctx context.Context, stream string, event events.Event) error { + values, err := event.Encode() + if err != nil { + return err + } + + values["occurred_at"] = time.Now().UnixNano() + values["stream"] = canonicalStream(stream) + + data, err := json.Marshal(values) + if err != nil { + return err + } + + return es.client.PublishContext(ctx, queueTopic(stream), data) +} + +func (es *pubEventStore) Close() error { + return es.client.Close() +} diff --git a/pkg/events/fluxmq/subscriber.go b/pkg/events/fluxmq/subscriber.go new file mode 100644 index 000000000..279f1c28d --- /dev/null +++ b/pkg/events/fluxmq/subscriber.go @@ -0,0 +1,133 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package fluxmq + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + + fluxamqp "github.com/absmach/fluxmq/client/amqp" + "github.com/absmach/supermq/pkg/events" + "github.com/absmach/supermq/pkg/messaging" +) + +var _ events.Subscriber = (*subEventStore)(nil) + +var ( + // ErrEmptyStream is returned when stream name is empty. + ErrEmptyStream = errors.New("stream name cannot be empty") + // ErrEmptyConsumer is returned when consumer name is empty. + ErrEmptyConsumer = errors.New("consumer name cannot be empty") +) + +type subEventStore struct { + client *fluxamqp.Client + logger *slog.Logger +} + +// NewSubscriber creates a FluxMQ-backed event subscriber. +func NewSubscriber(_ context.Context, url, connectionName string, logger *slog.Logger) (events.Subscriber, error) { + opts := fluxamqp.NewOptions().SetURL(url). + SetConnectionName(connectionName). + SetOnConnectionLost(func(err error) { + logger.Warn("FluxMQ event subscriber connection lost", "error", err) + }). + SetOnReconnecting(func(attempt int) { + logger.Info("FluxMQ event subscriber reconnecting", "attempt", attempt) + }). + SetOnConnect(func() { + logger.Info("FluxMQ event subscriber connected", url, connectionName) + }) + + client, err := fluxamqp.New(opts) + if err != nil { + return nil, err + } + if err := client.Connect(); err != nil { + return nil, err + } + if err := declareEventsStream(client); err != nil { + return nil, err + } + + return &subEventStore{ + client: client, + logger: logger, + }, nil +} + +func (es *subEventStore) Subscribe(ctx context.Context, cfg events.SubscriberConfig) error { + if cfg.Stream == "" { + return ErrEmptyStream + } + if cfg.Consumer == "" { + return ErrEmptyConsumer + } + + opts := &fluxamqp.StreamConsumeOptions{ + QueueName: eventsQueue, + Filter: streamFilter(cfg.Stream), + ConsumerGroup: cfg.Consumer, + } + + if cfg.DeliveryPolicy == messaging.DeliverNewPolicy { + opts.Offset = "last" + } + + return es.client.SubscribeToStream(opts, func(msg *fluxamqp.QueueMessage) { + if err := es.handle(ctx, cfg.Handler, msg); err != nil { + es.logWarn("failed to process FluxMQ event", "error", err) + } + }) +} + +func (es *subEventStore) Close() error { + return es.client.Close() +} + +func (es *subEventStore) handle(ctx context.Context, handler events.EventHandler, msg *fluxamqp.QueueMessage) error { + event := event{ + Data: make(map[string]any), + } + + if err := json.Unmarshal(msg.Body, &event.Data); err != nil { + if rejectErr := msg.Reject(); rejectErr != nil { + return errors.Join(err, rejectErr) + } + return err + } + + if err := handler.Handle(ctx, event); err != nil { + if nackErr := msg.Nack(); nackErr != nil { + return errors.Join(fmt.Errorf("failed to handle FluxMQ event: %w", err), nackErr) + } + return fmt.Errorf("failed to handle FluxMQ event: %w", err) + } + + if err := msg.Ack(); err != nil { + return err + } + + return nil +} + +func (es *subEventStore) logWarn(msg string, args ...any) { + if es.logger != nil { + es.logger.Warn(msg, args...) + return + } + + slog.Warn(msg, args...) +} + +type event struct { + Data map[string]any +} + +func (re event) Encode() (map[string]any, error) { + return re.Data, nil +} diff --git a/pkg/events/fluxmq/topic.go b/pkg/events/fluxmq/topic.go new file mode 100644 index 000000000..6668cac78 --- /dev/null +++ b/pkg/events/fluxmq/topic.go @@ -0,0 +1,61 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package fluxmq + +import ( + "strings" + + fluxamqp "github.com/absmach/fluxmq/client/amqp" +) + +func canonicalStream(stream string) string { + stream = strings.TrimSpace(stream) + if stream == "" { + return eventsPrefix + } + if strings.HasPrefix(stream, eventsPrefix) { + return stream + } + return eventsPrefix + stream +} + +func queueTopic(stream string) string { + path := brokerPath(stream) + if path == "" { + return queuePrefix + eventsQueue + } + return queuePrefix + eventsQueue + "/" + path +} + +func queueFilter(stream string) string { + path := brokerPath(stream) + if path == "" || path == "#" { + return queuePrefix + eventsQueue + "/#" + } + return queuePrefix + eventsQueue + "/" + path +} + +func streamFilter(stream string) string { + return brokerPath(stream) +} + +func brokerPath(stream string) string { + stream = strings.TrimSpace(stream) + stream = strings.TrimPrefix(stream, eventsPrefix) + if stream == "" { + return "" + } + + replacer := strings.NewReplacer(".", "/", "*", "+", ">", "#") + return replacer.Replace(stream) +} + +func declareEventsStream(client *fluxamqp.Client) error { + _, err := client.DeclareStreamQueue(&fluxamqp.StreamQueueOptions{ + Name: eventsQueue, + Durable: true, + MaxAge: "30D", + }) + return err +} diff --git a/pkg/events/fluxmq/topic_test.go b/pkg/events/fluxmq/topic_test.go new file mode 100644 index 000000000..ba1b2636f --- /dev/null +++ b/pkg/events/fluxmq/topic_test.go @@ -0,0 +1,92 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package fluxmq + +import "testing" + +func TestCanonicalStream(t *testing.T) { + tests := []struct { + name string + stream string + want string + }{ + { + name: "raw supermq stream", + stream: "supermq.domain.create", + want: "events.supermq.domain.create", + }, + { + name: "already prefixed stream", + stream: "events.supermq.group.*", + want: "events.supermq.group.*", + }, + { + name: "all events wildcard", + stream: ">", + want: "events.>", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := canonicalStream(tc.stream); got != tc.want { + t.Fatalf("canonicalStream(%q) = %q, want %q", tc.stream, got, tc.want) + } + }) + } +} + +func TestQueueFilter(t *testing.T) { + tests := []struct { + name string + stream string + want string + }{ + { + name: "domain wildcard", + stream: "events.supermq.domain.*", + want: "$queue/events/supermq/domain/+", + }, + { + name: "all events", + stream: ">", + want: "$queue/events/#", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := queueFilter(tc.stream); got != tc.want { + t.Fatalf("queueFilter(%q) = %q, want %q", tc.stream, got, tc.want) + } + }) + } +} + +func TestStreamFilter(t *testing.T) { + tests := []struct { + name string + stream string + want string + }{ + { + name: "domain wildcard", + stream: "events.supermq.domain.*", + want: "supermq/domain/+", + }, + { + name: "all events", + stream: ">", + want: "#", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := streamFilter(tc.stream); got != tc.want { + t.Fatalf("streamFilter(%q) = %q, want %q", tc.stream, got, tc.want) + } + }) + } +} diff --git a/pkg/events/rabbitmq/doc.go b/pkg/events/rabbitmq/doc.go deleted file mode 100644 index a12d446e1..000000000 --- a/pkg/events/rabbitmq/doc.go +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package redis contains the domain concept definitions needed to support -// SuperMQ redis events source service functionality. -// -// It provides the abstraction of the redis stream and its operations. -package rabbitmq diff --git a/pkg/events/rabbitmq/publisher.go b/pkg/events/rabbitmq/publisher.go deleted file mode 100644 index 56a3781f5..000000000 --- a/pkg/events/rabbitmq/publisher.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package rabbitmq - -import ( - "context" - "encoding/json" - "time" - - "github.com/absmach/supermq/pkg/events" - "github.com/absmach/supermq/pkg/messaging" - broker "github.com/absmach/supermq/pkg/messaging/rabbitmq" -) - -type pubEventStore struct { - publisher messaging.Publisher -} - -func NewPublisher(ctx context.Context, url string) (events.Publisher, error) { - publisher, err := broker.NewPublisher(url, broker.Prefix(eventsPrefix), broker.Exchange(exchangeName)) - if err != nil { - return nil, err - } - - es := &pubEventStore{ - publisher: publisher, - } - - return es, nil -} - -func (es *pubEventStore) Publish(ctx context.Context, stream string, event events.Event) error { - values, err := event.Encode() - if err != nil { - return err - } - values["occurred_at"] = time.Now().UnixNano() - - data, err := json.Marshal(values) - if err != nil { - return err - } - - record := &messaging.Message{ - Payload: data, - } - - return es.publisher.Publish(ctx, stream, record) -} - -func (es *pubEventStore) Close() error { - return es.publisher.Close() -} diff --git a/pkg/events/rabbitmq/publisher_test.go b/pkg/events/rabbitmq/publisher_test.go deleted file mode 100644 index 395fbbc55..000000000 --- a/pkg/events/rabbitmq/publisher_test.go +++ /dev/null @@ -1,326 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package rabbitmq_test - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "math/rand" - "testing" - "time" - - smqlog "github.com/absmach/supermq/logger" - "github.com/absmach/supermq/pkg/events" - "github.com/absmach/supermq/pkg/events/rabbitmq" - "github.com/stretchr/testify/assert" -) - -var ( - eventsChan = make(chan map[string]any) - logger = smqlog.NewMock() - errFailed = errors.New("failed") - numEvents = 100 -) - -type testEvent struct { - Data map[string]any -} - -func (te testEvent) Encode() (map[string]any, error) { - data := make(map[string]any) - for k, v := range te.Data { - switch v.(type) { - case string: - data[k] = v - case float64: - data[k] = v - default: - b, err := json.Marshal(v) - if err != nil { - return nil, err - } - data[k] = string(b) - } - } - - return data, nil -} - -func TestPublish(t *testing.T) { - _, err := rabbitmq.NewPublisher(context.Background(), "http://invaliurl.com") - assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err) - - publisher, err := rabbitmq.NewPublisher(context.Background(), rabbitmqURL) - assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err)) - defer publisher.Close() - - _, err = rabbitmq.NewSubscriber("http://invaliurl.com", logger) - assert.NotNilf(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err), err) - - subcriber, err := rabbitmq.NewSubscriber(rabbitmqURL, logger) - assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err)) - defer subcriber.Close() - - cfg := events.SubscriberConfig{ - Stream: "events." + stream, - Consumer: consumer, - Handler: handler{}, - } - err = subcriber.Subscribe(context.Background(), cfg) - assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err)) - - cases := []struct { - desc string - event map[string]any - err error - }{ - { - desc: "publish event successfully", - err: nil, - event: map[string]any{ - "temperature": fmt.Sprintf("%f", rand.Float64()), - "humidity": fmt.Sprintf("%f", rand.Float64()), - "sensor_id": "abc123", - "location": "Earth", - "status": "normal", - "timestamp": fmt.Sprintf("%d", time.Now().UnixNano()), - "operation": "create", - "occurred_at": time.Now().UnixNano(), - }, - }, - { - desc: "publish with nil event", - err: nil, - event: nil, - }, - { - desc: "publish event with invalid event location", - err: fmt.Errorf("json: unsupported type: chan int"), - event: map[string]any{ - "temperature": fmt.Sprintf("%f", rand.Float64()), - "humidity": fmt.Sprintf("%f", rand.Float64()), - "sensor_id": "abc123", - "location": make(chan int), - "status": "normal", - "timestamp": "invalid", - "operation": "create", - "occurred_at": time.Now().UnixNano(), - }, - }, - { - desc: "publish event with nested sting value", - err: nil, - event: map[string]any{ - "temperature": fmt.Sprintf("%f", rand.Float64()), - "humidity": fmt.Sprintf("%f", rand.Float64()), - "sensor_id": "abc123", - "location": map[string]string{ - "lat": fmt.Sprintf("%f", rand.Float64()), - "lng": fmt.Sprintf("%f", rand.Float64()), - }, - "status": "normal", - "timestamp": "invalid", - "operation": "create", - "occurred_at": time.Now().UnixNano(), - }, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - event := testEvent{Data: tc.event} - - err := publisher.Publish(context.Background(), stream, event) - switch tc.err { - case nil: - receivedEvent := <-eventsChan - - val := int64(receivedEvent["occurred_at"].(float64)) - if assert.WithinRange(t, time.Unix(0, val), time.Now().Add(-time.Second), time.Now().Add(time.Second)) { - delete(receivedEvent, "occurred_at") - delete(tc.event, "occurred_at") - } - - assert.Equal(t, tc.event["temperature"], receivedEvent["temperature"]) - assert.Equal(t, tc.event["humidity"], receivedEvent["humidity"]) - assert.Equal(t, tc.event["sensor_id"], receivedEvent["sensor_id"]) - assert.Equal(t, tc.event["status"], receivedEvent["status"]) - assert.Equal(t, tc.event["timestamp"], receivedEvent["timestamp"]) - assert.Equal(t, tc.event["operation"], receivedEvent["operation"]) - - default: - assert.ErrorContains(t, err, tc.err.Error()) - } - }) - } -} - -func TestPubsub(t *testing.T) { - cases := []struct { - desc string - stream string - consumer string - err error - handler events.EventHandler - }{ - { - desc: "Subscribe to a stream", - stream: fmt.Sprintf("events.%s", stream), - consumer: consumer, - err: nil, - handler: handler{false}, - }, - { - desc: "Subscribe to the same stream", - stream: fmt.Sprintf("events.%s", stream), - consumer: consumer, - err: nil, - handler: handler{false}, - }, - { - desc: "Subscribe to an empty stream with an empty consumer", - stream: "", - consumer: "", - err: rabbitmq.ErrEmptyStream, - handler: handler{false}, - }, - { - desc: "Subscribe to an empty stream with a valid consumer", - stream: "", - consumer: consumer, - err: rabbitmq.ErrEmptyStream, - handler: handler{false}, - }, - { - desc: "Subscribe to a valid stream with an empty consumer", - stream: fmt.Sprintf("events.%s", stream), - consumer: "", - err: rabbitmq.ErrEmptyConsumer, - handler: handler{false}, - }, - { - desc: "Subscribe to another stream", - stream: fmt.Sprintf("events.%s.%d", stream, 1), - consumer: consumer, - err: nil, - handler: handler{false}, - }, - { - desc: "Subscribe to a stream with malformed handler", - stream: fmt.Sprintf("events.%s", stream), - consumer: consumer, - err: nil, - handler: handler{true}, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - subcriber, err := rabbitmq.NewSubscriber(rabbitmqURL, logger) - if err != nil { - assert.Equal(t, err, tc.err) - - return - } - - cfg := events.SubscriberConfig{ - Stream: tc.stream, - Consumer: tc.consumer, - Handler: tc.handler, - } - switch err := subcriber.Subscribe(context.Background(), cfg); { - case err == nil: - assert.Nil(t, err) - default: - assert.Equal(t, err, tc.err) - } - - err = subcriber.Close() - assert.Nil(t, err) - }) - } -} - -func TestUnavailablePublish(t *testing.T) { - publisher, err := rabbitmq.NewPublisher(context.Background(), rabbitmqURL) - assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err)) - - subcriber, err := rabbitmq.NewSubscriber(rabbitmqURL, logger) - assert.Nil(t, err, fmt.Sprintf("got unexpected error on creating event store: %s", err)) - - cfg := events.SubscriberConfig{ - Stream: "events." + stream, - Consumer: consumer, - Handler: handler{}, - } - err = subcriber.Subscribe(context.Background(), cfg) - assert.Nil(t, err, fmt.Sprintf("got unexpected error on subscribing to event store: %s", err)) - - err = pool.Client.PauseContainer(container.Container.ID) - assert.Nil(t, err, fmt.Sprintf("got unexpected error on pausing container: %s", err)) - - spawnGoroutines(publisher, t) - - time.Sleep(1 * time.Second) - - err = pool.Client.UnpauseContainer(container.Container.ID) - assert.Nil(t, err, fmt.Sprintf("got unexpected error on unpausing container: %s", err)) - - // Wait for the events to be published. - time.Sleep(1 * time.Second) - - err = publisher.Close() - assert.Nil(t, err, fmt.Sprintf("got unexpected error on closing publisher: %s", err)) - - // read all the events from the channel and assert that they are 10. - var receivedEvents []map[string]any - for i := 0; i < numEvents; i++ { - event := <-eventsChan - receivedEvents = append(receivedEvents, event) - } - assert.Len(t, receivedEvents, numEvents, "got unexpected number of events") -} - -func generateRandomEvent() testEvent { - return testEvent{ - Data: map[string]any{ - "temperature": fmt.Sprintf("%f", rand.Float64()), - "humidity": fmt.Sprintf("%f", rand.Float64()), - "sensor_id": fmt.Sprintf("%d", rand.Intn(1000)), - "location": fmt.Sprintf("%f", rand.Float64()), - "status": fmt.Sprintf("%d", rand.Intn(1000)), - "timestamp": fmt.Sprintf("%d", time.Now().UnixNano()), - "operation": "create", - }, - } -} - -func spawnGoroutines(publisher events.Publisher, t *testing.T) { - for i := 0; i < numEvents; i++ { - go func() { - err := publisher.Publish(context.Background(), stream, generateRandomEvent()) - assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) - }() - } -} - -type handler struct { - fail bool -} - -func (h handler) Handle(_ context.Context, event events.Event) error { - if h.fail { - return errFailed - } - data, err := event.Encode() - if err != nil { - return err - } - - eventsChan <- data - - return nil -} diff --git a/pkg/events/rabbitmq/setup_test.go b/pkg/events/rabbitmq/setup_test.go deleted file mode 100644 index a157b500d..000000000 --- a/pkg/events/rabbitmq/setup_test.go +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package rabbitmq_test - -import ( - "context" - "fmt" - "log" - "os" - "os/signal" - "syscall" - "testing" - - "github.com/absmach/supermq/pkg/events/rabbitmq" - "github.com/ory/dockertest/v3" -) - -var ( - rabbitmqURL string - stream = "tests.events" - consumer = "tests-consumer" - pool *dockertest.Pool - container *dockertest.Resource -) - -func TestMain(m *testing.M) { - var err error - pool, err = dockertest.NewPool("") - if err != nil { - log.Fatalf("Could not connect to docker: %s", err) - } - - container, err = pool.RunWithOptions(&dockertest.RunOptions{ - Repository: "rabbitmq", - Tag: "3.12.12", - }) - if err != nil { - log.Fatalf("Could not start container: %s", err) - } - - handleInterrupt(pool, container) - - rabbitmqURL = fmt.Sprintf("amqp://%s:%s", "localhost", container.GetPort("5672/tcp")) - - if err := pool.Retry(func() error { - _, err = rabbitmq.NewPublisher(context.Background(), rabbitmqURL) - return err - }); err != nil { - log.Fatalf("Could not connect to docker: %s", err) - } - - if err := pool.Retry(func() error { - _, err = rabbitmq.NewSubscriber(rabbitmqURL, logger) - return err - }); err != nil { - log.Fatalf("Could not connect to docker: %s", err) - } - - code := m.Run() - - if err := pool.Purge(container); err != nil { - log.Fatalf("Could not purge container: %s", err) - } - - os.Exit(code) -} - -func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) { - c := make(chan os.Signal, 2) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - go func() { - <-c - if err := pool.Purge(container); err != nil { - log.Fatalf("Could not purge container: %s", err) - } - os.Exit(0) - }() -} diff --git a/pkg/events/rabbitmq/subscriber.go b/pkg/events/rabbitmq/subscriber.go deleted file mode 100644 index 571050d02..000000000 --- a/pkg/events/rabbitmq/subscriber.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package rabbitmq - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "log/slog" - - "github.com/absmach/supermq/pkg/events" - "github.com/absmach/supermq/pkg/messaging" - broker "github.com/absmach/supermq/pkg/messaging/rabbitmq" -) - -var _ events.Subscriber = (*subEventStore)(nil) - -var ( - exchangeName = "events" - eventsPrefix = "events" - - // ErrEmptyStream is returned when stream name is empty. - ErrEmptyStream = errors.New("stream name cannot be empty") - - // ErrEmptyConsumer is returned when consumer name is empty. - ErrEmptyConsumer = errors.New("consumer name cannot be empty") -) - -type subEventStore struct { - pubsub messaging.PubSub -} - -func NewSubscriber(url string, logger *slog.Logger) (events.Subscriber, error) { - pubsub, err := broker.NewPubSub(url, logger, broker.Prefix(eventsPrefix), broker.Exchange(exchangeName)) - if err != nil { - return nil, err - } - - return &subEventStore{ - pubsub: pubsub, - }, nil -} - -func (es *subEventStore) Subscribe(ctx context.Context, cfg events.SubscriberConfig) error { - if cfg.Stream == "" { - return ErrEmptyStream - } - if cfg.Consumer == "" { - return ErrEmptyConsumer - } - - subCfg := messaging.SubscriberConfig{ - ID: cfg.Consumer, - Topic: cfg.Stream, - Handler: &eventHandler{ - handler: cfg.Handler, - ctx: ctx, - }, - DeliveryPolicy: messaging.DeliverNewPolicy, - } - - return es.pubsub.Subscribe(ctx, subCfg) -} - -func (es *subEventStore) Close() error { - return es.pubsub.Close() -} - -type event struct { - Data map[string]any -} - -func (re event) Encode() (map[string]any, error) { - return re.Data, nil -} - -type eventHandler struct { - handler events.EventHandler - ctx context.Context -} - -func (eh *eventHandler) Handle(msg *messaging.Message) error { - event := event{ - Data: make(map[string]any), - } - - if err := json.Unmarshal(msg.GetPayload(), &event.Data); err != nil { - return err - } - - if err := eh.handler.Handle(eh.ctx, event); err != nil { - return fmt.Errorf("failed to handle rabbitmq event: %s", err) - } - - return nil -} - -func (eh *eventHandler) Cancel() error { - return nil -} diff --git a/pkg/events/store/store_fluxmq.go b/pkg/events/store/store_fluxmq.go new file mode 100644 index 000000000..52030b7d7 --- /dev/null +++ b/pkg/events/store/store_fluxmq.go @@ -0,0 +1,41 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build es_fluxmq +// +build es_fluxmq + +package store + +import ( + "context" + "log" + "log/slog" + + "github.com/absmach/supermq/pkg/events" + "github.com/absmach/supermq/pkg/events/fluxmq" +) + +// StreamAllEvents represents subject to subscribe for all the events. +const StreamAllEvents = ">" + +func init() { + log.Println("The binary was built using FluxMQ as the events store") +} + +func NewPublisher(ctx context.Context, url, connectionName string) (events.Publisher, error) { + pb, err := fluxmq.NewPublisher(ctx, url, connectionName) + if err != nil { + return nil, err + } + + return pb, nil +} + +func NewSubscriber(ctx context.Context, url, connectionName string, logger *slog.Logger) (events.Subscriber, error) { + pb, err := fluxmq.NewSubscriber(ctx, url, connectionName, logger) + if err != nil { + return nil, err + } + + return pb, nil +} diff --git a/pkg/events/store/store_nats.go b/pkg/events/store/store_nats.go index 111d36974..08378a040 100644 --- a/pkg/events/store/store_nats.go +++ b/pkg/events/store/store_nats.go @@ -22,7 +22,7 @@ func init() { log.Println("The binary was build using Nats as the events store") } -func NewPublisher(ctx context.Context, url string) (events.Publisher, error) { +func NewPublisher(ctx context.Context, url, _ string) (events.Publisher, error) { pb, err := nats.NewPublisher(ctx, url) if err != nil { return nil, err @@ -31,7 +31,7 @@ func NewPublisher(ctx context.Context, url string) (events.Publisher, error) { return pb, nil } -func NewSubscriber(ctx context.Context, url string, logger *slog.Logger) (events.Subscriber, error) { +func NewSubscriber(ctx context.Context, url, _ string, logger *slog.Logger) (events.Subscriber, error) { pb, err := nats.NewSubscriber(ctx, url, logger) if err != nil { return nil, err diff --git a/pkg/events/store/store_rabbitmq.go b/pkg/events/store/store_rabbitmq.go deleted file mode 100644 index 45f78710d..000000000 --- a/pkg/events/store/store_rabbitmq.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -//go:build es_rabbitmq -// +build es_rabbitmq - -package store - -import ( - "context" - "log" - "log/slog" - - "github.com/absmach/supermq/pkg/events" - "github.com/absmach/supermq/pkg/events/rabbitmq" -) - -// StreamAllEvents represents subject to subscribe for all the events. -const StreamAllEvents = "events.#" - -func init() { - log.Println("The binary was build using RabbitMQ as the events store") -} - -func NewPublisher(ctx context.Context, url string) (events.Publisher, error) { - pb, err := rabbitmq.NewPublisher(ctx, url) - if err != nil { - return nil, err - } - - return pb, nil -} - -func NewSubscriber(_ context.Context, url string, logger *slog.Logger) (events.Subscriber, error) { - pb, err := rabbitmq.NewSubscriber(url, logger) - if err != nil { - return nil, err - } - - return pb, nil -} diff --git a/pkg/events/store/store_redis.go b/pkg/events/store/store_redis.go index 134f48b2b..924b26cef 100644 --- a/pkg/events/store/store_redis.go +++ b/pkg/events/store/store_redis.go @@ -1,8 +1,8 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -//go:build !es_nats && !es_rabbitmq -// +build !es_nats,!es_rabbitmq +//go:build !es_nats && !es_rabbitmq && !es_fluxmq +// +build !es_nats,!es_rabbitmq,!es_fluxmq package store @@ -22,7 +22,7 @@ func init() { log.Println("The binary was build using redis as the events store") } -func NewPublisher(ctx context.Context, url string) (events.Publisher, error) { +func NewPublisher(ctx context.Context, url, _ string) (events.Publisher, error) { pb, err := redis.NewPublisher(ctx, url, events.UnpublishedEventsCheckInterval) if err != nil { return nil, err @@ -31,7 +31,7 @@ func NewPublisher(ctx context.Context, url string) (events.Publisher, error) { return pb, nil } -func NewSubscriber(_ context.Context, url string, logger *slog.Logger) (events.Subscriber, error) { +func NewSubscriber(_ context.Context, url, _ string, logger *slog.Logger) (events.Subscriber, error) { pb, err := redis.NewSubscriber(url, logger) if err != nil { return nil, err diff --git a/pkg/groups/events/consumer/streams.go b/pkg/groups/events/consumer/streams.go index 2e39e0ea2..7c94740d8 100644 --- a/pkg/groups/events/consumer/streams.go +++ b/pkg/groups/events/consumer/streams.go @@ -49,7 +49,7 @@ type eventHandler struct { } func GroupsEventsSubscribe(ctx context.Context, repo groups.Repository, esURL, esConsumerName string, logger *slog.Logger) error { - subscriber, err := store.NewSubscriber(ctx, esURL, logger) + subscriber, err := store.NewSubscriber(ctx, esURL, "groups-es-sub", logger) if err != nil { return err } diff --git a/pkg/grpcclient/connect_test.go b/pkg/grpcclient/connect_test.go index 3f7c8499d..6712fab19 100644 --- a/pkg/grpcclient/connect_test.go +++ b/pkg/grpcclient/connect_test.go @@ -43,8 +43,8 @@ func TestHandler(t *testing.T) { config: Config{ URL: "localhost:8080", Timeout: time.Second, - ClientCert: "../../docker/ssl/certs/supermq-server.crt", - ClientKey: "../../docker/ssl/certs/supermq-server.key", + ClientCert: "../../docker/ssl/certs/magistrala-server.crt", + ClientKey: "../../docker/ssl/certs/magistrala-server.key", ServerCAFile: "../../docker/ssl/certs/ca.crt", }, err: nil, @@ -72,7 +72,7 @@ func TestHandler(t *testing.T) { config: Config{ URL: "localhost:8080", Timeout: time.Second, - ServerCAFile: "../../docker/ssl/certs/supermq-server.key", + ServerCAFile: "../../docker/ssl/certs/magistrala-server.key", }, err: errors.New("failed to load root ca: failed to append root ca to tls.Config"), }, @@ -82,7 +82,7 @@ func TestHandler(t *testing.T) { URL: "localhost:8080", Timeout: time.Second, ClientCert: "invalid", - ClientKey: "../../docker/ssl/certs/supermq-server.key", + ClientKey: "../../docker/ssl/certs/magistrala-server.key", ServerCAFile: "../../docker/ssl/certs/ca.crt", }, err: errors.New("failed to client certificate and key tls: failed to find any PEM data in certificate input"), @@ -92,7 +92,7 @@ func TestHandler(t *testing.T) { config: Config{ URL: "localhost:8080", Timeout: time.Second, - ClientCert: "../../docker/ssl/certs/supermq-server.crt", + ClientCert: "../../docker/ssl/certs/magistrala-server.crt", ClientKey: "invalid", ServerCAFile: "../../docker/ssl/certs/ca.crt", }, diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go new file mode 100644 index 000000000..5daccfd24 --- /dev/null +++ b/pkg/logger/logger.go @@ -0,0 +1,12 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package logger + +import "log/slog" + +type RunInfo struct { + Level slog.Level + Details []slog.Attr + Message string +} diff --git a/pkg/messaging/brokers/brokers_fluxmq.go b/pkg/messaging/brokers/brokers_fluxmq.go new file mode 100644 index 000000000..ea5d4a7a7 --- /dev/null +++ b/pkg/messaging/brokers/brokers_fluxmq.go @@ -0,0 +1,48 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build msg_fluxmq +// +build msg_fluxmq + +package brokers + +import ( + "context" + "log" + "log/slog" + + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/messaging/fluxmq" +) + +// SubjectAllMessages represents subject to subscribe for all the messages. +const SubjectAllMessages = string(messaging.MsgTopicPrefix) + ".>" + +func init() { + log.Println("The binary was built using FluxMQ as the message broker") +} + +// ConnectionName returns an option that sets a human-readable connection name +// for identifying this client in the FluxMQ admin UI. +func ConnectionName(name string) messaging.Option { + return fluxmq.ConnectionName(name) +} + +func NewPublisher(ctx context.Context, url string, opts ...messaging.Option) (messaging.Publisher, error) { + pb, err := fluxmq.NewPublisher(ctx, url, opts...) + if err != nil { + return nil, err + } + + return pb, nil +} + +func NewPubSub(ctx context.Context, url string, logger *slog.Logger, opts ...messaging.Option) (messaging.PubSub, error) { + opts = append(opts, fluxmq.DirectTopicIngress()) + pb, err := fluxmq.NewPubSub(ctx, url, logger, opts...) + if err != nil { + return nil, err + } + + return pb, nil +} diff --git a/pkg/messaging/brokers/brokers_nats.go b/pkg/messaging/brokers/brokers_nats.go index 25bc88405..a28640770 100644 --- a/pkg/messaging/brokers/brokers_nats.go +++ b/pkg/messaging/brokers/brokers_nats.go @@ -1,8 +1,8 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -//go:build !msg_rabbitmq -// +build !msg_rabbitmq +//go:build !msg_fluxmq && !msg_rabbitmq && !rabbitmq +// +build !msg_fluxmq,!msg_rabbitmq,!rabbitmq package brokers @@ -19,7 +19,13 @@ import ( const SubjectAllMessages = string(messaging.MsgTopicPrefix) + ".>" func init() { - log.Println("The binary was build using Nats as the message broker") + log.Println("The binary was built using NATS as the message broker") +} + +// ConnectionName is a no-op for the NATS backend. It exists for API +// compatibility with the FluxMQ variant. +func ConnectionName(_ string) messaging.Option { + return func(_ any) error { return nil } } func NewPublisher(ctx context.Context, url string, opts ...messaging.Option) (messaging.Publisher, error) { diff --git a/pkg/messaging/brokers/brokers_rabbitmq.go b/pkg/messaging/brokers/brokers_rabbitmq.go deleted file mode 100644 index 9f59df804..000000000 --- a/pkg/messaging/brokers/brokers_rabbitmq.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -//go:build msg_rabbitmq -// +build msg_rabbitmq - -package brokers - -import ( - "context" - "log" - "log/slog" - - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/rabbitmq" -) - -// SubjectAllMessages represents subject to subscribe for all the messages. -const SubjectAllMessages = string(messaging.MsgTopicPrefix) + ".#" - -func init() { - log.Println("The binary was build using RabbitMQ as the message broker") -} - -func NewPublisher(_ context.Context, url string, opts ...messaging.Option) (messaging.Publisher, error) { - pb, err := rabbitmq.NewPublisher(url, opts...) - if err != nil { - return nil, err - } - - return pb, nil -} - -func NewPubSub(_ context.Context, url string, logger *slog.Logger, opts ...messaging.Option) (messaging.PubSub, error) { - pb, err := rabbitmq.NewPubSub(url, logger, opts...) - if err != nil { - return nil, err - } - - return pb, nil -} diff --git a/pkg/messaging/brokers/tracing/brokers_rabbitmq.go b/pkg/messaging/brokers/tracing/brokers_fluxmq.go similarity index 54% rename from pkg/messaging/brokers/tracing/brokers_rabbitmq.go rename to pkg/messaging/brokers/tracing/brokers_fluxmq.go index 241576f38..5bdb7c47a 100644 --- a/pkg/messaging/brokers/tracing/brokers_rabbitmq.go +++ b/pkg/messaging/brokers/tracing/brokers_fluxmq.go @@ -1,8 +1,8 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -//go:build msg_rabbitmq -// +build msg_rabbitmq +//go:build msg_fluxmq +// +build msg_fluxmq package brokers @@ -10,17 +10,17 @@ import ( "log" "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/rabbitmq/tracing" + "github.com/absmach/supermq/pkg/messaging/fluxmq/tracing" "github.com/absmach/supermq/pkg/server" "go.opentelemetry.io/otel/trace" ) func init() { - log.Println("The binary was build using RabbitMQ as the message broker") + log.Println("The binary was built using FluxMQ as the message broker") } -func NewPublisher(cfg server.Config, tracer trace.Tracer, pub messaging.Publisher) messaging.Publisher { - return tracing.NewPublisher(cfg, tracer, pub) +func NewPublisher(cfg server.Config, tracer trace.Tracer, publisher messaging.Publisher) messaging.Publisher { + return tracing.NewPublisher(cfg, tracer, publisher) } func NewPubSub(cfg server.Config, tracer trace.Tracer, pubsub messaging.PubSub) messaging.PubSub { diff --git a/pkg/messaging/brokers/tracing/brokers_nats.go b/pkg/messaging/brokers/tracing/brokers_nats.go index cfb77e3dd..7ce2faed7 100644 --- a/pkg/messaging/brokers/tracing/brokers_nats.go +++ b/pkg/messaging/brokers/tracing/brokers_nats.go @@ -1,8 +1,8 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -//go:build !msg_rabbitmq -// +build !msg_rabbitmq +//go:build !msg_fluxmq && !msg_rabbitmq && !rabbitmq +// +build !msg_fluxmq,!msg_rabbitmq,!rabbitmq package brokers @@ -16,7 +16,7 @@ import ( ) func init() { - log.Println("The binary was build using Nats as the message broker") + log.Println("The binary was built using NATS as the message broker") } func NewPublisher(cfg server.Config, tracer trace.Tracer, publisher messaging.Publisher) messaging.Publisher { diff --git a/pkg/messaging/events/publisher.go b/pkg/messaging/events/publisher.go index 22633c2fe..d3d601d1d 100644 --- a/pkg/messaging/events/publisher.go +++ b/pkg/messaging/events/publisher.go @@ -19,7 +19,7 @@ type publisherES struct { } func NewPublisherMiddleware(ctx context.Context, pub messaging.Publisher, url string) (messaging.Publisher, error) { - publisher, err := store.NewPublisher(ctx, url) + publisher, err := store.NewPublisher(ctx, url, "msg-es-pub") if err != nil { return nil, err } @@ -38,7 +38,7 @@ func (es *publisherES) Publish(ctx context.Context, topic string, msg *messaging me := publishEvent{ domainID: msg.Domain, channelID: msg.Channel, - clientID: msg.Publisher, + clientID: msg.ClientIdentity(), subtopic: msg.Subtopic, } diff --git a/pkg/messaging/events/pubsub.go b/pkg/messaging/events/pubsub.go index 09a4129c8..62369a1ba 100644 --- a/pkg/messaging/events/pubsub.go +++ b/pkg/messaging/events/pubsub.go @@ -26,7 +26,7 @@ type pubsubES struct { } func NewPubSubMiddleware(ctx context.Context, pubsub messaging.PubSub, url string) (messaging.PubSub, error) { - publisher, err := store.NewPublisher(ctx, url) + publisher, err := store.NewPublisher(ctx, url, "msg-es-pub") if err != nil { return nil, err } @@ -45,7 +45,7 @@ func (es *pubsubES) Publish(ctx context.Context, topic string, msg *messaging.Me me := publishEvent{ domainID: msg.Domain, channelID: msg.Channel, - clientID: msg.Publisher, + clientID: msg.ClientIdentity(), subtopic: msg.Subtopic, } diff --git a/pkg/messaging/fluxmq/options.go b/pkg/messaging/fluxmq/options.go new file mode 100644 index 000000000..236e06ee4 --- /dev/null +++ b/pkg/messaging/fluxmq/options.go @@ -0,0 +1,92 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package fluxmq + +import ( + "errors" + + "github.com/absmach/supermq/pkg/messaging" + "github.com/nats-io/nats.go/jetstream" +) + +// ErrInvalidType is returned when the provided value is not of the expected type. +var ErrInvalidType = errors.New("invalid type") + +const msgPrefix = "m" + +type options struct { + prefix string + connectionName string + directTopicIngress bool +} + +func defaultOptions() options { + return options{ + prefix: msgPrefix, + } +} + +// Prefix sets the topic prefix for publisher and subscriber. +func Prefix(prefix string) messaging.Option { + return func(val any) error { + switch v := val.(type) { + case *publisher: + v.prefix = prefix + case *pubsub: + v.prefix = prefix + default: + return ErrInvalidType + } + + return nil + } +} + +// ConnectionName sets a human-readable connection name sent to FluxMQ +// for identifying this client in the broker's admin UI. +func ConnectionName(name string) messaging.Option { + return func(val any) error { + switch v := val.(type) { + case *publisher: + v.connectionName = name + case *pubsub: + v.connectionName = name + default: + return ErrInvalidType + } + + return nil + } +} + +// DirectTopicIngress enables direct MQTT topic delivery in addition to stream +// queue delivery. This is opt-in because direct topic messages are normalized +// from broker-native metadata instead of the protobuf queue envelope. +func DirectTopicIngress() messaging.Option { + return func(val any) error { + switch v := val.(type) { + case *publisher: + return nil + case *pubsub: + v.directTopicIngress = true + default: + return ErrInvalidType + } + + return nil + } +} + +// JSStreamConfig is a no-op for FluxMQ AMQP backend and exists only to keep +// option-compatibility with legacy NATS broker wrappers. +func JSStreamConfig(_ jetstream.StreamConfig) messaging.Option { + return func(val any) error { + switch val.(type) { + case *publisher, *pubsub: + return nil + default: + return ErrInvalidType + } + } +} diff --git a/pkg/messaging/fluxmq/publisher.go b/pkg/messaging/fluxmq/publisher.go new file mode 100644 index 000000000..31fe8f677 --- /dev/null +++ b/pkg/messaging/fluxmq/publisher.go @@ -0,0 +1,89 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package fluxmq + +import ( + "context" + "log/slog" + "strconv" + + fluxamqp "github.com/absmach/fluxmq/client/amqp" + "github.com/absmach/supermq/pkg/messaging" +) + +var _ messaging.Publisher = (*publisher)(nil) + +type publisher struct { + client *fluxamqp.Client + options +} + +// NewPublisher creates a FluxMQ-backed message publisher. +func NewPublisher(_ context.Context, url string, opts ...messaging.Option) (messaging.Publisher, error) { + pub := &publisher{ + options: defaultOptions(), + } + + for _, opt := range opts { + if err := opt(pub); err != nil { + return nil, err + } + } + + logger := slog.Default() + amqpOpts := fluxamqp.NewOptions().SetURL(url). + SetConnectionName(pub.connectionName). + SetOnConnectionLost(func(err error) { + logger.Warn("FluxMQ message publisher connection lost", "error", err) + }). + SetOnReconnecting(func(attempt int) { + logger.Info("FluxMQ message publisher reconnecting", "attempt", attempt) + }). + SetOnConnect(func() { + logger.Info("FluxMQ message publisher connected") + }) + + client, err := fluxamqp.New(amqpOpts) + if err != nil { + return nil, err + } + if err := client.Connect(); err != nil { + return nil, err + } + if err := declareStream(client, pub.prefix); err != nil { + _ = client.Close() + return nil, err + } + + pub.client = client + + return pub, nil +} + +func (pub *publisher) Publish(ctx context.Context, topic string, msg *messaging.Message) error { + if topic == "" { + return ErrEmptyTopic + } + + props := map[string]string{ + "external_id": msg.GetPublisher(), + "protocol": msg.GetProtocol(), + } + if clientID := msg.ClientIdentity(); clientID != "" { + props["client_id"] = clientID + } + if msg.GetCreated() != 0 { + props["created"] = strconv.FormatInt(msg.GetCreated(), 10) + } + + return pub.client.PublishWithOptionsContext(ctx, &fluxamqp.PublishOptions{ + Topic: queueTopic(pub.prefix, topic), + Payload: msg.GetPayload(), + Properties: props, + }) +} + +func (pub *publisher) Close() error { + return pub.client.Close() +} diff --git a/pkg/messaging/fluxmq/pubsub.go b/pkg/messaging/fluxmq/pubsub.go new file mode 100644 index 000000000..fc1c4088d --- /dev/null +++ b/pkg/messaging/fluxmq/pubsub.go @@ -0,0 +1,293 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package fluxmq + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strconv" + "strings" + "sync" + "time" + + fluxamqp "github.com/absmach/fluxmq/client/amqp" + fluxtopics "github.com/absmach/fluxmq/topics" + "github.com/absmach/supermq/pkg/messaging" +) + +// Publisher and Subscriber errors. +var ( + ErrNotSubscribed = errors.New("not subscribed") + ErrEmptyTopic = errors.New("empty topic") + ErrEmptyID = errors.New("empty id") +) + +var _ messaging.PubSub = (*pubsub)(nil) + +type pubsub struct { + publisher + logger *slog.Logger + + mu sync.Mutex + subscriptions map[string]subscription +} + +type subscription struct { + streamTopic string + mqttTopic string +} + +// NewPubSub creates a FluxMQ-backed message publisher/subscriber. +func NewPubSub(_ context.Context, url string, logger *slog.Logger, opts ...messaging.Option) (messaging.PubSub, error) { + ps := &pubsub{ + publisher: publisher{ + options: defaultOptions(), + }, + logger: logger, + subscriptions: make(map[string]subscription), + } + + for _, opt := range opts { + if err := opt(ps); err != nil { + return nil, err + } + } + + amqpOpts := fluxamqp.NewOptions().SetURL(url). + SetConnectionName(ps.connectionName). + SetOnConnectionLost(func(err error) { + ps.logWarn("FluxMQ message pub/sub connection lost", "error", err) + }). + SetOnReconnecting(func(attempt int) { + ps.logInfo("FluxMQ message pub/sub reconnecting", "attempt", attempt) + }). + SetOnConnect(func() { + ps.logInfo("FluxMQ message pub/sub connected", url, ps.prefix) + }) + + client, err := fluxamqp.New(amqpOpts) + if err != nil { + return nil, err + } + if err := client.Connect(); err != nil { + return nil, err + } + if err := declareStream(client, ps.prefix); err != nil { + _ = client.Close() + return nil, err + } + + ps.client = client + + return ps, nil +} + +func (ps *pubsub) Subscribe(_ context.Context, cfg messaging.SubscriberConfig) error { + if cfg.ID == "" { + return ErrEmptyID + } + if cfg.Topic == "" { + return ErrEmptyTopic + } + + group := formatConsumerName(cfg.Topic, cfg.ID) + opts := &fluxamqp.StreamConsumeOptions{ + QueueName: streamQueue(ps.prefix), + Filter: streamFilter(ps.prefix, cfg.Topic), + ConsumerGroup: group, + } + + switch cfg.DeliveryPolicy { + case messaging.DeliverNewPolicy: + opts.Offset = "last" + case messaging.DeliverAllPolicy: + opts.Offset = "first" + } + + if err := ps.client.SubscribeToStream(opts, func(msg *fluxamqp.QueueMessage) { + if err := ps.handle(cfg.Handler, msg); err != nil { + ps.logWarn("failed to process FluxMQ stream message", "error", err, "topic", cfg.Topic, "consumer_group", group) + } + }); err != nil { + return err + } + + sub := subscription{ + streamTopic: queueFilter(ps.prefix, cfg.Topic), + } + + if ps.directTopicIngress { + // Subscribe to regular MQTT topics so that messages published directly + // by MQTT clients (not through the stream queue) are also received. + sub.mqttTopic = topicFilter(ps.prefix, cfg.Topic) + if err := ps.client.Subscribe(sub.mqttTopic, func(msg *fluxamqp.Message) { + if err := ps.handleTopicMessage(cfg.Handler, msg); err != nil { + ps.logWarn("failed to process FluxMQ topic message", "error", err, "topic", sub.mqttTopic) + } + }); err != nil { + _ = ps.client.UnsubscribeFromStream(sub.streamTopic) + + return err + } + } + + ps.mu.Lock() + ps.subscriptions[subscriptionKey(cfg.ID, cfg.Topic)] = sub + ps.mu.Unlock() + + return nil +} + +func (ps *pubsub) Unsubscribe(_ context.Context, id, topic string) error { + if id == "" { + return ErrEmptyID + } + if topic == "" { + return ErrEmptyTopic + } + + key := subscriptionKey(id, topic) + + ps.mu.Lock() + sub, ok := ps.subscriptions[key] + ps.mu.Unlock() + if !ok { + return ErrNotSubscribed + } + + streamErr := ps.client.UnsubscribeFromStream(sub.streamTopic) + var topicErr error + if sub.mqttTopic != "" { + topicErr = ps.client.Unsubscribe(sub.mqttTopic) + } + + ps.mu.Lock() + delete(ps.subscriptions, key) + ps.mu.Unlock() + + return errors.Join(streamErr, topicErr) +} + +func (ps *pubsub) handleTopicMessage(h messaging.MessageHandler, msg *fluxamqp.Message) error { + mqttTopic := fluxtopics.AMQPTopicToMQTT(msg.Topic) + m, err := messageFromDelivery(msg.Body, msg.Headers, msg.Timestamp, ps.prefix, mqttTopic) + if err != nil { + return fmt.Errorf("failed to parse MQTT topic %q: %w", msg.Topic, err) + } + + if err := h.Handle(m); err != nil { + ps.logWarn("failed to handle topic message", "error", err) + } + + return nil +} + +func (ps *pubsub) handle(h messaging.MessageHandler, msg *fluxamqp.QueueMessage) error { + mqttTopic := strings.TrimPrefix(msg.RoutingKey, queuePrefix) + m, err := messageFromDelivery(msg.Body, msg.Headers, msg.Timestamp, ps.prefix, mqttTopic) + if err != nil { + if rejectErr := msg.Reject(); rejectErr != nil { + return errors.Join(err, rejectErr) + } + return err + } + + handleErr := h.Handle(m) + ackType := ps.errAckType(handleErr) + if handleErr != nil { + ps.logWarn("failed to handle message", "ack_type", ackType.String(), "error", handleErr) + } + + if ackErr := ps.handleAck(ackType, msg); ackErr != nil { + return fmt.Errorf("failed to %s message: %w", ackType.String(), ackErr) + } + + return nil +} + +func messageFromDelivery(body []byte, headers map[string]any, ts time.Time, prefix, mqttTopic string) (*messaging.Message, error) { + domain, channel, subtopic, err := parseMQTTTopic(prefix, mqttTopic) + if err != nil { + return nil, err + } + + clientID := stringHeader(headers, "client_id") + publisher := stringHeader(headers, "external_id") + + protocol := stringHeader(headers, "protocol") + if protocol == "" { + protocol = "mqtt" + } + + created := ts.UnixNano() + if s := stringHeader(headers, "created"); s != "" { + if v, err := strconv.ParseInt(s, 10, 64); err == nil { + created = v + } + } + + return &messaging.Message{ + Domain: domain, + Channel: channel, + Subtopic: subtopic, + Payload: body, + Publisher: publisher, + ClientId: clientID, + Protocol: protocol, + Created: created, + }, nil +} + +func (ps *pubsub) errAckType(err error) messaging.AckType { + if err == nil { + return messaging.Ack + } + if e, ok := err.(messaging.Error); ok && e != nil { + return e.Ack() + } + return messaging.NoAck +} + +func (ps *pubsub) handleAck(at messaging.AckType, msg *fluxamqp.QueueMessage) error { + switch at { + case messaging.Ack, messaging.DoubleAck: + return msg.Ack() + case messaging.Nack, messaging.InProgress: + return msg.Nack() + case messaging.Term: + return msg.Reject() + case messaging.NoAck: + return nil + default: + return nil + } +} + +func (ps *pubsub) logInfo(msg string, args ...any) { + if ps.logger != nil { + ps.logger.Info(msg, args...) + return + } + + slog.Info(msg, args...) +} + +func (ps *pubsub) logWarn(msg string, args ...any) { + if ps.logger != nil { + ps.logger.Warn(msg, args...) + return + } + + slog.Warn(msg, args...) +} + +func (ps *pubsub) Close() error { + return ps.client.Close() +} + +func subscriptionKey(id, topic string) string { + return fmt.Sprintf("%s|%s", id, topic) +} diff --git a/pkg/messaging/fluxmq/pubsub_test.go b/pkg/messaging/fluxmq/pubsub_test.go new file mode 100644 index 000000000..a0c8a5344 --- /dev/null +++ b/pkg/messaging/fluxmq/pubsub_test.go @@ -0,0 +1,216 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package fluxmq + +import ( + "testing" + "time" + + fluxamqp "github.com/absmach/fluxmq/client/amqp" + "github.com/absmach/supermq/pkg/messaging" + amqp091 "github.com/rabbitmq/amqp091-go" +) + +type testHandler struct { + msg *messaging.Message +} + +func (h *testHandler) Handle(msg *messaging.Message) error { + h.msg = msg + return nil +} + +func (h *testHandler) Cancel() error { + return nil +} + +func TestHandleTopicMessageNormalizesAMQPRoutingKey(t *testing.T) { + ps := &pubsub{ + publisher: publisher{ + options: options{prefix: "m"}, + }, + } + h := &testHandler{} + ts := time.Unix(1710000000, 123) + + err := ps.handleTopicMessage(h, &fluxamqp.Message{ + Delivery: amqp091.Delivery{ + Body: []byte("payload"), + Timestamp: ts, + Headers: amqp091.Table{ + "external_id": "ext-user", + "client_id": "client-9", + }, + }, + Topic: "m.domain.c.channel.test", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h.msg == nil { + t.Fatal("expected handler to receive a message") + } + if h.msg.Domain != "domain" || h.msg.Channel != "channel" || h.msg.Subtopic != "test" { + t.Fatalf("unexpected parsed message: %+v", h.msg) + } + if string(h.msg.Payload) != "payload" { + t.Fatalf("unexpected payload: %q", string(h.msg.Payload)) + } + if h.msg.Publisher != "ext-user" { + t.Fatalf("unexpected publisher: %q", h.msg.Publisher) + } + if h.msg.GetClientId() != "client-9" { + t.Fatalf("unexpected client ID: %q", h.msg.GetClientId()) + } + if h.msg.Created != ts.UnixNano() { + t.Fatalf("unexpected created timestamp: %d", h.msg.Created) + } +} + +func TestHandleTopicMessageUsesMQTTIdentityFields(t *testing.T) { + ps := &pubsub{ + publisher: publisher{ + options: options{prefix: "m"}, + }, + } + h := &testHandler{} + ts := time.Unix(1710000000, 0) + + err := ps.handleTopicMessage(h, &fluxamqp.Message{ + Delivery: amqp091.Delivery{ + Body: []byte("payload"), + Timestamp: ts, + Headers: amqp091.Table{ + "external_id": "ext-77", + "client_id": "client-7", + "protocol": "http", + "created": "1234567890000000000", + }, + }, + Topic: "m.domain.c.channel.sub", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h.msg.Publisher != "ext-77" { + t.Fatalf("expected publisher from explicit header, got %q", h.msg.Publisher) + } + if h.msg.GetClientId() != "client-7" { + t.Fatalf("expected client ID from header, got %q", h.msg.GetClientId()) + } + if h.msg.Protocol != "http" { + t.Fatalf("expected protocol from header, got %q", h.msg.Protocol) + } + if h.msg.Created != 1234567890000000000 { + t.Fatalf("expected created from header, got %d", h.msg.Created) + } +} + +func TestMessageFromDelivery(t *testing.T) { + cases := []struct { + name string + body []byte + headers map[string]any + ts time.Time + prefix string + mqttTopic string + want *messaging.Message + wantErr bool + }{ + { + name: "use explicit publisher and client_id headers", + body: []byte(`{"temperature":22.5}`), + headers: map[string]any{"external_id": "ext-1", "client_id": "client-1", "protocol": "mqtt", "created": "1710000000000000123"}, + ts: time.Unix(1710000000, 0), + prefix: "writers", + mqttTopic: "writers/domain/c/channel/temp", + want: &messaging.Message{ + Domain: "domain", + Channel: "channel", + Subtopic: "temp", + Payload: []byte(`{"temperature":22.5}`), + Publisher: "ext-1", + ClientId: "client-1", + Protocol: "mqtt", + Created: 1710000000000000123, + }, + }, + { + name: "use explicit publisher header when present", + body: []byte("raw"), + headers: map[string]any{"external_id": "tenant-user", "client_id": "client-22"}, + ts: time.Unix(1710000000, 250), + prefix: "m", + mqttTopic: "m/dom/c/ch", + want: &messaging.Message{ + Domain: "dom", + Channel: "ch", + Subtopic: "", + Payload: []byte("raw"), + Publisher: "tenant-user", + ClientId: "client-22", + Protocol: "mqtt", + Created: time.Unix(1710000000, 250).UnixNano(), + }, + }, + { + name: "missing identity headers leaves publisher and client ID empty", + body: []byte("raw"), + headers: nil, + ts: time.Unix(1710000000, 500), + prefix: "m", + mqttTopic: "m/dom/c/ch", + want: &messaging.Message{ + Domain: "dom", + Channel: "ch", + Subtopic: "", + Payload: []byte("raw"), + Publisher: "", + ClientId: "", + Protocol: "mqtt", + Created: time.Unix(1710000000, 500).UnixNano(), + }, + }, + { + name: "invalid topic", + body: []byte("x"), + prefix: "m", + mqttTopic: "wrong/topic", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := messageFromDelivery(tc.body, tc.headers, tc.ts, tc.prefix, tc.mqttTopic) + if tc.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.Domain != tc.want.Domain || got.Channel != tc.want.Channel || got.Subtopic != tc.want.Subtopic { + t.Fatalf("topic mismatch: got domain=%q channel=%q subtopic=%q", got.Domain, got.Channel, got.Subtopic) + } + if string(got.Payload) != string(tc.want.Payload) { + t.Fatalf("payload mismatch: got %q, want %q", got.Payload, tc.want.Payload) + } + if got.Publisher != tc.want.Publisher { + t.Fatalf("publisher mismatch: got %q, want %q", got.Publisher, tc.want.Publisher) + } + if got.GetClientId() != tc.want.GetClientId() { + t.Fatalf("client ID mismatch: got %q, want %q", got.GetClientId(), tc.want.GetClientId()) + } + if got.Protocol != tc.want.Protocol { + t.Fatalf("protocol mismatch: got %q, want %q", got.Protocol, tc.want.Protocol) + } + if got.Created != tc.want.Created { + t.Fatalf("created mismatch: got %d, want %d", got.Created, tc.want.Created) + } + }) + } +} diff --git a/pkg/messaging/fluxmq/topic.go b/pkg/messaging/fluxmq/topic.go new file mode 100644 index 000000000..6b8bca591 --- /dev/null +++ b/pkg/messaging/fluxmq/topic.go @@ -0,0 +1,156 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package fluxmq + +import ( + "fmt" + "strings" + + fluxamqp "github.com/absmach/fluxmq/client/amqp" + "github.com/absmach/supermq/pkg/messaging" +) + +const queuePrefix = "$queue/" + +var ( + topicReplacer = strings.NewReplacer(".", "/", "*", "+", ">", "#") + nameReplacer = strings.NewReplacer( + " ", "_", + ".", "_", + "*", "_", + ">", "_", + "/", "_", + "\\", "_", + ) +) + +func canonicalPrefix(prefix string) string { + prefix = strings.TrimSpace(prefix) + if prefix == "" { + return msgPrefix + } + return prefix +} + +func streamQueue(prefix string) string { + return canonicalPrefix(prefix) +} + +func brokerPath(topic string) string { + topic = strings.TrimSpace(topic) + topic = strings.TrimPrefix(topic, ".") + if topic == "" { + return "" + } + + return topicReplacer.Replace(topic) +} + +func streamFilter(prefix, topic string) string { + path := filterPath(prefix, topic) + if path == "" { + return "#" + } + return path +} + +func queueFilter(prefix, topic string) string { + queue := streamQueue(prefix) + path := streamFilter(prefix, topic) + if path == "#" { + return queuePrefix + queue + "/#" + } + + return queuePrefix + queue + "/" + path +} + +func queueTopic(prefix, topic string) string { + queue := streamQueue(prefix) + path := brokerPath(topic) + if path == "" { + return queuePrefix + queue + } + + return queuePrefix + queue + "/" + path +} + +func filterPath(prefix, topic string) string { + topic = strings.TrimSpace(topic) + if topic == "" || topic == ">" { + return "#" + } + + prefix = canonicalPrefix(prefix) + switch { + case topic == prefix: + topic = ">" + case strings.HasPrefix(topic, prefix+"."): + topic = strings.TrimPrefix(topic, prefix+".") + } + + return brokerPath(topic) +} + +func formatConsumerName(topic, id string) string { + // Consumer group names must avoid whitespace and wildcard/path separators. + topic = nameReplacer.Replace(topic) + id = nameReplacer.Replace(id) + return fmt.Sprintf("%s-%s", topic, id) +} + +// topicFilter returns the MQTT topic filter for subscribing to regular +// (non-queued) messages. It converts a NATS-style topic to MQTT format +// with the prefix prepended. +// For example, with prefix "m" and topic "m.>", it returns "m/#". +func topicFilter(prefix, topic string) string { + prefix = canonicalPrefix(prefix) + path := filterPath(prefix, topic) + if path == "" || path == "#" { + return prefix + "/#" + } + + return prefix + "/" + path +} + +func parseMQTTTopic(prefix, topic string) (domainID, channelID, subtopic string, err error) { + topic = strings.TrimPrefix(strings.TrimSpace(topic), "/") + prefix = canonicalPrefix(prefix) + if !strings.HasPrefix(topic, prefix+"/") { + return "", "", "", messaging.ErrMalformedTopic + } + normalized := "/" + msgPrefix + "/" + strings.TrimPrefix(topic, prefix+"/") + + domainID, channelID, subtopic, _, err = messaging.ParseSubscribeTopic(normalized) + if err != nil { + return "", "", "", err + } + + return domainID, channelID, subtopic, nil +} + +func stringHeader(headers map[string]any, key string) string { + if headers == nil { + return "" + } + v, ok := headers[key] + if !ok { + return "" + } + switch s := v.(type) { + case string: + return s + case []byte: + return string(s) + default: + return "" + } +} + +func declareStream(client *fluxamqp.Client, prefix string) error { + _, err := client.DeclareStreamQueue(&fluxamqp.StreamQueueOptions{ + Name: streamQueue(prefix), + Durable: true, + }) + return err +} diff --git a/pkg/messaging/fluxmq/topic_test.go b/pkg/messaging/fluxmq/topic_test.go new file mode 100644 index 000000000..b4c844680 --- /dev/null +++ b/pkg/messaging/fluxmq/topic_test.go @@ -0,0 +1,272 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package fluxmq + +import ( + "strings" + "testing" +) + +func TestQueueTopic(t *testing.T) { + got := queueTopic("m", "domain.c.channel.subtopic") + want := "$queue/m/domain/c/channel/subtopic" + if got != want { + t.Fatalf("queue topic mismatch: got %q, want %q", got, want) + } +} + +func TestStreamFilter(t *testing.T) { + cases := []struct { + name string + prefix string + topic string + want string + }{ + { + name: "all messages with prefix", + prefix: "m", + topic: "m.>", + want: "#", + }, + { + name: "all messages without explicit prefix", + prefix: "writers", + topic: ">", + want: "#", + }, + { + name: "specific topic filter", + prefix: "writers", + topic: "writers.domain.c.channel.*", + want: "domain/c/channel/+", + }, + { + name: "topic without prefix", + prefix: "alarms", + topic: "domain.c.channel.>", + want: "domain/c/channel/#", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := streamFilter(tc.prefix, tc.topic) + if got != tc.want { + t.Fatalf("stream filter mismatch: got %q, want %q", got, tc.want) + } + }) + } +} + +func TestQueueFilter(t *testing.T) { + got := queueFilter("writers", "writers.>") + want := "$queue/writers/#" + if got != want { + t.Fatalf("queue filter mismatch: got %q, want %q", got, want) + } +} + +func TestTopicFilter(t *testing.T) { + cases := []struct { + name string + prefix string + topic string + want string + }{ + { + name: "all messages with prefix", + prefix: "m", + topic: "m.>", + want: "m/#", + }, + { + name: "wildcard topic", + prefix: "writers", + topic: ">", + want: "writers/#", + }, + { + name: "specific topic", + prefix: "m", + topic: "m.domain.c.channel.subtopic", + want: "m/domain/c/channel/subtopic", + }, + { + name: "single-level wildcard", + prefix: "m", + topic: "m.domain.c.*.subtopic", + want: "m/domain/c/+/subtopic", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := topicFilter(tc.prefix, tc.topic) + if got != tc.want { + t.Fatalf("topic filter mismatch: got %q, want %q", got, tc.want) + } + }) + } +} + +func TestParseMQTTTopic(t *testing.T) { + cases := []struct { + name string + prefix string + topic string + domain string + channel string + subtopic string + shouldErr bool + }{ + { + name: "default prefix with subtopic path", + prefix: "m", + topic: "m/domain/c/channel/sub/topic", + domain: "domain", + channel: "channel", + subtopic: "sub.topic", + }, + { + name: "alternate prefix without subtopic", + prefix: "writers", + topic: "writers/domain/c/channel", + domain: "domain", + channel: "channel", + subtopic: "", + }, + { + name: "leading slash is ignored", + prefix: "alarms", + topic: "/alarms/domain/c/channel/critical/high", + domain: "domain", + channel: "channel", + subtopic: "critical.high", + }, + { + name: "mismatched prefix", + prefix: "writers", + topic: "m/domain/c/channel", + shouldErr: true, + }, + { + name: "invalid shape", + prefix: "m", + topic: "m/domain/channel", + shouldErr: true, + }, + { + name: "empty subtopic segment", + prefix: "m", + topic: "m/domain/c/channel/sub//topic", + shouldErr: true, + }, + { + name: "dot topic is invalid", + prefix: "m", + topic: "m.domain.c.channel.sub.topic", + shouldErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + domain, channel, subtopic, err := parseMQTTTopic(tc.prefix, tc.topic) + if tc.shouldErr { + if err == nil { + t.Fatal("expected parse error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + if domain != tc.domain || channel != tc.channel || subtopic != tc.subtopic { + t.Fatalf("parsed topic mismatch: got domain=%q channel=%q subtopic=%q", domain, channel, subtopic) + } + }) + } +} + +func TestParseMQTTTopicFromStreamRoutingKey(t *testing.T) { + // Stream queue routing keys have the format "$queue///c/[/]". + // After stripping "$queue/", the remainder is a valid MQTT-style topic for parseMQTTTopic. + cases := []struct { + name string + routingKey string + prefix string + domain string + channel string + subtopic string + }{ + { + name: "writers queue with subtopic", + routingKey: "$queue/writers/domain/c/channel/temp", + prefix: "writers", + domain: "domain", + channel: "channel", + subtopic: "temp", + }, + { + name: "main queue without subtopic", + routingKey: "$queue/m/domain/c/channel", + prefix: "m", + domain: "domain", + channel: "channel", + subtopic: "", + }, + { + name: "alarms queue with nested subtopic", + routingKey: "$queue/alarms/dom/c/ch/critical/high", + prefix: "alarms", + domain: "dom", + channel: "ch", + subtopic: "critical.high", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + mqttTopic := strings.TrimPrefix(tc.routingKey, "$queue/") + domain, channel, subtopic, err := parseMQTTTopic(tc.prefix, mqttTopic) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if domain != tc.domain || channel != tc.channel || subtopic != tc.subtopic { + t.Fatalf("got domain=%q channel=%q subtopic=%q", domain, channel, subtopic) + } + }) + } +} + +func TestStringHeader(t *testing.T) { + headers := map[string]any{ + "external_id": "pub-1", + "number": 42, + "bytes": []byte("bin"), + } + if got := stringHeader(headers, "external_id"); got != "pub-1" { + t.Fatalf("expected pub-1, got %q", got) + } + if got := stringHeader(headers, "bytes"); got != "bin" { + t.Fatalf("expected bin, got %q", got) + } + if got := stringHeader(headers, "number"); got != "" { + t.Fatalf("expected empty for non-string, got %q", got) + } + if got := stringHeader(headers, "missing"); got != "" { + t.Fatalf("expected empty for missing key, got %q", got) + } + if got := stringHeader(nil, "any"); got != "" { + t.Fatalf("expected empty for nil headers, got %q", got) + } +} + +func TestFormatConsumerName(t *testing.T) { + got := formatConsumerName("m.domain.c.channel.>", "re/service 1") + want := "m_domain_c_channel__-re_service_1" + if got != want { + t.Fatalf("consumer name mismatch: got %q, want %q", got, want) + } +} diff --git a/pkg/messaging/rabbitmq/tracing/publisher.go b/pkg/messaging/fluxmq/tracing/publisher.go similarity index 77% rename from pkg/messaging/rabbitmq/tracing/publisher.go rename to pkg/messaging/fluxmq/tracing/publisher.go index 9213547a3..a73ad78d3 100644 --- a/pkg/messaging/rabbitmq/tracing/publisher.go +++ b/pkg/messaging/fluxmq/tracing/publisher.go @@ -16,10 +16,9 @@ import ( const publishOP = "publish" var defaultAttributes = []attribute.KeyValue{ - attribute.String("messaging.system", "rabbitmq"), + attribute.String("messaging.system", "fluxmq"), attribute.String("network.protocol.name", "amqp"), - attribute.String("network.protocol.version", "3.9.20"), - attribute.String("messaging.rabbitmq.destination.routing_key", "supermq"), + attribute.String("network.protocol.version", "0.9.1"), } var _ messaging.Publisher = (*publisherMiddleware)(nil) @@ -41,9 +40,8 @@ func NewPublisher(config server.Config, tracer trace.Tracer, publisher messaging } func (pm *publisherMiddleware) Publish(ctx context.Context, topic string, msg *messaging.Message) error { - ctx, span := tracing.CreateSpan(ctx, publishOP, msg.GetPublisher(), topic, msg.GetSubtopic(), len(msg.GetPayload()), pm.host, trace.SpanKindClient, pm.tracer) + ctx, span := tracing.CreateSpan(ctx, publishOP, msg.ClientIdentity(), topic, msg.GetSubtopic(), len(msg.GetPayload()), pm.host, trace.SpanKindClient, pm.tracer) defer span.End() - span.SetAttributes(defaultAttributes...) return pm.publisher.Publish(ctx, topic, msg) diff --git a/pkg/messaging/rabbitmq/tracing/pubsub.go b/pkg/messaging/fluxmq/tracing/pubsub.go similarity index 100% rename from pkg/messaging/rabbitmq/tracing/pubsub.go rename to pkg/messaging/fluxmq/tracing/pubsub.go diff --git a/pkg/messaging/handler/logging.go b/pkg/messaging/handler/logging.go deleted file mode 100644 index 429306581..000000000 --- a/pkg/messaging/handler/logging.go +++ /dev/null @@ -1,171 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -//go:build !test - -package handler - -import ( - "context" - "log/slog" - "strings" - "time" - - "github.com/absmach/mgate/pkg/session" -) - -var _ session.Handler = (*loggingMiddleware)(nil) - -type loggingMiddleware struct { - logger *slog.Logger - svc session.Handler -} - -// NewLogging adds logging facilities to the adapter. -func NewLogging(svc session.Handler, logger *slog.Logger) session.Handler { - return &loggingMiddleware{logger, svc} -} - -// AuthConnect implements session.Handler. -func (lm *loggingMiddleware) AuthConnect(ctx context.Context) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("AuthConnect failed", args...) - return - } - lm.logger.Info("AuthConnect completed successfully", args...) - }(time.Now()) - return lm.svc.AuthConnect(ctx) -} - -// AuthPublish implements session.Handler. -func (lm *loggingMiddleware) AuthPublish(ctx context.Context, topic *string, payload *[]byte) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - } - if topic != nil { - args = append(args, slog.String("topic", *topic)) - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("AuthPublish failed", args...) - return - } - lm.logger.Info("AuthPublish completed successfully", args...) - }(time.Now()) - return lm.svc.AuthPublish(ctx, topic, payload) -} - -// AuthSubscribe implements session.Handler. -func (lm *loggingMiddleware) AuthSubscribe(ctx context.Context, topics *[]string) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - } - if topics != nil { - args = append(args, slog.String("topics", strings.Join(*topics, ", "))) - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("AuthSubscribe failed", args...) - return - } - lm.logger.Info("AuthSubscribe completed successfully", args...) - }(time.Now()) - return lm.svc.AuthSubscribe(ctx, topics) -} - -// Connect implements session.Handler. -func (lm *loggingMiddleware) Connect(ctx context.Context) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("Connect failed", args...) - return - } - lm.logger.Info("Connect completed successfully", args...) - }(time.Now()) - return lm.svc.Connect(ctx) -} - -// Disconnect implements session.Handler. -func (lm *loggingMiddleware) Disconnect(ctx context.Context) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("Disconnect failed", args...) - return - } - lm.logger.Info("Disconnect completed successfully", args...) - }(time.Now()) - return lm.svc.Disconnect(ctx) -} - -// Publish logs the publish request. It logs the time it took to complete the request. -// If the request fails, it logs the error. -func (lm *loggingMiddleware) Publish(ctx context.Context, topic *string, payload *[]byte) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - } - if topic != nil { - args = append(args, slog.String("topic", *topic)) - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("Publish failed", args...) - return - } - lm.logger.Info("Publish completed successfully", args...) - }(time.Now()) - return lm.svc.Publish(ctx, topic, payload) -} - -// Subscribe implements session.Handler. -func (lm *loggingMiddleware) Subscribe(ctx context.Context, topics *[]string) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - } - if topics != nil { - args = append(args, slog.String("topics", strings.Join(*topics, ", "))) - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("Subscribe failed", args...) - return - } - lm.logger.Info("Subscribe completed successfully", args...) - }(time.Now()) - return lm.svc.Subscribe(ctx, topics) -} - -// Unsubscribe implements session.Handler. -func (lm *loggingMiddleware) Unsubscribe(ctx context.Context, topics *[]string) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - } - if topics != nil { - args = append(args, slog.String("topics", strings.Join(*topics, ", "))) - } - if err != nil { - args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("Unsubscribe failed", args...) - return - } - lm.logger.Info("Unsubscribe completed successfully", args...) - }(time.Now()) - return lm.svc.Unsubscribe(ctx, topics) -} diff --git a/pkg/messaging/handler/metrics.go b/pkg/messaging/handler/metrics.go deleted file mode 100644 index 4f9521011..000000000 --- a/pkg/messaging/handler/metrics.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -//go:build !test - -package handler - -import ( - "context" - "time" - - "github.com/absmach/mgate/pkg/session" - "github.com/go-kit/kit/metrics" -) - -var _ session.Handler = (*metricsMiddleware)(nil) - -type metricsMiddleware struct { - counter metrics.Counter - latency metrics.Histogram - svc session.Handler -} - -// NewMetrics instruments adapter by tracking request count and latency. -func NewMetrics(svc session.Handler, counter metrics.Counter, latency metrics.Histogram) session.Handler { - return &metricsMiddleware{ - counter: counter, - latency: latency, - svc: svc, - } -} - -// AuthConnect implements session.Handler. -func (mm *metricsMiddleware) AuthConnect(ctx context.Context) error { - defer func(begin time.Time) { - mm.counter.With("method", "publish").Add(1) - mm.latency.With("method", "publish").Observe(time.Since(begin).Seconds()) - }(time.Now()) - - return mm.svc.AuthConnect(ctx) -} - -// AuthPublish implements session.Handler. -func (mm *metricsMiddleware) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error { - defer func(begin time.Time) { - mm.counter.With("method", "publish").Add(1) - mm.latency.With("method", "publish").Observe(time.Since(begin).Seconds()) - }(time.Now()) - - return mm.svc.AuthPublish(ctx, topic, payload) -} - -// AuthSubscribe implements session.Handler. -func (*metricsMiddleware) AuthSubscribe(ctx context.Context, topics *[]string) error { - return nil -} - -// Connect implements session.Handler. -func (*metricsMiddleware) Connect(ctx context.Context) error { - return nil -} - -// Disconnect implements session.Handler. -func (*metricsMiddleware) Disconnect(ctx context.Context) error { - return nil -} - -// Publish instruments Publish method with metrics. -func (mm *metricsMiddleware) Publish(ctx context.Context, topic *string, payload *[]byte) error { - defer func(begin time.Time) { - mm.counter.With("method", "publish").Add(1) - mm.latency.With("method", "publish").Observe(time.Since(begin).Seconds()) - }(time.Now()) - - return mm.svc.Publish(ctx, topic, payload) -} - -// Subscribe implements session.Handler. -func (*metricsMiddleware) Subscribe(ctx context.Context, topics *[]string) error { - return nil -} - -// Unsubscribe implements session.Handler. -func (*metricsMiddleware) Unsubscribe(ctx context.Context, topics *[]string) error { - return nil -} diff --git a/pkg/messaging/handler/tracing.go b/pkg/messaging/handler/tracing.go deleted file mode 100644 index 3475c0f66..000000000 --- a/pkg/messaging/handler/tracing.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package handler - -import ( - "context" - - "github.com/absmach/mgate/pkg/session" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" -) - -const ( - authConnectOP = "auth_connect_op" - authPublishOP = "auth_publish_op" - authSubscribeOP = "auth_subscribe_op" - connectOP = "connect_op" - disconnectOP = "disconnect_op" - subscribeOP = "subscribe_op" - unsubscribeOP = "unsubscribe_op" - publishOP = "publish_op" -) - -var _ session.Handler = (*handlerMiddleware)(nil) - -type handlerMiddleware struct { - handler session.Handler - tracer trace.Tracer -} - -// NewTracing creates a new session.Handler middleware with tracing. -func NewTracing(tracer trace.Tracer, handler session.Handler) session.Handler { - return &handlerMiddleware{ - tracer: tracer, - handler: handler, - } -} - -// AuthConnect traces auth connect operations. -func (h *handlerMiddleware) AuthConnect(ctx context.Context) error { - kvOpts := []attribute.KeyValue{} - s, ok := session.FromContext(ctx) - if ok { - kvOpts = append(kvOpts, attribute.String("client_id", s.ID)) - kvOpts = append(kvOpts, attribute.String("username", s.Username)) - } - ctx, span := h.tracer.Start(ctx, authConnectOP, trace.WithAttributes(kvOpts...)) - defer span.End() - return h.handler.AuthConnect(ctx) -} - -// AuthPublish traces auth publish operations. -func (h *handlerMiddleware) AuthPublish(ctx context.Context, topic *string, payload *[]byte) error { - kvOpts := []attribute.KeyValue{} - s, ok := session.FromContext(ctx) - if ok { - kvOpts = append(kvOpts, attribute.String("client_id", s.ID)) - if topic != nil { - kvOpts = append(kvOpts, attribute.String("topic", *topic)) - } - } - ctx, span := h.tracer.Start(ctx, authPublishOP, trace.WithAttributes(kvOpts...)) - defer span.End() - return h.handler.AuthPublish(ctx, topic, payload) -} - -// AuthSubscribe traces auth subscribe operations. -func (h *handlerMiddleware) AuthSubscribe(ctx context.Context, topics *[]string) error { - kvOpts := []attribute.KeyValue{} - s, ok := session.FromContext(ctx) - if ok { - kvOpts = append(kvOpts, attribute.String("client_id", s.ID)) - if topics != nil { - kvOpts = append(kvOpts, attribute.StringSlice("topics", *topics)) - } - } - ctx, span := h.tracer.Start(ctx, authSubscribeOP, trace.WithAttributes(kvOpts...)) - defer span.End() - return h.handler.AuthSubscribe(ctx, topics) -} - -// Connect traces connect operations. -func (h *handlerMiddleware) Connect(ctx context.Context) error { - ctx, span := h.tracer.Start(ctx, connectOP) - defer span.End() - return h.handler.Connect(ctx) -} - -// Disconnect traces disconnect operations. -func (h *handlerMiddleware) Disconnect(ctx context.Context) error { - ctx, span := h.tracer.Start(ctx, disconnectOP) - defer span.End() - return h.handler.Disconnect(ctx) -} - -// Publish traces publish operations. -func (h *handlerMiddleware) Publish(ctx context.Context, topic *string, payload *[]byte) error { - ctx, span := h.tracer.Start(ctx, publishOP) - defer span.End() - return h.handler.Publish(ctx, topic, payload) -} - -// Subscribe traces subscribe operations. -func (h *handlerMiddleware) Subscribe(ctx context.Context, topics *[]string) error { - ctx, span := h.tracer.Start(ctx, subscribeOP) - defer span.End() - return h.handler.Subscribe(ctx, topics) -} - -// Unsubscribe traces unsubscribe operations. -func (h *handlerMiddleware) Unsubscribe(ctx context.Context, topics *[]string) error { - ctx, span := h.tracer.Start(ctx, unsubscribeOP) - defer span.End() - return h.handler.Unsubscribe(ctx, topics) -} diff --git a/pkg/messaging/message.pb.go b/pkg/messaging/message.pb.go index c1738b73c..1b55e949e 100644 --- a/pkg/messaging/message.pb.go +++ b/pkg/messaging/message.pb.go @@ -33,7 +33,8 @@ type Message struct { Publisher string `protobuf:"bytes,4,opt,name=publisher,proto3" json:"publisher,omitempty"` Protocol string `protobuf:"bytes,5,opt,name=protocol,proto3" json:"protocol,omitempty"` Payload []byte `protobuf:"bytes,6,opt,name=payload,proto3" json:"payload,omitempty"` - Created int64 `protobuf:"varint,7,opt,name=created,proto3" json:"created,omitempty"` // Unix timestamp in nanoseconds + Created int64 `protobuf:"varint,7,opt,name=created,proto3" json:"created,omitempty"` // Unix timestamp in nanoseconds + ClientId string `protobuf:"bytes,8,opt,name=client_id,json=clientId,proto3" json:"client_id,omitempty"` // Transport-level client identifier unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -117,11 +118,18 @@ func (x *Message) GetCreated() int64 { return 0 } +func (x *Message) GetClientId() string { + if x != nil { + return x.ClientId + } + return "" +} + var File_pkg_messaging_message_proto protoreflect.FileDescriptor const file_pkg_messaging_message_proto_rawDesc = "" + "\n" + - "\x1bpkg/messaging/message.proto\x12\tmessaging\"\xc5\x01\n" + + "\x1bpkg/messaging/message.proto\x12\tmessaging\"\xe2\x01\n" + "\aMessage\x12\x18\n" + "\achannel\x18\x01 \x01(\tR\achannel\x12\x16\n" + "\x06domain\x18\x02 \x01(\tR\x06domain\x12\x1a\n" + @@ -129,7 +137,8 @@ const file_pkg_messaging_message_proto_rawDesc = "" + "\tpublisher\x18\x04 \x01(\tR\tpublisher\x12\x1a\n" + "\bprotocol\x18\x05 \x01(\tR\bprotocol\x12\x18\n" + "\apayload\x18\x06 \x01(\fR\apayload\x12\x18\n" + - "\acreated\x18\a \x01(\x03R\acreatedB\rZ\v./messagingb\x06proto3" + "\acreated\x18\a \x01(\x03R\acreated\x12\x1b\n" + + "\tclient_id\x18\b \x01(\tR\bclientIdB\rZ\v./messagingb\x06proto3" var ( file_pkg_messaging_message_proto_rawDescOnce sync.Once diff --git a/pkg/messaging/message.proto b/pkg/messaging/message.proto index 723e1f0a5..47df3b0c4 100644 --- a/pkg/messaging/message.proto +++ b/pkg/messaging/message.proto @@ -15,4 +15,5 @@ message Message { string protocol = 5; bytes payload = 6; int64 created = 7; // Unix timestamp in nanoseconds + string client_id = 8; // Transport-level client identifier } diff --git a/pkg/messaging/message_identity.go b/pkg/messaging/message_identity.go new file mode 100644 index 000000000..b66e99b4d --- /dev/null +++ b/pkg/messaging/message_identity.go @@ -0,0 +1,16 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package messaging + +// ClientIdentity returns the transport client identifier carried by the message. +// It falls back to Publisher for backward compatibility with older messages. +func (m *Message) ClientIdentity() string { + if m == nil { + return "" + } + if clientID := m.GetClientId(); clientID != "" { + return clientID + } + return m.GetPublisher() +} diff --git a/pkg/messaging/nats/tracing/publisher.go b/pkg/messaging/nats/tracing/publisher.go index a3380cedb..77dde8dff 100644 --- a/pkg/messaging/nats/tracing/publisher.go +++ b/pkg/messaging/nats/tracing/publisher.go @@ -40,7 +40,7 @@ func NewPublisher(config server.Config, tracer trace.Tracer, publisher messaging } func (pm *publisherMiddleware) Publish(ctx context.Context, topic string, msg *messaging.Message) error { - ctx, span := tracing.CreateSpan(ctx, publishOP, msg.GetPublisher(), topic, msg.GetSubtopic(), len(msg.GetPayload()), pm.host, trace.SpanKindClient, pm.tracer) + ctx, span := tracing.CreateSpan(ctx, publishOP, msg.ClientIdentity(), topic, msg.GetSubtopic(), len(msg.GetPayload()), pm.host, trace.SpanKindClient, pm.tracer) defer span.End() span.SetAttributes(defaultAttributes...) diff --git a/pkg/messaging/rabbitmq/doc.go b/pkg/messaging/rabbitmq/doc.go deleted file mode 100644 index 401e0669b..000000000 --- a/pkg/messaging/rabbitmq/doc.go +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package rabbitmq holds the implementation of the Publisher and PubSub -// interfaces for the RabbitMQ messaging system, the internal messaging -// broker of the SuperMQ IoT platform. Due to the practical requirements -// implementation Publisher is created alongside PubSub. The reason for -// this is that Subscriber implementation of RabbitMQ brings the burden of -// additional struct fields which are not used by Publisher. Subscriber -// is not implemented separately because PubSub can be used where Subscriber is needed. -package rabbitmq diff --git a/pkg/messaging/rabbitmq/options.go b/pkg/messaging/rabbitmq/options.go deleted file mode 100644 index e8121515a..000000000 --- a/pkg/messaging/rabbitmq/options.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package rabbitmq - -import ( - "errors" - - "github.com/absmach/supermq/pkg/messaging" -) - -// ErrInvalidType is returned when the provided value is not of the expected type. -var ErrInvalidType = errors.New("invalid type") - -const ( - exchangeName = "messages" - msgPrefix = "m" -) - -type options struct { - prefix string - exchange string -} - -func defaultOptions() options { - return options{ - prefix: msgPrefix, - exchange: exchangeName, - } -} - -// Prefix sets the prefix for the publisher. -func Prefix(prefix string) messaging.Option { - return func(val any) error { - switch v := val.(type) { - case *publisher: - v.prefix = prefix - case *pubsub: - v.prefix = prefix - default: - return ErrInvalidType - } - return nil - } -} - -// Exchange sets the exchange for the publisher or subscriber. -func Exchange(exchange string) messaging.Option { - return func(val any) error { - switch v := val.(type) { - case *publisher: - v.exchange = exchange - case *pubsub: - v.exchange = exchange - default: - return ErrInvalidType - } - - return nil - } -} diff --git a/pkg/messaging/rabbitmq/publisher.go b/pkg/messaging/rabbitmq/publisher.go deleted file mode 100644 index 448815930..000000000 --- a/pkg/messaging/rabbitmq/publisher.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package rabbitmq - -import ( - "context" - "fmt" - - "github.com/absmach/supermq/pkg/messaging" - amqp "github.com/rabbitmq/amqp091-go" - "google.golang.org/protobuf/proto" -) - -var _ messaging.Publisher = (*publisher)(nil) - -type publisher struct { - conn *amqp.Connection - channel *amqp.Channel - options -} - -// NewPublisher returns RabbitMQ message Publisher. -func NewPublisher(url string, opts ...messaging.Option) (messaging.Publisher, error) { - pub := &publisher{ - options: defaultOptions(), - } - - for _, opt := range opts { - if err := opt(pub); err != nil { - return nil, err - } - } - - conn, err := amqp.Dial(url) - if err != nil { - return nil, err - } - pub.conn = conn - - ch, err := conn.Channel() - if err != nil { - return nil, err - } - if err := ch.ExchangeDeclare(pub.exchange, amqp.ExchangeTopic, true, false, false, false, nil); err != nil { - return nil, err - } - pub.channel = ch - - return pub, nil -} - -func (pub *publisher) Publish(ctx context.Context, topic string, msg *messaging.Message) error { - if topic == "" { - return ErrEmptyTopic - } - data, err := proto.Marshal(msg) - if err != nil { - return err - } - - subject := fmt.Sprintf("%s.%s", pub.prefix, topic) - - err = pub.channel.PublishWithContext( - ctx, - pub.exchange, - subject, - false, - false, - amqp.Publishing{ - Headers: amqp.Table{}, - ContentType: "application/octet-stream", - AppId: "supermq-publisher", - Body: data, - }) - if err != nil { - return err - } - - return nil -} - -func (pub *publisher) Close() error { - return pub.conn.Close() -} diff --git a/pkg/messaging/rabbitmq/pubsub.go b/pkg/messaging/rabbitmq/pubsub.go deleted file mode 100644 index e520faf0a..000000000 --- a/pkg/messaging/rabbitmq/pubsub.go +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package rabbitmq - -import ( - "context" - "errors" - "fmt" - "log/slog" - "strings" - "sync" - - "github.com/absmach/supermq/pkg/messaging" - amqp "github.com/rabbitmq/amqp091-go" - "google.golang.org/protobuf/proto" -) - -var ( - // ErrNotSubscribed indicates that the topic is not subscribed to. - ErrNotSubscribed = errors.New("not subscribed") - - // ErrEmptyTopic indicates the absence of topic. - ErrEmptyTopic = errors.New("empty topic") - - // ErrEmptyID indicates the absence of ID. - ErrEmptyID = errors.New("empty ID") -) -var _ messaging.PubSub = (*pubsub)(nil) - -type subscription struct { - cancel func() error -} -type pubsub struct { - publisher - logger *slog.Logger - subscriptions map[string]map[string]subscription - mu sync.Mutex -} - -// NewPubSub returns RabbitMQ message publisher/subscriber. -func NewPubSub(url string, logger *slog.Logger, opts ...messaging.Option) (messaging.PubSub, error) { - ps := &pubsub{ - publisher: publisher{ - options: defaultOptions(), - }, - logger: logger, - subscriptions: make(map[string]map[string]subscription), - } - - for _, opt := range opts { - if err := opt(ps); err != nil { - return nil, err - } - } - conn, err := amqp.Dial(url) - if err != nil { - return nil, err - } - ps.conn = conn - - ch, err := conn.Channel() - if err != nil { - return nil, err - } - if err := ch.ExchangeDeclare(exchangeName, amqp.ExchangeTopic, true, false, false, false, nil); err != nil { - return nil, err - } - ps.channel = ch - - return ps, nil -} - -func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) error { - if cfg.ID == "" { - return ErrEmptyID - } - if cfg.Topic == "" { - return ErrEmptyTopic - } - ps.mu.Lock() - - cfg.Topic = formatTopic(cfg.Topic) - // Check topic - s, ok := ps.subscriptions[cfg.Topic] - if ok { - // Check client ID - if _, ok := s[cfg.ID]; ok { - // Unlocking, so that Unsubscribe() can access ps.subscriptions - ps.mu.Unlock() - if err := ps.Unsubscribe(ctx, cfg.ID, cfg.Topic); err != nil { - return err - } - - ps.mu.Lock() - // value of s can be changed while ps.mu is unlocked - s = ps.subscriptions[cfg.Topic] - } - } - defer ps.mu.Unlock() - if s == nil { - s = make(map[string]subscription) - ps.subscriptions[cfg.Topic] = s - } - - clientID := fmt.Sprintf("%s-%s", cfg.Topic, cfg.ID) - - queue, err := ps.channel.QueueDeclare(clientID, true, false, false, false, nil) - if err != nil { - return err - } - - if err := ps.channel.QueueBind(queue.Name, cfg.Topic, ps.exchange, false, nil); err != nil { - return err - } - - msgs, err := ps.channel.Consume(queue.Name, clientID, true, false, false, false, nil) - if err != nil { - return err - } - go ps.handle(msgs, cfg.Handler) - s[cfg.ID] = subscription{ - cancel: func() error { - if err := ps.channel.Cancel(clientID, false); err != nil { - return err - } - return cfg.Handler.Cancel() - }, - } - - return nil -} - -func (ps *pubsub) Unsubscribe(ctx context.Context, id, topic string) error { - if id == "" { - return ErrEmptyID - } - if topic == "" { - return ErrEmptyTopic - } - ps.mu.Lock() - defer ps.mu.Unlock() - - topic = formatTopic(topic) - // Check topic - s, ok := ps.subscriptions[topic] - if !ok { - return ErrNotSubscribed - } - // Check topic ID - current, ok := s[id] - if !ok { - return ErrNotSubscribed - } - if current.cancel != nil { - if err := current.cancel(); err != nil { - return err - } - } - if err := ps.channel.QueueUnbind(topic, topic, exchangeName, nil); err != nil { - return err - } - - delete(s, id) - if len(s) == 0 { - delete(ps.subscriptions, topic) - } - return nil -} - -func (ps *pubsub) handle(deliveries <-chan amqp.Delivery, h messaging.MessageHandler) { - for d := range deliveries { - var msg messaging.Message - if err := proto.Unmarshal(d.Body, &msg); err != nil { - ps.logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err)) - return - } - if err := h.Handle(&msg); err != nil { - ps.logger.Warn(fmt.Sprintf("Failed to handle SuperMQ message: %s", err)) - return - } - } -} - -func formatTopic(topic string) string { - return strings.ReplaceAll(topic, ">", "#") -} diff --git a/pkg/messaging/rabbitmq/pubsub_test.go b/pkg/messaging/rabbitmq/pubsub_test.go deleted file mode 100644 index a3432a6a9..000000000 --- a/pkg/messaging/rabbitmq/pubsub_test.go +++ /dev/null @@ -1,460 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package rabbitmq_test - -import ( - "context" - "errors" - "fmt" - "testing" - - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/rabbitmq" - amqp "github.com/rabbitmq/amqp091-go" - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/proto" -) - -const ( - topic = "topic" - msgPrefix = "m" - channel = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b" - subtopic = "engine" - clientID = "9b7b1b3f-b1b0-46a8-a717-b8213f9eda3b" - exchangeName = "messages" -) - -var ( - msgChan = make(chan *messaging.Message) - data = []byte("payload") -) - -var errFailedHandleMessage = errors.New("failed to handle supermq message") - -func TestPublisher(t *testing.T) { - // Subscribing with topic, and with subtopic, so that we can publish messages. - conn, ch, err := newConn() - assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) - - topicChan := subscribe(t, ch, fmt.Sprintf("%s.%s", msgPrefix, topic)) - subtopicChan := subscribe(t, ch, fmt.Sprintf("%s.%s.%s", msgPrefix, topic, subtopic)) - - go rabbitHandler(topicChan, handler{}) - go rabbitHandler(subtopicChan, handler{}) - - t.Cleanup(func() { - conn.Close() - ch.Close() - }) - - cases := []struct { - desc string - channel string - subtopic string - payload []byte - }{ - { - desc: "publish message with nil payload", - payload: nil, - }, - { - desc: "publish message with string payload", - payload: data, - }, - { - desc: "publish message with channel", - payload: data, - channel: channel, - }, - { - desc: "publish message with subtopic", - payload: data, - subtopic: subtopic, - }, - { - desc: "publish message with channel and subtopic", - payload: data, - channel: channel, - subtopic: subtopic, - }, - } - - for _, tc := range cases { - expectedMsg := messaging.Message{ - Publisher: clientID, - Channel: tc.channel, - Subtopic: tc.subtopic, - Payload: tc.payload, - } - err = pubsub.Publish(context.TODO(), topic, &expectedMsg) - assert.Nil(t, err, fmt.Sprintf("%s: got unexpected error: %s", tc.desc, err)) - - receivedMsg := <-msgChan - assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - assert.Equal(t, expectedMsg.Created, receivedMsg.Created, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - assert.Equal(t, expectedMsg.Protocol, receivedMsg.Protocol, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - assert.Equal(t, expectedMsg.Publisher, receivedMsg.Publisher, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - assert.Equal(t, expectedMsg.Subtopic, receivedMsg.Subtopic, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - } -} - -func TestSubscribe(t *testing.T) { - // Creating rabbitmq connection and channel, so that we can publish messages. - conn, ch, err := newConn() - assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) - - t.Cleanup(func() { - conn.Close() - ch.Close() - }) - - cases := []struct { - desc string - topic string - clientID string - err error - handler messaging.MessageHandler - }{ - { - desc: "Subscribe to a topic with an ID", - topic: topic, - clientID: "clientid1", - err: nil, - handler: handler{false, "clientid1"}, - }, - { - desc: "Subscribe to the same topic with a different ID", - topic: topic, - clientID: "clientid2", - err: nil, - handler: handler{false, "clientid2"}, - }, - { - desc: "Subscribe to an already subscribed topic with an ID", - topic: topic, - clientID: "clientid1", - err: nil, - handler: handler{false, "clientid1"}, - }, - { - desc: "Subscribe to a topic with a subtopic with an ID", - topic: fmt.Sprintf("%s.%s", topic, subtopic), - clientID: "clientid1", - err: nil, - handler: handler{false, "clientid1"}, - }, - { - desc: "Subscribe to an already subscribed topic with a subtopic with an ID", - topic: fmt.Sprintf("%s.%s", topic, subtopic), - clientID: "clientid1", - err: nil, - handler: handler{false, "clientid1"}, - }, - { - desc: "Subscribe to an empty topic with an ID", - topic: "", - clientID: "clientid1", - err: rabbitmq.ErrEmptyTopic, - handler: handler{false, "clientid1"}, - }, - { - desc: "Subscribe to a topic with empty id", - topic: topic, - clientID: "", - err: rabbitmq.ErrEmptyID, - handler: handler{false, ""}, - }, - } - for _, tc := range cases { - subCfg := messaging.SubscriberConfig{ - ID: tc.clientID, - Topic: tc.topic, - Handler: tc.handler, - } - err := pubsub.Subscribe(context.TODO(), subCfg) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err)) - - if tc.err == nil { - expectedMsg := messaging.Message{ - Publisher: "CLIENTID", - Channel: channel, - Subtopic: subtopic, - Payload: data, - } - - data, err := proto.Marshal(&expectedMsg) - assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) - - err = ch.PublishWithContext( - context.Background(), - exchangeName, - tc.topic, - false, - false, - amqp.Publishing{ - Headers: amqp.Table{}, - ContentType: "application/octet-stream", - AppId: "supermq-publisher", - Body: data, - }) - assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) - - receivedMsg := <-msgChan - assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - assert.Equal(t, expectedMsg.Created, receivedMsg.Created, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - assert.Equal(t, expectedMsg.Protocol, receivedMsg.Protocol, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - assert.Equal(t, expectedMsg.Publisher, receivedMsg.Publisher, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - assert.Equal(t, expectedMsg.Subtopic, receivedMsg.Subtopic, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - } - } -} - -func TestUnsubscribe(t *testing.T) { - // Test Subscribe and Unsubscribe - cases := []struct { - desc string - topic string - clientID string - err error - subscribe bool // True for subscribe and false for unsubscribe. - handler messaging.MessageHandler - }{ - { - desc: "Subscribe to a topic with an ID", - topic: fmt.Sprintf("%s.%s", msgPrefix, topic), - clientID: "clientid4", - err: nil, - subscribe: true, - handler: handler{false, "clientid4"}, - }, - { - desc: "Subscribe to the same topic with a different ID", - topic: fmt.Sprintf("%s.%s", msgPrefix, topic), - clientID: "clientid9", - err: nil, - subscribe: true, - handler: handler{false, "clientid9"}, - }, - { - desc: "Unsubscribe from a topic with an ID", - topic: fmt.Sprintf("%s.%s", msgPrefix, topic), - clientID: "clientid4", - err: nil, - subscribe: false, - handler: handler{false, "clientid4"}, - }, - { - desc: "Unsubscribe from same topic with different ID", - topic: fmt.Sprintf("%s.%s", msgPrefix, topic), - clientID: "clientid9", - err: nil, - subscribe: false, - handler: handler{false, "clientid9"}, - }, - { - desc: "Unsubscribe from a non-existent topic with an ID", - topic: "h", - clientID: "clientid4", - err: rabbitmq.ErrNotSubscribed, - subscribe: false, - handler: handler{false, "clientid4"}, - }, - { - desc: "Unsubscribe from an already unsubscribed topic with an ID", - topic: fmt.Sprintf("%s.%s", msgPrefix, topic), - clientID: "clientid4", - err: rabbitmq.ErrNotSubscribed, - subscribe: false, - handler: handler{false, "clientid4"}, - }, - { - desc: "Subscribe to a topic with a subtopic with an ID", - topic: fmt.Sprintf("%s.%s.%s", msgPrefix, topic, subtopic), - clientID: "clientidd4", - err: nil, - subscribe: true, - handler: handler{false, "clientidd4"}, - }, - { - desc: "Unsubscribe from a topic with a subtopic with an ID", - topic: fmt.Sprintf("%s.%s.%s", msgPrefix, topic, subtopic), - clientID: "clientidd4", - err: nil, - subscribe: false, - handler: handler{false, "clientidd4"}, - }, - { - desc: "Unsubscribe from an already unsubscribed topic with a subtopic with an ID", - topic: fmt.Sprintf("%s.%s.%s", msgPrefix, topic, subtopic), - clientID: "clientid4", - err: rabbitmq.ErrNotSubscribed, - subscribe: false, - handler: handler{false, "clientid4"}, - }, - { - desc: "Unsubscribe from an empty topic with an ID", - topic: "", - clientID: "clientid4", - err: rabbitmq.ErrEmptyTopic, - subscribe: false, - handler: handler{false, "clientid4"}, - }, - { - desc: "Unsubscribe from a topic with empty ID", - topic: fmt.Sprintf("%s.%s", msgPrefix, topic), - clientID: "", - err: rabbitmq.ErrEmptyID, - subscribe: false, - handler: handler{false, ""}, - }, - { - desc: "Subscribe to a new topic with an ID", - topic: fmt.Sprintf("%s.%s", msgPrefix, topic+"2"), - clientID: "clientid55", - err: nil, - subscribe: true, - handler: handler{true, "clientid5"}, - }, - { - desc: "Unsubscribe from a topic with an ID with failing handler", - topic: fmt.Sprintf("%s.%s", msgPrefix, topic+"2"), - clientID: "clientid55", - err: errFailedHandleMessage, - subscribe: false, - handler: handler{true, "clientid5"}, - }, - { - desc: "Subscribe to a new topic with subtopic with an ID", - topic: fmt.Sprintf("%s.%s.%s", msgPrefix, topic+"2", subtopic), - clientID: "clientid55", - err: nil, - subscribe: true, - handler: handler{true, "clientid5"}, - }, - { - desc: "Unsubscribe from a topic with subtopic with an ID with failing handler", - topic: fmt.Sprintf("%s.%s.%s", msgPrefix, topic+"2", subtopic), - clientID: "clientid55", - err: errFailedHandleMessage, - subscribe: false, - handler: handler{true, "clientid5"}, - }, - } - - for _, tc := range cases { - subCfg := messaging.SubscriberConfig{ - ID: tc.clientID, - Topic: tc.topic, - Handler: tc.handler, - } - switch tc.subscribe { - case true: - err := pubsub.Subscribe(context.TODO(), subCfg) - assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err)) - default: - err := pubsub.Unsubscribe(context.TODO(), tc.clientID, tc.topic) - assert.Equal(t, err, tc.err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, tc.err, err)) - } - } -} - -func TestPubSub(t *testing.T) { - cases := []struct { - desc string - topic string - clientID string - err error - handler messaging.MessageHandler - }{ - { - desc: "Subscribe to a topic with an ID", - topic: topic, - clientID: clientID, - err: nil, - handler: handler{false, clientID}, - }, - { - desc: "Subscribe to the same topic with a different ID", - topic: topic, - clientID: clientID + "1", - err: nil, - handler: handler{false, clientID + "1"}, - }, - { - desc: "Subscribe to a topic with a subtopic with an ID", - topic: fmt.Sprintf("%s.%s", topic, subtopic), - clientID: clientID + "2", - err: nil, - handler: handler{false, clientID + "2"}, - }, - { - desc: "Subscribe to an empty topic with an ID", - topic: "", - clientID: clientID, - err: rabbitmq.ErrEmptyTopic, - handler: handler{false, clientID}, - }, - { - desc: "Subscribe to a topic with empty id", - topic: topic, - clientID: "", - err: rabbitmq.ErrEmptyID, - handler: handler{false, ""}, - }, - } - for _, tc := range cases { - subject := "" - if tc.topic != "" { - subject = fmt.Sprintf("%s.%s", msgPrefix, tc.topic) - } - subCfg := messaging.SubscriberConfig{ - ID: tc.clientID, - Topic: subject, - Handler: tc.handler, - } - err := pubsub.Subscribe(context.TODO(), subCfg) - - switch tc.err { - case nil: - // If no error, publish message, and receive after subscribing. - expectedMsg := messaging.Message{ - Channel: channel, - Payload: data, - } - - err = pubsub.Publish(context.TODO(), tc.topic, &expectedMsg) - assert.Nil(t, err, fmt.Sprintf("%s got unexpected error: %s", tc.desc, err)) - - receivedMsg := <-msgChan - assert.Equal(t, expectedMsg.Channel, receivedMsg.Channel, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - assert.Equal(t, expectedMsg.Payload, receivedMsg.Payload, fmt.Sprintf("%s: expected %+v got %+v\n", tc.desc, &expectedMsg, receivedMsg)) - - err = pubsub.Unsubscribe(context.TODO(), tc.clientID, fmt.Sprintf("%s.%s", msgPrefix, tc.topic)) - assert.Nil(t, err, fmt.Sprintf("%s got unexpected error: %s", tc.desc, err)) - default: - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected: %s, but got: %s", tc.desc, err, tc.err)) - } - } -} - -type handler struct { - fail bool - publisher string -} - -func (h handler) Handle(msg *messaging.Message) error { - if msg.GetPublisher() != h.publisher { - msgChan <- msg - } - return nil -} - -func (h handler) Cancel() error { - if h.fail { - return errFailedHandleMessage - } - return nil -} diff --git a/pkg/messaging/rabbitmq/setup_test.go b/pkg/messaging/rabbitmq/setup_test.go deleted file mode 100644 index 1aaeb672e..000000000 --- a/pkg/messaging/rabbitmq/setup_test.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package rabbitmq_test - -import ( - "fmt" - "log" - "log/slog" - "os" - "os/signal" - "syscall" - "testing" - - smqlog "github.com/absmach/supermq/logger" - "github.com/absmach/supermq/pkg/messaging" - "github.com/absmach/supermq/pkg/messaging/rabbitmq" - "github.com/ory/dockertest/v3" - amqp "github.com/rabbitmq/amqp091-go" - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/proto" -) - -const ( - port = "5672/tcp" - brokerName = "rabbitmq" - brokerVersion = "3.12.12-alpine" -) - -var ( - publisher messaging.Publisher - pubsub messaging.PubSub - logger *slog.Logger - address string -) - -func TestMain(m *testing.M) { - pool, err := dockertest.NewPool("") - if err != nil { - log.Fatalf("Could not connect to docker: %s", err) - } - - container, err := pool.Run(brokerName, brokerVersion, []string{}) - if err != nil { - log.Fatalf("Could not start container: %s", err) - } - handleInterrupt(pool, container) - - address = fmt.Sprintf("amqp://%s:%s", "localhost", container.GetPort(port)) - if err := pool.Retry(func() error { - publisher, err = rabbitmq.NewPublisher(address) - return err - }); err != nil { - log.Fatalf("Could not connect to docker: %s", err) - } - - logger, err = smqlog.New(os.Stdout, "debug") - if err != nil { - log.Fatal(err.Error()) - } - if err := pool.Retry(func() error { - pubsub, err = rabbitmq.NewPubSub(address, logger) - return err - }); err != nil { - log.Fatalf("Could not connect to docker: %s", err) - } - - code := m.Run() - - if err := pool.Purge(container); err != nil { - log.Fatalf("Could not purge container: %s", err) - } - - os.Exit(code) -} - -func newConn() (*amqp.Connection, *amqp.Channel, error) { - conn, err := amqp.Dial(address) - if err != nil { - return nil, nil, err - } - ch, err := conn.Channel() - if err != nil { - return nil, nil, err - } - if err := ch.ExchangeDeclare(exchangeName, amqp.ExchangeTopic, true, false, false, false, nil); err != nil { - return nil, nil, err - } - - return conn, ch, nil -} - -func rabbitHandler(deliveries <-chan amqp.Delivery, h messaging.MessageHandler) { - for d := range deliveries { - var msg messaging.Message - if err := proto.Unmarshal(d.Body, &msg); err != nil { - logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err)) - return - } - if err := h.Handle(&msg); err != nil { - logger.Warn(fmt.Sprintf("Failed to handle SuperMQ message: %s", err)) - return - } - } -} - -func subscribe(t *testing.T, ch *amqp.Channel, topic string) <-chan amqp.Delivery { - _, err := ch.QueueDeclare(topic, true, true, true, false, nil) - assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) - - err = ch.QueueBind(topic, topic, exchangeName, false, nil) - assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) - - clientID := fmt.Sprintf("%s-%s", topic, clientID) - msgs, err := ch.Consume(topic, clientID, true, false, false, false, nil) - assert.Nil(t, err, fmt.Sprintf("got unexpected error: %s", err)) - - return msgs -} - -func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) { - c := make(chan os.Signal, 2) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - go func() { - <-c - if err := pool.Purge(container); err != nil { - log.Fatalf("Could not purge container: %s", err) - } - os.Exit(0) - }() -} diff --git a/pkg/policies/evaluator.go b/pkg/policies/evaluator.go index 0bcedd76d..334a604de 100644 --- a/pkg/policies/evaluator.go +++ b/pkg/policies/evaluator.go @@ -27,6 +27,9 @@ const ( UserType = "user" DomainType = "domain" PlatformType = "platform" + RulesType = "rules" + ReportsType = "reports" + AlarmsType = "alarms" ) const ( @@ -54,7 +57,7 @@ const ( CreatePermission = "create" ) -const SuperMQObject = "supermq" +const MagistralaObject = "magistrala" type Evaluator interface { // CheckPolicy checks if the subject has a relation on the object. diff --git a/pkg/policies/spicedb/service.go b/pkg/policies/spicedb/service.go index a9b82f42f..5207fc98d 100644 --- a/pkg/policies/spicedb/service.go +++ b/pkg/policies/spicedb/service.go @@ -318,7 +318,7 @@ func (ps *policyService) ListPermissions(ctx context.Context, pr policies.Policy } func (ps *policyService) policyValidation(pr policies.Policy) error { - if pr.ObjectType == policies.PlatformType && pr.Object != policies.SuperMQObject { + if pr.ObjectType == policies.PlatformType && pr.Object != policies.MagistralaObject { return errPlatform } @@ -409,7 +409,7 @@ func (ps *policyService) userGroupPreConditions(ctx context.Context, pr policies Subject: pr.Subject, SubjectType: pr.SubjectType, Permission: policies.AdminPermission, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }); err == nil { isSuperAdmin = true @@ -484,7 +484,7 @@ func (ps *policyService) userClientPreConditions(ctx context.Context, pr policie Subject: pr.Subject, SubjectType: pr.SubjectType, Permission: policies.AdminPermission, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }); err == nil { isSuperAdmin = true @@ -549,7 +549,7 @@ func (ps *policyService) userDomainPreConditions(ctx context.Context, pr policie Subject: pr.Subject, SubjectType: pr.SubjectType, Permission: policies.AdminPermission, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }); err == nil { return preconds, fmt.Errorf("use already exists in domain") diff --git a/pkg/re/events/consumer/decode.go b/pkg/re/events/consumer/decode.go new file mode 100644 index 000000000..fb7cb59ac --- /dev/null +++ b/pkg/re/events/consumer/decode.go @@ -0,0 +1,204 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import ( + "encoding/json" + "time" + + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/roles" + rconsumer "github.com/absmach/supermq/pkg/roles/rolemanager/events/consumer" + "github.com/absmach/supermq/pkg/schedule" + "github.com/absmach/supermq/re" +) + +var ( + errDecodeAddRuleEvent = errors.New("failed to decode rule add event") + errDecodeUpdateRuleEvent = errors.New("failed to decode rule update event") + errDecodeUpdateRuleTagsEvent = errors.New("failed to decode rule update tags event") + errDecodeUpdateRuleScheduleEvent = errors.New("failed to decode rule update schedule event") + errDecodeEnableRuleEvent = errors.New("failed to decode rule enable event") + errDecodeDisableRuleEvent = errors.New("failed to decode rule disable event") + errDecodeRemoveRuleEvent = errors.New("failed to decode rule remove event") + + errID = errors.New("missing or invalid 'id'") + errName = errors.New("missing or invalid 'name'") + errTags = errors.New("invalid 'tags'") + errStatus = errors.New("missing or invalid 'status'") + errConvertStatus = errors.New("failed to convert status") + errCreatedBy = errors.New("missing or invalid 'created_by'") + errCreatedAt = errors.New("failed to parse 'created_at' time") + errUpdatedAt = errors.New("failed to parse 'updated_at' time") + errDecodeLogic = errors.New("failed to decode 'logic'") + errDecodeSchedule = errors.New("failed to decode 'schedule'") +) + +// ToRule decodes a map[string]any event payload into a re.Rule. +func ToRule(data map[string]any) (re.Rule, error) { + var r re.Rule + + id, ok := data["id"].(string) + if !ok { + return re.Rule{}, errID + } + r.ID = id + + name, ok := data["name"].(string) + if !ok { + return re.Rule{}, errName + } + r.Name = name + + stat, ok := data["status"].(string) + if !ok { + return re.Rule{}, errStatus + } + st, err := re.ToStatus(stat) + if err != nil { + return re.Rule{}, errors.Wrap(errConvertStatus, err) + } + r.Status = st + + cby, ok := data["created_by"].(string) + if !ok { + return re.Rule{}, errCreatedBy + } + r.CreatedBy = cby + + cat, ok := data["created_at"].(string) + if !ok { + return re.Rule{}, errCreatedAt + } + ct, err := time.Parse(re.TimeLayout, cat) + if err != nil { + return re.Rule{}, errors.Wrap(errCreatedAt, err) + } + r.CreatedAt = ct + + if domain, ok := data["domain"].(string); ok { + r.DomainID = domain + } + + if itags, ok := data["tags"].([]any); ok { + tags, err := rconsumer.ToStrings(itags) + if err != nil { + return re.Rule{}, errors.Wrap(errTags, err) + } + r.Tags = tags + } + + if meta, ok := data["metadata"].(map[string]any); ok { + r.Metadata = meta + } + + if uby, ok := data["updated_by"].(string); ok { + r.UpdatedBy = uby + } + + if uat, ok := data["updated_at"].(string); ok { + ut, err := time.Parse(re.TimeLayout, uat) + if err != nil { + return re.Rule{}, errors.Wrap(errUpdatedAt, err) + } + r.UpdatedAt = ut + } + + if ic, ok := data["input_channel"].(string); ok { + r.InputChannel = ic + } + + if it, ok := data["input_topic"].(string); ok { + r.InputTopic = it + } + + if rawLogic, ok := data["logic"].(map[string]any); ok { + b, err := json.Marshal(rawLogic) + if err != nil { + return re.Rule{}, errors.Wrap(errDecodeLogic, err) + } + if err := json.Unmarshal(b, &r.Logic); err != nil { + return re.Rule{}, errors.Wrap(errDecodeLogic, err) + } + } + + if rawSched, ok := data["schedule"].(map[string]any); ok { + b, err := json.Marshal(rawSched) + if err != nil { + return re.Rule{}, errors.Wrap(errDecodeSchedule, err) + } + var sched schedule.Schedule + if err := json.Unmarshal(b, &sched); err != nil { + return re.Rule{}, errors.Wrap(errDecodeSchedule, err) + } + r.Schedule = sched + } + + return r, nil +} + +func decodeAddRuleEvent(data map[string]any) (re.Rule, []roles.RoleProvision, error) { + r, err := ToRule(data) + if err != nil { + return re.Rule{}, nil, errors.Wrap(errDecodeAddRuleEvent, err) + } + + var rps []roles.RoleProvision + if irps, ok := data["roles_provisioned"].([]any); ok { + rps, err = rconsumer.ToRoleProvisions(irps) + if err != nil { + return re.Rule{}, nil, errors.Wrap(errDecodeAddRuleEvent, err) + } + } + + return r, rps, nil +} + +func decodeUpdateRuleEvent(data map[string]any) (re.Rule, error) { + r, err := ToRule(data) + if err != nil { + return re.Rule{}, errors.Wrap(errDecodeUpdateRuleEvent, err) + } + return r, nil +} + +func decodeUpdateRuleTagsEvent(data map[string]any) (re.Rule, error) { + r, err := ToRule(data) + if err != nil { + return re.Rule{}, errors.Wrap(errDecodeUpdateRuleTagsEvent, err) + } + return r, nil +} + +func decodeUpdateRuleScheduleEvent(data map[string]any) (re.Rule, error) { + r, err := ToRule(data) + if err != nil { + return re.Rule{}, errors.Wrap(errDecodeUpdateRuleScheduleEvent, err) + } + return r, nil +} + +func decodeEnableRuleEvent(data map[string]any) (re.Rule, error) { + r, err := ToRule(data) + if err != nil { + return re.Rule{}, errors.Wrap(errDecodeEnableRuleEvent, err) + } + return r, nil +} + +func decodeDisableRuleEvent(data map[string]any) (re.Rule, error) { + r, err := ToRule(data) + if err != nil { + return re.Rule{}, errors.Wrap(errDecodeDisableRuleEvent, err) + } + return r, nil +} + +func decodeRemoveRuleEvent(data map[string]any) (string, error) { + id, ok := data["id"].(string) + if !ok { + return "", errors.Wrap(errDecodeRemoveRuleEvent, errID) + } + return id, nil +} diff --git a/pkg/re/events/consumer/doc.go b/pkg/re/events/consumer/doc.go new file mode 100644 index 000000000..581353cb3 --- /dev/null +++ b/pkg/re/events/consumer/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package consumer contains events consumer for events +// published by the Rules Engine service. +package consumer diff --git a/pkg/re/events/consumer/stream.go b/pkg/re/events/consumer/stream.go new file mode 100644 index 000000000..5c777e813 --- /dev/null +++ b/pkg/re/events/consumer/stream.go @@ -0,0 +1,193 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import ( + "context" + "log/slog" + + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/events" + "github.com/absmach/supermq/pkg/events/store" + rconsumer "github.com/absmach/supermq/pkg/roles/rolemanager/events/consumer" + "github.com/absmach/supermq/re" +) + +const ( + stream = "events.supermq.rule.*" + + create = "rule.create" + update = "rule.update" + updateTags = "rule.update_tags" + updateSchedule = "rule.update_schedule" + enable = "rule.enable" + disable = "rule.disable" + remove = "rule.remove" +) + +var ( + errNoOperationKey = errors.New("operation key is not found in event message") + errAddRuleEvent = errors.New("failed to consume rule create event") + errUpdateRuleEvent = errors.New("failed to consume rule update event") + errUpdateRuleTagsEvent = errors.New("failed to consume rule update tags event") + errUpdateRuleScheduleEvent = errors.New("failed to consume rule update schedule event") + errEnableRuleEvent = errors.New("failed to consume rule enable event") + errDisableRuleEvent = errors.New("failed to consume rule disable event") + errRemoveRuleEvent = errors.New("failed to consume rule remove event") +) + +type eventHandler struct { + repo re.Repository + rolesEventHandler rconsumer.EventHandler +} + +func RulesEventsSubscribe(ctx context.Context, repo re.Repository, esURL, esConsumerName string, logger *slog.Logger) error { + subscriber, err := store.NewSubscriber(ctx, esURL, "re-es-sub", logger) + if err != nil { + return err + } + + subConfig := events.SubscriberConfig{ + Stream: stream, + Consumer: esConsumerName, + Handler: NewEventHandler(repo), + Ordered: true, + } + return subscriber.Subscribe(ctx, subConfig) +} + +// NewEventHandler returns new event store handler. +func NewEventHandler(repo re.Repository) events.EventHandler { + reh := rconsumer.NewEventHandler("rule", repo) + return &eventHandler{ + repo: repo, + rolesEventHandler: reh, + } +} + +func (es *eventHandler) Handle(ctx context.Context, event events.Event) error { + msg, err := event.Encode() + if err != nil { + return err + } + + op, ok := msg["operation"] + if !ok { + return errNoOperationKey + } + + switch op { + case create: + return es.addRuleHandler(ctx, msg) + case update: + return es.updateRuleHandler(ctx, msg) + case updateTags: + return es.updateRuleTagsHandler(ctx, msg) + case updateSchedule: + return es.updateRuleScheduleHandler(ctx, msg) + case enable: + return es.enableRuleHandler(ctx, msg) + case disable: + return es.disableRuleHandler(ctx, msg) + case remove: + return es.removeRuleHandler(ctx, msg) + } + + return es.rolesEventHandler.Handle(ctx, op, msg) +} + +func (es *eventHandler) addRuleHandler(ctx context.Context, data map[string]any) error { + r, rps, err := decodeAddRuleEvent(data) + if err != nil { + return errors.Wrap(errAddRuleEvent, err) + } + + if _, err := es.repo.AddRule(ctx, r); err != nil { + return errors.Wrap(errAddRuleEvent, err) + } + + if _, err := es.repo.AddRoles(ctx, rps); err != nil { + return errors.Wrap(errAddRuleEvent, err) + } + + return nil +} + +func (es *eventHandler) updateRuleHandler(ctx context.Context, data map[string]any) error { + r, err := decodeUpdateRuleEvent(data) + if err != nil { + return errors.Wrap(errUpdateRuleEvent, err) + } + + if _, err := es.repo.UpdateRule(ctx, r); err != nil { + return errors.Wrap(errUpdateRuleEvent, err) + } + + return nil +} + +func (es *eventHandler) updateRuleTagsHandler(ctx context.Context, data map[string]any) error { + r, err := decodeUpdateRuleTagsEvent(data) + if err != nil { + return errors.Wrap(errUpdateRuleTagsEvent, err) + } + + if _, err := es.repo.UpdateRuleTags(ctx, r); err != nil { + return errors.Wrap(errUpdateRuleTagsEvent, err) + } + + return nil +} + +func (es *eventHandler) updateRuleScheduleHandler(ctx context.Context, data map[string]any) error { + r, err := decodeUpdateRuleScheduleEvent(data) + if err != nil { + return errors.Wrap(errUpdateRuleScheduleEvent, err) + } + + if _, err := es.repo.UpdateRuleSchedule(ctx, r); err != nil { + return errors.Wrap(errUpdateRuleScheduleEvent, err) + } + + return nil +} + +func (es *eventHandler) enableRuleHandler(ctx context.Context, data map[string]any) error { + r, err := decodeEnableRuleEvent(data) + if err != nil { + return errors.Wrap(errEnableRuleEvent, err) + } + + if _, err := es.repo.UpdateRuleStatus(ctx, r); err != nil { + return errors.Wrap(errEnableRuleEvent, err) + } + + return nil +} + +func (es *eventHandler) disableRuleHandler(ctx context.Context, data map[string]any) error { + r, err := decodeDisableRuleEvent(data) + if err != nil { + return errors.Wrap(errDisableRuleEvent, err) + } + + if _, err := es.repo.UpdateRuleStatus(ctx, r); err != nil { + return errors.Wrap(errDisableRuleEvent, err) + } + + return nil +} + +func (es *eventHandler) removeRuleHandler(ctx context.Context, data map[string]any) error { + id, err := decodeRemoveRuleEvent(data) + if err != nil { + return errors.Wrap(errRemoveRuleEvent, err) + } + + if err := es.repo.RemoveRule(ctx, id); err != nil { + return errors.Wrap(errRemoveRuleEvent, err) + } + + return nil +} diff --git a/pkg/reltime/reltime.go b/pkg/reltime/reltime.go new file mode 100644 index 000000000..c3a83a63f --- /dev/null +++ b/pkg/reltime/reltime.go @@ -0,0 +1,86 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package reltime + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "github.com/absmach/supermq/pkg/errors" +) + +var ( + re = regexp.MustCompile(`(?i)^now\(\)([\+\-])(.+)$`) + + ErrInvalidDuration = errors.New("invalid duration format") + ErrInvalidExpression = errors.New("invalid time expression") + ErrUnsupportedUnit = errors.New("unsupported unit") +) + +func Parse(expr string) (time.Time, error) { + now := time.Now().UTC() + expr = strings.ReplaceAll(expr, " ", "") + + if strings.EqualFold(expr, "now()") { + return now, nil + } + + matches := re.FindStringSubmatch(expr) + if len(matches) != 3 { + return time.Time{}, errors.Wrap(ErrInvalidExpression, fmt.Errorf("%s", expr)) + } + + sign := matches[1] + durStr := matches[2] + if strings.ContainsAny(durStr, "+-") { + return time.Time{}, errors.Wrap(ErrInvalidExpression, fmt.Errorf("%s", expr)) + } + + dur, err := parseComplexDuration(durStr) + if err != nil { + return time.Time{}, err + } + + if sign == "-" { + return now.Add(-dur), nil + } + return now.Add(dur), nil +} + +func parseComplexDuration(s string) (time.Duration, error) { + var total time.Duration + re := regexp.MustCompile(`(\d+)([smhdwMY])`) + matches := re.FindAllStringSubmatch(s, -1) + + if matches == nil { + return 0, errors.Wrap(ErrInvalidDuration, fmt.Errorf("%s", s)) + } + + for _, match := range matches { + val, _ := strconv.Atoi(match[1]) + unit := match[2] + + var d time.Duration + switch unit { + case "s": + d = time.Duration(val) * time.Second + case "m": + d = time.Duration(val) * time.Minute + case "h": + d = time.Duration(val) * time.Hour + case "d": + d = time.Duration(val) * 24 * time.Hour + case "w": + d = time.Duration(val) * 7 * 24 * time.Hour + default: + return 0, errors.Wrap(ErrUnsupportedUnit, fmt.Errorf("%s", unit)) + } + + total += d + } + return total, nil +} diff --git a/pkg/reltime/reltime_test.go b/pkg/reltime/reltime_test.go new file mode 100644 index 000000000..4cc5aa4b6 --- /dev/null +++ b/pkg/reltime/reltime_test.go @@ -0,0 +1,76 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package reltime + +import ( + "fmt" + "testing" + "time" + + "github.com/absmach/supermq/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + now := time.Now() + + tests := []struct { + desc string + expr string + expected time.Time + err error + }{ + { + desc: "testing expression now()-5d", + expr: "now()-5d", + expected: now.Add(-5 * 24 * time.Hour), + err: nil, + }, + { + desc: "testing expression now()+2h30m", + expr: "now()+2h30m", + expected: now.Add(2*time.Hour + 30*time.Minute), + err: nil, + }, + { + desc: "testing expression now()-1w3d10h40m", + expr: "now()-1w3d10h40m", + expected: now.Add(-(7*24+3*24+10)*time.Hour - 40*time.Minute), + err: nil, + }, + { + desc: "testing expression yesterday", + expr: "yesterday", + err: ErrInvalidExpression, + }, + { + desc: "testing expression now()--5d", + expr: "now()--5d", + err: ErrInvalidExpression, + }, + { + desc: "testing expression now()+", + expr: "now()+", + err: ErrInvalidExpression, + }, + { + desc: "testing expression now()+5r", + expr: "now()+5r", + err: ErrInvalidDuration, + }, + { + desc: "testing expression now()+5M", + expr: "now()+5M", + err: ErrUnsupportedUnit, + }, + } + + for _, tc := range tests { + got, err := Parse(tc.expr) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v and response time %v\n", tc.desc, tc.err, err, got)) + if err == nil { + assert.WithinDuration(t, tc.expected, got, time.Duration(10*time.Second)) + } + } +} diff --git a/pkg/roles/rolemanager/middleware/authorization.go b/pkg/roles/rolemanager/middleware/authorization.go index a9807674d..86a221bfe 100644 --- a/pkg/roles/rolemanager/middleware/authorization.go +++ b/pkg/roles/rolemanager/middleware/authorization.go @@ -347,7 +347,7 @@ func (ram RoleManagerAuthorizationMiddleware) validateMembers(ctx context.Contex Subject: member, SubjectType: policies.UserType, SubjectKind: policies.UsersKind, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, ObjectType: policies.PlatformType, }, nil); err != nil { return errors.Wrap(errors.ErrMissingMember, err) diff --git a/pkg/schedule/schedule.go b/pkg/schedule/schedule.go new file mode 100644 index 000000000..e12b12371 --- /dev/null +++ b/pkg/schedule/schedule.go @@ -0,0 +1,172 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package schedule + +import ( + "encoding/json" + "time" + + "github.com/absmach/supermq/pkg/errors" +) + +const ( + noneType = "none" + hourlyType = "hourly" + dailyType = "daily" + weeklyType = "weekly" + monthlyType = "monthly" +) + +var ( + ErrInvalidRecurringType = errors.NewRequestError("invalid recurring type") + ErrStartDateTimeInPast = errors.NewRequestError("start_datetime must be greater than or equal to current time") +) + +// Type can be daily, weekly or monthly. +type Recurring uint + +const ( + None Recurring = iota + Hourly + Daily + Weekly + Monthly +) + +func (rt Recurring) String() string { + switch rt { + case Hourly: + return hourlyType + case Daily: + return dailyType + case Weekly: + return weeklyType + case Monthly: + return monthlyType + default: + return noneType + } +} + +func (rt Recurring) MarshalJSON() ([]byte, error) { + return json.Marshal(rt.String()) +} + +func (rt *Recurring) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + switch s { + case hourlyType: + *rt = Hourly + case dailyType: + *rt = Daily + case weeklyType: + *rt = Weekly + case monthlyType: + *rt = Monthly + case noneType: + *rt = None + default: + return ErrInvalidRecurringType + } + return nil +} + +type Schedule struct { + StartDateTime time.Time `json:"start_datetime,omitempty"` // When the schedule becomes active + Time time.Time `json:"time,omitempty"` // Specific time for the rule to run + Recurring Recurring `json:"recurring,omitempty"` // None, Daily, Weekly, Monthly + RecurringPeriod uint `json:"recurring_period,omitempty"` // Controls how many intervals to skip between executions: 1 = every interval, 2 = every second interval, etc. +} + +func (s Schedule) Validate() error { + if !s.StartDateTime.IsZero() { + now := time.Now().UTC() + if s.StartDateTime.Before(now) { + return ErrStartDateTimeInPast + } + } + return nil +} + +func (s Schedule) MarshalJSON() ([]byte, error) { + type Alias Schedule + jTimes := struct { + StartDateTime *string `json:"start_datetime"` + Time string `json:"time"` + *Alias + }{ + Time: s.Time.Format(time.RFC3339), + Alias: (*Alias)(&s), + } + if !s.StartDateTime.IsZero() { + formatted := s.StartDateTime.Format(time.RFC3339) + jTimes.StartDateTime = &formatted + } + + return json.Marshal(jTimes) +} + +func (s *Schedule) UnmarshalJSON(data []byte) error { + type Alias Schedule + temp := struct { + StartDateTime string `json:"start_datetime,omitempty"` + Time string `json:"time,omitempty"` + *Alias + }{ + Alias: (*Alias)(s), + } + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + if temp.StartDateTime != "" { + startDateTime, err := time.Parse(time.RFC3339, temp.StartDateTime) + if err != nil { + return err + } + s.StartDateTime = startDateTime + } + if temp.Time != "" { + parsedTime, err := time.Parse(time.RFC3339, temp.Time) + if err != nil { + return err + } + s.Time = parsedTime + } + return nil +} + +func (s Schedule) NextDue() time.Time { + switch s.Recurring { + case Hourly: + return s.Time.Add(time.Hour * time.Duration(s.RecurringPeriod)) + case Daily: + return s.Time.AddDate(0, 0, int(s.RecurringPeriod)) + case Weekly: + return s.Time.AddDate(0, 0, int(s.RecurringPeriod)*7) + case Monthly: + return s.Time.AddDate(0, int(s.RecurringPeriod), 0) + default: + return time.Time{} + } +} + +// EventEncode converts a schedule.Schedule struct to map[string]any. +func (s Schedule) EventEncode() map[string]any { + m := map[string]any{ + "recurring": s.Recurring.String(), + "recurring_period": s.RecurringPeriod, + } + if !s.StartDateTime.IsZero() { + m["start_datetime"] = s.StartDateTime.Format(time.RFC3339) + } + if !s.Time.IsZero() { + m["time"] = s.Time.Format(time.RFC3339) + } + return m +} diff --git a/pkg/sdk/alarms.go b/pkg/sdk/alarms.go new file mode 100644 index 000000000..3b950de6a --- /dev/null +++ b/pkg/sdk/alarms.go @@ -0,0 +1,115 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/absmach/supermq/pkg/errors" +) + +const alarmsEndpoint = "alarms" + +// Alarm represents an alarm instance. +type Alarm struct { + ID string `json:"id,omitempty"` + RuleID string `json:"rule_id,omitempty"` + DomainID string `json:"domain_id,omitempty"` + ChannelID string `json:"channel_id,omitempty"` + ClientID string `json:"client_id,omitempty"` + Subtopic string `json:"subtopic,omitempty"` + Status string `json:"status,omitempty"` + Measurement string `json:"measurement,omitempty"` + Value string `json:"value,omitempty"` + Unit string `json:"unit,omitempty"` + Threshold string `json:"threshold,omitempty"` + Cause string `json:"cause,omitempty"` + Severity uint8 `json:"severity,omitempty"` + AssigneeID string `json:"assignee_id,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` + UpdatedBy string `json:"updated_by,omitempty"` + AssignedAt time.Time `json:"assigned_at,omitempty"` + AssignedBy string `json:"assigned_by,omitempty"` + AcknowledgedAt time.Time `json:"acknowledged_at,omitempty"` + AcknowledgedBy string `json:"acknowledged_by,omitempty"` + ResolvedAt time.Time `json:"resolved_at,omitempty"` + ResolvedBy string `json:"resolved_by,omitempty"` + Metadata Metadata `json:"metadata,omitempty"` +} + +type AlarmsPage struct { + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Total uint64 `json:"total"` + Alarms []Alarm `json:"alarms"` +} + +func (sdk mgSDK) UpdateAlarm(ctx context.Context, alarm Alarm, domainID, token string) (Alarm, errors.SDKError) { + data, err := json.Marshal(alarm) + if err != nil { + return Alarm{}, errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s/%s/%s", sdk.alarmsURL, domainID, alarmsEndpoint, alarm.ID) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPut, url, token, data, nil, http.StatusOK) + if sdkerr != nil { + return Alarm{}, sdkerr + } + + var a Alarm + if err := json.Unmarshal(body, &a); err != nil { + return Alarm{}, errors.NewSDKError(err) + } + + return a, nil +} + +func (sdk mgSDK) ViewAlarm(ctx context.Context, id, domainID, token string) (Alarm, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s/%s", sdk.alarmsURL, domainID, alarmsEndpoint, id) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return Alarm{}, sdkerr + } + + var a Alarm + if err := json.Unmarshal(body, &a); err != nil { + return Alarm{}, errors.NewSDKError(err) + } + + return a, nil +} + +func (sdk mgSDK) ListAlarms(ctx context.Context, pm PageMetadata, domainID, token string) (AlarmsPage, errors.SDKError) { + endpoint := fmt.Sprintf("%s/%s", domainID, alarmsEndpoint) + url, err := sdk.withQueryParams(sdk.alarmsURL, endpoint, pm) + if err != nil { + return AlarmsPage{}, errors.NewSDKError(err) + } + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return AlarmsPage{}, sdkerr + } + + var ap AlarmsPage + if err := json.Unmarshal(body, &ap); err != nil { + return AlarmsPage{}, errors.NewSDKError(err) + } + + return ap, nil +} + +func (sdk mgSDK) DeleteAlarm(ctx context.Context, id, domainID, token string) errors.SDKError { + url := fmt.Sprintf("%s/%s/%s/%s", sdk.alarmsURL, domainID, alarmsEndpoint, id) + + _, _, sdkerr := sdk.processRequest(ctx, http.MethodDelete, url, token, nil, nil, http.StatusNoContent, http.StatusOK) + return sdkerr +} diff --git a/pkg/sdk/alarms_test.go b/pkg/sdk/alarms_test.go new file mode 100644 index 000000000..5a67f2992 --- /dev/null +++ b/pkg/sdk/alarms_test.go @@ -0,0 +1,390 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk_test + +import ( + "context" + "net/http/httptest" + "testing" + "time" + + "github.com/absmach/supermq/alarms" + "github.com/absmach/supermq/alarms/api" + amocks "github.com/absmach/supermq/alarms/mocks" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/sdk" + "github.com/absmach/supermq/pkg/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const alarmID = "alarm-1" + +var testAlarm = sdk.Alarm{ + ID: alarmID, + RuleID: "rule-1", + DomainID: domainID, + ChannelID: "chan-1", + ClientID: "client-1", + Subtopic: "subtopic", + Status: "active", + Measurement: "temperature", + Value: "30.5", + Unit: "C", + Threshold: "25", + Cause: "threshold_exceeded", + Severity: 80, + AssigneeID: "user-1", + Metadata: sdk.Metadata{"key": "value"}, +} + +func setupAlarms() (*httptest.Server, *amocks.Service, *authnmocks.Authentication) { + asvc := new(amocks.Service) + logger := smqlog.NewMock() + authn := new(authnmocks.Authentication) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + idp := uuid.NewMock() + mux := api.MakeHandler(asvc, logger, idp, "", am) + return httptest.NewServer(mux), asvc, authn +} + +func TestUpdateAlarm(t *testing.T) { + as, asvc, auth := setupAlarms() + defer as.Close() + + conf := sdk.Config{ + AlarmsURL: as.URL, + } + mgsdk := sdk.NewSDK(conf) + + updated := testAlarm + updated.Status = "cleared" + + svcAlarm := alarms.Alarm{ + ID: alarmID, + RuleID: "rule-1", + DomainID: domainID, + ChannelID: "chan-1", + ClientID: "client-1", + Subtopic: "subtopic", + Status: alarms.ClearedStatus, + Measurement: "temperature", + Value: "30.5", + Unit: "C", + Threshold: "25", + Cause: "threshold_exceeded", + Severity: 80, + AssigneeID: "user-1", + Metadata: alarms.Metadata{"key": "value"}, + } + + cases := []struct { + desc string + alarm sdk.Alarm + token string + session smqauthn.Session + svcRes alarms.Alarm + svcErr error + authenticateErr error + wantErr bool + resp sdk.Alarm + }{ + { + desc: "update alarm successfully", + alarm: updated, + token: validToken, + svcRes: svcAlarm, + resp: testAlarm, + }, + { + desc: "update alarm with empty token", + alarm: updated, + token: "", + wantErr: true, + }, + { + desc: "update non-existent alarm", + alarm: sdk.Alarm{ID: "non-existent"}, + token: validToken, + svcErr: errors.New("not found"), + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := asvc.On("UpdateAlarm", mock.Anything, tc.session, mock.Anything).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.UpdateAlarm(context.Background(), tc.alarm, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestViewAlarm(t *testing.T) { + as, asvc, auth := setupAlarms() + defer as.Close() + + conf := sdk.Config{ + AlarmsURL: as.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcAlarm := alarms.Alarm{ + ID: alarmID, + RuleID: "rule-1", + DomainID: domainID, + ChannelID: "chan-1", + ClientID: "client-1", + Subtopic: "subtopic", + Status: alarms.ActiveStatus, + Measurement: "temperature", + Value: "30.5", + Unit: "C", + Threshold: "25", + Cause: "threshold_exceeded", + Severity: 80, + AssigneeID: "user-1", + Metadata: alarms.Metadata{"key": "value"}, + } + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + svcRes alarms.Alarm + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "view alarm successfully", + id: alarmID, + token: validToken, + svcRes: svcAlarm, + }, + { + desc: "view alarm with empty token", + id: alarmID, + token: "", + wantErr: true, + }, + { + desc: "view non-existent alarm", + id: "non-existent", + token: validToken, + svcErr: errors.New("not found"), + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := asvc.On("ViewAlarm", mock.Anything, tc.session, tc.id).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.ViewAlarm(context.Background(), tc.id, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestListAlarms(t *testing.T) { + as, asvc, auth := setupAlarms() + defer as.Close() + + conf := sdk.Config{ + AlarmsURL: as.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcAlarm := alarms.Alarm{ + ID: alarmID, + RuleID: "rule-1", + DomainID: domainID, + ChannelID: "chan-1", + ClientID: "client-1", + Subtopic: "subtopic", + Status: alarms.ActiveStatus, + Measurement: "temperature", + Value: "30.5", + Unit: "C", + Threshold: "25", + Cause: "threshold_exceeded", + Severity: 80, + AssigneeID: "user-1", + Metadata: alarms.Metadata{"key": "value"}, + } + + svcAlarmsPage := alarms.AlarmsPage{ + Total: 2, + Offset: 0, + Limit: 10, + Alarms: []alarms.Alarm{svcAlarm}, + } + + cases := []struct { + desc string + pm sdk.PageMetadata + token string + session smqauthn.Session + svcRes alarms.AlarmsPage + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "list alarms successfully", + pm: sdk.PageMetadata{Offset: 0, Limit: 10}, + token: validToken, + svcRes: svcAlarmsPage, + }, + { + desc: "list alarms with status and entity filters", + pm: sdk.PageMetadata{ + Limit: 5, + Status: "active", + ChannelID: "chan-1", + ClientID: "client-1", + RuleID: "rule-1", + AssigneeID: "user-1", + Severity: 80, + }, + token: validToken, + svcRes: svcAlarmsPage, + }, + { + desc: "list alarms with time range and sorting", + pm: sdk.PageMetadata{ + Limit: 10, + CreatedFrom: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + CreatedTo: time.Date(2024, 12, 31, 0, 0, 0, 0, time.UTC), + Order: "created_at", + Dir: "asc", + }, + token: validToken, + svcRes: svcAlarmsPage, + }, + { + desc: "list alarms with actor filters", + pm: sdk.PageMetadata{ + Limit: 10, + UpdatedBy: "user-2", + AssignedBy: "user-3", + AcknowledgedBy: "user-4", + ResolvedBy: "user-5", + Subtopic: "subtopic-1", + }, + token: validToken, + svcRes: svcAlarmsPage, + }, + { + desc: "list alarms with empty metadata excludes severity", + pm: sdk.PageMetadata{}, + token: validToken, + svcRes: alarms.AlarmsPage{}, + }, + { + desc: "list alarms with zero severity excluded", + pm: sdk.PageMetadata{Status: "active", Severity: 0}, + token: validToken, + svcRes: alarms.AlarmsPage{}, + }, + { + desc: "list alarms with empty token", + pm: sdk.PageMetadata{Limit: 10}, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := asvc.On("ListAlarms", mock.Anything, tc.session, mock.Anything).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.ListAlarms(context.Background(), tc.pm, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.Equal(t, tc.svcRes.Total, result.Total) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDeleteAlarm(t *testing.T) { + as, asvc, auth := setupAlarms() + defer as.Close() + + conf := sdk.Config{ + AlarmsURL: as.URL, + } + mgsdk := sdk.NewSDK(conf) + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "delete alarm successfully", + id: alarmID, + token: validToken, + }, + { + desc: "delete alarm with empty token", + id: alarmID, + token: "", + wantErr: true, + }, + { + desc: "delete non-existent alarm", + id: "non-existent", + token: validToken, + svcErr: errors.New("not found"), + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := asvc.On("DeleteAlarm", mock.Anything, tc.session, tc.id).Return(tc.svcErr) + err := mgsdk.DeleteAlarm(context.Background(), tc.id, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + svcCall.Unset() + authCall.Unset() + }) + } +} diff --git a/pkg/sdk/bootstrap.go b/pkg/sdk/bootstrap.go new file mode 100644 index 000000000..ebf783b62 --- /dev/null +++ b/pkg/sdk/bootstrap.go @@ -0,0 +1,323 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/pkg/errors" +) + +const ( + configsEndpoint = "clients/configs" + bootstrapEndpoint = "clients/bootstrap" + whitelistEndpoint = "clients/state" + bootstrapCertsEndpoint = "clients/configs/certs" + bootstrapConnEndpoint = "clients/configs/connections" + secureEndpoint = "secure" +) + +// BootstrapConfig represents Configuration entity. It wraps information about external entity +// as well as info about corresponding SuperMQ entities. +// MGClient represents corresponding SuperMQ Client ID. +// MGKey is key of corresponding SuperMQ Client. +// MGChannels is a list of SuperMQ Channels corresponding SuperMQ Client connects to. +type BootstrapConfig struct { + Channels any `json:"channels,omitempty"` + ExternalID string `json:"external_id,omitempty"` + ExternalKey string `json:"external_key,omitempty"` + ClientID string `json:"client_id,omitempty"` + ClientSecret string `json:"client_secret,omitempty"` + Name string `json:"name,omitempty"` + ClientCert string `json:"client_cert,omitempty"` + ClientKey string `json:"client_key,omitempty"` + CACert string `json:"ca_cert,omitempty"` + Content string `json:"content,omitempty"` + State int `json:"state,omitempty"` +} + +func (ts *BootstrapConfig) UnmarshalJSON(data []byte) error { + var rawData map[string]json.RawMessage + if err := json.Unmarshal(data, &rawData); err != nil { + return err + } + + if channelData, ok := rawData["channels"]; ok { + var stringData []string + if err := json.Unmarshal(channelData, &stringData); err == nil { + ts.Channels = stringData + } else { + var channels []Channel + if err := json.Unmarshal(channelData, &channels); err == nil { + ts.Channels = channels + } else { + return fmt.Errorf("unsupported channel data type") + } + } + } + + if err := json.Unmarshal(data, &struct { + ExternalID *string `json:"external_id,omitempty"` + ExternalKey *string `json:"external_key,omitempty"` + ClientID *string `json:"client_id,omitempty"` + ClientSecret *string `json:"client_secret,omitempty"` + Name *string `json:"name,omitempty"` + ClientCert *string `json:"client_cert,omitempty"` + ClientKey *string `json:"client_key,omitempty"` + CACert *string `json:"ca_cert,omitempty"` + Content *string `json:"content,omitempty"` + State *int `json:"state,omitempty"` + }{ + ExternalID: &ts.ExternalID, + ExternalKey: &ts.ExternalKey, + ClientID: &ts.ClientID, + ClientSecret: &ts.ClientSecret, + Name: &ts.Name, + ClientCert: &ts.ClientCert, + ClientKey: &ts.ClientKey, + CACert: &ts.CACert, + Content: &ts.Content, + State: &ts.State, + }); err != nil { + return err + } + + return nil +} + +func (sdk mgSDK) AddBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) (string, errors.SDKError) { + data, err := json.Marshal(cfg) + if err != nil { + return "", errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s/%s", sdk.bootstrapURL, domainID, configsEndpoint) + + headers, _, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, data, nil, http.StatusOK, http.StatusCreated) + if sdkerr != nil { + return "", sdkerr + } + + id := strings.TrimPrefix(headers.Get("Location"), "/clients/configs/") + + return id, nil +} + +func (sdk mgSDK) Bootstraps(ctx context.Context, pm PageMetadata, domainID, token string) (BootstrapPage, errors.SDKError) { + endpoint := fmt.Sprintf("%s/%s", domainID, configsEndpoint) + url, err := sdk.withQueryParams(sdk.bootstrapURL, endpoint, pm) + if err != nil { + return BootstrapPage{}, errors.NewSDKError(err) + } + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return BootstrapPage{}, sdkerr + } + + var bb BootstrapPage + if err = json.Unmarshal(body, &bb); err != nil { + return BootstrapPage{}, errors.NewSDKError(err) + } + + return bb, nil +} + +func (sdk mgSDK) Whitelist(ctx context.Context, clientID string, state int, domainID, token string) errors.SDKError { + if clientID == "" { + return errors.NewSDKError(apiutil.ErrMissingID) + } + + data, err := json.Marshal(BootstrapConfig{State: state}) + if err != nil { + return errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, whitelistEndpoint, clientID) + + _, _, sdkerr := sdk.processRequest(ctx, http.MethodPut, url, token, data, nil, http.StatusCreated, http.StatusOK) + + return sdkerr +} + +func (sdk mgSDK) ViewBootstrap(ctx context.Context, id, domainID, token string) (BootstrapConfig, errors.SDKError) { + if id == "" { + return BootstrapConfig{}, errors.NewSDKError(apiutil.ErrMissingID) + } + url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, configsEndpoint, id) + + _, body, err := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if err != nil { + return BootstrapConfig{}, err + } + + var bc BootstrapConfig + if err := json.Unmarshal(body, &bc); err != nil { + return BootstrapConfig{}, errors.NewSDKError(err) + } + + return bc, nil +} + +func (sdk mgSDK) UpdateBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) errors.SDKError { + if cfg.ClientID == "" { + return errors.NewSDKError(apiutil.ErrMissingID) + } + url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, configsEndpoint, cfg.ClientID) + + data, err := json.Marshal(cfg) + if err != nil { + return errors.NewSDKError(err) + } + + _, _, sdkerr := sdk.processRequest(ctx, http.MethodPut, url, token, data, nil, http.StatusOK) + + return sdkerr +} + +func (sdk mgSDK) UpdateBootstrapCerts(ctx context.Context, id, clientCert, clientKey, ca, domainID, token string) (BootstrapConfig, errors.SDKError) { + if id == "" { + return BootstrapConfig{}, errors.NewSDKError(apiutil.ErrMissingID) + } + url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, bootstrapCertsEndpoint, id) + request := BootstrapConfig{ + ClientCert: clientCert, + ClientKey: clientKey, + CACert: ca, + } + + data, err := json.Marshal(request) + if err != nil { + return BootstrapConfig{}, errors.NewSDKError(err) + } + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK) + if sdkerr != nil { + return BootstrapConfig{}, sdkerr + } + + var bc BootstrapConfig + if err := json.Unmarshal(body, &bc); err != nil { + return BootstrapConfig{}, errors.NewSDKError(err) + } + + return bc, nil +} + +func (sdk mgSDK) UpdateBootstrapConnection(ctx context.Context, id string, channels []string, domainID, token string) errors.SDKError { + if id == "" { + return errors.NewSDKError(apiutil.ErrMissingID) + } + url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, bootstrapConnEndpoint, id) + request := map[string][]string{ + "channels": channels, + } + data, err := json.Marshal(request) + if err != nil { + return errors.NewSDKError(err) + } + + _, _, sdkerr := sdk.processRequest(ctx, http.MethodPut, url, token, data, nil, http.StatusOK) + return sdkerr +} + +func (sdk mgSDK) RemoveBootstrap(ctx context.Context, id, domainID, token string) errors.SDKError { + if id == "" { + return errors.NewSDKError(apiutil.ErrMissingID) + } + url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, domainID, configsEndpoint, id) + + _, _, err := sdk.processRequest(ctx, http.MethodDelete, url, token, nil, nil, http.StatusNoContent) + return err +} + +func (sdk mgSDK) Bootstrap(ctx context.Context, externalID, externalKey string) (BootstrapConfig, errors.SDKError) { + if externalID == "" { + return BootstrapConfig{}, errors.NewSDKError(apiutil.ErrMissingID) + } + url := fmt.Sprintf("%s/%s/%s", sdk.bootstrapURL, bootstrapEndpoint, externalID) + + _, body, err := sdk.processRequest(ctx, http.MethodGet, url, ClientPrefix+externalKey, nil, nil, http.StatusOK) + if err != nil { + return BootstrapConfig{}, err + } + + var bc BootstrapConfig + if err := json.Unmarshal(body, &bc); err != nil { + return BootstrapConfig{}, errors.NewSDKError(err) + } + + return bc, nil +} + +func (sdk mgSDK) BootstrapSecure(ctx context.Context, externalID, externalKey, cryptoKey string) (BootstrapConfig, errors.SDKError) { + if externalID == "" { + return BootstrapConfig{}, errors.NewSDKError(apiutil.ErrMissingID) + } + url := fmt.Sprintf("%s/%s/%s/%s", sdk.bootstrapURL, bootstrapEndpoint, secureEndpoint, externalID) + + encExtKey, err := bootstrapEncrypt([]byte(externalKey), cryptoKey) + if err != nil { + return BootstrapConfig{}, errors.NewSDKError(err) + } + + _, body, sdkErr := sdk.processRequest(ctx, http.MethodGet, url, ClientPrefix+encExtKey, nil, nil, http.StatusOK) + if sdkErr != nil { + return BootstrapConfig{}, sdkErr + } + + decBody, decErr := bootstrapDecrypt(body, cryptoKey) + if decErr != nil { + return BootstrapConfig{}, errors.NewSDKError(decErr) + } + var bc BootstrapConfig + if err := json.Unmarshal(decBody, &bc); err != nil { + return BootstrapConfig{}, errors.NewSDKError(err) + } + + return bc, nil +} + +func bootstrapEncrypt(in []byte, cryptoKey string) (string, error) { + block, err := aes.NewCipher([]byte(cryptoKey)) + if err != nil { + return "", err + } + ciphertext := make([]byte, aes.BlockSize+len(in)) + iv := ciphertext[:aes.BlockSize] + + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return "", err + } + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(ciphertext[aes.BlockSize:], in) + return hex.EncodeToString(ciphertext), nil +} + +func bootstrapDecrypt(in []byte, cryptoKey string) ([]byte, error) { + ciphertext := in + + block, err := aes.NewCipher([]byte(cryptoKey)) + if err != nil { + return nil, err + } + if len(ciphertext) < aes.BlockSize { + return nil, err + } + iv := ciphertext[:aes.BlockSize] + ciphertext = ciphertext[aes.BlockSize:] + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(ciphertext, ciphertext) + return ciphertext, nil +} diff --git a/pkg/sdk/bootstrap_test.go b/pkg/sdk/bootstrap_test.go new file mode 100644 index 000000000..95ccaf606 --- /dev/null +++ b/pkg/sdk/bootstrap_test.go @@ -0,0 +1,1350 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk_test + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/bootstrap" + "github.com/absmach/supermq/bootstrap/api" + bmocks "github.com/absmach/supermq/bootstrap/mocks" + "github.com/absmach/supermq/internal/testsutil" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + sdk "github.com/absmach/supermq/pkg/sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + externalId = testsutil.GenerateUUID(&testing.T{}) + externalKey = testsutil.GenerateUUID(&testing.T{}) + clientId = testsutil.GenerateUUID(&testing.T{}) + clientSecret = testsutil.GenerateUUID(&testing.T{}) + channel1Id = testsutil.GenerateUUID(&testing.T{}) + channel2Id = testsutil.GenerateUUID(&testing.T{}) + clientCert = "newcert" + clientKey = "newkey" + caCert = "newca" + content = "newcontent" + state = 1 + bsName = "test" + encKey = []byte("1234567891011121") + bootstrapConfig = bootstrap.Config{ + ClientID: clientId, + Name: "test", + ClientCert: clientCert, + ClientKey: clientKey, + CACert: caCert, + Channels: []bootstrap.Channel{ + { + ID: channel1Id, + }, + { + ID: channel2Id, + }, + }, + ExternalID: externalId, + ExternalKey: externalKey, + Content: content, + State: bootstrap.Inactive, + } + sdkBootstrapConfig = sdk.BootstrapConfig{ + Channels: []string{channel1Id, channel2Id}, + ExternalID: externalId, + ExternalKey: externalKey, + ClientID: clientId, + ClientSecret: clientSecret, + Name: bsName, + ClientCert: clientCert, + ClientKey: clientKey, + CACert: caCert, + Content: content, + State: state, + } + sdkBootstrapConfigRes = sdk.BootstrapConfig{ + ClientID: clientId, + ClientSecret: clientSecret, + Channels: []sdk.Channel{ + { + ID: channel1Id, + }, + { + ID: channel2Id, + }, + }, + ClientCert: clientCert, + ClientKey: clientKey, + CACert: caCert, + } + readConfigResponse = struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + Channels []readerChannelRes `json:"channels"` + Content string `json:"content,omitempty"` + ClientCert string `json:"client_cert,omitempty"` + ClientKey string `json:"client_key,omitempty"` + CACert string `json:"ca_cert,omitempty"` + }{ + ClientID: clientId, + ClientSecret: clientSecret, + Channels: []readerChannelRes{ + { + ID: channel1Id, + }, + { + ID: channel2Id, + }, + }, + ClientCert: clientCert, + ClientKey: clientKey, + CACert: caCert, + } +) + +var ( + errMarshalChan = errors.New("json: unsupported type: chan int") + errJsonEOF = errors.New("unexpected end of JSON input") +) + +type readerChannelRes struct { + ID string `json:"id"` + Name string `json:"name,omitempty"` + Metadata any `json:"metadata,omitempty"` +} + +func setupBootstrap() (*httptest.Server, *bmocks.Service, *bmocks.ConfigReader, *authnmocks.Authentication) { + bsvc := new(bmocks.Service) + reader := new(bmocks.ConfigReader) + logger := smqlog.NewMock() + authn := new(authnmocks.Authentication) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + + mux := api.MakeHandler(bsvc, am, reader, logger, "") + + return httptest.NewServer(mux), bsvc, reader, authn +} + +func TestAddBootstrap(t *testing.T) { + bs, bsvc, _, auth := setupBootstrap() + defer bs.Close() + + conf := sdk.Config{ + BootstrapURL: bs.URL, + } + mgsdk := sdk.NewSDK(conf) + + neID := sdkBootstrapConfig + neID.ClientID = "non-existent" + + neReqId := bootstrapConfig + neReqId.ClientID = "non-existent" + + cases := []struct { + desc string + domainID string + token string + session smqauthn.Session + cfg sdk.BootstrapConfig + svcReq bootstrap.Config + svcRes bootstrap.Config + svcErr error + authenticateErr error + response string + err errors.SDKError + }{ + { + desc: "add successfully", + domainID: domainID, + token: validToken, + cfg: sdkBootstrapConfig, + svcReq: bootstrapConfig, + svcRes: bootstrapConfig, + svcErr: nil, + err: nil, + }, + { + desc: "add with invalid token", + domainID: domainID, + token: invalidToken, + cfg: sdkBootstrapConfig, + svcReq: bootstrapConfig, + svcRes: bootstrap.Config{}, + authenticateErr: svcerr.ErrAuthentication, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "add with config that cannot be marshalled", + domainID: domainID, + token: validToken, + cfg: sdk.BootstrapConfig{ + Channels: map[string]any{ + "channel1": make(chan int), + }, + ExternalID: externalId, + ExternalKey: externalKey, + ClientID: clientId, + ClientSecret: clientSecret, + Name: bsName, + ClientCert: clientCert, + ClientKey: clientKey, + CACert: caCert, + Content: content, + }, + svcReq: bootstrap.Config{}, + svcRes: bootstrap.Config{}, + svcErr: nil, + err: errors.NewSDKError(errMarshalChan), + }, + { + desc: "add an existing config", + domainID: domainID, + token: validToken, + cfg: sdkBootstrapConfig, + svcReq: bootstrapConfig, + svcRes: bootstrap.Config{}, + svcErr: svcerr.ErrConflict, + err: errors.NewSDKErrorWithStatus(svcerr.ErrConflict, http.StatusBadRequest), + }, + { + desc: "add empty config", + domainID: domainID, + token: validToken, + cfg: sdk.BootstrapConfig{}, + svcReq: bootstrap.Config{}, + svcRes: bootstrap.Config{}, + svcErr: nil, + err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest), + }, + { + desc: "add with non-existent client Id", + domainID: domainID, + token: validToken, + cfg: neID, + svcReq: neReqId, + svcRes: bootstrap.Config{}, + svcErr: svcerr.ErrNotFound, + err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := bsvc.On("Add", mock.Anything, tc.session, tc.token, tc.svcReq).Return(tc.svcRes, tc.svcErr) + resp, err := mgsdk.AddBootstrap(context.Background(), tc.cfg, tc.domainID, tc.token) + assert.Equal(t, tc.err, err) + if err == nil { + assert.Equal(t, bootstrapConfig.ClientID, resp) + ok := svcCall.Parent.AssertCalled(t, "Add", mock.Anything, tc.session, tc.token, tc.svcReq) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestListBootstraps(t *testing.T) { + bs, bsvc, _, auth := setupBootstrap() + defer bs.Close() + + conf := sdk.Config{ + BootstrapURL: bs.URL, + } + mgsdk := sdk.NewSDK(conf) + + configRes := sdk.BootstrapConfig{ + Channels: []sdk.Channel{ + { + ID: channel1Id, + }, + { + ID: channel2Id, + }, + }, + ClientID: clientId, + Name: bsName, + ExternalID: externalId, + ExternalKey: externalKey, + Content: content, + } + unmarshalableConfig := bootstrapConfig + unmarshalableConfig.Channels = []bootstrap.Channel{ + { + ID: channel1Id, + Metadata: map[string]any{ + "test": make(chan int), + }, + }, + } + + cases := []struct { + desc string + domainID string + token string + session smqauthn.Session + pageMeta sdk.PageMetadata + svcResp bootstrap.ConfigsPage + svcErr error + authenticateErr error + response sdk.BootstrapPage + err errors.SDKError + }{ + { + desc: "list successfully", + domainID: domainID, + token: validToken, + pageMeta: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + svcResp: bootstrap.ConfigsPage{ + Total: 1, + Offset: 0, + Configs: []bootstrap.Config{bootstrapConfig}, + }, + response: sdk.BootstrapPage{ + PageRes: sdk.PageRes{ + Total: 1, + }, + Configs: []sdk.BootstrapConfig{configRes}, + }, + err: nil, + }, + { + desc: "list with invalid token", + domainID: domainID, + token: invalidToken, + pageMeta: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + svcResp: bootstrap.ConfigsPage{}, + authenticateErr: svcerr.ErrAuthentication, + response: sdk.BootstrapPage{}, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "list with empty token", + domainID: domainID, + token: "", + pageMeta: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + svcResp: bootstrap.ConfigsPage{}, + svcErr: nil, + response: sdk.BootstrapPage{}, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + }, + { + desc: "list with invalid query params", + domainID: domainID, + token: validToken, + pageMeta: sdk.PageMetadata{ + Offset: 1, + Limit: 10, + Metadata: map[string]any{ + "test": make(chan int), + }, + }, + svcResp: bootstrap.ConfigsPage{}, + svcErr: nil, + response: sdk.BootstrapPage{}, + err: errors.NewSDKError(errMarshalChan), + }, + { + desc: "list with response that cannot be unmarshalled", + domainID: domainID, + token: validToken, + pageMeta: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + svcResp: bootstrap.ConfigsPage{ + Total: 1, + Offset: 0, + Configs: []bootstrap.Config{unmarshalableConfig}, + }, + svcErr: nil, + response: sdk.BootstrapPage{}, + err: errors.NewSDKError(errJsonEOF), + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := bsvc.On("List", mock.Anything, tc.session, mock.Anything, tc.pageMeta.Offset, tc.pageMeta.Limit).Return(tc.svcResp, tc.svcErr) + resp, err := mgsdk.Bootstraps(context.Background(), tc.pageMeta, tc.domainID, tc.token) + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.response, resp) + if err == nil { + ok := svcCall.Parent.AssertCalled(t, "List", mock.Anything, tc.session, mock.Anything, tc.pageMeta.Offset, tc.pageMeta.Limit) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestWhiteList(t *testing.T) { + bs, bsvc, _, auth := setupBootstrap() + defer bs.Close() + + conf := sdk.Config{ + BootstrapURL: bs.URL, + } + mgsdk := sdk.NewSDK(conf) + + active := 1 + inactive := 0 + + cases := []struct { + desc string + domainID string + token string + session smqauthn.Session + clientID string + state int + svcReq bootstrap.State + svcErr error + authenticateErr error + err errors.SDKError + }{ + { + desc: "whitelist to active state successfully", + domainID: domainID, + token: validToken, + clientID: clientId, + state: active, + svcReq: bootstrap.Active, + svcErr: nil, + err: nil, + }, + { + desc: "whitelist to inactive state successfully", + domainID: domainID, + token: validToken, + clientID: clientId, + state: inactive, + svcReq: bootstrap.Inactive, + svcErr: nil, + err: nil, + }, + { + desc: "whitelist with invalid token", + domainID: domainID, + token: invalidToken, + clientID: clientId, + state: active, + svcReq: bootstrap.Active, + authenticateErr: svcerr.ErrAuthentication, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "whitelist with empty token", + domainID: domainID, + token: "", + clientID: clientId, + state: active, + svcReq: bootstrap.Active, + svcErr: nil, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + }, + { + desc: "whitelist with invalid state", + domainID: domainID, + token: validToken, + clientID: clientId, + state: -1, + svcReq: bootstrap.Active, + svcErr: nil, + err: errors.NewSDKErrorWithStatus(bootstrap.ErrBootstrapState, http.StatusBadRequest), + }, + { + desc: "whitelist with empty client Id", + domainID: domainID, + token: validToken, + clientID: "", + state: 1, + svcReq: bootstrap.Active, + svcErr: nil, + err: errors.NewSDKError(apiutil.ErrMissingID), + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := bsvc.On("ChangeState", mock.Anything, tc.session, tc.token, tc.clientID, tc.svcReq).Return(tc.svcErr) + err := mgsdk.Whitelist(context.Background(), tc.clientID, tc.state, tc.domainID, tc.token) + assert.Equal(t, tc.err, err) + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "ChangeState", mock.Anything, tc.session, tc.token, tc.clientID, tc.svcReq) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestViewBootstrap(t *testing.T) { + bs, bsvc, _, auth := setupBootstrap() + defer bs.Close() + + conf := sdk.Config{ + BootstrapURL: bs.URL, + } + mgsdk := sdk.NewSDK(conf) + + viewBoostrapRes := sdk.BootstrapConfig{ + ClientID: clientId, + Channels: sdkBootstrapConfigRes.Channels, + ExternalID: externalId, + ExternalKey: externalKey, + Name: bsName, + Content: content, + State: 0, + } + + cases := []struct { + desc string + domainID string + token string + session smqauthn.Session + id string + svcResp bootstrap.Config + svcErr error + authenticateErr error + response sdk.BootstrapConfig + err errors.SDKError + }{ + { + desc: "view successfully", + domainID: domainID, + token: validToken, + id: clientId, + svcResp: bootstrapConfig, + svcErr: nil, + response: viewBoostrapRes, + err: nil, + }, + { + desc: "view with invalid token", + domainID: domainID, + token: invalidToken, + id: clientId, + svcResp: bootstrap.Config{}, + authenticateErr: svcerr.ErrAuthentication, + response: sdk.BootstrapConfig{}, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "view with empty token", + domainID: domainID, + token: "", + id: clientId, + svcResp: bootstrap.Config{}, + svcErr: nil, + response: sdk.BootstrapConfig{}, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + }, + { + desc: "view with non-existent client Id", + domainID: domainID, + token: validToken, + id: invalid, + svcResp: bootstrap.Config{}, + svcErr: svcerr.ErrNotFound, + response: sdk.BootstrapConfig{}, + err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), + }, + { + desc: "view with response that cannot be unmarshalled", + domainID: domainID, + token: validToken, + id: clientId, + svcResp: bootstrap.Config{ + ClientID: clientId, + Channels: []bootstrap.Channel{ + { + ID: channel1Id, + Metadata: map[string]any{ + "test": make(chan int), + }, + }, + }, + }, + svcErr: nil, + response: sdk.BootstrapConfig{}, + err: errors.NewSDKError(errJsonEOF), + }, + { + desc: "view with empty client Id", + domainID: domainID, + token: validToken, + id: "", + svcResp: bootstrap.Config{}, + svcErr: nil, + response: sdk.BootstrapConfig{}, + err: errors.NewSDKError(apiutil.ErrMissingID), + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := bsvc.On("View", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr) + resp, err := mgsdk.ViewBootstrap(context.Background(), tc.id, tc.domainID, tc.token) + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.response, resp) + if err == nil { + ok := svcCall.Parent.AssertCalled(t, "View", mock.Anything, tc.session, tc.id) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateBootstrap(t *testing.T) { + bs, bsvc, _, auth := setupBootstrap() + defer bs.Close() + + conf := sdk.Config{ + BootstrapURL: bs.URL, + } + mgsdk := sdk.NewSDK(conf) + + cases := []struct { + desc string + domainID string + token string + session smqauthn.Session + cfg sdk.BootstrapConfig + svcReq bootstrap.Config + svcErr error + authenticationErr error + err errors.SDKError + }{ + { + desc: "update successfully", + domainID: domainID, + token: validToken, + cfg: sdkBootstrapConfig, + svcReq: bootstrap.Config{ + ClientID: clientId, + Name: bsName, + Content: content, + }, + svcErr: nil, + err: nil, + }, + { + desc: "update with invalid token", + domainID: domainID, + token: invalidToken, + cfg: sdkBootstrapConfig, + svcReq: bootstrap.Config{ + ClientID: clientId, + Name: bsName, + Content: content, + }, + authenticationErr: svcerr.ErrAuthentication, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "update with empty token", + domainID: domainID, + token: "", + cfg: sdkBootstrapConfig, + svcReq: bootstrap.Config{}, + svcErr: nil, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + }, + { + desc: "update with config that cannot be marshalled", + domainID: domainID, + token: validToken, + cfg: sdk.BootstrapConfig{ + Channels: map[string]any{ + "channel1": make(chan int), + }, + ExternalID: externalId, + ExternalKey: externalKey, + ClientID: clientId, + ClientSecret: clientSecret, + Name: bsName, + ClientCert: clientCert, + ClientKey: clientKey, + CACert: caCert, + Content: content, + }, + svcReq: bootstrap.Config{ + ClientID: clientId, + Name: bsName, + Content: content, + }, + svcErr: nil, + err: errors.NewSDKError(errMarshalChan), + }, + { + desc: "update with non-existent client Id", + domainID: domainID, + token: validToken, + cfg: sdk.BootstrapConfig{ + ClientID: invalid, + Channels: []sdk.Channel{ + { + ID: channel1Id, + }, + }, + ExternalID: externalId, + ExternalKey: externalKey, + Content: content, + Name: bsName, + }, + svcReq: bootstrap.Config{ + ClientID: invalid, + Name: bsName, + Content: content, + }, + svcErr: svcerr.ErrNotFound, + err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), + }, + { + desc: "update with empty client Id", + domainID: domainID, + token: validToken, + cfg: sdk.BootstrapConfig{ + ClientID: "", + Channels: []sdk.Channel{ + { + ID: channel1Id, + }, + }, + ExternalID: externalId, + ExternalKey: externalKey, + Content: content, + Name: bsName, + }, + svcReq: bootstrap.Config{ + ClientID: "", + Name: bsName, + Content: content, + }, + svcErr: nil, + err: errors.NewSDKError(apiutil.ErrMissingID), + }, + { + desc: "update with config with only client Id", + domainID: domainID, + token: validToken, + cfg: sdk.BootstrapConfig{ + ClientID: clientId, + }, + svcReq: bootstrap.Config{ + ClientID: clientId, + }, + svcErr: nil, + err: nil, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticationErr) + svcCall := bsvc.On("Update", mock.Anything, tc.session, tc.svcReq).Return(tc.svcErr) + err := mgsdk.UpdateBootstrap(context.Background(), tc.cfg, tc.domainID, tc.token) + assert.Equal(t, tc.err, err) + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "Update", mock.Anything, tc.session, tc.svcReq) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateBootstrapCerts(t *testing.T) { + bs, bsvc, _, auth := setupBootstrap() + defer bs.Close() + + conf := sdk.Config{ + BootstrapURL: bs.URL, + } + mgsdk := sdk.NewSDK(conf) + + updateconfigRes := sdk.BootstrapConfig{ + ClientID: clientId, + ClientCert: clientCert, + CACert: caCert, + ClientKey: clientKey, + } + + cases := []struct { + desc string + domainID string + token string + session smqauthn.Session + id string + clientCert string + clientKey string + caCert string + svcResp bootstrap.Config + svcErr error + authenticateErr error + response sdk.BootstrapConfig + err errors.SDKError + }{ + { + desc: "update certs successfully", + domainID: domainID, + token: validToken, + id: clientId, + clientCert: clientCert, + clientKey: clientKey, + caCert: caCert, + svcResp: bootstrapConfig, + svcErr: nil, + response: updateconfigRes, + err: nil, + }, + { + desc: "update certs with invalid token", + domainID: domainID, + token: validToken, + id: clientId, + clientCert: clientCert, + clientKey: clientKey, + caCert: caCert, + svcResp: bootstrap.Config{}, + authenticateErr: svcerr.ErrAuthentication, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "update certs with empty token", + domainID: domainID, + token: "", + id: clientId, + clientCert: clientCert, + clientKey: clientKey, + caCert: caCert, + svcResp: bootstrap.Config{}, + svcErr: nil, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + }, + { + desc: "update certs with non-existent client Id", + domainID: domainID, + token: validToken, + id: invalid, + clientCert: clientCert, + clientKey: clientKey, + caCert: caCert, + svcResp: bootstrap.Config{}, + svcErr: svcerr.ErrNotFound, + err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), + }, + { + desc: "update certs with empty certs", + domainID: domainID, + token: validToken, + id: clientId, + clientCert: "", + clientKey: "", + caCert: "", + svcResp: bootstrap.Config{}, + svcErr: nil, + err: nil, + }, + { + desc: "update certs with empty id", + domainID: domainID, + token: validToken, + id: "", + clientCert: clientCert, + clientKey: clientKey, + caCert: caCert, + svcResp: bootstrap.Config{}, + svcErr: nil, + err: errors.NewSDKError(apiutil.ErrMissingID), + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := bsvc.On("UpdateCert", mock.Anything, tc.session, tc.id, tc.clientCert, tc.clientKey, tc.caCert).Return(tc.svcResp, tc.svcErr) + resp, err := mgsdk.UpdateBootstrapCerts(context.Background(), tc.id, tc.clientCert, tc.clientKey, tc.caCert, tc.domainID, tc.token) + assert.Equal(t, tc.err, err) + if err == nil { + assert.Equal(t, tc.response, resp) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateBootstrapConnection(t *testing.T) { + bs, bsvc, _, auth := setupBootstrap() + defer bs.Close() + + conf := sdk.Config{ + BootstrapURL: bs.URL, + } + mgsdk := sdk.NewSDK(conf) + + cases := []struct { + desc string + domainID string + token string + session smqauthn.Session + id string + channels []string + svcRes bootstrap.Config + svcErr error + authenticateErr error + err errors.SDKError + }{ + { + desc: "update connection successfully", + domainID: domainID, + token: validToken, + id: clientId, + channels: []string{channel1Id, channel2Id}, + svcErr: nil, + err: nil, + }, + { + desc: "update connection with invalid token", + domainID: domainID, + token: invalidToken, + id: clientId, + channels: []string{channel1Id, channel2Id}, + authenticateErr: svcerr.ErrAuthentication, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "update connection with empty token", + domainID: domainID, + token: "", + id: clientId, + channels: []string{channel1Id, channel2Id}, + svcErr: nil, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + }, + { + desc: "update connection with non-existent client Id", + domainID: domainID, + token: validToken, + id: invalid, + channels: []string{channel1Id, channel2Id}, + svcErr: svcerr.ErrNotFound, + err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), + }, + { + desc: "update connection with non-existent channel Id", + domainID: domainID, + token: validToken, + id: clientId, + channels: []string{invalid}, + svcErr: svcerr.ErrNotFound, + err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), + }, + { + desc: "update connection with empty channels", + domainID: domainID, + token: validToken, + id: clientId, + channels: []string{}, + svcErr: svcerr.ErrUpdateEntity, + err: errors.NewSDKErrorWithStatus(svcerr.ErrUpdateEntity, http.StatusUnprocessableEntity), + }, + { + desc: "update connection with empty id", + domainID: domainID, + token: validToken, + id: "", + channels: []string{channel1Id, channel2Id}, + svcErr: nil, + err: errors.NewSDKError(apiutil.ErrMissingID), + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := bsvc.On("UpdateConnections", mock.Anything, tc.session, tc.token, tc.id, tc.channels).Return(tc.svcErr) + err := mgsdk.UpdateBootstrapConnection(context.Background(), tc.id, tc.channels, tc.domainID, tc.token) + assert.Equal(t, tc.err, err) + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "UpdateConnections", mock.Anything, tc.session, tc.token, tc.id, tc.channels) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestRemoveBootstrap(t *testing.T) { + bs, bsvc, _, auth := setupBootstrap() + defer bs.Close() + + conf := sdk.Config{ + BootstrapURL: bs.URL, + } + mgsdk := sdk.NewSDK(conf) + + cases := []struct { + desc string + domainID string + token string + session smqauthn.Session + id string + svcErr error + authenticateErr error + err errors.SDKError + }{ + { + desc: "remove successfully", + domainID: domainID, + token: validToken, + id: clientId, + svcErr: nil, + err: nil, + }, + { + desc: "remove with invalid token", + domainID: domainID, + token: invalidToken, + id: clientId, + authenticateErr: svcerr.ErrAuthentication, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "remove with non-existent client Id", + domainID: domainID, + token: validToken, + id: invalid, + svcErr: svcerr.ErrNotFound, + err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), + }, + { + desc: "remove removed bootstrap", + domainID: domainID, + token: validToken, + id: clientId, + svcErr: svcerr.ErrNotFound, + err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), + }, + { + desc: "remove with empty token", + domainID: domainID, + token: "", + id: clientId, + svcErr: nil, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + }, + { + desc: "remove with empty id", + domainID: domainID, + token: validToken, + id: "", + svcErr: nil, + err: errors.NewSDKError(apiutil.ErrMissingID), + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := bsvc.On("Remove", mock.Anything, tc.session, tc.id).Return(tc.svcErr) + err := mgsdk.RemoveBootstrap(context.Background(), tc.id, tc.domainID, tc.token) + assert.Equal(t, tc.err, err) + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "Remove", mock.Anything, tc.session, tc.id) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestBoostrap(t *testing.T) { + bs, bsvc, reader, _ := setupBootstrap() + defer bs.Close() + + conf := sdk.Config{ + BootstrapURL: bs.URL, + } + mgsdk := sdk.NewSDK(conf) + + cases := []struct { + desc string + token string + externalID string + externalKey string + svcResp bootstrap.Config + svcErr error + readerResp any + readerErr error + response sdk.BootstrapConfig + err errors.SDKError + }{ + { + desc: "bootstrap successfully", + token: validToken, + externalID: externalId, + externalKey: externalKey, + svcResp: bootstrapConfig, + svcErr: nil, + readerResp: readConfigResponse, + readerErr: nil, + response: sdkBootstrapConfigRes, + err: nil, + }, + { + desc: "bootstrap with invalid token", + token: invalidToken, + externalID: externalId, + externalKey: externalKey, + svcResp: bootstrap.Config{}, + svcErr: svcerr.ErrAuthentication, + readerResp: bootstrap.Config{}, + readerErr: nil, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "bootstrap with error in reader", + token: validToken, + externalID: externalId, + externalKey: externalKey, + svcResp: bootstrapConfig, + svcErr: nil, + readerResp: []byte{0}, + readerErr: errJsonEOF, + err: errors.NewSDKErrorWithStatus(errJsonEOF, http.StatusInternalServerError), + }, + { + desc: "boostrap with response that cannot be unmarshalled", + token: validToken, + externalID: externalId, + externalKey: externalKey, + svcResp: bootstrapConfig, + svcErr: nil, + readerResp: []byte{0}, + readerErr: nil, + err: errors.NewSDKError(errors.New("json: cannot unmarshal string into Go value of type map[string]json.RawMessage")), + }, + { + desc: "bootstrap with empty id", + token: validToken, + externalID: "", + externalKey: externalKey, + svcResp: bootstrap.Config{}, + svcErr: nil, + readerResp: bootstrap.Config{}, + readerErr: nil, + err: errors.NewSDKError(apiutil.ErrMissingID), + }, + { + desc: "boostrap with empty key", + token: validToken, + externalID: externalId, + externalKey: "", + svcResp: bootstrap.Config{}, + svcErr: nil, + readerResp: bootstrap.Config{}, + readerErr: nil, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerKey, http.StatusUnauthorized), + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := bsvc.On("Bootstrap", mock.Anything, tc.externalKey, tc.externalID, false).Return(tc.svcResp, tc.svcErr) + readerCall := reader.On("ReadConfig", tc.svcResp, false).Return(tc.readerResp, tc.readerErr) + resp, err := mgsdk.Bootstrap(context.Background(), tc.externalID, tc.externalKey) + assert.Equal(t, tc.err, err) + if err == nil { + assert.Equal(t, tc.response, resp) + ok := svcCall.Parent.AssertCalled(t, "Bootstrap", mock.Anything, tc.externalKey, tc.externalID, false) + assert.True(t, ok) + } + svcCall.Unset() + readerCall.Unset() + }) + } +} + +func TestBootstrapSecure(t *testing.T) { + bs, bsvc, reader, _ := setupBootstrap() + defer bs.Close() + + conf := sdk.Config{ + BootstrapURL: bs.URL, + } + mgsdk := sdk.NewSDK(conf) + + b, err := json.Marshal(readConfigResponse) + assert.Nil(t, err, fmt.Sprintf("Marshalling bootstrap response expected to succeed: %s.\n", err)) + encResponse, err := encrypt(b, encKey) + assert.Nil(t, err, fmt.Sprintf("Encrypting bootstrap response expected to succeed: %s.\n", err)) + + cases := []struct { + desc string + token string + externalID string + externalKey string + cryptoKey string + svcResp bootstrap.Config + svcErr error + readerResp []byte + readerErr error + response sdk.BootstrapConfig + err errors.SDKError + }{ + { + desc: "bootstrap successfully", + token: validToken, + externalID: externalId, + externalKey: externalKey, + cryptoKey: string(encKey), + svcResp: bootstrapConfig, + svcErr: nil, + readerResp: encResponse, + readerErr: nil, + response: sdkBootstrapConfigRes, + err: nil, + }, + { + desc: "bootstrap with invalid token", + token: invalidToken, + externalID: externalId, + externalKey: externalKey, + cryptoKey: string(encKey), + svcResp: bootstrap.Config{}, + svcErr: svcerr.ErrAuthentication, + readerResp: []byte{0}, + readerErr: nil, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "booostrap with invalid crypto key", + token: validToken, + externalID: externalId, + externalKey: externalKey, + cryptoKey: invalid, + svcResp: bootstrap.Config{}, + svcErr: nil, + readerResp: []byte{0}, + readerErr: nil, + err: errors.NewSDKError(errors.New("crypto/aes: invalid key size 7")), + }, + { + desc: "bootstrap with error in reader", + token: validToken, + externalID: externalId, + externalKey: externalKey, + cryptoKey: string(encKey), + svcResp: bootstrapConfig, + svcErr: nil, + readerResp: []byte{0}, + readerErr: errJsonEOF, + err: errors.NewSDKErrorWithStatus(errJsonEOF, http.StatusInternalServerError), + }, + { + desc: "bootstrap with response that cannot be unmarshalled", + token: validToken, + externalID: externalId, + externalKey: externalKey, + cryptoKey: string(encKey), + svcResp: bootstrapConfig, + svcErr: nil, + readerResp: []byte{0}, + readerErr: nil, + err: errors.NewSDKError(errJsonEOF), + }, + { + desc: "bootstrap with empty id", + token: validToken, + externalID: "", + externalKey: externalKey, + svcResp: bootstrap.Config{}, + svcErr: nil, + readerResp: []byte{0}, + readerErr: nil, + err: errors.NewSDKError(apiutil.ErrMissingID), + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := bsvc.On("Bootstrap", mock.Anything, mock.Anything, tc.externalID, true).Return(tc.svcResp, tc.svcErr) + readerCall := reader.On("ReadConfig", tc.svcResp, true).Return(tc.readerResp, tc.readerErr) + resp, err := mgsdk.BootstrapSecure(context.Background(), tc.externalID, tc.externalKey, tc.cryptoKey) + assert.Equal(t, tc.err, err) + if err == nil { + assert.Equal(t, sdkBootstrapConfigRes, resp) + ok := svcCall.Parent.AssertCalled(t, "Bootstrap", mock.Anything, mock.Anything, tc.externalID, true) + assert.True(t, ok) + } + svcCall.Unset() + readerCall.Unset() + }) + } +} + +func encrypt(in, encKey []byte) ([]byte, error) { + block, err := aes.NewCipher(encKey) + if err != nil { + return nil, err + } + ciphertext := make([]byte, aes.BlockSize+len(in)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(ciphertext[aes.BlockSize:], in) + return ciphertext, nil +} diff --git a/pkg/sdk/certs.go b/pkg/sdk/certs.go new file mode 100644 index 000000000..ceb15328d --- /dev/null +++ b/pkg/sdk/certs.go @@ -0,0 +1,377 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import ( + "archive/zip" + "bytes" + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "net" + "net/http" + "time" + + "github.com/absmach/supermq/certs" + "github.com/absmach/supermq/pkg/errors" + "golang.org/x/crypto/ocsp" +) + +const ( + certsEndpoint = "certs" + csrEndpoint = "csrs" + crlEndpoint = "crl" +) + +func (sdk mgSDK) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, opts Options, domainID, token string) (Certificate, errors.SDKError) { + type certReq struct { + IpAddrs []string `json:"ip_addresses"` + TTL string `json:"ttl"` + Options Options `json:"options"` + } + r := certReq{ + IpAddrs: ipAddrs, + TTL: ttl, + Options: opts, + } + d, err := json.Marshal(r) + if err != nil { + return Certificate{}, errors.NewSDKError(err) + } + url := fmt.Sprintf("%s/%s/%s/issue/%s", sdk.certsURL, domainID, certsEndpoint, entityID) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, d, nil, http.StatusCreated) + if sdkerr != nil { + return Certificate{}, sdkerr + } + var cert Certificate + if err := json.Unmarshal(body, &cert); err != nil { + return Certificate{}, errors.NewSDKError(err) + } + return cert, nil +} + +func (sdk mgSDK) ViewCert(ctx context.Context, serialNumber, domainID, token string) (Certificate, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s/%s", sdk.certsURL, domainID, certsEndpoint, serialNumber) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return Certificate{}, sdkerr + } + var cert Certificate + if err := json.Unmarshal(body, &cert); err != nil { + return Certificate{}, errors.NewSDKError(err) + } + return cert, nil +} + +func (sdk mgSDK) RevokeCert(ctx context.Context, serialNumber, domainID, token string) errors.SDKError { + url := fmt.Sprintf("%s/%s/%s/%s/revoke", sdk.certsURL, domainID, certsEndpoint, serialNumber) + _, _, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, nil, nil, http.StatusNoContent) + return sdkerr +} + +func (sdk mgSDK) RenewCert(ctx context.Context, serialNumber, domainID, token string) (Certificate, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s/%s/renew", sdk.certsURL, domainID, certsEndpoint, serialNumber) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return Certificate{}, sdkerr + } + var renewRes struct { + Renewed bool `json:"renewed"` + Certificate Certificate `json:"certificate"` + } + if err := json.Unmarshal(body, &renewRes); err != nil { + return Certificate{}, errors.NewSDKError(err) + } + return renewRes.Certificate, nil +} + +func (sdk mgSDK) ListCerts(ctx context.Context, pm PageMetadata, domainID, token string) (CertificatePage, errors.SDKError) { + url, err := sdk.withQueryParams(fmt.Sprintf("%s/%s", sdk.certsURL, domainID), certsEndpoint, pm) + if err != nil { + return CertificatePage{}, errors.NewSDKError(err) + } + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return CertificatePage{}, sdkerr + } + var cp CertificatePage + if err := json.Unmarshal(body, &cp); err != nil { + return CertificatePage{}, errors.NewSDKError(err) + } + return cp, nil +} + +func (sdk mgSDK) DeleteCert(ctx context.Context, entityID, domainID, token string) errors.SDKError { + url := fmt.Sprintf("%s/%s/%s/%s/delete", sdk.certsURL, domainID, certsEndpoint, entityID) + _, _, sdkerr := sdk.processRequest(ctx, http.MethodDelete, url, token, nil, nil, http.StatusNoContent) + return sdkerr +} + +func (sdk mgSDK) OCSP(ctx context.Context, serialNumber, cert string) (OCSPResponse, errors.SDKError) { + if serialNumber == "" && cert == "" { + return OCSPResponse{}, errors.NewSDKError(errors.New("either serial number or certificate must be provided")) + } + ocspReq := struct { + SerialNumber string `json:"serial_number,omitempty"` + Certificate string `json:"certificate,omitempty"` + }{} + if serialNumber != "" { + ocspReq.SerialNumber = serialNumber + } + if cert != "" { + ocspReq.Certificate = cert + } + requestBody, err := json.Marshal(ocspReq) + if err != nil { + return OCSPResponse{}, errors.NewSDKError(err) + } + url := fmt.Sprintf("%s/certs/ocsp", sdk.certsURL) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, "", requestBody, nil, http.StatusOK) + if sdkerr != nil { + return OCSPResponse{}, sdkerr + } + ocspResp, err := ocsp.ParseResponse(body, nil) + if err != nil { + return OCSPResponse{}, errors.NewSDKError(fmt.Errorf("failed to parse OCSP response: %w", err)) + } + var status CertStatus + switch ocspResp.Status { + case ocsp.Good: + status = CertValid + case ocsp.Revoked: + status = CertRevoked + default: + status = CertUnknown + } + resp := OCSPResponse{ + Status: status, + SerialNumber: ocspResp.SerialNumber.String(), + Certificate: body, + } + if ocspResp.RevokedAt != (time.Time{}) { + resp.RevokedAt = &ocspResp.RevokedAt + } + if ocspResp.ProducedAt != (time.Time{}) { + resp.ProducedAt = &ocspResp.ProducedAt + } + if ocspResp.ThisUpdate != (time.Time{}) { + resp.ThisUpdate = &ocspResp.ThisUpdate + } + if ocspResp.NextUpdate != (time.Time{}) { + resp.NextUpdate = &ocspResp.NextUpdate + } + resp.RevocationReason = int(ocspResp.RevocationReason) + return resp, nil +} + +func (sdk mgSDK) ViewCA(ctx context.Context) (Certificate, errors.SDKError) { + url := fmt.Sprintf("%s/%s/view-ca", sdk.certsURL, certsEndpoint) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, "", nil, nil, http.StatusOK) + if sdkerr != nil { + return Certificate{}, sdkerr + } + var cert Certificate + if err := json.Unmarshal(body, &cert); err != nil { + return Certificate{}, errors.NewSDKError(err) + } + return cert, nil +} + +func (sdk mgSDK) DownloadCA(ctx context.Context) (CertificateBundle, errors.SDKError) { + url := fmt.Sprintf("%s/%s/download-ca", sdk.certsURL, certsEndpoint) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, "", nil, nil, http.StatusOK) + if sdkerr != nil { + return CertificateBundle{}, sdkerr + } + zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) + if err != nil { + return CertificateBundle{}, errors.NewSDKError(err) + } + var bundle CertificateBundle + for _, file := range zipReader.File { + fileContent, err := readZipFile(file) + if err != nil { + return CertificateBundle{}, errors.NewSDKError(err) + } + if file.Name == "ca.crt" { + bundle.Certificate = fileContent + } + } + return bundle, nil +} + +func (sdk mgSDK) IssueFromCSR(ctx context.Context, entityID, ttl, csr, domainID, token string) (Certificate, errors.SDKError) { + pm := PageMetadata{TTL: ttl} + type csrReq struct { + CSR []byte `json:"csr,omitempty"` + } + r := csrReq{CSR: []byte(csr)} + d, err := json.Marshal(r) + if err != nil { + return Certificate{}, errors.NewSDKError(err) + } + url, err := sdk.withQueryParams(fmt.Sprintf("%s/%s/%s/%s", sdk.certsURL, domainID, certsEndpoint, csrEndpoint), entityID, pm) + if err != nil { + return Certificate{}, errors.NewSDKError(err) + } + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, d, nil, http.StatusOK) + if sdkerr != nil { + return Certificate{}, sdkerr + } + var cert Certificate + if err := json.Unmarshal(body, &cert); err != nil { + return Certificate{}, errors.NewSDKError(err) + } + return cert, nil +} + +func (sdk mgSDK) IssueFromCSRInternal(ctx context.Context, entityID, ttl, csr, token string) (Certificate, errors.SDKError) { + type csrReq struct { + CSR []byte `json:"csr,omitempty"` + } + r := csrReq{CSR: []byte(csr)} + d, err := json.Marshal(r) + if err != nil { + return Certificate{}, errors.NewSDKError(err) + } + pm := PageMetadata{TTL: ttl} + url, err := sdk.withQueryParams(fmt.Sprintf("%s/certs/csrs", sdk.certsURL), entityID, pm) + if err != nil { + return Certificate{}, errors.NewSDKError(err) + } + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, d, nil, http.StatusOK) + if sdkerr != nil { + return Certificate{}, sdkerr + } + var cert Certificate + if err := json.Unmarshal(body, &cert); err != nil { + return Certificate{}, errors.NewSDKError(err) + } + return cert, nil +} + +func (sdk mgSDK) GenerateCRL(ctx context.Context) ([]byte, errors.SDKError) { + url := fmt.Sprintf("%s/certs/%s", sdk.certsURL, crlEndpoint) + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, "", nil, nil, http.StatusOK) + if sdkerr != nil { + return nil, sdkerr + } + var crlRes struct { + CRL string `json:"crl"` + } + if err := json.Unmarshal(body, &crlRes); err != nil { + return nil, errors.NewSDKError(err) + } + crlData, err := base64.StdEncoding.DecodeString(crlRes.CRL) + if err != nil { + return nil, errors.NewSDKError(err) + } + return crlData, nil +} + +func (sdk mgSDK) RevokeAll(ctx context.Context, entityID, domainID, token string) errors.SDKError { + url := fmt.Sprintf("%s/%s/%s/%s/delete", sdk.certsURL, domainID, certsEndpoint, entityID) + _, _, sdkerr := sdk.processRequest(ctx, http.MethodDelete, url, token, nil, nil, http.StatusNoContent) + return sdkerr +} + +func (sdk mgSDK) EntityID(ctx context.Context, serialNumber, domainID, token string) (string, errors.SDKError) { + cert, err := sdk.ViewCert(ctx, serialNumber, domainID, token) + if err != nil { + return "", err + } + return cert.EntityID, nil +} + +// CreateCSR creates a Certificate Signing Request from the given metadata and private key. +// The private key may be a PEM-encoded []byte or a crypto.Signer (rsa, ecdsa, ed25519). +func (sdk mgSDK) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey any) (certs.CSR, errors.SDKError) { + template := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: metadata.CommonName, + Organization: metadata.Organization, + OrganizationalUnit: metadata.OrganizationalUnit, + Country: metadata.Country, + Province: metadata.Province, + Locality: metadata.Locality, + StreetAddress: metadata.StreetAddress, + PostalCode: metadata.PostalCode, + }, + EmailAddresses: metadata.EmailAddresses, + DNSNames: metadata.DNSNames, + ExtraExtensions: metadata.ExtraExtensions, + } + for _, ip := range metadata.IPAddresses { + if parsed := net.ParseIP(ip); parsed != nil { + template.IPAddresses = append(template.IPAddresses, parsed) + } + } + actualKey := privKey + if keyBytes, ok := privKey.([]byte); ok { + var err error + actualKey, err = extractPrivateKey(keyBytes) + if err != nil { + return certs.CSR{}, errors.NewSDKError(errors.Wrap(certs.ErrCreateEntity, err)) + } + } + var signer crypto.Signer + switch key := actualKey.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey: + signer = key.(crypto.Signer) + case ed25519.PrivateKey: + signer = key + default: + return certs.CSR{}, errors.NewSDKError(errors.Wrap(certs.ErrCreateEntity, certs.ErrPrivKeyType)) + } + csrBytes, err := x509.CreateCertificateRequest(rand.Reader, template, signer) + if err != nil { + return certs.CSR{}, errors.NewSDKError(errors.Wrap(certs.ErrCreateEntity, err)) + } + csrPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrBytes}) + return certs.CSR{CSR: csrPEM}, nil +} + +func readZipFile(file *zip.File) ([]byte, error) { + fc, err := file.Open() + if err != nil { + return nil, err + } + defer fc.Close() + return io.ReadAll(fc) +} + +func extractPrivateKey(pemKey []byte) (any, error) { + block, _ := pem.Decode(pemKey) + if block == nil { + return nil, errors.New("failed to parse private key PEM") + } + var ( + privateKey any + err error + ) + switch block.Type { + case certs.RSAPrivateKey: + privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) + case certs.ECPrivateKey: + privateKey, err = x509.ParseECPrivateKey(block.Bytes) + case certs.PrivateKey, certs.PKCS8PrivateKey, certs.EDPrivateKey: + privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) + default: + err = certs.ErrPrivKeyType + } + if err != nil { + return nil, certs.ErrFailedParse + } + return privateKey, nil +} diff --git a/pkg/sdk/certs_test.go b/pkg/sdk/certs_test.go new file mode 100644 index 000000000..dc4b3e5da --- /dev/null +++ b/pkg/sdk/certs_test.go @@ -0,0 +1,1032 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/absmach/supermq/certs" + httpapi "github.com/absmach/supermq/certs/api/http" + "github.com/absmach/supermq/certs/mocks" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const ( + certsInstanceID = "5de9b29a-feb9-11ed-be56-0242ac120002" + certsContentType = "application/senml+json" + serialNum = "8e7a30c-bc9f-22de-ae67-1342bc139507" + certsID = "c333e6f-59bb-4c39-9e13-3a2766af8ba5" + ttl = "10h" + commonName = "test" + token = "token" + agentToken = "agent-token" + certsDomainID = "domain-certsID" +) + +func setupCerts() (*httptest.Server, *mocks.Service, *authnmocks.Authentication) { + svc := new(mocks.Service) + logger := smqlog.NewMock() + authn := new(authnmocks.Authentication) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + handler := httpapi.MakeHandler(svc, am, logger, certsInstanceID, agentToken) + + return httptest.NewServer(handler), svc, authn +} + +func TestIssueCert(t *testing.T) { + ts, svc, auth := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: certsContentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + ipAddr := []string{"192.128.101.82"} + cases := []struct { + desc string + entityID string + ttl string + ipAddrs []string + commonName string + svcresp certs.Certificate + svcerr error + authenticateErr error + err errors.SDKError + sdkCert sdk.Certificate + domain string + token string + session smqauthn.Session + }{ + { + desc: "IssueCert success", + entityID: certsID, + ttl: ttl, + ipAddrs: ipAddr, + commonName: commonName, + svcresp: certs.Certificate{ + SerialNumber: serialNum, + }, + sdkCert: sdk.Certificate{ + SerialNumber: serialNum, + }, + svcerr: nil, + err: nil, + domain: certsDomainID, + token: token, + }, + { + desc: "IssueCert failure", + entityID: certsID, + ttl: ttl, + ipAddrs: ipAddr, + commonName: commonName, + svcresp: certs.Certificate{}, + svcerr: certs.ErrCreateEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrCreateEntity, http.StatusUnprocessableEntity), + domain: certsDomainID, + token: token, + }, + { + desc: "IssueCert with empty entityID", + entityID: `""`, + ttl: ttl, + ipAddrs: ipAddr, + commonName: commonName, + svcresp: certs.Certificate{}, + svcerr: certs.ErrMalformedEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), + domain: certsDomainID, + token: token, + }, + { + desc: "IssueCert with empty ipAddrs", + entityID: certsID, + ttl: ttl, + commonName: commonName, + svcresp: certs.Certificate{SerialNumber: serialNum}, + sdkCert: sdk.Certificate{ + SerialNumber: serialNum, + }, + svcerr: nil, + err: nil, + domain: certsDomainID, + token: token, + }, + { + desc: "IssueCert with empty ttl", + entityID: certsID, + ttl: "", + ipAddrs: ipAddr, + commonName: commonName, + svcresp: certs.Certificate{SerialNumber: serialNum}, + sdkCert: sdk.Certificate{ + SerialNumber: serialNum, + }, + svcerr: nil, + err: nil, + domain: certsDomainID, + token: token, + }, + { + desc: "IssueCert with empty commonName", + entityID: certsID, + ttl: ttl, + ipAddrs: ipAddr, + commonName: "", + svcresp: certs.Certificate{}, + svcerr: certs.ErrMalformedEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), + domain: certsDomainID, + token: token, + }, + { + desc: "IssueCert with empty token", + entityID: certsID, + ttl: ttl, + ipAddrs: ipAddr, + commonName: commonName, + svcresp: certs.Certificate{}, + svcerr: nil, + err: errors.NewSDKErrorWithStatus(errors.New("missing or invalid bearer user token"), http.StatusUnauthorized), + domain: certsDomainID, + token: "", + }, + { + desc: "IssueCert with empty domain", + entityID: certsID, + ttl: ttl, + ipAddrs: ipAddr, + commonName: commonName, + svcresp: certs.Certificate{}, + svcerr: nil, + err: errors.NewSDKErrorWithStatus(errors.New("missing domainID"), http.StatusBadRequest), + domain: "", + token: token, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == token { + tc.session = smqauthn.Session{DomainUserID: certsID, UserID: certsID, DomainID: certsDomainID} + } + + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("IssueCert", mock.Anything, tc.session, tc.entityID, tc.ttl, tc.ipAddrs, certs.SubjectOptions{CommonName: tc.commonName}).Return(tc.svcresp, tc.svcerr) + resp, err := ctsdk.IssueCert(context.Background(), tc.entityID, tc.ttl, tc.ipAddrs, sdk.Options{CommonName: tc.commonName}, tc.domain, tc.token) + assert.Equal(t, tc.err, err) + if tc.err == nil { + assert.Equal(t, tc.sdkCert.SerialNumber, resp.SerialNumber) + ok := svcCall.Parent.AssertCalled(t, "IssueCert", mock.Anything, tc.session, tc.entityID, tc.ttl, tc.ipAddrs, certs.SubjectOptions{CommonName: tc.commonName}) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestRevokeCert(t *testing.T) { + ts, svc, auth := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: certsContentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + cases := []struct { + desc string + serial string + svcresp string + svcerr error + authenticateErr error + err errors.SDKError + domain string + token string + session smqauthn.Session + }{ + { + desc: "RevokeCert success", + serial: serialNum, + svcerr: nil, + err: nil, + domain: certsDomainID, + token: token, + }, + { + desc: "RevokeCert failure", + serial: serialNum, + svcerr: certs.ErrUpdateEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), + domain: certsDomainID, + token: token, + }, + { + desc: "RevokeCert with empty serial", + serial: "", + svcerr: certs.ErrMalformedEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), + domain: certsDomainID, + token: token, + }, + { + desc: "RevokeCert with empty token", + serial: serialNum, + svcerr: nil, + err: errors.NewSDKErrorWithStatus(errors.New("missing or invalid bearer user token"), http.StatusUnauthorized), + domain: certsDomainID, + token: "", + }, + { + desc: "RevokeCert with empty domain", + serial: serialNum, + svcerr: nil, + err: errors.NewSDKErrorWithStatus(errors.New("missing domainID"), http.StatusBadRequest), + domain: "", + token: token, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == token { + tc.session = smqauthn.Session{DomainUserID: certsID, UserID: certsID, DomainID: certsDomainID} + } + + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("RevokeBySerial", mock.Anything, tc.session, tc.serial).Return(tc.svcerr) + + err := ctsdk.RevokeCert(context.Background(), tc.serial, tc.domain, tc.token) + assert.Equal(t, tc.err, err) + if tc.desc != "RevokeCert with empty serial" && tc.desc != "RevokeCert with empty token" && tc.desc != "RevokeCert with empty domain" { + ok := svcCall.Parent.AssertCalled(t, "RevokeBySerial", mock.Anything, tc.session, tc.serial) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDeleteCert(t *testing.T) { + ts, svc, auth := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: certsContentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + cases := []struct { + desc string + entityID string + svcresp string + svcerr error + authenticateErr error + err errors.SDKError + domain string + token string + session smqauthn.Session + }{ + { + desc: "DeleteCert success", + entityID: certsID, + svcerr: nil, + err: nil, + domain: certsDomainID, + token: token, + }, + { + desc: "DeleteCert failure", + entityID: certsID, + svcerr: certs.ErrUpdateEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), + domain: certsDomainID, + token: token, + }, + { + desc: "DeleteCert with empty entity certsID", + entityID: "", + svcerr: certs.ErrMalformedEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), + domain: certsDomainID, + token: token, + }, + { + desc: "DeleteCert with empty token", + entityID: certsID, + svcerr: nil, + err: errors.NewSDKErrorWithStatus(errors.New("missing or invalid bearer user token"), http.StatusUnauthorized), + domain: certsDomainID, + token: "", + }, + { + desc: "DeleteCert with empty domain", + entityID: certsID, + svcerr: nil, + err: errors.NewSDKErrorWithStatus(errors.New("missing domainID"), http.StatusBadRequest), + domain: "", + token: token, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == token { + tc.session = smqauthn.Session{DomainUserID: certsID, UserID: certsID, DomainID: certsDomainID} + } + + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("RevokeAll", mock.Anything, tc.session, tc.entityID).Return(tc.svcerr) + + err := ctsdk.DeleteCert(context.Background(), tc.entityID, tc.domain, tc.token) + assert.Equal(t, tc.err, err) + if tc.desc != "DeleteCert with empty entity certsID" && tc.desc != "DeleteCert with empty token" && tc.desc != "DeleteCert with empty domain" { + ok := svcCall.Parent.AssertCalled(t, "RevokeAll", mock.Anything, tc.session, tc.entityID) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestRenewCert(t *testing.T) { + ts, svc, auth := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: certsContentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + cases := []struct { + desc string + serial string + svcresp certs.Certificate + svcerr error + authenticateErr error + err errors.SDKError + expected sdk.Certificate + domain string + token string + session smqauthn.Session + }{ + { + desc: "RenewCert success", + serial: serialNum, + svcresp: certs.Certificate{ + SerialNumber: "new-serial-123", + EntityID: "test-entity", + }, + svcerr: nil, + err: nil, + expected: sdk.Certificate{ + SerialNumber: "new-serial-123", + EntityID: "test-entity", + }, + domain: certsDomainID, + token: token, + }, + { + desc: "RenewCert failure", + serial: serialNum, + svcresp: certs.Certificate{}, + svcerr: certs.ErrUpdateEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), + expected: sdk.Certificate{}, + domain: certsDomainID, + token: token, + }, + { + desc: "RenewCert with empty serial", + serial: "", + svcresp: certs.Certificate{}, + svcerr: certs.ErrMalformedEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), + expected: sdk.Certificate{}, + domain: certsDomainID, + token: token, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == token { + tc.session = smqauthn.Session{DomainUserID: certsID, UserID: certsID, DomainID: certsDomainID} + } + + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("RenewCert", mock.Anything, tc.session, tc.serial).Return(tc.svcresp, tc.svcerr) + + cert, err := ctsdk.RenewCert(context.Background(), tc.serial, tc.domain, tc.token) + assert.Equal(t, tc.err, err) + if tc.err == nil { + assert.Equal(t, tc.expected, cert) + } else { + assert.Equal(t, sdk.Certificate{}, cert) + } + if tc.desc != "RenewCert with empty serial" { + ok := svcCall.Parent.AssertCalled(t, "RenewCert", mock.Anything, tc.session, tc.serial) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestListCerts(t *testing.T) { + ts, svc, auth := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: certsContentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + cases := []struct { + desc string + svcResp certs.CertificatePage + sdkPm sdk.PageMetadata + svcerr error + authenticateErr error + err errors.SDKError + domain string + token string + session smqauthn.Session + }{ + { + desc: "ListCerts success", + sdkPm: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + svcResp: certs.CertificatePage{ + PageMetadata: certs.PageMetadata{ + Total: 1, + Offset: 0, + Limit: 10, + }, + Certificates: []certs.Certificate{ + { + SerialNumber: serialNum, + }, + }, + }, + domain: certsDomainID, + token: token, + }, + { + desc: "ListCerts success with entity certsID", + sdkPm: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + EntityID: certsID, + }, + svcResp: certs.CertificatePage{ + PageMetadata: certs.PageMetadata{ + Total: 1, + Offset: 0, + Limit: 10, + }, + Certificates: []certs.Certificate{ + { + SerialNumber: serialNum, + EntityID: certsID, + }, + }, + }, + domain: certsDomainID, + token: token, + }, + { + desc: "ListCerts failure", + sdkPm: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + svcerr: certs.ErrViewEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity), + domain: certsDomainID, + token: token, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == token { + tc.session = smqauthn.Session{DomainUserID: certsID, UserID: certsID, DomainID: certsDomainID} + } + + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("ListCerts", mock.Anything, tc.session, mock.Anything).Return(tc.svcResp, tc.svcerr) + + resp, err := ctsdk.ListCerts(context.Background(), tc.sdkPm, tc.domain, tc.token) + assert.Equal(t, tc.err, err) + if tc.err == nil { + assert.Equal(t, tc.svcResp.Total, resp.Total) + assert.Equal(t, tc.svcResp.Certificates[0].SerialNumber, resp.Certificates[0].SerialNumber) + if tc.desc == "ListCerts success with entity certsID" { + assert.Equal(t, tc.svcResp.Certificates[0].EntityID, resp.Certificates[0].EntityID) + } + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestViewCert(t *testing.T) { + ts, svc, auth := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: certsContentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + cert := sdk.Certificate{ + SerialNumber: serialNum, + } + + cases := []struct { + desc string + serial string + svcresp certs.Certificate + svcerr error + authenticateErr error + err errors.SDKError + sdkCert sdk.Certificate + domain string + token string + session smqauthn.Session + }{ + { + desc: "ViewCert success", + serial: serialNum, + svcresp: certs.Certificate{ + SerialNumber: serialNum, + }, + sdkCert: cert, + svcerr: nil, + err: nil, + domain: certsDomainID, + token: token, + }, + { + desc: "ViewCert failure", + serial: serialNum, + svcresp: certs.Certificate{}, + svcerr: certs.ErrViewEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity), + domain: certsDomainID, + token: token, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == token { + tc.session = smqauthn.Session{DomainUserID: certsID, UserID: certsID, DomainID: certsDomainID} + } + + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("ViewCert", mock.Anything, tc.session, tc.serial).Return(tc.svcresp, tc.svcerr) + + c, err := ctsdk.ViewCert(context.Background(), tc.serial, tc.domain, tc.token) + assert.Equal(t, tc.err, err) + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "ViewCert", mock.Anything, tc.session, tc.serial) + assert.True(t, ok) + } + assert.Equal(t, tc.sdkCert.SerialNumber, c.SerialNumber, fmt.Sprintf("expected: %v, got: %v", tc.sdkCert.SerialNumber, c.SerialNumber)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDownloadCACert(t *testing.T) { + ts, svc, _ := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: certsContentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + cert := sdk.Certificate{ + SerialNumber: serialNum, + } + + cases := []struct { + desc string + svcresp certs.Certificate + svcerr error + err errors.SDKError + sdkCert sdk.Certificate + }{ + { + desc: "Download CA successfully", + svcresp: certs.Certificate{ + SerialNumber: serialNum, + Certificate: []byte("cert"), + Key: []byte("key"), + }, + sdkCert: cert, + svcerr: nil, + err: nil, + }, + { + desc: "Download CA failure", + svcresp: certs.Certificate{}, + svcerr: certs.ErrViewEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("RetrieveCAChain", mock.Anything).Return(tc.svcresp, tc.svcerr) + + _, err := ctsdk.DownloadCA(context.Background()) + assert.Equal(t, tc.err, err) + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "RetrieveCAChain", mock.Anything) + assert.True(t, ok) + } + svcCall.Unset() + }) + } +} + +func TestViewCA(t *testing.T) { + ts, svc, _ := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: certsContentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + cert := sdk.Certificate{ + SerialNumber: serialNum, + Certificate: "cert", + Key: "Key", + } + + cases := []struct { + desc string + svcresp certs.Certificate + svcerr error + err errors.SDKError + sdkCert sdk.Certificate + }{ + { + desc: "ViewCA success", + svcresp: certs.Certificate{ + Certificate: []byte("cert"), + }, + sdkCert: cert, + svcerr: nil, + err: nil, + }, + { + desc: "ViewCA failure", + svcresp: certs.Certificate{}, + svcerr: certs.ErrViewEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("RetrieveCAChain", mock.Anything).Return(tc.svcresp, tc.svcerr) + + c, err := ctsdk.ViewCA(context.Background()) + assert.Equal(t, tc.err, err) + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "RetrieveCAChain", mock.Anything) + assert.True(t, ok) + } + assert.Equal(t, tc.sdkCert.Certificate, c.Certificate, fmt.Sprintf("expected: %v, got: %v", tc.sdkCert.Certificate, c.Certificate)) + svcCall.Unset() + }) + } +} + +func TestGenerateCRL(t *testing.T) { + ts, svc, _ := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: certsContentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + crlData := []byte("mock-crl-data") + + cases := []struct { + desc string + svcresp []byte + svcerr error + err errors.SDKError + }{ + { + desc: "GenerateCRL success", + svcresp: crlData, + svcerr: nil, + err: nil, + }, + { + desc: "GenerateCRL failure", + svcresp: nil, + svcerr: certs.ErrFailedCertCreation, + err: errors.NewSDKErrorWithStatus(certs.ErrFailedCertCreation, http.StatusUnprocessableEntity), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("GenerateCRL", mock.Anything).Return(tc.svcresp, tc.svcerr) + + resp, err := ctsdk.GenerateCRL(context.Background()) + assert.Equal(t, tc.err, err) + if tc.err == nil { + assert.Equal(t, tc.svcresp, resp) + ok := svcCall.Parent.AssertCalled(t, "GenerateCRL", mock.Anything) + assert.True(t, ok) + } + svcCall.Unset() + }) + } +} + +func TestRevokeAll(t *testing.T) { + ts, svc, auth := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: certsContentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + cases := []struct { + desc string + entityID string + svcerr error + authenticateErr error + err errors.SDKError + domain string + token string + session smqauthn.Session + }{ + { + desc: "RevokeAll success", + entityID: certsID, + svcerr: nil, + err: nil, + domain: certsDomainID, + token: token, + }, + { + desc: "RevokeAll failure", + entityID: certsID, + svcerr: certs.ErrUpdateEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), + domain: certsDomainID, + token: token, + }, + { + desc: "RevokeAll with empty entityID", + entityID: "", + svcerr: certs.ErrMalformedEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), + domain: certsDomainID, + token: token, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == token { + tc.session = smqauthn.Session{DomainUserID: certsID, UserID: certsID, DomainID: certsDomainID} + } + + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := svc.On("RevokeAll", mock.Anything, tc.session, tc.entityID).Return(tc.svcerr) + + err := ctsdk.RevokeAll(context.Background(), tc.entityID, tc.domain, tc.token) + assert.Equal(t, tc.err, err) + if tc.desc != "RevokeAll with empty entityID" { + ok := svcCall.Parent.AssertCalled(t, "RevokeAll", mock.Anything, tc.session, tc.entityID) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestGetEntityID(t *testing.T) { + ts, svc, auth := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: certsContentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + entityID := "test-entity-certsID" + + cases := []struct { + desc string + serial string + svcresp certs.Certificate + svcerr error + authenticateErr error + err errors.SDKError + expected string + domain string + token string + session smqauthn.Session + }{ + { + desc: "GetEntityID success", + serial: serialNum, + svcresp: certs.Certificate{ + SerialNumber: serialNum, + EntityID: entityID, + }, + svcerr: nil, + err: nil, + expected: entityID, + domain: certsDomainID, + token: token, + }, + { + desc: "GetEntityID failure", + serial: serialNum, + svcresp: certs.Certificate{}, + svcerr: certs.ErrViewEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity), + expected: "", + domain: certsDomainID, + token: token, + }, + { + desc: "GetEntityID with empty serial", + serial: "", + svcresp: certs.Certificate{}, + svcerr: certs.ErrMalformedEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), + expected: "", + domain: certsDomainID, + token: token, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == token { + tc.session = smqauthn.Session{DomainUserID: certsID, UserID: certsID, DomainID: certsDomainID} + } + + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + var svcCall *mock.Call + if tc.desc == "GetEntityID with empty serial" { + // Empty serial routes to ListCerts endpoint instead of ViewCert + svcCall = svc.On("ListCerts", mock.Anything, tc.session, mock.Anything).Return(certs.CertificatePage{}, tc.svcerr) + } else { + svcCall = svc.On("ViewCert", mock.Anything, tc.session, tc.serial).Return(tc.svcresp, tc.svcerr) + } + + resp, err := ctsdk.EntityID(context.Background(), tc.serial, tc.domain, tc.token) + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.expected, resp) + if tc.desc != "GetEntityID with empty serial" { + ok := svcCall.Parent.AssertCalled(t, "ViewCert", mock.Anything, tc.session, tc.serial) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestIssueFromCSRInternal(t *testing.T) { + ts, svc, auth := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: certsContentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + cert := sdk.Certificate{ + SerialNumber: serialNum, + } + + cases := []struct { + desc string + entityID string + ttl string + csr string + svcresp certs.Certificate + svcerr error + err errors.SDKError + sdkCert sdk.Certificate + }{ + { + desc: "IssueFromCSRInternal success", + entityID: certsID, + ttl: ttl, + csr: "valid-csr-content", + svcresp: certs.Certificate{ + SerialNumber: serialNum, + Certificate: []byte("cert"), + Key: []byte("key"), + }, + sdkCert: cert, + svcerr: nil, + err: nil, + }, + { + desc: "IssueFromCSRInternal failure", + entityID: certsID, + ttl: ttl, + csr: "invalid-csr-content", + svcresp: certs.Certificate{}, + svcerr: certs.ErrFailedCertCreation, + err: errors.NewSDKErrorWithStatus(certs.ErrFailedCertCreation, http.StatusUnprocessableEntity), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + agentSession := smqauthn.Session{DomainUserID: certsID, UserID: certsID, DomainID: certsDomainID} + authCall := auth.On("Authenticate", mock.Anything, agentToken).Return(agentSession, nil) + svcCall := svc.On("IssueFromCSRInternal", mock.Anything, tc.entityID, tc.ttl, mock.Anything).Return(tc.svcresp, tc.svcerr) + + c, err := ctsdk.IssueFromCSRInternal(context.Background(), tc.entityID, tc.ttl, tc.csr, agentToken) + assert.Equal(t, tc.err, err) + if tc.err == nil { + assert.Equal(t, tc.sdkCert.SerialNumber, c.SerialNumber) + ok := svcCall.Parent.AssertCalled(t, "IssueFromCSRInternal", mock.Anything, tc.entityID, tc.ttl, mock.Anything) + assert.True(t, ok) + } + svcCall.Unset() + authCall.Unset() + }) + } +} diff --git a/pkg/sdk/consumers.go b/pkg/sdk/consumers.go new file mode 100644 index 000000000..60c2dcf0d --- /dev/null +++ b/pkg/sdk/consumers.go @@ -0,0 +1,88 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/absmach/supermq/pkg/errors" +) + +const subscriptionEndpoint = "subscriptions" + +type Subscription struct { + ID string `json:"id,omitempty"` + OwnerID string `json:"owner_id,omitempty"` + Topic string `json:"topic,omitempty"` + Contact string `json:"contact,omitempty"` +} + +func (sdk mgSDK) CreateSubscription(ctx context.Context, topic, contact, token string) (string, errors.SDKError) { + sub := Subscription{ + Topic: topic, + Contact: contact, + } + data, err := json.Marshal(sub) + if err != nil { + return "", errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s", sdk.usersURL, subscriptionEndpoint) + + headers, _, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, data, nil, http.StatusCreated) + if sdkerr != nil { + return "", sdkerr + } + + id := strings.TrimPrefix(headers.Get("Location"), fmt.Sprintf("/%s/", subscriptionEndpoint)) + + return id, nil +} + +func (sdk mgSDK) ListSubscriptions(ctx context.Context, pm PageMetadata, token string) (SubscriptionPage, errors.SDKError) { + url, err := sdk.withQueryParams(sdk.usersURL, subscriptionEndpoint, pm) + if err != nil { + return SubscriptionPage{}, errors.NewSDKError(err) + } + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return SubscriptionPage{}, sdkerr + } + + var sp SubscriptionPage + if err := json.Unmarshal(body, &sp); err != nil { + return SubscriptionPage{}, errors.NewSDKError(err) + } + + return sp, nil +} + +func (sdk mgSDK) ViewSubscription(ctx context.Context, id, token string) (Subscription, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s", sdk.usersURL, subscriptionEndpoint, id) + + _, body, err := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if err != nil { + return Subscription{}, err + } + + var sub Subscription + if err := json.Unmarshal(body, &sub); err != nil { + return Subscription{}, errors.NewSDKError(err) + } + + return sub, nil +} + +func (sdk mgSDK) DeleteSubscription(ctx context.Context, id, token string) errors.SDKError { + url := fmt.Sprintf("%s/%s/%s", sdk.usersURL, subscriptionEndpoint, id) + + _, _, err := sdk.processRequest(ctx, http.MethodDelete, url, token, nil, nil, http.StatusNoContent) + + return err +} diff --git a/pkg/sdk/consumers_test.go b/pkg/sdk/consumers_test.go new file mode 100644 index 000000000..66c799fb9 --- /dev/null +++ b/pkg/sdk/consumers_test.go @@ -0,0 +1,454 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/consumers/notifiers" + httpapi "github.com/absmach/supermq/consumers/notifiers/api" + notmocks "github.com/absmach/supermq/consumers/notifiers/mocks" + "github.com/absmach/supermq/internal/testsutil" + smqlog "github.com/absmach/supermq/logger" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + ownerID = testsutil.GenerateUUID(&testing.T{}) + subID = testsutil.GenerateUUID(&testing.T{}) + sdkSubReq = sdk.Subscription{ + Topic: "topic", + Contact: "contact", + } + sdkSubRes = sdk.Subscription{ + Topic: "topic", + Contact: "contact", + OwnerID: ownerID, + ID: subID, + } + notSubReq = notifiers.Subscription{ + Contact: "contact", + Topic: "topic", + } + notSubRes = notifiers.Subscription{ + Contact: "contact", + Topic: "topic", + OwnerID: ownerID, + ID: subID, + } + instanceID = "instanceID" +) + +func setupSubscriptions() (*httptest.Server, *notmocks.Service) { + nsvc := new(notmocks.Service) + logger := smqlog.NewMock() + mux := httpapi.MakeHandler(nsvc, logger, instanceID) + + return httptest.NewServer(mux), nsvc +} + +func TestCreateSubscription(t *testing.T) { + ts, nsvc := setupSubscriptions() + defer ts.Close() + + sdkConf := sdk.Config{ + UsersURL: ts.URL, + MsgContentType: contentType, + TLSVerification: false, + } + + mgsdk := sdk.NewSDK(sdkConf) + + cases := []struct { + desc string + subscription sdk.Subscription + token string + empty bool + id string + svcReq notifiers.Subscription + svcErr error + svcRes string + err errors.SDKError + }{ + { + desc: "create new subscription", + subscription: sdkSubReq, + token: validToken, + empty: false, + svcReq: notSubReq, + svcRes: subID, + svcErr: nil, + err: nil, + }, + { + desc: "create new subscription with empty token", + subscription: sdkSubReq, + token: "", + empty: true, + svcReq: notifiers.Subscription{}, + svcRes: "", + svcErr: nil, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + }, + { + desc: "create new subscription with invalid token", + subscription: sdkSubReq, + token: invalidToken, + empty: true, + svcReq: notSubReq, + svcRes: "", + svcErr: svcerr.ErrAuthentication, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "create new subscription with empty topic", + subscription: sdk.Subscription{ + Topic: "", + Contact: "contact", + }, + token: validToken, + empty: true, + svcReq: notifiers.Subscription{}, + svcErr: nil, + svcRes: "", + err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidTopic, http.StatusBadRequest), + }, + { + desc: "create new subscription with empty contact", + subscription: sdk.Subscription{ + Topic: "topic", + Contact: "", + }, + token: validToken, + empty: true, + svcReq: notifiers.Subscription{}, + svcErr: nil, + svcRes: "", + err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidContact, http.StatusBadRequest), + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := nsvc.On("CreateSubscription", mock.Anything, tc.token, tc.svcReq).Return(tc.svcRes, tc.svcErr) + loc, err := mgsdk.CreateSubscription(context.Background(), tc.subscription.Topic, tc.subscription.Contact, tc.token) + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.empty, loc == "") + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "CreateSubscription", mock.Anything, tc.token, tc.svcReq) + assert.True(t, ok) + } + svcCall.Unset() + }) + } +} + +func TestViewSubscription(t *testing.T) { + ts, nsvc := setupSubscriptions() + defer ts.Close() + sdkConf := sdk.Config{ + UsersURL: ts.URL, + MsgContentType: contentType, + TLSVerification: false, + } + + mgsdk := sdk.NewSDK(sdkConf) + + cases := []struct { + desc string + subID string + token string + svcRes notifiers.Subscription + svcErr error + response sdk.Subscription + err errors.SDKError + }{ + { + desc: "view existing subscription", + subID: subID, + token: validToken, + svcRes: notSubRes, + svcErr: nil, + response: sdkSubRes, + err: nil, + }, + { + desc: "view non-existent subscription", + subID: wrongID, + token: validToken, + svcRes: notifiers.Subscription{}, + svcErr: svcerr.ErrNotFound, + response: sdk.Subscription{}, + err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), + }, + { + desc: "view subscription with invalid token", + subID: subID, + token: invalidToken, + svcRes: notifiers.Subscription{}, + svcErr: svcerr.ErrAuthentication, + response: sdk.Subscription{}, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "view subscription with empty token", + subID: subID, + token: "", + svcRes: notifiers.Subscription{}, + svcErr: nil, + response: sdk.Subscription{}, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := nsvc.On("ViewSubscription", mock.Anything, tc.token, tc.subID).Return(tc.svcRes, tc.svcErr) + resp, err := mgsdk.ViewSubscription(context.Background(), tc.subID, tc.token) + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.response, resp) + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "ViewSubscription", mock.Anything, tc.token, tc.subID) + assert.True(t, ok) + } + svcCall.Unset() + }) + } +} + +func TestListSubscription(t *testing.T) { + ts, nsvc := setupSubscriptions() + defer ts.Close() + sdkConf := sdk.Config{ + UsersURL: ts.URL, + MsgContentType: contentType, + TLSVerification: false, + } + + mgsdk := sdk.NewSDK(sdkConf) + nSubs := 10 + noSubs := []notifiers.Subscription{} + sdSubs := []sdk.Subscription{} + for i := 0; i < nSubs; i++ { + nosub := notifiers.Subscription{ + OwnerID: ownerID, + Topic: fmt.Sprintf("topic_%d", i), + Contact: fmt.Sprintf("contact_%d", i), + } + noSubs = append(noSubs, nosub) + sdsub := sdk.Subscription{ + OwnerID: ownerID, + Topic: fmt.Sprintf("topic_%d", i), + Contact: fmt.Sprintf("contact_%d", i), + } + sdSubs = append(sdSubs, sdsub) + } + + cases := []struct { + desc string + token string + pageMeta sdk.PageMetadata + svcReq notifiers.PageMetadata + svcRes notifiers.Page + svcErr error + response sdk.SubscriptionPage + err errors.SDKError + }{ + { + desc: "list all subscription", + token: validToken, + pageMeta: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + svcReq: notifiers.PageMetadata{ + Offset: 0, + Limit: 10, + }, + svcRes: notifiers.Page{ + Total: 10, + Subscriptions: noSubs, + }, + svcErr: nil, + response: sdk.SubscriptionPage{ + PageRes: sdk.PageRes{ + Total: 10, + }, + Subscriptions: sdSubs, + }, + err: nil, + }, + { + desc: "list subscription with specific topic", + token: validToken, + pageMeta: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + Topic: "topic_1", + }, + svcReq: notifiers.PageMetadata{ + Offset: 0, + Limit: 10, + Topic: "topic_1", + }, + svcRes: notifiers.Page{ + Total: uint(len(noSubs[1:2])), + Subscriptions: noSubs[1:2], + }, + svcErr: nil, + response: sdk.SubscriptionPage{ + PageRes: sdk.PageRes{ + Total: uint64(len(sdSubs[1:2])), + }, + Subscriptions: sdSubs[1:2], + }, + err: nil, + }, + { + desc: "list subscription with specific contact", + token: validToken, + pageMeta: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + Contact: "contact_1", + }, + svcReq: notifiers.PageMetadata{ + Offset: 0, + Limit: 10, + Contact: "contact_1", + }, + svcRes: notifiers.Page{ + Total: uint(len(noSubs[1:2])), + Subscriptions: noSubs[1:2], + }, + svcErr: nil, + response: sdk.SubscriptionPage{ + PageRes: sdk.PageRes{ + Total: uint64(len(sdSubs[1:2])), + }, + Subscriptions: sdSubs[1:2], + }, + err: nil, + }, + { + desc: "list subscription with invalid token", + token: invalidToken, + pageMeta: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + svcReq: notifiers.PageMetadata{ + Offset: 0, + Limit: 10, + }, + svcRes: notifiers.Page{}, + svcErr: svcerr.ErrAuthentication, + response: sdk.SubscriptionPage{}, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "list subscription with empty token", + token: "", + pageMeta: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + svcReq: notifiers.PageMetadata{}, + svcRes: notifiers.Page{}, + svcErr: nil, + response: sdk.SubscriptionPage{}, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := nsvc.On("ListSubscriptions", mock.Anything, tc.token, tc.svcReq).Return(tc.svcRes, tc.svcErr) + resp, err := mgsdk.ListSubscriptions(context.Background(), tc.pageMeta, tc.token) + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.response, resp) + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "ListSubscriptions", mock.Anything, tc.token, tc.svcReq) + assert.True(t, ok) + } + svcCall.Unset() + }) + } +} + +func TestDeleteSubscription(t *testing.T) { + ts, nsvc := setupSubscriptions() + defer ts.Close() + sdkConf := sdk.Config{ + UsersURL: ts.URL, + MsgContentType: contentType, + TLSVerification: false, + } + + mgsdk := sdk.NewSDK(sdkConf) + + cases := []struct { + desc string + subID string + token string + svcErr error + err errors.SDKError + }{ + { + desc: "delete existing subscription", + subID: subID, + token: validToken, + svcErr: nil, + err: nil, + }, + { + desc: "delete non-existent subscription", + subID: wrongID, + token: validToken, + svcErr: svcerr.ErrRemoveEntity, + err: errors.NewSDKErrorWithStatus(svcerr.ErrRemoveEntity, http.StatusUnprocessableEntity), + }, + { + desc: "delete subscription with invalid token", + subID: subID, + token: invalidToken, + svcErr: svcerr.ErrAuthentication, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + }, + { + desc: "delete subscription with empty token", + subID: subID, + token: "", + svcErr: nil, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + }, + { + desc: "delete subscription with empty subID", + subID: "", + token: validToken, + svcErr: nil, + err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := nsvc.On("RemoveSubscription", mock.Anything, tc.token, tc.subID).Return(tc.svcErr) + err := mgsdk.DeleteSubscription(context.Background(), tc.subID, tc.token) + assert.Equal(t, tc.err, err) + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "RemoveSubscription", mock.Anything, tc.token, tc.subID) + assert.True(t, ok) + } + svcCall.Unset() + }) + } +} diff --git a/pkg/sdk/health.go b/pkg/sdk/health.go index bd77143cf..97ff0b0c5 100644 --- a/pkg/sdk/health.go +++ b/pkg/sdk/health.go @@ -38,8 +38,6 @@ func (sdk mgSDK) Health(service string) (HealthInfo, errors.SDKError) { url = fmt.Sprintf("%s/health", sdk.usersURL) case "certs": url = fmt.Sprintf("%s/health", sdk.certsURL) - case "http-adapter": - url = fmt.Sprintf("%s/health", sdk.httpAdapterURL) case "groups": url = fmt.Sprintf("%s/health", sdk.groupsURL) case "channels": @@ -48,6 +46,8 @@ func (sdk mgSDK) Health(service string) (HealthInfo, errors.SDKError) { url = fmt.Sprintf("%s/health", sdk.domainsURL) case "journal": url = fmt.Sprintf("%s/health", sdk.journalURL) + case "fluxmq": + url = fmt.Sprintf("%s/health", sdk.httpAdapterURL) } resp, err := sdk.client.Get(url) diff --git a/pkg/sdk/health_test.go b/pkg/sdk/health_test.go index 7d2cb34f1..3e13b9f57 100644 --- a/pkg/sdk/health_test.go +++ b/pkg/sdk/health_test.go @@ -20,9 +20,6 @@ func TestHealth(t *testing.T) { usersTs, _, _ := setupUsers() defer usersTs.Close() - httpAdapterTs, _ := setupMessages(t) - defer httpAdapterTs.Close() - groupsTs, _, _ := setupGroups() defer groupsTs.Close() @@ -35,10 +32,13 @@ func TestHealth(t *testing.T) { journalTs, _, _ := setupJournal() defer journalTs.Close() + fluxmqTs := setupFluxMQ("any") + defer fluxmqTs.Close() + sdkConf := sdk.Config{ ClientsURL: clientsTs.URL, UsersURL: usersTs.URL, - HTTPAdapterURL: httpAdapterTs.URL, + HTTPAdapterURL: fluxmqTs.URL, GroupsURL: groupsTs.URL, ChannelsURL: channelsTs.URL, DomainsURL: domainsTs.URL, @@ -72,14 +72,6 @@ func TestHealth(t *testing.T) { description: "users service", status: "pass", }, - { - desc: "get http-adapter service health check", - service: "http-adapter", - empty: false, - err: nil, - description: "http service", - status: "pass", - }, { desc: "get groups service health check", service: "groups", @@ -124,4 +116,11 @@ func TestHealth(t *testing.T) { assert.Equal(t, supermq.BuildTime, h.BuildTime, fmt.Sprintf("%s: expected default epoch date, got %s", tc.desc, h.BuildTime)) }) } + + // FluxMQ returns a simpler health response without version/commit/description. + t.Run("get fluxmq service health check", func(t *testing.T) { + h, err := mgsdk.Health("fluxmq") + assert.Nil(t, err) + assert.Equal(t, "healthy", h.Status) + }) } diff --git a/pkg/sdk/message.go b/pkg/sdk/message.go index 37b5f7791..04e176dc0 100644 --- a/pkg/sdk/message.go +++ b/pkg/sdk/message.go @@ -5,6 +5,7 @@ package sdk import ( "context" + "encoding/json" "fmt" "net/http" "strings" @@ -15,19 +16,37 @@ import ( const channelParts = 2 +type publishRequest struct { + Topic string `json:"topic"` + Payload []byte `json:"payload"` + QoS byte `json:"qos"` + Retain bool `json:"retain"` +} + func (sdk mgSDK) SendMessage(ctx context.Context, domainID, topic, msg, secret string) errors.SDKError { chanNameParts := strings.SplitN(topic, ".", channelParts) chanID := chanNameParts[0] - subtopicPart := "" + brokerTopic := fmt.Sprintf("m/%s/c/%s", domainID, chanID) if len(chanNameParts) == channelParts { - subtopicPart = fmt.Sprintf("/%s", strings.ReplaceAll(chanNameParts[1], ".", "/")) + brokerTopic = fmt.Sprintf("%s/%s", brokerTopic, strings.ReplaceAll(chanNameParts[1], ".", "/")) } - reqURL := fmt.Sprintf("%s/m/%s/c/%s%s", sdk.httpAdapterURL, domainID, chanID, subtopicPart) + data, err := json.Marshal(publishRequest{ + Topic: brokerTopic, + Payload: []byte(msg), + }) + if err != nil { + return errors.NewSDKError(err) + } - _, _, err := sdk.processRequest(ctx, http.MethodPost, reqURL, ClientPrefix+secret, []byte(msg), nil, http.StatusAccepted) + headers := map[string]string{ + "X-FluxMQ-Password": secret, + } - return err + reqURL := fmt.Sprintf("%s/publish", sdk.httpAdapterURL) + _, _, sdkErr := sdk.processRequest(ctx, http.MethodPost, reqURL, "", data, headers, http.StatusOK) + + return sdkErr } func (sdk *mgSDK) SetContentType(ct ContentType) errors.SDKError { diff --git a/pkg/sdk/message_test.go b/pkg/sdk/message_test.go index 270dba3ec..51565480e 100644 --- a/pkg/sdk/message_test.go +++ b/pkg/sdk/message_test.go @@ -5,212 +5,145 @@ package sdk_test import ( "context" + "encoding/json" "fmt" - "net" + "io" "net/http" "net/http/httptest" - "net/url" - "strings" "testing" - "github.com/absmach/mgate" - proxy "github.com/absmach/mgate/pkg/http" - grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" - grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" - grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1" apiutil "github.com/absmach/supermq/api/http/util" - chmocks "github.com/absmach/supermq/channels/mocks" - climocks "github.com/absmach/supermq/clients/mocks" - dmocks "github.com/absmach/supermq/domains/mocks" - adapter "github.com/absmach/supermq/http" - "github.com/absmach/supermq/http/api" - httpmocks "github.com/absmach/supermq/http/mocks" - smqlog "github.com/absmach/supermq/logger" - authnmocks "github.com/absmach/supermq/pkg/authn/mocks" "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/messaging" - pubsub "github.com/absmach/supermq/pkg/messaging/mocks" sdk "github.com/absmach/supermq/pkg/sdk" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" ) -var ( - channelsGRPCClient *chmocks.ChannelsServiceClient - clientsGRPCClient *climocks.ClientsServiceClient - domainsGRPCClient *dmocks.DomainsServiceClient -) +type publishReq struct { + Topic string `json:"topic"` + Payload []byte `json:"payload"` + QoS byte `json:"qos"` + Retain bool `json:"retain"` +} -func setupMessages(t *testing.T) (*httptest.Server, *pubsub.PubSub) { - clientsGRPCClient = new(climocks.ClientsServiceClient) - channelsGRPCClient = new(chmocks.ChannelsServiceClient) - domainsGRPCClient = new(dmocks.DomainsServiceClient) - pub := new(pubsub.PubSub) - authn := new(authnmocks.Authentication) - svc := new(httpmocks.Service) +func setupFluxMQ(secret string, expectedTopic ...string) *httptest.Server { + mux := http.NewServeMux() - parser, err := messaging.NewTopicParser(messaging.DefaultCacheConfig, channelsGRPCClient, domainsGRPCClient) - assert.Nil(t, err, fmt.Sprintf("unexpected error while setting up parser: %v", err)) - handler := adapter.NewHandler(pub, smqlog.NewMock(), authn, clientsGRPCClient, channelsGRPCClient, parser) - resolver := messaging.NewTopicResolver(channelsGRPCClient, domainsGRPCClient) + mux.HandleFunc("POST /publish", func(w http.ResponseWriter, r *http.Request) { + password := r.Header.Get("X-FluxMQ-Password") + if password == "" || password != secret { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } - mux := api.MakeHandler(context.Background(), svc, resolver, smqlog.NewMock(), "") - target := httptest.NewServer(mux) + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + defer r.Body.Close() - ptUrl, _ := url.Parse(target.URL) - ptHost, ptPort, _ := net.SplitHostPort(ptUrl.Host) - config := mgate.Config{ - Host: "", - Port: "", - PathPrefix: "", - TargetHost: ptHost, - TargetPort: ptPort, - TargetProtocol: ptUrl.Scheme, - TargetPath: ptUrl.Path, - } + var req publishReq + if err := json.Unmarshal(body, &req); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } - mp, err := proxy.NewProxy(config, handler, smqlog.NewMock(), []string{}, []string{"/health", "/metrics"}) - if err != nil { - return nil, nil - } + if req.Topic == "" { + http.Error(w, "empty topic", http.StatusBadRequest) + return + } + if len(expectedTopic) > 0 && req.Topic != expectedTopic[0] { + http.Error(w, fmt.Sprintf("unexpected topic: %s", req.Topic), http.StatusBadRequest) + return + } - return httptest.NewServer(http.HandlerFunc(mp.ServeHTTP)), pub + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"status":"ok"}`) + }) + + mux.HandleFunc("GET /health", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"status":"healthy"}`) + }) + + return httptest.NewServer(mux) } func TestSendMessage(t *testing.T) { - ts, pub := setupMessages(t) - defer ts.Close() - - msg := `[{"n":"current","t":-1,"v":1.6}]` - clientKey := "clientKey" - channelID := "channelID" - domainID := "domainID" - - sdkConf := sdk.Config{ - HTTPAdapterURL: ts.URL, - MsgContentType: "application/senml+json", - TLSVerification: false, - } - - mgsdk := sdk.NewSDK(sdkConf) + clientSecret := "validSecret" cases := []struct { - desc string - topic string - domainID string - msg string - secret string - authRes *grpcClientsV1.AuthnRes - authErr error - svcErr error - err errors.SDKError + desc string + topic string + domainID string + wantTopic string + msg string + secret string + err errors.SDKError }{ { - desc: "publish message successfully", - topic: channelID, - domainID: domainID, - msg: msg, - secret: clientKey, - authRes: &grpcClientsV1.AuthnRes{Authenticated: true, Id: ""}, - authErr: nil, - svcErr: nil, - err: nil, + desc: "publish message successfully", + topic: "channelID", + domainID: "domainID", + wantTopic: "m/domainID/c/channelID", + msg: `[{"n":"current","t":-1,"v":1.6}]`, + secret: clientSecret, + err: nil, }, { - desc: "publish message with empty client key", - topic: channelID, - domainID: domainID, - msg: msg, - secret: "", - authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""}, - svcErr: nil, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + desc: "publish message with subtopic", + topic: "channelID.sub.topic", + domainID: "domainID", + wantTopic: "m/domainID/c/channelID/sub/topic", + msg: `[{"n":"current","t":-1,"v":1.6}]`, + secret: clientSecret, + err: nil, }, { - desc: "publish message with invalid client key", - topic: channelID, - domainID: domainID, - msg: msg, - secret: "invalid", - authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""}, - svcErr: svcerr.ErrAuthentication, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + desc: "publish message with invalid secret", + topic: "channelID", + domainID: "domainID", + wantTopic: "m/domainID/c/channelID", + msg: `[{"n":"current","t":-1,"v":1.6}]`, + secret: "invalid", + err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.New(""), errors.New("")), http.StatusUnauthorized), }, { - desc: "publish message with invalid channel ID", - topic: wrongID, - domainID: domainID, - msg: msg, - secret: clientKey, - authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""}, - svcErr: svcerr.ErrAuthentication, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), - }, - { - desc: "publish message with empty message body", - topic: channelID, - domainID: domainID, - msg: "", - secret: clientKey, - authRes: &grpcClientsV1.AuthnRes{Authenticated: true, Id: ""}, - authErr: nil, - svcErr: nil, - err: errors.NewSDKErrorWithStatus(apiutil.ErrEmptyMessage, http.StatusBadRequest), - }, - { - desc: "publish message with channel subtopic", - topic: channelID + ".subtopic", - domainID: domainID, - msg: msg, - secret: clientKey, - authRes: &grpcClientsV1.AuthnRes{Authenticated: true, Id: ""}, - authErr: nil, - svcErr: nil, - err: nil, - }, - { - desc: "publish message with invalid domain ID", - topic: channelID, - domainID: wrongID, - msg: msg, - secret: clientKey, - authRes: &grpcClientsV1.AuthnRes{Authenticated: false, Id: ""}, - svcErr: svcerr.ErrAuthentication, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), + desc: "publish message with empty secret", + topic: "channelID", + domainID: "domainID", + wantTopic: "m/domainID/c/channelID", + msg: `[{"n":"current","t":-1,"v":1.6}]`, + secret: "", + err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.New(""), errors.New("")), http.StatusUnauthorized), }, } for _, tc := range cases { - internalTopic := tc.domainID + ".c." + strings.ReplaceAll(tc.topic, "/", ".") t.Run(tc.desc, func(t *testing.T) { - authzCall := clientsGRPCClient.On("Authenticate", mock.Anything, mock.Anything).Return(tc.authRes, tc.authErr) - authnCall := channelsGRPCClient.On("Authorize", mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, nil) - svcCall := pub.On("Publish", mock.Anything, internalTopic, mock.Anything).Return(tc.svcErr) - domainsCall := domainsGRPCClient.On("RetrieveIDByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: tc.domainID}}, nil) - channelsCall := channelsGRPCClient.On("RetrieveIDByRoute", mock.Anything, mock.Anything).Return(&grpcCommonV1.RetrieveEntityRes{Entity: &grpcCommonV1.EntityBasic{Id: channelID}}, nil) + ts := setupFluxMQ(clientSecret, tc.wantTopic) + defer ts.Close() + + sdkConf := sdk.Config{ + HTTPAdapterURL: ts.URL, + MsgContentType: "application/senml+json", + TLSVerification: false, + } + mgsdk := sdk.NewSDK(sdkConf) + err := mgsdk.SendMessage(context.Background(), tc.domainID, tc.topic, tc.msg, tc.secret) if tc.err != nil { - assert.Contains(t, err.Error(), tc.err.Error(), fmt.Sprintf("expected error message to contain: %v, got: %v", tc.err, err)) + assert.NotNil(t, err, fmt.Sprintf("%s: expected error, got nil", tc.desc)) + } else { + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error: %v", tc.desc, err)) } - if tc.err == nil { - ok := svcCall.Parent.AssertCalled(t, "Publish", mock.Anything, internalTopic, mock.Anything) - assert.True(t, ok) - } - svcCall.Unset() - authzCall.Unset() - authnCall.Unset() - domainsCall.Unset() - channelsCall.Unset() }) } } func TestSetContentType(t *testing.T) { - ts, _ := setupMessages(t) - defer ts.Close() - sdkConf := sdk.Config{ - HTTPAdapterURL: ts.URL, MsgContentType: "application/senml+json", TLSVerification: false, } @@ -226,6 +159,11 @@ func TestSetContentType(t *testing.T) { cType: "application/senml+json", err: nil, }, + { + desc: "set json content type", + cType: "application/json", + err: nil, + }, { desc: "set invalid content type", cType: "invalid", @@ -233,7 +171,9 @@ func TestSetContentType(t *testing.T) { }, } for _, tc := range cases { - err := mgsdk.SetContentType(tc.cType) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) + t.Run(tc.desc, func(t *testing.T) { + err := mgsdk.SetContentType(tc.cType) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) + }) } } diff --git a/pkg/sdk/messages.go b/pkg/sdk/messages.go new file mode 100644 index 000000000..a30b30f6f --- /dev/null +++ b/pkg/sdk/messages.go @@ -0,0 +1,76 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + + "github.com/absmach/supermq/pkg/errors" +) + +func (sdk mgSDK) ReadMessages(ctx context.Context, pm MessagePageMetadata, chanName, domainID, token string) (MessagesPage, errors.SDKError) { + chanNameParts := strings.SplitN(chanName, ".", channelParts) + chanID := chanNameParts[0] + subtopicPart := "" + if len(chanNameParts) == channelParts { + subtopicPart = fmt.Sprintf("?subtopic=%s", chanNameParts[1]) + } + + msgURL, err := sdk.withMessageQueryParams(sdk.readersURL, fmt.Sprintf("%s/channels/%s/messages%s", domainID, chanID, subtopicPart), pm) + if err != nil { + return MessagesPage{}, errors.NewSDKError(err) + } + + header := make(map[string]string) + header["Content-Type"] = string(sdk.msgContentType) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, msgURL, token, nil, header, http.StatusOK) + if sdkerr != nil { + return MessagesPage{}, sdkerr + } + + var mp MessagesPage + if err := json.Unmarshal(body, &mp); err != nil { + return MessagesPage{}, errors.NewSDKError(err) + } + + return mp, nil +} + +func (sdk mgSDK) withMessageQueryParams(baseURL, endpoint string, mpm MessagePageMetadata) (string, error) { + b, err := json.Marshal(mpm) + if err != nil { + return "", err + } + q := map[string]any{} + if err := json.Unmarshal(b, &q); err != nil { + return "", err + } + ret := url.Values{} + for k, v := range q { + switch t := v.(type) { + case string: + ret.Add(k, t) + case float64: + ret.Add(k, strconv.FormatFloat(t, 'f', -1, 64)) + case uint64: + ret.Add(k, strconv.FormatUint(t, 10)) + case int64: + ret.Add(k, strconv.FormatInt(t, 10)) + case json.Number: + ret.Add(k, t.String()) + case bool: + ret.Add(k, strconv.FormatBool(t)) + } + } + qs := ret.Encode() + + return fmt.Sprintf("%s/%s?%s", baseURL, endpoint, qs), nil +} diff --git a/pkg/sdk/messages_test.go b/pkg/sdk/messages_test.go new file mode 100644 index 000000000..afaa41450 --- /dev/null +++ b/pkg/sdk/messages_test.go @@ -0,0 +1,240 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" + apiutil "github.com/absmach/supermq/api/http/util" + chmocks "github.com/absmach/supermq/channels/mocks" + climocks "github.com/absmach/supermq/clients/mocks" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/sdk" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/absmach/supermq/readers" + readersapi "github.com/absmach/supermq/readers/api/http" + readersmocks "github.com/absmach/supermq/readers/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func setupReaders() (*httptest.Server, *authnmocks.Authentication, *readersmocks.MessageRepository) { + repo := new(readersmocks.MessageRepository) + authn := new(authnmocks.Authentication) + clientsGRPCClient = new(climocks.ClientsServiceClient) + channelsGRPCClient = new(chmocks.ChannelsServiceClient) + + mux := readersapi.MakeHandler(repo, authn, clientsGRPCClient, channelsGRPCClient, "test", "") + return httptest.NewServer(mux), authn, repo +} + +func TestReadMessages(t *testing.T) { + ts, authn, repo := setupReaders() + defer ts.Close() + + channelID := "channelID" + msgValue := 1.6 + boolVal := true + msg := senml.Message{ + Name: "current", + Time: 1720000000, + Value: &msgValue, + Publisher: validID, + } + invalidMsg := "[{\"n\":\"current\",\"t\":-1,\"v\":1.6}]" + + sdkConf := sdk.Config{ + ReaderURL: ts.URL, + } + + mgsdk := sdk.NewSDK(sdkConf) + + cases := []struct { + desc string + token string + chanName string + domainID string + messagePageMeta sdk.MessagePageMetadata + authzErr error + authnErr error + repoRes readers.MessagesPage + repoErr error + response sdk.MessagesPage + err errors.SDKError + }{ + { + desc: "read messages successfully", + token: validToken, + chanName: channelID, + domainID: validID, + messagePageMeta: sdk.MessagePageMetadata{ + PageMetadata: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + Level: 0, + }, + Publisher: validID, + BoolValue: &boolVal, + }, + repoRes: readers.MessagesPage{ + Total: 1, + Messages: []readers.Message{msg}, + }, + repoErr: nil, + response: sdk.MessagesPage{ + PageRes: sdk.PageRes{ + Total: 1, + }, + Messages: []senml.Message{msg}, + }, + err: nil, + }, + { + desc: "read messages successfully with subtopic", + token: validToken, + chanName: channelID + ".subtopic", + domainID: validID, + messagePageMeta: sdk.MessagePageMetadata{ + PageMetadata: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + Publisher: validID, + }, + repoRes: readers.MessagesPage{ + Total: 1, + Messages: []readers.Message{msg}, + }, + repoErr: nil, + response: sdk.MessagesPage{ + PageRes: sdk.PageRes{ + Total: 1, + }, + Messages: []senml.Message{msg}, + }, + err: nil, + }, + { + desc: "read messages with invalid token", + token: invalidToken, + chanName: channelID, + domainID: validID, + messagePageMeta: sdk.MessagePageMetadata{ + PageMetadata: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + Subtopic: "subtopic", + Publisher: validID, + }, + authzErr: svcerr.ErrAuthorization, + repoRes: readers.MessagesPage{}, + response: sdk.MessagesPage{}, + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + }, + { + desc: "read messages with empty token", + token: "", + chanName: channelID, + domainID: validID, + messagePageMeta: sdk.MessagePageMetadata{ + PageMetadata: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + Subtopic: "subtopic", + Publisher: validID, + }, + authnErr: svcerr.ErrAuthentication, + repoRes: readers.MessagesPage{}, + response: sdk.MessagesPage{}, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + }, + { + desc: "read messages with empty channel ID", + token: validToken, + chanName: "", + domainID: validID, + messagePageMeta: sdk.MessagePageMetadata{ + PageMetadata: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + Subtopic: "subtopic", + Publisher: validID, + }, + repoRes: readers.MessagesPage{}, + repoErr: nil, + response: sdk.MessagesPage{}, + err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingID, http.StatusBadRequest), + }, + { + desc: "read messages with invalid message page metadata", + token: validToken, + chanName: channelID, + domainID: validID, + messagePageMeta: sdk.MessagePageMetadata{ + PageMetadata: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + Metadata: map[string]any{ + "key": make(chan int), + }, + }, + Subtopic: "subtopic", + Publisher: validID, + }, + repoRes: readers.MessagesPage{}, + repoErr: nil, + response: sdk.MessagesPage{}, + err: errors.NewSDKError(errors.New("json: unsupported type: chan int")), + }, + { + desc: "read messages with response that cannot be unmarshalled", + token: validToken, + chanName: channelID, + domainID: validID, + messagePageMeta: sdk.MessagePageMetadata{ + PageMetadata: sdk.PageMetadata{ + Offset: 0, + Limit: 10, + }, + Subtopic: "subtopic", + Publisher: validID, + }, + repoRes: readers.MessagesPage{ + Total: 1, + Messages: []readers.Message{invalidMsg}, + }, + repoErr: nil, + response: sdk.MessagesPage{}, + err: errors.NewSDKError(errors.New("json: cannot unmarshal string into Go struct field MessagesPage.messages of type senml.Message")), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + authCall1 := authn.On("Authenticate", mock.Anything, tc.token).Return(smqauthn.Session{UserID: validID}, tc.authnErr) + authzCall := channelsGRPCClient.On("Authorize", mock.Anything, mock.Anything).Return(&grpcChannelsV1.AuthzRes{Authorized: true}, tc.authzErr) + repoCall := repo.On("ReadAll", channelID, mock.Anything).Return(tc.repoRes, tc.repoErr) + response, err := mgsdk.ReadMessages(context.Background(), tc.messagePageMeta, tc.chanName, tc.domainID, tc.token) + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.response, response) + if tc.err == nil { + ok := repoCall.Parent.AssertCalled(t, "ReadAll", channelID, mock.Anything) + assert.True(t, ok) + } + authCall1.Unset() + authzCall.Unset() + repoCall.Unset() + }) + } +} diff --git a/pkg/sdk/mocks/sdk.go b/pkg/sdk/mocks/sdk.go index fcf7f9f68..ce1d42c0f 100644 --- a/pkg/sdk/mocks/sdk.go +++ b/pkg/sdk/mocks/sdk.go @@ -11,6 +11,7 @@ package mocks import ( "context" + "github.com/absmach/supermq/certs" "github.com/absmach/supermq/pkg/errors" "github.com/absmach/supermq/pkg/sdk" mock "github.com/stretchr/testify/mock" @@ -106,6 +107,86 @@ func (_c *SDK_AcceptInvitation_Call) RunAndReturn(run func(ctx context.Context, return _c } +// AddBootstrap provides a mock function for the type SDK +func (_mock *SDK) AddBootstrap(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string) (string, errors.SDKError) { + ret := _mock.Called(ctx, cfg, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for AddBootstrap") + } + + var r0 string + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapConfig, string, string) (string, errors.SDKError)); ok { + return returnFunc(ctx, cfg, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapConfig, string, string) string); ok { + r0 = returnFunc(ctx, cfg, domainID, token) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.BootstrapConfig, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, cfg, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_AddBootstrap_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddBootstrap' +type SDK_AddBootstrap_Call struct { + *mock.Call +} + +// AddBootstrap is a helper method to define mock.On call +// - ctx context.Context +// - cfg sdk.BootstrapConfig +// - domainID string +// - token string +func (_e *SDK_Expecter) AddBootstrap(ctx interface{}, cfg interface{}, domainID interface{}, token interface{}) *SDK_AddBootstrap_Call { + return &SDK_AddBootstrap_Call{Call: _e.mock.On("AddBootstrap", ctx, cfg, domainID, token)} +} + +func (_c *SDK_AddBootstrap_Call) Run(run func(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string)) *SDK_AddBootstrap_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.BootstrapConfig + if args[1] != nil { + arg1 = args[1].(sdk.BootstrapConfig) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_AddBootstrap_Call) Return(s string, sDKError errors.SDKError) *SDK_AddBootstrap_Call { + _c.Call.Return(s, sDKError) + return _c +} + +func (_c *SDK_AddBootstrap_Call) RunAndReturn(run func(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string) (string, errors.SDKError)) *SDK_AddBootstrap_Call { + _c.Call.Return(run) + return _c +} + // AddChildren provides a mock function for the type SDK func (_mock *SDK) AddChildren(ctx context.Context, id string, domainID string, groupIDs []string, token string) errors.SDKError { ret := _mock.Called(ctx, id, domainID, groupIDs, token) @@ -735,6 +816,166 @@ func (_c *SDK_AddGroupRoleMembers_Call) RunAndReturn(run func(ctx context.Contex return _c } +// AddReportConfig provides a mock function for the type SDK +func (_mock *SDK) AddReportConfig(ctx context.Context, cfg sdk.ReportConfig, domainID string, token string) (sdk.ReportConfig, errors.SDKError) { + ret := _mock.Called(ctx, cfg, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for AddReportConfig") + } + + var r0 sdk.ReportConfig + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.ReportConfig, string, string) (sdk.ReportConfig, errors.SDKError)); ok { + return returnFunc(ctx, cfg, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.ReportConfig, string, string) sdk.ReportConfig); ok { + r0 = returnFunc(ctx, cfg, domainID, token) + } else { + r0 = ret.Get(0).(sdk.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.ReportConfig, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, cfg, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_AddReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddReportConfig' +type SDK_AddReportConfig_Call struct { + *mock.Call +} + +// AddReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - cfg sdk.ReportConfig +// - domainID string +// - token string +func (_e *SDK_Expecter) AddReportConfig(ctx interface{}, cfg interface{}, domainID interface{}, token interface{}) *SDK_AddReportConfig_Call { + return &SDK_AddReportConfig_Call{Call: _e.mock.On("AddReportConfig", ctx, cfg, domainID, token)} +} + +func (_c *SDK_AddReportConfig_Call) Run(run func(ctx context.Context, cfg sdk.ReportConfig, domainID string, token string)) *SDK_AddReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.ReportConfig + if args[1] != nil { + arg1 = args[1].(sdk.ReportConfig) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_AddReportConfig_Call) Return(reportConfig sdk.ReportConfig, sDKError errors.SDKError) *SDK_AddReportConfig_Call { + _c.Call.Return(reportConfig, sDKError) + return _c +} + +func (_c *SDK_AddReportConfig_Call) RunAndReturn(run func(ctx context.Context, cfg sdk.ReportConfig, domainID string, token string) (sdk.ReportConfig, errors.SDKError)) *SDK_AddReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// AddRule provides a mock function for the type SDK +func (_mock *SDK) AddRule(ctx context.Context, r sdk.Rule, domainID string, token string) (sdk.Rule, errors.SDKError) { + ret := _mock.Called(ctx, r, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for AddRule") + } + + var r0 sdk.Rule + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.Rule, string, string) (sdk.Rule, errors.SDKError)); ok { + return returnFunc(ctx, r, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.Rule, string, string) sdk.Rule); ok { + r0 = returnFunc(ctx, r, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.Rule, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, r, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_AddRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddRule' +type SDK_AddRule_Call struct { + *mock.Call +} + +// AddRule is a helper method to define mock.On call +// - ctx context.Context +// - r sdk.Rule +// - domainID string +// - token string +func (_e *SDK_Expecter) AddRule(ctx interface{}, r interface{}, domainID interface{}, token interface{}) *SDK_AddRule_Call { + return &SDK_AddRule_Call{Call: _e.mock.On("AddRule", ctx, r, domainID, token)} +} + +func (_c *SDK_AddRule_Call) Run(run func(ctx context.Context, r sdk.Rule, domainID string, token string)) *SDK_AddRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.Rule + if args[1] != nil { + arg1 = args[1].(sdk.Rule) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_AddRule_Call) Return(rule sdk.Rule, sDKError errors.SDKError) *SDK_AddRule_Call { + _c.Call.Return(rule, sDKError) + return _c +} + +func (_c *SDK_AddRule_Call) RunAndReturn(run func(ctx context.Context, r sdk.Rule, domainID string, token string) (sdk.Rule, errors.SDKError)) *SDK_AddRule_Call { + _c.Call.Return(run) + return _c +} + // AvailableClientRoleActions provides a mock function for the type SDK func (_mock *SDK) AvailableClientRoleActions(ctx context.Context, domainID string, token string) ([]string, errors.SDKError) { ret := _mock.Called(ctx, domainID, token) @@ -957,6 +1198,240 @@ func (_c *SDK_AvailableGroupRoleActions_Call) RunAndReturn(run func(ctx context. return _c } +// Bootstrap provides a mock function for the type SDK +func (_mock *SDK) Bootstrap(ctx context.Context, externalID string, externalKey string) (sdk.BootstrapConfig, errors.SDKError) { + ret := _mock.Called(ctx, externalID, externalKey) + + if len(ret) == 0 { + panic("no return value specified for Bootstrap") + } + + var r0 sdk.BootstrapConfig + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { + return returnFunc(ctx, externalID, externalKey) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) sdk.BootstrapConfig); ok { + r0 = returnFunc(ctx, externalID, externalKey) + } else { + r0 = ret.Get(0).(sdk.BootstrapConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, externalID, externalKey) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_Bootstrap_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Bootstrap' +type SDK_Bootstrap_Call struct { + *mock.Call +} + +// Bootstrap is a helper method to define mock.On call +// - ctx context.Context +// - externalID string +// - externalKey string +func (_e *SDK_Expecter) Bootstrap(ctx interface{}, externalID interface{}, externalKey interface{}) *SDK_Bootstrap_Call { + return &SDK_Bootstrap_Call{Call: _e.mock.On("Bootstrap", ctx, externalID, externalKey)} +} + +func (_c *SDK_Bootstrap_Call) Run(run func(ctx context.Context, externalID string, externalKey string)) *SDK_Bootstrap_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *SDK_Bootstrap_Call) Return(bootstrapConfig sdk.BootstrapConfig, sDKError errors.SDKError) *SDK_Bootstrap_Call { + _c.Call.Return(bootstrapConfig, sDKError) + return _c +} + +func (_c *SDK_Bootstrap_Call) RunAndReturn(run func(ctx context.Context, externalID string, externalKey string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_Bootstrap_Call { + _c.Call.Return(run) + return _c +} + +// BootstrapSecure provides a mock function for the type SDK +func (_mock *SDK) BootstrapSecure(ctx context.Context, externalID string, externalKey string, cryptoKey string) (sdk.BootstrapConfig, errors.SDKError) { + ret := _mock.Called(ctx, externalID, externalKey, cryptoKey) + + if len(ret) == 0 { + panic("no return value specified for BootstrapSecure") + } + + var r0 sdk.BootstrapConfig + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { + return returnFunc(ctx, externalID, externalKey, cryptoKey) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.BootstrapConfig); ok { + r0 = returnFunc(ctx, externalID, externalKey, cryptoKey) + } else { + r0 = ret.Get(0).(sdk.BootstrapConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, externalID, externalKey, cryptoKey) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_BootstrapSecure_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BootstrapSecure' +type SDK_BootstrapSecure_Call struct { + *mock.Call +} + +// BootstrapSecure is a helper method to define mock.On call +// - ctx context.Context +// - externalID string +// - externalKey string +// - cryptoKey string +func (_e *SDK_Expecter) BootstrapSecure(ctx interface{}, externalID interface{}, externalKey interface{}, cryptoKey interface{}) *SDK_BootstrapSecure_Call { + return &SDK_BootstrapSecure_Call{Call: _e.mock.On("BootstrapSecure", ctx, externalID, externalKey, cryptoKey)} +} + +func (_c *SDK_BootstrapSecure_Call) Run(run func(ctx context.Context, externalID string, externalKey string, cryptoKey string)) *SDK_BootstrapSecure_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_BootstrapSecure_Call) Return(bootstrapConfig sdk.BootstrapConfig, sDKError errors.SDKError) *SDK_BootstrapSecure_Call { + _c.Call.Return(bootstrapConfig, sDKError) + return _c +} + +func (_c *SDK_BootstrapSecure_Call) RunAndReturn(run func(ctx context.Context, externalID string, externalKey string, cryptoKey string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_BootstrapSecure_Call { + _c.Call.Return(run) + return _c +} + +// Bootstraps provides a mock function for the type SDK +func (_mock *SDK) Bootstraps(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.BootstrapPage, errors.SDKError) { + ret := _mock.Called(ctx, pm, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for Bootstraps") + } + + var r0 sdk.BootstrapPage + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) (sdk.BootstrapPage, errors.SDKError)); ok { + return returnFunc(ctx, pm, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) sdk.BootstrapPage); ok { + r0 = returnFunc(ctx, pm, domainID, token) + } else { + r0 = ret.Get(0).(sdk.BootstrapPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.PageMetadata, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_Bootstraps_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Bootstraps' +type SDK_Bootstraps_Call struct { + *mock.Call +} + +// Bootstraps is a helper method to define mock.On call +// - ctx context.Context +// - pm sdk.PageMetadata +// - domainID string +// - token string +func (_e *SDK_Expecter) Bootstraps(ctx interface{}, pm interface{}, domainID interface{}, token interface{}) *SDK_Bootstraps_Call { + return &SDK_Bootstraps_Call{Call: _e.mock.On("Bootstraps", ctx, pm, domainID, token)} +} + +func (_c *SDK_Bootstraps_Call) Run(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string)) *SDK_Bootstraps_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.PageMetadata + if args[1] != nil { + arg1 = args[1].(sdk.PageMetadata) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_Bootstraps_Call) Return(bootstrapPage sdk.BootstrapPage, sDKError errors.SDKError) *SDK_Bootstraps_Call { + _c.Call.Return(bootstrapPage, sDKError) + return _c +} + +func (_c *SDK_Bootstraps_Call) RunAndReturn(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.BootstrapPage, errors.SDKError)) *SDK_Bootstraps_Call { + _c.Call.Return(run) + return _c +} + // Channel provides a mock function for the type SDK func (_mock *SDK) Channel(ctx context.Context, id string, domainID string, token string) (sdk.Channel, errors.SDKError) { ret := _mock.Called(ctx, id, domainID, token) @@ -1869,6 +2344,80 @@ func (_c *SDK_ConnectClients_Call) RunAndReturn(run func(ctx context.Context, ch return _c } +// CreateCSR provides a mock function for the type SDK +func (_mock *SDK) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey any) (certs.CSR, errors.SDKError) { + ret := _mock.Called(ctx, metadata, privKey) + + if len(ret) == 0 { + panic("no return value specified for CreateCSR") + } + + var r0 certs.CSR + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, any) (certs.CSR, errors.SDKError)); ok { + return returnFunc(ctx, metadata, privKey) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, any) certs.CSR); ok { + r0 = returnFunc(ctx, metadata, privKey) + } else { + r0 = ret.Get(0).(certs.CSR) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, certs.CSRMetadata, any) errors.SDKError); ok { + r1 = returnFunc(ctx, metadata, privKey) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_CreateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCSR' +type SDK_CreateCSR_Call struct { + *mock.Call +} + +// CreateCSR is a helper method to define mock.On call +// - ctx context.Context +// - metadata certs.CSRMetadata +// - privKey any +func (_e *SDK_Expecter) CreateCSR(ctx interface{}, metadata interface{}, privKey interface{}) *SDK_CreateCSR_Call { + return &SDK_CreateCSR_Call{Call: _e.mock.On("CreateCSR", ctx, metadata, privKey)} +} + +func (_c *SDK_CreateCSR_Call) Run(run func(ctx context.Context, metadata certs.CSRMetadata, privKey any)) *SDK_CreateCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 certs.CSRMetadata + if args[1] != nil { + arg1 = args[1].(certs.CSRMetadata) + } + var arg2 any + if args[2] != nil { + arg2 = args[2].(any) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *SDK_CreateCSR_Call) Return(cSR certs.CSR, sDKError errors.SDKError) *SDK_CreateCSR_Call { + _c.Call.Return(cSR, sDKError) + return _c +} + +func (_c *SDK_CreateCSR_Call) RunAndReturn(run func(ctx context.Context, metadata certs.CSRMetadata, privKey any) (certs.CSR, errors.SDKError)) *SDK_CreateCSR_Call { + _c.Call.Return(run) + return _c +} + // CreateChannel provides a mock function for the type SDK func (_mock *SDK) CreateChannel(ctx context.Context, channel sdk.Channel, domainID string, token string) (sdk.Channel, errors.SDKError) { ret := _mock.Called(ctx, channel, domainID, token) @@ -2599,6 +3148,86 @@ func (_c *SDK_CreateGroupRole_Call) RunAndReturn(run func(ctx context.Context, i return _c } +// CreateSubscription provides a mock function for the type SDK +func (_mock *SDK) CreateSubscription(ctx context.Context, topic string, contact string, token string) (string, errors.SDKError) { + ret := _mock.Called(ctx, topic, contact, token) + + if len(ret) == 0 { + panic("no return value specified for CreateSubscription") + } + + var r0 string + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (string, errors.SDKError)); ok { + return returnFunc(ctx, topic, contact, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) string); ok { + r0 = returnFunc(ctx, topic, contact, token) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, topic, contact, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_CreateSubscription_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateSubscription' +type SDK_CreateSubscription_Call struct { + *mock.Call +} + +// CreateSubscription is a helper method to define mock.On call +// - ctx context.Context +// - topic string +// - contact string +// - token string +func (_e *SDK_Expecter) CreateSubscription(ctx interface{}, topic interface{}, contact interface{}, token interface{}) *SDK_CreateSubscription_Call { + return &SDK_CreateSubscription_Call{Call: _e.mock.On("CreateSubscription", ctx, topic, contact, token)} +} + +func (_c *SDK_CreateSubscription_Call) Run(run func(ctx context.Context, topic string, contact string, token string)) *SDK_CreateSubscription_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_CreateSubscription_Call) Return(s string, sDKError errors.SDKError) *SDK_CreateSubscription_Call { + _c.Call.Return(s, sDKError) + return _c +} + +func (_c *SDK_CreateSubscription_Call) RunAndReturn(run func(ctx context.Context, topic string, contact string, token string) (string, errors.SDKError)) *SDK_CreateSubscription_Call { + _c.Call.Return(run) + return _c +} + // CreateToken provides a mock function for the type SDK func (_mock *SDK) CreateToken(ctx context.Context, lt sdk.Login) (sdk.Token, errors.SDKError) { ret := _mock.Called(ctx, lt) @@ -2741,6 +3370,148 @@ func (_c *SDK_CreateUser_Call) RunAndReturn(run func(ctx context.Context, user s return _c } +// DeleteAlarm provides a mock function for the type SDK +func (_mock *SDK) DeleteAlarm(ctx context.Context, id string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for DeleteAlarm") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_DeleteAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteAlarm' +type SDK_DeleteAlarm_Call struct { + *mock.Call +} + +// DeleteAlarm is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) DeleteAlarm(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_DeleteAlarm_Call { + return &SDK_DeleteAlarm_Call{Call: _e.mock.On("DeleteAlarm", ctx, id, domainID, token)} +} + +func (_c *SDK_DeleteAlarm_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_DeleteAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_DeleteAlarm_Call) Return(sDKError errors.SDKError) *SDK_DeleteAlarm_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_DeleteAlarm_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) errors.SDKError) *SDK_DeleteAlarm_Call { + _c.Call.Return(run) + return _c +} + +// DeleteCert provides a mock function for the type SDK +func (_mock *SDK) DeleteCert(ctx context.Context, entityID string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, entityID, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for DeleteCert") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, entityID, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_DeleteCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteCert' +type SDK_DeleteCert_Call struct { + *mock.Call +} + +// DeleteCert is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - domainID string +// - token string +func (_e *SDK_Expecter) DeleteCert(ctx interface{}, entityID interface{}, domainID interface{}, token interface{}) *SDK_DeleteCert_Call { + return &SDK_DeleteCert_Call{Call: _e.mock.On("DeleteCert", ctx, entityID, domainID, token)} +} + +func (_c *SDK_DeleteCert_Call) Run(run func(ctx context.Context, entityID string, domainID string, token string)) *SDK_DeleteCert_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_DeleteCert_Call) Return(sDKError errors.SDKError) *SDK_DeleteCert_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_DeleteCert_Call) RunAndReturn(run func(ctx context.Context, entityID string, domainID string, token string) errors.SDKError) *SDK_DeleteCert_Call { + _c.Call.Return(run) + return _c +} + // DeleteChannel provides a mock function for the type SDK func (_mock *SDK) DeleteChannel(ctx context.Context, id string, domainID string, token string) errors.SDKError { ret := _mock.Called(ctx, id, domainID, token) @@ -3248,6 +4019,142 @@ func (_c *SDK_DeleteInvitation_Call) RunAndReturn(run func(ctx context.Context, return _c } +// DeleteReportTemplate provides a mock function for the type SDK +func (_mock *SDK) DeleteReportTemplate(ctx context.Context, id string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for DeleteReportTemplate") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_DeleteReportTemplate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteReportTemplate' +type SDK_DeleteReportTemplate_Call struct { + *mock.Call +} + +// DeleteReportTemplate is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) DeleteReportTemplate(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_DeleteReportTemplate_Call { + return &SDK_DeleteReportTemplate_Call{Call: _e.mock.On("DeleteReportTemplate", ctx, id, domainID, token)} +} + +func (_c *SDK_DeleteReportTemplate_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_DeleteReportTemplate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_DeleteReportTemplate_Call) Return(sDKError errors.SDKError) *SDK_DeleteReportTemplate_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_DeleteReportTemplate_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) errors.SDKError) *SDK_DeleteReportTemplate_Call { + _c.Call.Return(run) + return _c +} + +// DeleteSubscription provides a mock function for the type SDK +func (_mock *SDK) DeleteSubscription(ctx context.Context, id string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, token) + + if len(ret) == 0 { + panic("no return value specified for DeleteSubscription") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_DeleteSubscription_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteSubscription' +type SDK_DeleteSubscription_Call struct { + *mock.Call +} + +// DeleteSubscription is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - token string +func (_e *SDK_Expecter) DeleteSubscription(ctx interface{}, id interface{}, token interface{}) *SDK_DeleteSubscription_Call { + return &SDK_DeleteSubscription_Call{Call: _e.mock.On("DeleteSubscription", ctx, id, token)} +} + +func (_c *SDK_DeleteSubscription_Call) Run(run func(ctx context.Context, id string, token string)) *SDK_DeleteSubscription_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *SDK_DeleteSubscription_Call) Return(sDKError errors.SDKError) *SDK_DeleteSubscription_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_DeleteSubscription_Call) RunAndReturn(run func(ctx context.Context, id string, token string) errors.SDKError) *SDK_DeleteSubscription_Call { + _c.Call.Return(run) + return _c +} + // DeleteUser provides a mock function for the type SDK func (_mock *SDK) DeleteUser(ctx context.Context, id string, token string) errors.SDKError { ret := _mock.Called(ctx, id, token) @@ -3618,6 +4525,166 @@ func (_c *SDK_DisableGroup_Call) RunAndReturn(run func(ctx context.Context, id s return _c } +// DisableReportConfig provides a mock function for the type SDK +func (_mock *SDK) DisableReportConfig(ctx context.Context, id string, domainID string, token string) (sdk.ReportConfig, errors.SDKError) { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for DisableReportConfig") + } + + var r0 sdk.ReportConfig + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.ReportConfig, errors.SDKError)); ok { + return returnFunc(ctx, id, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.ReportConfig); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + r0 = ret.Get(0).(sdk.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_DisableReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DisableReportConfig' +type SDK_DisableReportConfig_Call struct { + *mock.Call +} + +// DisableReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) DisableReportConfig(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_DisableReportConfig_Call { + return &SDK_DisableReportConfig_Call{Call: _e.mock.On("DisableReportConfig", ctx, id, domainID, token)} +} + +func (_c *SDK_DisableReportConfig_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_DisableReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_DisableReportConfig_Call) Return(reportConfig sdk.ReportConfig, sDKError errors.SDKError) *SDK_DisableReportConfig_Call { + _c.Call.Return(reportConfig, sDKError) + return _c +} + +func (_c *SDK_DisableReportConfig_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) (sdk.ReportConfig, errors.SDKError)) *SDK_DisableReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// DisableRule provides a mock function for the type SDK +func (_mock *SDK) DisableRule(ctx context.Context, id string, domainID string, token string) (sdk.Rule, errors.SDKError) { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for DisableRule") + } + + var r0 sdk.Rule + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.Rule, errors.SDKError)); ok { + return returnFunc(ctx, id, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.Rule); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_DisableRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DisableRule' +type SDK_DisableRule_Call struct { + *mock.Call +} + +// DisableRule is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) DisableRule(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_DisableRule_Call { + return &SDK_DisableRule_Call{Call: _e.mock.On("DisableRule", ctx, id, domainID, token)} +} + +func (_c *SDK_DisableRule_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_DisableRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_DisableRule_Call) Return(rule sdk.Rule, sDKError errors.SDKError) *SDK_DisableRule_Call { + _c.Call.Return(rule, sDKError) + return _c +} + +func (_c *SDK_DisableRule_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) (sdk.Rule, errors.SDKError)) *SDK_DisableRule_Call { + _c.Call.Return(run) + return _c +} + // DisableUser provides a mock function for the type SDK func (_mock *SDK) DisableUser(ctx context.Context, id string, token string) (sdk.User, errors.SDKError) { ret := _mock.Called(ctx, id, token) @@ -4400,6 +5467,68 @@ func (_c *SDK_Domains_Call) RunAndReturn(run func(ctx context.Context, pm sdk.Pa return _c } +// DownloadCA provides a mock function for the type SDK +func (_mock *SDK) DownloadCA(ctx context.Context) (sdk.CertificateBundle, errors.SDKError) { + ret := _mock.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for DownloadCA") + } + + var r0 sdk.CertificateBundle + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context) (sdk.CertificateBundle, errors.SDKError)); ok { + return returnFunc(ctx) + } + if returnFunc, ok := ret.Get(0).(func(context.Context) sdk.CertificateBundle); ok { + r0 = returnFunc(ctx) + } else { + r0 = ret.Get(0).(sdk.CertificateBundle) + } + if returnFunc, ok := ret.Get(1).(func(context.Context) errors.SDKError); ok { + r1 = returnFunc(ctx) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_DownloadCA_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DownloadCA' +type SDK_DownloadCA_Call struct { + *mock.Call +} + +// DownloadCA is a helper method to define mock.On call +// - ctx context.Context +func (_e *SDK_Expecter) DownloadCA(ctx interface{}) *SDK_DownloadCA_Call { + return &SDK_DownloadCA_Call{Call: _e.mock.On("DownloadCA", ctx)} +} + +func (_c *SDK_DownloadCA_Call) Run(run func(ctx context.Context)) *SDK_DownloadCA_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *SDK_DownloadCA_Call) Return(certificateBundle sdk.CertificateBundle, sDKError errors.SDKError) *SDK_DownloadCA_Call { + _c.Call.Return(certificateBundle, sDKError) + return _c +} + +func (_c *SDK_DownloadCA_Call) RunAndReturn(run func(ctx context.Context) (sdk.CertificateBundle, errors.SDKError)) *SDK_DownloadCA_Call { + _c.Call.Return(run) + return _c +} + // EnableChannel provides a mock function for the type SDK func (_mock *SDK) EnableChannel(ctx context.Context, id string, domainID string, token string) (sdk.Channel, errors.SDKError) { ret := _mock.Called(ctx, id, domainID, token) @@ -4705,6 +5834,166 @@ func (_c *SDK_EnableGroup_Call) RunAndReturn(run func(ctx context.Context, id st return _c } +// EnableReportConfig provides a mock function for the type SDK +func (_mock *SDK) EnableReportConfig(ctx context.Context, id string, domainID string, token string) (sdk.ReportConfig, errors.SDKError) { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for EnableReportConfig") + } + + var r0 sdk.ReportConfig + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.ReportConfig, errors.SDKError)); ok { + return returnFunc(ctx, id, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.ReportConfig); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + r0 = ret.Get(0).(sdk.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_EnableReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EnableReportConfig' +type SDK_EnableReportConfig_Call struct { + *mock.Call +} + +// EnableReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) EnableReportConfig(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_EnableReportConfig_Call { + return &SDK_EnableReportConfig_Call{Call: _e.mock.On("EnableReportConfig", ctx, id, domainID, token)} +} + +func (_c *SDK_EnableReportConfig_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_EnableReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_EnableReportConfig_Call) Return(reportConfig sdk.ReportConfig, sDKError errors.SDKError) *SDK_EnableReportConfig_Call { + _c.Call.Return(reportConfig, sDKError) + return _c +} + +func (_c *SDK_EnableReportConfig_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) (sdk.ReportConfig, errors.SDKError)) *SDK_EnableReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// EnableRule provides a mock function for the type SDK +func (_mock *SDK) EnableRule(ctx context.Context, id string, domainID string, token string) (sdk.Rule, errors.SDKError) { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for EnableRule") + } + + var r0 sdk.Rule + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.Rule, errors.SDKError)); ok { + return returnFunc(ctx, id, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.Rule); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_EnableRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EnableRule' +type SDK_EnableRule_Call struct { + *mock.Call +} + +// EnableRule is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) EnableRule(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_EnableRule_Call { + return &SDK_EnableRule_Call{Call: _e.mock.On("EnableRule", ctx, id, domainID, token)} +} + +func (_c *SDK_EnableRule_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_EnableRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_EnableRule_Call) Return(rule sdk.Rule, sDKError errors.SDKError) *SDK_EnableRule_Call { + _c.Call.Return(rule, sDKError) + return _c +} + +func (_c *SDK_EnableRule_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) (sdk.Rule, errors.SDKError)) *SDK_EnableRule_Call { + _c.Call.Return(run) + return _c +} + // EnableUser provides a mock function for the type SDK func (_mock *SDK) EnableUser(ctx context.Context, id string, token string) (sdk.User, errors.SDKError) { ret := _mock.Called(ctx, id, token) @@ -4779,6 +6068,86 @@ func (_c *SDK_EnableUser_Call) RunAndReturn(run func(ctx context.Context, id str return _c } +// EntityID provides a mock function for the type SDK +func (_mock *SDK) EntityID(ctx context.Context, serialNumber string, domainID string, token string) (string, errors.SDKError) { + ret := _mock.Called(ctx, serialNumber, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for EntityID") + } + + var r0 string + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (string, errors.SDKError)); ok { + return returnFunc(ctx, serialNumber, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) string); ok { + r0 = returnFunc(ctx, serialNumber, domainID, token) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, serialNumber, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_EntityID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EntityID' +type SDK_EntityID_Call struct { + *mock.Call +} + +// EntityID is a helper method to define mock.On call +// - ctx context.Context +// - serialNumber string +// - domainID string +// - token string +func (_e *SDK_Expecter) EntityID(ctx interface{}, serialNumber interface{}, domainID interface{}, token interface{}) *SDK_EntityID_Call { + return &SDK_EntityID_Call{Call: _e.mock.On("EntityID", ctx, serialNumber, domainID, token)} +} + +func (_c *SDK_EntityID_Call) Run(run func(ctx context.Context, serialNumber string, domainID string, token string)) *SDK_EntityID_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_EntityID_Call) Return(s string, sDKError errors.SDKError) *SDK_EntityID_Call { + _c.Call.Return(s, sDKError) + return _c +} + +func (_c *SDK_EntityID_Call) RunAndReturn(run func(ctx context.Context, serialNumber string, domainID string, token string) (string, errors.SDKError)) *SDK_EntityID_Call { + _c.Call.Return(run) + return _c +} + // FreezeDomain provides a mock function for the type SDK func (_mock *SDK) FreezeDomain(ctx context.Context, domainID string, token string) errors.SDKError { ret := _mock.Called(ctx, domainID, token) @@ -4844,6 +6213,164 @@ func (_c *SDK_FreezeDomain_Call) RunAndReturn(run func(ctx context.Context, doma return _c } +// GenerateCRL provides a mock function for the type SDK +func (_mock *SDK) GenerateCRL(ctx context.Context) ([]byte, errors.SDKError) { + ret := _mock.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for GenerateCRL") + } + + var r0 []byte + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context) ([]byte, errors.SDKError)); ok { + return returnFunc(ctx) + } + if returnFunc, ok := ret.Get(0).(func(context.Context) []byte); ok { + r0 = returnFunc(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context) errors.SDKError); ok { + r1 = returnFunc(ctx) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_GenerateCRL_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GenerateCRL' +type SDK_GenerateCRL_Call struct { + *mock.Call +} + +// GenerateCRL is a helper method to define mock.On call +// - ctx context.Context +func (_e *SDK_Expecter) GenerateCRL(ctx interface{}) *SDK_GenerateCRL_Call { + return &SDK_GenerateCRL_Call{Call: _e.mock.On("GenerateCRL", ctx)} +} + +func (_c *SDK_GenerateCRL_Call) Run(run func(ctx context.Context)) *SDK_GenerateCRL_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *SDK_GenerateCRL_Call) Return(bytes []byte, sDKError errors.SDKError) *SDK_GenerateCRL_Call { + _c.Call.Return(bytes, sDKError) + return _c +} + +func (_c *SDK_GenerateCRL_Call) RunAndReturn(run func(ctx context.Context) ([]byte, errors.SDKError)) *SDK_GenerateCRL_Call { + _c.Call.Return(run) + return _c +} + +// GenerateReport provides a mock function for the type SDK +func (_mock *SDK) GenerateReport(ctx context.Context, config sdk.ReportConfig, action sdk.ReportAction, domainID string, token string) (sdk.ReportPage, *sdk.ReportFile, errors.SDKError) { + ret := _mock.Called(ctx, config, action, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for GenerateReport") + } + + var r0 sdk.ReportPage + var r1 *sdk.ReportFile + var r2 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.ReportConfig, sdk.ReportAction, string, string) (sdk.ReportPage, *sdk.ReportFile, errors.SDKError)); ok { + return returnFunc(ctx, config, action, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.ReportConfig, sdk.ReportAction, string, string) sdk.ReportPage); ok { + r0 = returnFunc(ctx, config, action, domainID, token) + } else { + r0 = ret.Get(0).(sdk.ReportPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.ReportConfig, sdk.ReportAction, string, string) *sdk.ReportFile); ok { + r1 = returnFunc(ctx, config, action, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*sdk.ReportFile) + } + } + if returnFunc, ok := ret.Get(2).(func(context.Context, sdk.ReportConfig, sdk.ReportAction, string, string) errors.SDKError); ok { + r2 = returnFunc(ctx, config, action, domainID, token) + } else { + if ret.Get(2) != nil { + r2 = ret.Get(2).(errors.SDKError) + } + } + return r0, r1, r2 +} + +// SDK_GenerateReport_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GenerateReport' +type SDK_GenerateReport_Call struct { + *mock.Call +} + +// GenerateReport is a helper method to define mock.On call +// - ctx context.Context +// - config sdk.ReportConfig +// - action sdk.ReportAction +// - domainID string +// - token string +func (_e *SDK_Expecter) GenerateReport(ctx interface{}, config interface{}, action interface{}, domainID interface{}, token interface{}) *SDK_GenerateReport_Call { + return &SDK_GenerateReport_Call{Call: _e.mock.On("GenerateReport", ctx, config, action, domainID, token)} +} + +func (_c *SDK_GenerateReport_Call) Run(run func(ctx context.Context, config sdk.ReportConfig, action sdk.ReportAction, domainID string, token string)) *SDK_GenerateReport_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.ReportConfig + if args[1] != nil { + arg1 = args[1].(sdk.ReportConfig) + } + var arg2 sdk.ReportAction + if args[2] != nil { + arg2 = args[2].(sdk.ReportAction) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *SDK_GenerateReport_Call) Return(reportPage sdk.ReportPage, reportFile *sdk.ReportFile, sDKError errors.SDKError) *SDK_GenerateReport_Call { + _c.Call.Return(reportPage, reportFile, sDKError) + return _c +} + +func (_c *SDK_GenerateReport_Call) RunAndReturn(run func(ctx context.Context, config sdk.ReportConfig, action sdk.ReportAction, domainID string, token string) (sdk.ReportPage, *sdk.ReportFile, errors.SDKError)) *SDK_GenerateReport_Call { + _c.Call.Return(run) + return _c +} + // Group provides a mock function for the type SDK func (_mock *SDK) Group(ctx context.Context, id string, domainID string, token string) (sdk.Group, errors.SDKError) { ret := _mock.Called(ctx, id, domainID, token) @@ -5576,6 +7103,282 @@ func (_c *SDK_Invitations_Call) RunAndReturn(run func(ctx context.Context, pm sd return _c } +// IssueCert provides a mock function for the type SDK +func (_mock *SDK) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, opts sdk.Options, domainID string, token string) (sdk.Certificate, errors.SDKError) { + ret := _mock.Called(ctx, entityID, ttl, ipAddrs, opts, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for IssueCert") + } + + var r0 sdk.Certificate + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, []string, sdk.Options, string, string) (sdk.Certificate, errors.SDKError)); ok { + return returnFunc(ctx, entityID, ttl, ipAddrs, opts, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, []string, sdk.Options, string, string) sdk.Certificate); ok { + r0 = returnFunc(ctx, entityID, ttl, ipAddrs, opts, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, []string, sdk.Options, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, entityID, ttl, ipAddrs, opts, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_IssueCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IssueCert' +type SDK_IssueCert_Call struct { + *mock.Call +} + +// IssueCert is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - ttl string +// - ipAddrs []string +// - opts sdk.Options +// - domainID string +// - token string +func (_e *SDK_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, opts interface{}, domainID interface{}, token interface{}) *SDK_IssueCert_Call { + return &SDK_IssueCert_Call{Call: _e.mock.On("IssueCert", ctx, entityID, ttl, ipAddrs, opts, domainID, token)} +} + +func (_c *SDK_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, opts sdk.Options, domainID string, token string)) *SDK_IssueCert_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 []string + if args[3] != nil { + arg3 = args[3].([]string) + } + var arg4 sdk.Options + if args[4] != nil { + arg4 = args[4].(sdk.Options) + } + var arg5 string + if args[5] != nil { + arg5 = args[5].(string) + } + var arg6 string + if args[6] != nil { + arg6 = args[6].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + arg6, + ) + }) + return _c +} + +func (_c *SDK_IssueCert_Call) Return(certificate sdk.Certificate, sDKError errors.SDKError) *SDK_IssueCert_Call { + _c.Call.Return(certificate, sDKError) + return _c +} + +func (_c *SDK_IssueCert_Call) RunAndReturn(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, opts sdk.Options, domainID string, token string) (sdk.Certificate, errors.SDKError)) *SDK_IssueCert_Call { + _c.Call.Return(run) + return _c +} + +// IssueFromCSR provides a mock function for the type SDK +func (_mock *SDK) IssueFromCSR(ctx context.Context, entityID string, ttl string, csr string, domainID string, token string) (sdk.Certificate, errors.SDKError) { + ret := _mock.Called(ctx, entityID, ttl, csr, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for IssueFromCSR") + } + + var r0 sdk.Certificate + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) (sdk.Certificate, errors.SDKError)); ok { + return returnFunc(ctx, entityID, ttl, csr, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) sdk.Certificate); ok { + r0 = returnFunc(ctx, entityID, ttl, csr, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, entityID, ttl, csr, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_IssueFromCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IssueFromCSR' +type SDK_IssueFromCSR_Call struct { + *mock.Call +} + +// IssueFromCSR is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - ttl string +// - csr string +// - domainID string +// - token string +func (_e *SDK_Expecter) IssueFromCSR(ctx interface{}, entityID interface{}, ttl interface{}, csr interface{}, domainID interface{}, token interface{}) *SDK_IssueFromCSR_Call { + return &SDK_IssueFromCSR_Call{Call: _e.mock.On("IssueFromCSR", ctx, entityID, ttl, csr, domainID, token)} +} + +func (_c *SDK_IssueFromCSR_Call) Run(run func(ctx context.Context, entityID string, ttl string, csr string, domainID string, token string)) *SDK_IssueFromCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + var arg5 string + if args[5] != nil { + arg5 = args[5].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + ) + }) + return _c +} + +func (_c *SDK_IssueFromCSR_Call) Return(certificate sdk.Certificate, sDKError errors.SDKError) *SDK_IssueFromCSR_Call { + _c.Call.Return(certificate, sDKError) + return _c +} + +func (_c *SDK_IssueFromCSR_Call) RunAndReturn(run func(ctx context.Context, entityID string, ttl string, csr string, domainID string, token string) (sdk.Certificate, errors.SDKError)) *SDK_IssueFromCSR_Call { + _c.Call.Return(run) + return _c +} + +// IssueFromCSRInternal provides a mock function for the type SDK +func (_mock *SDK) IssueFromCSRInternal(ctx context.Context, entityID string, ttl string, csr string, token string) (sdk.Certificate, errors.SDKError) { + ret := _mock.Called(ctx, entityID, ttl, csr, token) + + if len(ret) == 0 { + panic("no return value specified for IssueFromCSRInternal") + } + + var r0 sdk.Certificate + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) (sdk.Certificate, errors.SDKError)); ok { + return returnFunc(ctx, entityID, ttl, csr, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) sdk.Certificate); ok { + r0 = returnFunc(ctx, entityID, ttl, csr, token) + } else { + r0 = ret.Get(0).(sdk.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, entityID, ttl, csr, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_IssueFromCSRInternal_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IssueFromCSRInternal' +type SDK_IssueFromCSRInternal_Call struct { + *mock.Call +} + +// IssueFromCSRInternal is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - ttl string +// - csr string +// - token string +func (_e *SDK_Expecter) IssueFromCSRInternal(ctx interface{}, entityID interface{}, ttl interface{}, csr interface{}, token interface{}) *SDK_IssueFromCSRInternal_Call { + return &SDK_IssueFromCSRInternal_Call{Call: _e.mock.On("IssueFromCSRInternal", ctx, entityID, ttl, csr, token)} +} + +func (_c *SDK_IssueFromCSRInternal_Call) Run(run func(ctx context.Context, entityID string, ttl string, csr string, token string)) *SDK_IssueFromCSRInternal_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *SDK_IssueFromCSRInternal_Call) Return(certificate sdk.Certificate, sDKError errors.SDKError) *SDK_IssueFromCSRInternal_Call { + _c.Call.Return(certificate, sDKError) + return _c +} + +func (_c *SDK_IssueFromCSRInternal_Call) RunAndReturn(run func(ctx context.Context, entityID string, ttl string, csr string, token string) (sdk.Certificate, errors.SDKError)) *SDK_IssueFromCSRInternal_Call { + _c.Call.Return(run) + return _c +} + // Journal provides a mock function for the type SDK func (_mock *SDK) Journal(ctx context.Context, entityType string, entityID string, domainID string, pm sdk.PageMetadata, token string) (sdk.JournalsPage, error) { ret := _mock.Called(ctx, entityType, entityID, domainID, pm, token) @@ -5666,6 +7469,166 @@ func (_c *SDK_Journal_Call) RunAndReturn(run func(ctx context.Context, entityTyp return _c } +// ListAlarms provides a mock function for the type SDK +func (_mock *SDK) ListAlarms(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.AlarmsPage, errors.SDKError) { + ret := _mock.Called(ctx, pm, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for ListAlarms") + } + + var r0 sdk.AlarmsPage + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) (sdk.AlarmsPage, errors.SDKError)); ok { + return returnFunc(ctx, pm, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) sdk.AlarmsPage); ok { + r0 = returnFunc(ctx, pm, domainID, token) + } else { + r0 = ret.Get(0).(sdk.AlarmsPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.PageMetadata, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ListAlarms_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAlarms' +type SDK_ListAlarms_Call struct { + *mock.Call +} + +// ListAlarms is a helper method to define mock.On call +// - ctx context.Context +// - pm sdk.PageMetadata +// - domainID string +// - token string +func (_e *SDK_Expecter) ListAlarms(ctx interface{}, pm interface{}, domainID interface{}, token interface{}) *SDK_ListAlarms_Call { + return &SDK_ListAlarms_Call{Call: _e.mock.On("ListAlarms", ctx, pm, domainID, token)} +} + +func (_c *SDK_ListAlarms_Call) Run(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string)) *SDK_ListAlarms_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.PageMetadata + if args[1] != nil { + arg1 = args[1].(sdk.PageMetadata) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_ListAlarms_Call) Return(alarmsPage sdk.AlarmsPage, sDKError errors.SDKError) *SDK_ListAlarms_Call { + _c.Call.Return(alarmsPage, sDKError) + return _c +} + +func (_c *SDK_ListAlarms_Call) RunAndReturn(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.AlarmsPage, errors.SDKError)) *SDK_ListAlarms_Call { + _c.Call.Return(run) + return _c +} + +// ListCerts provides a mock function for the type SDK +func (_mock *SDK) ListCerts(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.CertificatePage, errors.SDKError) { + ret := _mock.Called(ctx, pm, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for ListCerts") + } + + var r0 sdk.CertificatePage + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) (sdk.CertificatePage, errors.SDKError)); ok { + return returnFunc(ctx, pm, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) sdk.CertificatePage); ok { + r0 = returnFunc(ctx, pm, domainID, token) + } else { + r0 = ret.Get(0).(sdk.CertificatePage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.PageMetadata, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ListCerts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCerts' +type SDK_ListCerts_Call struct { + *mock.Call +} + +// ListCerts is a helper method to define mock.On call +// - ctx context.Context +// - pm sdk.PageMetadata +// - domainID string +// - token string +func (_e *SDK_Expecter) ListCerts(ctx interface{}, pm interface{}, domainID interface{}, token interface{}) *SDK_ListCerts_Call { + return &SDK_ListCerts_Call{Call: _e.mock.On("ListCerts", ctx, pm, domainID, token)} +} + +func (_c *SDK_ListCerts_Call) Run(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string)) *SDK_ListCerts_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.PageMetadata + if args[1] != nil { + arg1 = args[1].(sdk.PageMetadata) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_ListCerts_Call) Return(certificatePage sdk.CertificatePage, sDKError errors.SDKError) *SDK_ListCerts_Call { + _c.Call.Return(certificatePage, sDKError) + return _c +} + +func (_c *SDK_ListCerts_Call) RunAndReturn(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.CertificatePage, errors.SDKError)) *SDK_ListCerts_Call { + _c.Call.Return(run) + return _c +} + // ListChannelMembers provides a mock function for the type SDK func (_mock *SDK) ListChannelMembers(ctx context.Context, channelID string, domainID string, pm sdk.PageMetadata, token string) (sdk.EntityMembersPage, errors.SDKError) { ret := _mock.Called(ctx, channelID, domainID, pm, token) @@ -6004,6 +7967,400 @@ func (_c *SDK_ListGroupMembers_Call) RunAndReturn(run func(ctx context.Context, return _c } +// ListReportsConfig provides a mock function for the type SDK +func (_mock *SDK) ListReportsConfig(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.ReportConfigPage, errors.SDKError) { + ret := _mock.Called(ctx, pm, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for ListReportsConfig") + } + + var r0 sdk.ReportConfigPage + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) (sdk.ReportConfigPage, errors.SDKError)); ok { + return returnFunc(ctx, pm, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) sdk.ReportConfigPage); ok { + r0 = returnFunc(ctx, pm, domainID, token) + } else { + r0 = ret.Get(0).(sdk.ReportConfigPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.PageMetadata, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ListReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListReportsConfig' +type SDK_ListReportsConfig_Call struct { + *mock.Call +} + +// ListReportsConfig is a helper method to define mock.On call +// - ctx context.Context +// - pm sdk.PageMetadata +// - domainID string +// - token string +func (_e *SDK_Expecter) ListReportsConfig(ctx interface{}, pm interface{}, domainID interface{}, token interface{}) *SDK_ListReportsConfig_Call { + return &SDK_ListReportsConfig_Call{Call: _e.mock.On("ListReportsConfig", ctx, pm, domainID, token)} +} + +func (_c *SDK_ListReportsConfig_Call) Run(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string)) *SDK_ListReportsConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.PageMetadata + if args[1] != nil { + arg1 = args[1].(sdk.PageMetadata) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_ListReportsConfig_Call) Return(reportConfigPage sdk.ReportConfigPage, sDKError errors.SDKError) *SDK_ListReportsConfig_Call { + _c.Call.Return(reportConfigPage, sDKError) + return _c +} + +func (_c *SDK_ListReportsConfig_Call) RunAndReturn(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.ReportConfigPage, errors.SDKError)) *SDK_ListReportsConfig_Call { + _c.Call.Return(run) + return _c +} + +// ListRules provides a mock function for the type SDK +func (_mock *SDK) ListRules(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.Page, errors.SDKError) { + ret := _mock.Called(ctx, pm, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for ListRules") + } + + var r0 sdk.Page + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) (sdk.Page, errors.SDKError)); ok { + return returnFunc(ctx, pm, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string, string) sdk.Page); ok { + r0 = returnFunc(ctx, pm, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Page) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.PageMetadata, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ListRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListRules' +type SDK_ListRules_Call struct { + *mock.Call +} + +// ListRules is a helper method to define mock.On call +// - ctx context.Context +// - pm sdk.PageMetadata +// - domainID string +// - token string +func (_e *SDK_Expecter) ListRules(ctx interface{}, pm interface{}, domainID interface{}, token interface{}) *SDK_ListRules_Call { + return &SDK_ListRules_Call{Call: _e.mock.On("ListRules", ctx, pm, domainID, token)} +} + +func (_c *SDK_ListRules_Call) Run(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string)) *SDK_ListRules_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.PageMetadata + if args[1] != nil { + arg1 = args[1].(sdk.PageMetadata) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_ListRules_Call) Return(page sdk.Page, sDKError errors.SDKError) *SDK_ListRules_Call { + _c.Call.Return(page, sDKError) + return _c +} + +func (_c *SDK_ListRules_Call) RunAndReturn(run func(ctx context.Context, pm sdk.PageMetadata, domainID string, token string) (sdk.Page, errors.SDKError)) *SDK_ListRules_Call { + _c.Call.Return(run) + return _c +} + +// ListSubscriptions provides a mock function for the type SDK +func (_mock *SDK) ListSubscriptions(ctx context.Context, pm sdk.PageMetadata, token string) (sdk.SubscriptionPage, errors.SDKError) { + ret := _mock.Called(ctx, pm, token) + + if len(ret) == 0 { + panic("no return value specified for ListSubscriptions") + } + + var r0 sdk.SubscriptionPage + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string) (sdk.SubscriptionPage, errors.SDKError)); ok { + return returnFunc(ctx, pm, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.PageMetadata, string) sdk.SubscriptionPage); ok { + r0 = returnFunc(ctx, pm, token) + } else { + r0 = ret.Get(0).(sdk.SubscriptionPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.PageMetadata, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ListSubscriptions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListSubscriptions' +type SDK_ListSubscriptions_Call struct { + *mock.Call +} + +// ListSubscriptions is a helper method to define mock.On call +// - ctx context.Context +// - pm sdk.PageMetadata +// - token string +func (_e *SDK_Expecter) ListSubscriptions(ctx interface{}, pm interface{}, token interface{}) *SDK_ListSubscriptions_Call { + return &SDK_ListSubscriptions_Call{Call: _e.mock.On("ListSubscriptions", ctx, pm, token)} +} + +func (_c *SDK_ListSubscriptions_Call) Run(run func(ctx context.Context, pm sdk.PageMetadata, token string)) *SDK_ListSubscriptions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.PageMetadata + if args[1] != nil { + arg1 = args[1].(sdk.PageMetadata) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *SDK_ListSubscriptions_Call) Return(subscriptionPage sdk.SubscriptionPage, sDKError errors.SDKError) *SDK_ListSubscriptions_Call { + _c.Call.Return(subscriptionPage, sDKError) + return _c +} + +func (_c *SDK_ListSubscriptions_Call) RunAndReturn(run func(ctx context.Context, pm sdk.PageMetadata, token string) (sdk.SubscriptionPage, errors.SDKError)) *SDK_ListSubscriptions_Call { + _c.Call.Return(run) + return _c +} + +// OCSP provides a mock function for the type SDK +func (_mock *SDK) OCSP(ctx context.Context, serialNumber string, cert string) (sdk.OCSPResponse, errors.SDKError) { + ret := _mock.Called(ctx, serialNumber, cert) + + if len(ret) == 0 { + panic("no return value specified for OCSP") + } + + var r0 sdk.OCSPResponse + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (sdk.OCSPResponse, errors.SDKError)); ok { + return returnFunc(ctx, serialNumber, cert) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) sdk.OCSPResponse); ok { + r0 = returnFunc(ctx, serialNumber, cert) + } else { + r0 = ret.Get(0).(sdk.OCSPResponse) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, serialNumber, cert) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_OCSP_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'OCSP' +type SDK_OCSP_Call struct { + *mock.Call +} + +// OCSP is a helper method to define mock.On call +// - ctx context.Context +// - serialNumber string +// - cert string +func (_e *SDK_Expecter) OCSP(ctx interface{}, serialNumber interface{}, cert interface{}) *SDK_OCSP_Call { + return &SDK_OCSP_Call{Call: _e.mock.On("OCSP", ctx, serialNumber, cert)} +} + +func (_c *SDK_OCSP_Call) Run(run func(ctx context.Context, serialNumber string, cert string)) *SDK_OCSP_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *SDK_OCSP_Call) Return(oCSPResponse sdk.OCSPResponse, sDKError errors.SDKError) *SDK_OCSP_Call { + _c.Call.Return(oCSPResponse, sDKError) + return _c +} + +func (_c *SDK_OCSP_Call) RunAndReturn(run func(ctx context.Context, serialNumber string, cert string) (sdk.OCSPResponse, errors.SDKError)) *SDK_OCSP_Call { + _c.Call.Return(run) + return _c +} + +// ReadMessages provides a mock function for the type SDK +func (_mock *SDK) ReadMessages(ctx context.Context, pm sdk.MessagePageMetadata, chanID string, domainID string, token string) (sdk.MessagesPage, errors.SDKError) { + ret := _mock.Called(ctx, pm, chanID, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for ReadMessages") + } + + var r0 sdk.MessagesPage + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.MessagePageMetadata, string, string, string) (sdk.MessagesPage, errors.SDKError)); ok { + return returnFunc(ctx, pm, chanID, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.MessagePageMetadata, string, string, string) sdk.MessagesPage); ok { + r0 = returnFunc(ctx, pm, chanID, domainID, token) + } else { + r0 = ret.Get(0).(sdk.MessagesPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.MessagePageMetadata, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, pm, chanID, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ReadMessages_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadMessages' +type SDK_ReadMessages_Call struct { + *mock.Call +} + +// ReadMessages is a helper method to define mock.On call +// - ctx context.Context +// - pm sdk.MessagePageMetadata +// - chanID string +// - domainID string +// - token string +func (_e *SDK_Expecter) ReadMessages(ctx interface{}, pm interface{}, chanID interface{}, domainID interface{}, token interface{}) *SDK_ReadMessages_Call { + return &SDK_ReadMessages_Call{Call: _e.mock.On("ReadMessages", ctx, pm, chanID, domainID, token)} +} + +func (_c *SDK_ReadMessages_Call) Run(run func(ctx context.Context, pm sdk.MessagePageMetadata, chanID string, domainID string, token string)) *SDK_ReadMessages_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.MessagePageMetadata + if args[1] != nil { + arg1 = args[1].(sdk.MessagePageMetadata) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *SDK_ReadMessages_Call) Return(messagesPage sdk.MessagesPage, sDKError errors.SDKError) *SDK_ReadMessages_Call { + _c.Call.Return(messagesPage, sDKError) + return _c +} + +func (_c *SDK_ReadMessages_Call) RunAndReturn(run func(ctx context.Context, pm sdk.MessagePageMetadata, chanID string, domainID string, token string) (sdk.MessagesPage, errors.SDKError)) *SDK_ReadMessages_Call { + _c.Call.Return(run) + return _c +} + // RefreshToken provides a mock function for the type SDK func (_mock *SDK) RefreshToken(ctx context.Context, token string) (sdk.Token, errors.SDKError) { ret := _mock.Called(ctx, token) @@ -6656,6 +9013,77 @@ func (_c *SDK_RemoveAllGroupRoleMembers_Call) RunAndReturn(run func(ctx context. return _c } +// RemoveBootstrap provides a mock function for the type SDK +func (_mock *SDK) RemoveBootstrap(ctx context.Context, id string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for RemoveBootstrap") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_RemoveBootstrap_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveBootstrap' +type SDK_RemoveBootstrap_Call struct { + *mock.Call +} + +// RemoveBootstrap is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) RemoveBootstrap(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_RemoveBootstrap_Call { + return &SDK_RemoveBootstrap_Call{Call: _e.mock.On("RemoveBootstrap", ctx, id, domainID, token)} +} + +func (_c *SDK_RemoveBootstrap_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_RemoveBootstrap_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_RemoveBootstrap_Call) Return(sDKError errors.SDKError) *SDK_RemoveBootstrap_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_RemoveBootstrap_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) errors.SDKError) *SDK_RemoveBootstrap_Call { + _c.Call.Return(run) + return _c +} + // RemoveChannelParent provides a mock function for the type SDK func (_mock *SDK) RemoveChannelParent(ctx context.Context, id string, domainID string, groupID string, token string) errors.SDKError { ret := _mock.Called(ctx, id, domainID, groupID, token) @@ -7450,6 +9878,228 @@ func (_c *SDK_RemoveGroupRoleMembers_Call) RunAndReturn(run func(ctx context.Con return _c } +// RemoveReportConfig provides a mock function for the type SDK +func (_mock *SDK) RemoveReportConfig(ctx context.Context, id string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for RemoveReportConfig") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_RemoveReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveReportConfig' +type SDK_RemoveReportConfig_Call struct { + *mock.Call +} + +// RemoveReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) RemoveReportConfig(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_RemoveReportConfig_Call { + return &SDK_RemoveReportConfig_Call{Call: _e.mock.On("RemoveReportConfig", ctx, id, domainID, token)} +} + +func (_c *SDK_RemoveReportConfig_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_RemoveReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_RemoveReportConfig_Call) Return(sDKError errors.SDKError) *SDK_RemoveReportConfig_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_RemoveReportConfig_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) errors.SDKError) *SDK_RemoveReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// RemoveRule provides a mock function for the type SDK +func (_mock *SDK) RemoveRule(ctx context.Context, id string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for RemoveRule") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_RemoveRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveRule' +type SDK_RemoveRule_Call struct { + *mock.Call +} + +// RemoveRule is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) RemoveRule(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_RemoveRule_Call { + return &SDK_RemoveRule_Call{Call: _e.mock.On("RemoveRule", ctx, id, domainID, token)} +} + +func (_c *SDK_RemoveRule_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_RemoveRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_RemoveRule_Call) Return(sDKError errors.SDKError) *SDK_RemoveRule_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_RemoveRule_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) errors.SDKError) *SDK_RemoveRule_Call { + _c.Call.Return(run) + return _c +} + +// RenewCert provides a mock function for the type SDK +func (_mock *SDK) RenewCert(ctx context.Context, serialNumber string, domainID string, token string) (sdk.Certificate, errors.SDKError) { + ret := _mock.Called(ctx, serialNumber, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for RenewCert") + } + + var r0 sdk.Certificate + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.Certificate, errors.SDKError)); ok { + return returnFunc(ctx, serialNumber, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.Certificate); ok { + r0 = returnFunc(ctx, serialNumber, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, serialNumber, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_RenewCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RenewCert' +type SDK_RenewCert_Call struct { + *mock.Call +} + +// RenewCert is a helper method to define mock.On call +// - ctx context.Context +// - serialNumber string +// - domainID string +// - token string +func (_e *SDK_Expecter) RenewCert(ctx interface{}, serialNumber interface{}, domainID interface{}, token interface{}) *SDK_RenewCert_Call { + return &SDK_RenewCert_Call{Call: _e.mock.On("RenewCert", ctx, serialNumber, domainID, token)} +} + +func (_c *SDK_RenewCert_Call) Run(run func(ctx context.Context, serialNumber string, domainID string, token string)) *SDK_RenewCert_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_RenewCert_Call) Return(certificate sdk.Certificate, sDKError errors.SDKError) *SDK_RenewCert_Call { + _c.Call.Return(certificate, sDKError) + return _c +} + +func (_c *SDK_RenewCert_Call) RunAndReturn(run func(ctx context.Context, serialNumber string, domainID string, token string) (sdk.Certificate, errors.SDKError)) *SDK_RenewCert_Call { + _c.Call.Return(run) + return _c +} + // ResetPassword provides a mock function for the type SDK func (_mock *SDK) ResetPassword(ctx context.Context, password string, confPass string, token string) errors.SDKError { ret := _mock.Called(ctx, password, confPass, token) @@ -7580,6 +10230,148 @@ func (_c *SDK_ResetPasswordRequest_Call) RunAndReturn(run func(ctx context.Conte return _c } +// RevokeAll provides a mock function for the type SDK +func (_mock *SDK) RevokeAll(ctx context.Context, entityID string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, entityID, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for RevokeAll") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, entityID, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_RevokeAll_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeAll' +type SDK_RevokeAll_Call struct { + *mock.Call +} + +// RevokeAll is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - domainID string +// - token string +func (_e *SDK_Expecter) RevokeAll(ctx interface{}, entityID interface{}, domainID interface{}, token interface{}) *SDK_RevokeAll_Call { + return &SDK_RevokeAll_Call{Call: _e.mock.On("RevokeAll", ctx, entityID, domainID, token)} +} + +func (_c *SDK_RevokeAll_Call) Run(run func(ctx context.Context, entityID string, domainID string, token string)) *SDK_RevokeAll_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_RevokeAll_Call) Return(sDKError errors.SDKError) *SDK_RevokeAll_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_RevokeAll_Call) RunAndReturn(run func(ctx context.Context, entityID string, domainID string, token string) errors.SDKError) *SDK_RevokeAll_Call { + _c.Call.Return(run) + return _c +} + +// RevokeCert provides a mock function for the type SDK +func (_mock *SDK) RevokeCert(ctx context.Context, serialNumber string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, serialNumber, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for RevokeCert") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, serialNumber, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_RevokeCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeCert' +type SDK_RevokeCert_Call struct { + *mock.Call +} + +// RevokeCert is a helper method to define mock.On call +// - ctx context.Context +// - serialNumber string +// - domainID string +// - token string +func (_e *SDK_Expecter) RevokeCert(ctx interface{}, serialNumber interface{}, domainID interface{}, token interface{}) *SDK_RevokeCert_Call { + return &SDK_RevokeCert_Call{Call: _e.mock.On("RevokeCert", ctx, serialNumber, domainID, token)} +} + +func (_c *SDK_RevokeCert_Call) Run(run func(ctx context.Context, serialNumber string, domainID string, token string)) *SDK_RevokeCert_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_RevokeCert_Call) Return(sDKError errors.SDKError) *SDK_RevokeCert_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_RevokeCert_Call) RunAndReturn(run func(ctx context.Context, serialNumber string, domainID string, token string) errors.SDKError) *SDK_RevokeCert_Call { + _c.Call.Return(run) + return _c +} + // SearchUsers provides a mock function for the type SDK func (_mock *SDK) SearchUsers(ctx context.Context, pm sdk.PageMetadata, token string) (sdk.UsersPage, errors.SDKError) { ret := _mock.Called(ctx, pm, token) @@ -8137,6 +10929,332 @@ func (_c *SDK_SetGroupParent_Call) RunAndReturn(run func(ctx context.Context, id return _c } +// UpdateAlarm provides a mock function for the type SDK +func (_mock *SDK) UpdateAlarm(ctx context.Context, alarm sdk.Alarm, domainID string, token string) (sdk.Alarm, errors.SDKError) { + ret := _mock.Called(ctx, alarm, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for UpdateAlarm") + } + + var r0 sdk.Alarm + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.Alarm, string, string) (sdk.Alarm, errors.SDKError)); ok { + return returnFunc(ctx, alarm, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.Alarm, string, string) sdk.Alarm); ok { + r0 = returnFunc(ctx, alarm, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.Alarm, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, alarm, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_UpdateAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateAlarm' +type SDK_UpdateAlarm_Call struct { + *mock.Call +} + +// UpdateAlarm is a helper method to define mock.On call +// - ctx context.Context +// - alarm sdk.Alarm +// - domainID string +// - token string +func (_e *SDK_Expecter) UpdateAlarm(ctx interface{}, alarm interface{}, domainID interface{}, token interface{}) *SDK_UpdateAlarm_Call { + return &SDK_UpdateAlarm_Call{Call: _e.mock.On("UpdateAlarm", ctx, alarm, domainID, token)} +} + +func (_c *SDK_UpdateAlarm_Call) Run(run func(ctx context.Context, alarm sdk.Alarm, domainID string, token string)) *SDK_UpdateAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.Alarm + if args[1] != nil { + arg1 = args[1].(sdk.Alarm) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_UpdateAlarm_Call) Return(alarm1 sdk.Alarm, sDKError errors.SDKError) *SDK_UpdateAlarm_Call { + _c.Call.Return(alarm1, sDKError) + return _c +} + +func (_c *SDK_UpdateAlarm_Call) RunAndReturn(run func(ctx context.Context, alarm sdk.Alarm, domainID string, token string) (sdk.Alarm, errors.SDKError)) *SDK_UpdateAlarm_Call { + _c.Call.Return(run) + return _c +} + +// UpdateBootstrap provides a mock function for the type SDK +func (_mock *SDK) UpdateBootstrap(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, cfg, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for UpdateBootstrap") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.BootstrapConfig, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, cfg, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_UpdateBootstrap_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBootstrap' +type SDK_UpdateBootstrap_Call struct { + *mock.Call +} + +// UpdateBootstrap is a helper method to define mock.On call +// - ctx context.Context +// - cfg sdk.BootstrapConfig +// - domainID string +// - token string +func (_e *SDK_Expecter) UpdateBootstrap(ctx interface{}, cfg interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrap_Call { + return &SDK_UpdateBootstrap_Call{Call: _e.mock.On("UpdateBootstrap", ctx, cfg, domainID, token)} +} + +func (_c *SDK_UpdateBootstrap_Call) Run(run func(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string)) *SDK_UpdateBootstrap_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.BootstrapConfig + if args[1] != nil { + arg1 = args[1].(sdk.BootstrapConfig) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_UpdateBootstrap_Call) Return(sDKError errors.SDKError) *SDK_UpdateBootstrap_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_UpdateBootstrap_Call) RunAndReturn(run func(ctx context.Context, cfg sdk.BootstrapConfig, domainID string, token string) errors.SDKError) *SDK_UpdateBootstrap_Call { + _c.Call.Return(run) + return _c +} + +// UpdateBootstrapCerts provides a mock function for the type SDK +func (_mock *SDK) UpdateBootstrapCerts(ctx context.Context, id string, clientCert string, clientKey string, ca string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError) { + ret := _mock.Called(ctx, id, clientCert, clientKey, ca, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for UpdateBootstrapCerts") + } + + var r0 sdk.BootstrapConfig + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { + return returnFunc(ctx, id, clientCert, clientKey, ca, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string, string) sdk.BootstrapConfig); ok { + r0 = returnFunc(ctx, id, clientCert, clientKey, ca, domainID, token) + } else { + r0 = ret.Get(0).(sdk.BootstrapConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, clientCert, clientKey, ca, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_UpdateBootstrapCerts_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBootstrapCerts' +type SDK_UpdateBootstrapCerts_Call struct { + *mock.Call +} + +// UpdateBootstrapCerts is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - clientCert string +// - clientKey string +// - ca string +// - domainID string +// - token string +func (_e *SDK_Expecter) UpdateBootstrapCerts(ctx interface{}, id interface{}, clientCert interface{}, clientKey interface{}, ca interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrapCerts_Call { + return &SDK_UpdateBootstrapCerts_Call{Call: _e.mock.On("UpdateBootstrapCerts", ctx, id, clientCert, clientKey, ca, domainID, token)} +} + +func (_c *SDK_UpdateBootstrapCerts_Call) Run(run func(ctx context.Context, id string, clientCert string, clientKey string, ca string, domainID string, token string)) *SDK_UpdateBootstrapCerts_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + var arg5 string + if args[5] != nil { + arg5 = args[5].(string) + } + var arg6 string + if args[6] != nil { + arg6 = args[6].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + arg6, + ) + }) + return _c +} + +func (_c *SDK_UpdateBootstrapCerts_Call) Return(bootstrapConfig sdk.BootstrapConfig, sDKError errors.SDKError) *SDK_UpdateBootstrapCerts_Call { + _c.Call.Return(bootstrapConfig, sDKError) + return _c +} + +func (_c *SDK_UpdateBootstrapCerts_Call) RunAndReturn(run func(ctx context.Context, id string, clientCert string, clientKey string, ca string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_UpdateBootstrapCerts_Call { + _c.Call.Return(run) + return _c +} + +// UpdateBootstrapConnection provides a mock function for the type SDK +func (_mock *SDK) UpdateBootstrapConnection(ctx context.Context, id string, channels []string, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, id, channels, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for UpdateBootstrapConnection") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, id, channels, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_UpdateBootstrapConnection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateBootstrapConnection' +type SDK_UpdateBootstrapConnection_Call struct { + *mock.Call +} + +// UpdateBootstrapConnection is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - channels []string +// - domainID string +// - token string +func (_e *SDK_Expecter) UpdateBootstrapConnection(ctx interface{}, id interface{}, channels interface{}, domainID interface{}, token interface{}) *SDK_UpdateBootstrapConnection_Call { + return &SDK_UpdateBootstrapConnection_Call{Call: _e.mock.On("UpdateBootstrapConnection", ctx, id, channels, domainID, token)} +} + +func (_c *SDK_UpdateBootstrapConnection_Call) Run(run func(ctx context.Context, id string, channels []string, domainID string, token string)) *SDK_UpdateBootstrapConnection_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *SDK_UpdateBootstrapConnection_Call) Return(sDKError errors.SDKError) *SDK_UpdateBootstrapConnection_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_UpdateBootstrapConnection_Call) RunAndReturn(run func(ctx context.Context, id string, channels []string, domainID string, token string) errors.SDKError) *SDK_UpdateBootstrapConnection_Call { + _c.Call.Return(run) + return _c +} + // UpdateChannel provides a mock function for the type SDK func (_mock *SDK) UpdateChannel(ctx context.Context, channel sdk.Channel, domainID string, token string) (sdk.Channel, errors.SDKError) { ret := _mock.Called(ctx, channel, domainID, token) @@ -9201,6 +12319,477 @@ func (_c *SDK_UpdateProfilePicture_Call) RunAndReturn(run func(ctx context.Conte return _c } +// UpdateReportConfig provides a mock function for the type SDK +func (_mock *SDK) UpdateReportConfig(ctx context.Context, cfg sdk.ReportConfig, domainID string, token string) (sdk.ReportConfig, errors.SDKError) { + ret := _mock.Called(ctx, cfg, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for UpdateReportConfig") + } + + var r0 sdk.ReportConfig + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.ReportConfig, string, string) (sdk.ReportConfig, errors.SDKError)); ok { + return returnFunc(ctx, cfg, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.ReportConfig, string, string) sdk.ReportConfig); ok { + r0 = returnFunc(ctx, cfg, domainID, token) + } else { + r0 = ret.Get(0).(sdk.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.ReportConfig, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, cfg, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_UpdateReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateReportConfig' +type SDK_UpdateReportConfig_Call struct { + *mock.Call +} + +// UpdateReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - cfg sdk.ReportConfig +// - domainID string +// - token string +func (_e *SDK_Expecter) UpdateReportConfig(ctx interface{}, cfg interface{}, domainID interface{}, token interface{}) *SDK_UpdateReportConfig_Call { + return &SDK_UpdateReportConfig_Call{Call: _e.mock.On("UpdateReportConfig", ctx, cfg, domainID, token)} +} + +func (_c *SDK_UpdateReportConfig_Call) Run(run func(ctx context.Context, cfg sdk.ReportConfig, domainID string, token string)) *SDK_UpdateReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.ReportConfig + if args[1] != nil { + arg1 = args[1].(sdk.ReportConfig) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_UpdateReportConfig_Call) Return(reportConfig sdk.ReportConfig, sDKError errors.SDKError) *SDK_UpdateReportConfig_Call { + _c.Call.Return(reportConfig, sDKError) + return _c +} + +func (_c *SDK_UpdateReportConfig_Call) RunAndReturn(run func(ctx context.Context, cfg sdk.ReportConfig, domainID string, token string) (sdk.ReportConfig, errors.SDKError)) *SDK_UpdateReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// UpdateReportSchedule provides a mock function for the type SDK +func (_mock *SDK) UpdateReportSchedule(ctx context.Context, cfg sdk.ReportConfig, domainID string, token string) (sdk.ReportConfig, errors.SDKError) { + ret := _mock.Called(ctx, cfg, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for UpdateReportSchedule") + } + + var r0 sdk.ReportConfig + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.ReportConfig, string, string) (sdk.ReportConfig, errors.SDKError)); ok { + return returnFunc(ctx, cfg, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.ReportConfig, string, string) sdk.ReportConfig); ok { + r0 = returnFunc(ctx, cfg, domainID, token) + } else { + r0 = ret.Get(0).(sdk.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.ReportConfig, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, cfg, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_UpdateReportSchedule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateReportSchedule' +type SDK_UpdateReportSchedule_Call struct { + *mock.Call +} + +// UpdateReportSchedule is a helper method to define mock.On call +// - ctx context.Context +// - cfg sdk.ReportConfig +// - domainID string +// - token string +func (_e *SDK_Expecter) UpdateReportSchedule(ctx interface{}, cfg interface{}, domainID interface{}, token interface{}) *SDK_UpdateReportSchedule_Call { + return &SDK_UpdateReportSchedule_Call{Call: _e.mock.On("UpdateReportSchedule", ctx, cfg, domainID, token)} +} + +func (_c *SDK_UpdateReportSchedule_Call) Run(run func(ctx context.Context, cfg sdk.ReportConfig, domainID string, token string)) *SDK_UpdateReportSchedule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.ReportConfig + if args[1] != nil { + arg1 = args[1].(sdk.ReportConfig) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_UpdateReportSchedule_Call) Return(reportConfig sdk.ReportConfig, sDKError errors.SDKError) *SDK_UpdateReportSchedule_Call { + _c.Call.Return(reportConfig, sDKError) + return _c +} + +func (_c *SDK_UpdateReportSchedule_Call) RunAndReturn(run func(ctx context.Context, cfg sdk.ReportConfig, domainID string, token string) (sdk.ReportConfig, errors.SDKError)) *SDK_UpdateReportSchedule_Call { + _c.Call.Return(run) + return _c +} + +// UpdateReportTemplate provides a mock function for the type SDK +func (_mock *SDK) UpdateReportTemplate(ctx context.Context, cfg sdk.ReportConfig, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, cfg, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for UpdateReportTemplate") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.ReportConfig, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, cfg, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_UpdateReportTemplate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateReportTemplate' +type SDK_UpdateReportTemplate_Call struct { + *mock.Call +} + +// UpdateReportTemplate is a helper method to define mock.On call +// - ctx context.Context +// - cfg sdk.ReportConfig +// - domainID string +// - token string +func (_e *SDK_Expecter) UpdateReportTemplate(ctx interface{}, cfg interface{}, domainID interface{}, token interface{}) *SDK_UpdateReportTemplate_Call { + return &SDK_UpdateReportTemplate_Call{Call: _e.mock.On("UpdateReportTemplate", ctx, cfg, domainID, token)} +} + +func (_c *SDK_UpdateReportTemplate_Call) Run(run func(ctx context.Context, cfg sdk.ReportConfig, domainID string, token string)) *SDK_UpdateReportTemplate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.ReportConfig + if args[1] != nil { + arg1 = args[1].(sdk.ReportConfig) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_UpdateReportTemplate_Call) Return(sDKError errors.SDKError) *SDK_UpdateReportTemplate_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_UpdateReportTemplate_Call) RunAndReturn(run func(ctx context.Context, cfg sdk.ReportConfig, domainID string, token string) errors.SDKError) *SDK_UpdateReportTemplate_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRule provides a mock function for the type SDK +func (_mock *SDK) UpdateRule(ctx context.Context, r sdk.Rule, domainID string, token string) (sdk.Rule, errors.SDKError) { + ret := _mock.Called(ctx, r, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for UpdateRule") + } + + var r0 sdk.Rule + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.Rule, string, string) (sdk.Rule, errors.SDKError)); ok { + return returnFunc(ctx, r, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.Rule, string, string) sdk.Rule); ok { + r0 = returnFunc(ctx, r, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.Rule, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, r, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_UpdateRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRule' +type SDK_UpdateRule_Call struct { + *mock.Call +} + +// UpdateRule is a helper method to define mock.On call +// - ctx context.Context +// - r sdk.Rule +// - domainID string +// - token string +func (_e *SDK_Expecter) UpdateRule(ctx interface{}, r interface{}, domainID interface{}, token interface{}) *SDK_UpdateRule_Call { + return &SDK_UpdateRule_Call{Call: _e.mock.On("UpdateRule", ctx, r, domainID, token)} +} + +func (_c *SDK_UpdateRule_Call) Run(run func(ctx context.Context, r sdk.Rule, domainID string, token string)) *SDK_UpdateRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.Rule + if args[1] != nil { + arg1 = args[1].(sdk.Rule) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_UpdateRule_Call) Return(rule sdk.Rule, sDKError errors.SDKError) *SDK_UpdateRule_Call { + _c.Call.Return(rule, sDKError) + return _c +} + +func (_c *SDK_UpdateRule_Call) RunAndReturn(run func(ctx context.Context, r sdk.Rule, domainID string, token string) (sdk.Rule, errors.SDKError)) *SDK_UpdateRule_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRuleSchedule provides a mock function for the type SDK +func (_mock *SDK) UpdateRuleSchedule(ctx context.Context, r sdk.Rule, domainID string, token string) (sdk.Rule, errors.SDKError) { + ret := _mock.Called(ctx, r, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for UpdateRuleSchedule") + } + + var r0 sdk.Rule + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.Rule, string, string) (sdk.Rule, errors.SDKError)); ok { + return returnFunc(ctx, r, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.Rule, string, string) sdk.Rule); ok { + r0 = returnFunc(ctx, r, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.Rule, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, r, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_UpdateRuleSchedule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRuleSchedule' +type SDK_UpdateRuleSchedule_Call struct { + *mock.Call +} + +// UpdateRuleSchedule is a helper method to define mock.On call +// - ctx context.Context +// - r sdk.Rule +// - domainID string +// - token string +func (_e *SDK_Expecter) UpdateRuleSchedule(ctx interface{}, r interface{}, domainID interface{}, token interface{}) *SDK_UpdateRuleSchedule_Call { + return &SDK_UpdateRuleSchedule_Call{Call: _e.mock.On("UpdateRuleSchedule", ctx, r, domainID, token)} +} + +func (_c *SDK_UpdateRuleSchedule_Call) Run(run func(ctx context.Context, r sdk.Rule, domainID string, token string)) *SDK_UpdateRuleSchedule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.Rule + if args[1] != nil { + arg1 = args[1].(sdk.Rule) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_UpdateRuleSchedule_Call) Return(rule sdk.Rule, sDKError errors.SDKError) *SDK_UpdateRuleSchedule_Call { + _c.Call.Return(rule, sDKError) + return _c +} + +func (_c *SDK_UpdateRuleSchedule_Call) RunAndReturn(run func(ctx context.Context, r sdk.Rule, domainID string, token string) (sdk.Rule, errors.SDKError)) *SDK_UpdateRuleSchedule_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRuleTags provides a mock function for the type SDK +func (_mock *SDK) UpdateRuleTags(ctx context.Context, r sdk.Rule, domainID string, token string) (sdk.Rule, errors.SDKError) { + ret := _mock.Called(ctx, r, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for UpdateRuleTags") + } + + var r0 sdk.Rule + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.Rule, string, string) (sdk.Rule, errors.SDKError)); ok { + return returnFunc(ctx, r, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, sdk.Rule, string, string) sdk.Rule); ok { + r0 = returnFunc(ctx, r, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, sdk.Rule, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, r, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_UpdateRuleTags_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRuleTags' +type SDK_UpdateRuleTags_Call struct { + *mock.Call +} + +// UpdateRuleTags is a helper method to define mock.On call +// - ctx context.Context +// - r sdk.Rule +// - domainID string +// - token string +func (_e *SDK_Expecter) UpdateRuleTags(ctx interface{}, r interface{}, domainID interface{}, token interface{}) *SDK_UpdateRuleTags_Call { + return &SDK_UpdateRuleTags_Call{Call: _e.mock.On("UpdateRuleTags", ctx, r, domainID, token)} +} + +func (_c *SDK_UpdateRuleTags_Call) Run(run func(ctx context.Context, r sdk.Rule, domainID string, token string)) *SDK_UpdateRuleTags_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 sdk.Rule + if args[1] != nil { + arg1 = args[1].(sdk.Rule) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_UpdateRuleTags_Call) Return(rule sdk.Rule, sDKError errors.SDKError) *SDK_UpdateRuleTags_Call { + _c.Call.Return(rule, sDKError) + return _c +} + +func (_c *SDK_UpdateRuleTags_Call) RunAndReturn(run func(ctx context.Context, r sdk.Rule, domainID string, token string) (sdk.Rule, errors.SDKError)) *SDK_UpdateRuleTags_Call { + _c.Call.Return(run) + return _c +} + // UpdateUser provides a mock function for the type SDK func (_mock *SDK) UpdateUser(ctx context.Context, user sdk.User, token string) (sdk.User, errors.SDKError) { ret := _mock.Called(ctx, user, token) @@ -9845,3 +13434,698 @@ func (_c *SDK_VerifyEmail_Call) RunAndReturn(run func(ctx context.Context, verif _c.Call.Return(run) return _c } + +// ViewAlarm provides a mock function for the type SDK +func (_mock *SDK) ViewAlarm(ctx context.Context, id string, domainID string, token string) (sdk.Alarm, errors.SDKError) { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for ViewAlarm") + } + + var r0 sdk.Alarm + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.Alarm, errors.SDKError)); ok { + return returnFunc(ctx, id, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.Alarm); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Alarm) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ViewAlarm_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewAlarm' +type SDK_ViewAlarm_Call struct { + *mock.Call +} + +// ViewAlarm is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) ViewAlarm(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_ViewAlarm_Call { + return &SDK_ViewAlarm_Call{Call: _e.mock.On("ViewAlarm", ctx, id, domainID, token)} +} + +func (_c *SDK_ViewAlarm_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_ViewAlarm_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_ViewAlarm_Call) Return(alarm sdk.Alarm, sDKError errors.SDKError) *SDK_ViewAlarm_Call { + _c.Call.Return(alarm, sDKError) + return _c +} + +func (_c *SDK_ViewAlarm_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) (sdk.Alarm, errors.SDKError)) *SDK_ViewAlarm_Call { + _c.Call.Return(run) + return _c +} + +// ViewBootstrap provides a mock function for the type SDK +func (_mock *SDK) ViewBootstrap(ctx context.Context, id string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError) { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for ViewBootstrap") + } + + var r0 sdk.BootstrapConfig + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.BootstrapConfig, errors.SDKError)); ok { + return returnFunc(ctx, id, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.BootstrapConfig); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + r0 = ret.Get(0).(sdk.BootstrapConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ViewBootstrap_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewBootstrap' +type SDK_ViewBootstrap_Call struct { + *mock.Call +} + +// ViewBootstrap is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) ViewBootstrap(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_ViewBootstrap_Call { + return &SDK_ViewBootstrap_Call{Call: _e.mock.On("ViewBootstrap", ctx, id, domainID, token)} +} + +func (_c *SDK_ViewBootstrap_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_ViewBootstrap_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_ViewBootstrap_Call) Return(bootstrapConfig sdk.BootstrapConfig, sDKError errors.SDKError) *SDK_ViewBootstrap_Call { + _c.Call.Return(bootstrapConfig, sDKError) + return _c +} + +func (_c *SDK_ViewBootstrap_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) (sdk.BootstrapConfig, errors.SDKError)) *SDK_ViewBootstrap_Call { + _c.Call.Return(run) + return _c +} + +// ViewCA provides a mock function for the type SDK +func (_mock *SDK) ViewCA(ctx context.Context) (sdk.Certificate, errors.SDKError) { + ret := _mock.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for ViewCA") + } + + var r0 sdk.Certificate + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context) (sdk.Certificate, errors.SDKError)); ok { + return returnFunc(ctx) + } + if returnFunc, ok := ret.Get(0).(func(context.Context) sdk.Certificate); ok { + r0 = returnFunc(ctx) + } else { + r0 = ret.Get(0).(sdk.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context) errors.SDKError); ok { + r1 = returnFunc(ctx) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ViewCA_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewCA' +type SDK_ViewCA_Call struct { + *mock.Call +} + +// ViewCA is a helper method to define mock.On call +// - ctx context.Context +func (_e *SDK_Expecter) ViewCA(ctx interface{}) *SDK_ViewCA_Call { + return &SDK_ViewCA_Call{Call: _e.mock.On("ViewCA", ctx)} +} + +func (_c *SDK_ViewCA_Call) Run(run func(ctx context.Context)) *SDK_ViewCA_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *SDK_ViewCA_Call) Return(certificate sdk.Certificate, sDKError errors.SDKError) *SDK_ViewCA_Call { + _c.Call.Return(certificate, sDKError) + return _c +} + +func (_c *SDK_ViewCA_Call) RunAndReturn(run func(ctx context.Context) (sdk.Certificate, errors.SDKError)) *SDK_ViewCA_Call { + _c.Call.Return(run) + return _c +} + +// ViewCert provides a mock function for the type SDK +func (_mock *SDK) ViewCert(ctx context.Context, serialNumber string, domainID string, token string) (sdk.Certificate, errors.SDKError) { + ret := _mock.Called(ctx, serialNumber, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for ViewCert") + } + + var r0 sdk.Certificate + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.Certificate, errors.SDKError)); ok { + return returnFunc(ctx, serialNumber, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.Certificate); ok { + r0 = returnFunc(ctx, serialNumber, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Certificate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, serialNumber, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ViewCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewCert' +type SDK_ViewCert_Call struct { + *mock.Call +} + +// ViewCert is a helper method to define mock.On call +// - ctx context.Context +// - serialNumber string +// - domainID string +// - token string +func (_e *SDK_Expecter) ViewCert(ctx interface{}, serialNumber interface{}, domainID interface{}, token interface{}) *SDK_ViewCert_Call { + return &SDK_ViewCert_Call{Call: _e.mock.On("ViewCert", ctx, serialNumber, domainID, token)} +} + +func (_c *SDK_ViewCert_Call) Run(run func(ctx context.Context, serialNumber string, domainID string, token string)) *SDK_ViewCert_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_ViewCert_Call) Return(certificate sdk.Certificate, sDKError errors.SDKError) *SDK_ViewCert_Call { + _c.Call.Return(certificate, sDKError) + return _c +} + +func (_c *SDK_ViewCert_Call) RunAndReturn(run func(ctx context.Context, serialNumber string, domainID string, token string) (sdk.Certificate, errors.SDKError)) *SDK_ViewCert_Call { + _c.Call.Return(run) + return _c +} + +// ViewReportConfig provides a mock function for the type SDK +func (_mock *SDK) ViewReportConfig(ctx context.Context, id string, domainID string, token string) (sdk.ReportConfig, errors.SDKError) { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for ViewReportConfig") + } + + var r0 sdk.ReportConfig + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.ReportConfig, errors.SDKError)); ok { + return returnFunc(ctx, id, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.ReportConfig); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + r0 = ret.Get(0).(sdk.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ViewReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewReportConfig' +type SDK_ViewReportConfig_Call struct { + *mock.Call +} + +// ViewReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) ViewReportConfig(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_ViewReportConfig_Call { + return &SDK_ViewReportConfig_Call{Call: _e.mock.On("ViewReportConfig", ctx, id, domainID, token)} +} + +func (_c *SDK_ViewReportConfig_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_ViewReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_ViewReportConfig_Call) Return(reportConfig sdk.ReportConfig, sDKError errors.SDKError) *SDK_ViewReportConfig_Call { + _c.Call.Return(reportConfig, sDKError) + return _c +} + +func (_c *SDK_ViewReportConfig_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) (sdk.ReportConfig, errors.SDKError)) *SDK_ViewReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// ViewReportTemplate provides a mock function for the type SDK +func (_mock *SDK) ViewReportTemplate(ctx context.Context, id string, domainID string, token string) (sdk.ReportTemplate, errors.SDKError) { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for ViewReportTemplate") + } + + var r0 sdk.ReportTemplate + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.ReportTemplate, errors.SDKError)); ok { + return returnFunc(ctx, id, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.ReportTemplate); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(sdk.ReportTemplate) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ViewReportTemplate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewReportTemplate' +type SDK_ViewReportTemplate_Call struct { + *mock.Call +} + +// ViewReportTemplate is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) ViewReportTemplate(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_ViewReportTemplate_Call { + return &SDK_ViewReportTemplate_Call{Call: _e.mock.On("ViewReportTemplate", ctx, id, domainID, token)} +} + +func (_c *SDK_ViewReportTemplate_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_ViewReportTemplate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_ViewReportTemplate_Call) Return(reportTemplate sdk.ReportTemplate, sDKError errors.SDKError) *SDK_ViewReportTemplate_Call { + _c.Call.Return(reportTemplate, sDKError) + return _c +} + +func (_c *SDK_ViewReportTemplate_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) (sdk.ReportTemplate, errors.SDKError)) *SDK_ViewReportTemplate_Call { + _c.Call.Return(run) + return _c +} + +// ViewRule provides a mock function for the type SDK +func (_mock *SDK) ViewRule(ctx context.Context, id string, domainID string, token string) (sdk.Rule, errors.SDKError) { + ret := _mock.Called(ctx, id, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for ViewRule") + } + + var r0 sdk.Rule + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (sdk.Rule, errors.SDKError)); ok { + return returnFunc(ctx, id, domainID, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) sdk.Rule); ok { + r0 = returnFunc(ctx, id, domainID, token) + } else { + r0 = ret.Get(0).(sdk.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, domainID, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ViewRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewRule' +type SDK_ViewRule_Call struct { + *mock.Call +} + +// ViewRule is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - domainID string +// - token string +func (_e *SDK_Expecter) ViewRule(ctx interface{}, id interface{}, domainID interface{}, token interface{}) *SDK_ViewRule_Call { + return &SDK_ViewRule_Call{Call: _e.mock.On("ViewRule", ctx, id, domainID, token)} +} + +func (_c *SDK_ViewRule_Call) Run(run func(ctx context.Context, id string, domainID string, token string)) *SDK_ViewRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *SDK_ViewRule_Call) Return(rule sdk.Rule, sDKError errors.SDKError) *SDK_ViewRule_Call { + _c.Call.Return(rule, sDKError) + return _c +} + +func (_c *SDK_ViewRule_Call) RunAndReturn(run func(ctx context.Context, id string, domainID string, token string) (sdk.Rule, errors.SDKError)) *SDK_ViewRule_Call { + _c.Call.Return(run) + return _c +} + +// ViewSubscription provides a mock function for the type SDK +func (_mock *SDK) ViewSubscription(ctx context.Context, id string, token string) (sdk.Subscription, errors.SDKError) { + ret := _mock.Called(ctx, id, token) + + if len(ret) == 0 { + panic("no return value specified for ViewSubscription") + } + + var r0 sdk.Subscription + var r1 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (sdk.Subscription, errors.SDKError)); ok { + return returnFunc(ctx, id, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) sdk.Subscription); ok { + r0 = returnFunc(ctx, id, token) + } else { + r0 = ret.Get(0).(sdk.Subscription) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) errors.SDKError); ok { + r1 = returnFunc(ctx, id, token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + return r0, r1 +} + +// SDK_ViewSubscription_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewSubscription' +type SDK_ViewSubscription_Call struct { + *mock.Call +} + +// ViewSubscription is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - token string +func (_e *SDK_Expecter) ViewSubscription(ctx interface{}, id interface{}, token interface{}) *SDK_ViewSubscription_Call { + return &SDK_ViewSubscription_Call{Call: _e.mock.On("ViewSubscription", ctx, id, token)} +} + +func (_c *SDK_ViewSubscription_Call) Run(run func(ctx context.Context, id string, token string)) *SDK_ViewSubscription_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *SDK_ViewSubscription_Call) Return(subscription sdk.Subscription, sDKError errors.SDKError) *SDK_ViewSubscription_Call { + _c.Call.Return(subscription, sDKError) + return _c +} + +func (_c *SDK_ViewSubscription_Call) RunAndReturn(run func(ctx context.Context, id string, token string) (sdk.Subscription, errors.SDKError)) *SDK_ViewSubscription_Call { + _c.Call.Return(run) + return _c +} + +// Whitelist provides a mock function for the type SDK +func (_mock *SDK) Whitelist(ctx context.Context, clientID string, state int, domainID string, token string) errors.SDKError { + ret := _mock.Called(ctx, clientID, state, domainID, token) + + if len(ret) == 0 { + panic("no return value specified for Whitelist") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string, int, string, string) errors.SDKError); ok { + r0 = returnFunc(ctx, clientID, state, domainID, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_Whitelist_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Whitelist' +type SDK_Whitelist_Call struct { + *mock.Call +} + +// Whitelist is a helper method to define mock.On call +// - ctx context.Context +// - clientID string +// - state int +// - domainID string +// - token string +func (_e *SDK_Expecter) Whitelist(ctx interface{}, clientID interface{}, state interface{}, domainID interface{}, token interface{}) *SDK_Whitelist_Call { + return &SDK_Whitelist_Call{Call: _e.mock.On("Whitelist", ctx, clientID, state, domainID, token)} +} + +func (_c *SDK_Whitelist_Call) Run(run func(ctx context.Context, clientID string, state int, domainID string, token string)) *SDK_Whitelist_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 int + if args[2] != nil { + arg2 = args[2].(int) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *SDK_Whitelist_Call) Return(sDKError errors.SDKError) *SDK_Whitelist_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_Whitelist_Call) RunAndReturn(run func(ctx context.Context, clientID string, state int, domainID string, token string) errors.SDKError) *SDK_Whitelist_Call { + _c.Call.Return(run) + return _c +} diff --git a/pkg/sdk/reports.go b/pkg/sdk/reports.go new file mode 100644 index 000000000..5a00df9e3 --- /dev/null +++ b/pkg/sdk/reports.go @@ -0,0 +1,302 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/absmach/supermq/pkg/errors" +) + +const ( + reportsEndpoint = "reports" + configsEndpointReports = "configs" +) + +// ReportConfig represents a report configuration. +type ReportConfig struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + DomainID string `json:"domain_id,omitempty"` + Schedule any `json:"schedule,omitempty"` + Config any `json:"config,omitempty"` + Email any `json:"email,omitempty"` + Metrics any `json:"metrics,omitempty"` + ReportTemplate ReportTemplate `json:"report_template,omitempty"` + Status string `json:"status,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` + CreatedBy string `json:"created_by,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` + UpdatedBy string `json:"updated_by,omitempty"` +} + +type ReportTemplate any + +type ReportFile struct { + Name string + Format string + Data []byte +} + +type ReportPage struct { + Total uint64 `json:"total"` + From time.Time `json:"from,omitempty"` + To time.Time `json:"to,omitempty"` + Aggregation any `json:"aggregation,omitempty"` + Reports any `json:"reports,omitempty"` + File any `json:"file,omitempty"` +} + +type ReportConfigPage struct { + Total uint64 `json:"total"` + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + ReportConfigs []ReportConfig `json:"report_configs"` +} + +type ReportAction string + +const ( + ViewReportAction ReportAction = "view" + DownloadReportAction ReportAction = "download" + EmailReportAction ReportAction = "email" +) + +func (sdk mgSDK) AddReportConfig(ctx context.Context, cfg ReportConfig, domainID, token string) (ReportConfig, errors.SDKError) { + data, err := json.Marshal(cfg) + if err != nil { + return ReportConfig{}, errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s/%s/%s", sdk.reportsURL, domainID, reportsEndpoint, configsEndpointReports) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, data, nil, http.StatusCreated, http.StatusOK) + if sdkerr != nil { + return ReportConfig{}, sdkerr + } + + var rc ReportConfig + if err := json.Unmarshal(body, &rc); err != nil { + return ReportConfig{}, errors.NewSDKError(err) + } + + return rc, nil +} + +func (sdk mgSDK) ViewReportConfig(ctx context.Context, id, domainID, token string) (ReportConfig, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s/%s/%s", sdk.reportsURL, domainID, reportsEndpoint, configsEndpointReports, id) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return ReportConfig{}, sdkerr + } + + var rc ReportConfig + if err := json.Unmarshal(body, &rc); err != nil { + return ReportConfig{}, errors.NewSDKError(err) + } + + return rc, nil +} + +func (sdk mgSDK) UpdateReportConfig(ctx context.Context, cfg ReportConfig, domainID, token string) (ReportConfig, errors.SDKError) { + data, err := json.Marshal(cfg) + if err != nil { + return ReportConfig{}, errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s/%s/%s/%s", sdk.reportsURL, domainID, reportsEndpoint, configsEndpointReports, cfg.ID) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK) + if sdkerr != nil { + return ReportConfig{}, sdkerr + } + + var rc ReportConfig + if err := json.Unmarshal(body, &rc); err != nil { + return ReportConfig{}, errors.NewSDKError(err) + } + + return rc, nil +} + +func (sdk mgSDK) UpdateReportSchedule(ctx context.Context, cfg ReportConfig, domainID, token string) (ReportConfig, errors.SDKError) { + data, err := json.Marshal(map[string]any{"schedule": cfg.Schedule}) + if err != nil { + return ReportConfig{}, errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s/%s/%s/%s/schedule", sdk.reportsURL, domainID, reportsEndpoint, configsEndpointReports, cfg.ID) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK) + if sdkerr != nil { + return ReportConfig{}, sdkerr + } + + var rc ReportConfig + if err := json.Unmarshal(body, &rc); err != nil { + return ReportConfig{}, errors.NewSDKError(err) + } + + return rc, nil +} + +func (sdk mgSDK) RemoveReportConfig(ctx context.Context, id, domainID, token string) errors.SDKError { + url := fmt.Sprintf("%s/%s/%s/%s/%s", sdk.reportsURL, domainID, reportsEndpoint, configsEndpointReports, id) + + _, _, sdkerr := sdk.processRequest(ctx, http.MethodDelete, url, token, nil, nil, http.StatusNoContent, http.StatusOK) + return sdkerr +} + +func (sdk mgSDK) ListReportsConfig(ctx context.Context, pm PageMetadata, domainID, token string) (ReportConfigPage, errors.SDKError) { + endpoint := fmt.Sprintf("%s/%s/%s", domainID, reportsEndpoint, configsEndpointReports) + url, err := sdk.withQueryParams(sdk.reportsURL, endpoint, pm) + if err != nil { + return ReportConfigPage{}, errors.NewSDKError(err) + } + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return ReportConfigPage{}, sdkerr + } + + var rcp ReportConfigPage + if err := json.Unmarshal(body, &rcp); err != nil { + return ReportConfigPage{}, errors.NewSDKError(err) + } + + return rcp, nil +} + +func (sdk mgSDK) EnableReportConfig(ctx context.Context, id, domainID, token string) (ReportConfig, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s/%s/%s/enable", sdk.reportsURL, domainID, reportsEndpoint, configsEndpointReports, id) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return ReportConfig{}, sdkerr + } + + var rc ReportConfig + if err := json.Unmarshal(body, &rc); err != nil { + return ReportConfig{}, errors.NewSDKError(err) + } + + return rc, nil +} + +func (sdk mgSDK) DisableReportConfig(ctx context.Context, id, domainID, token string) (ReportConfig, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s/%s/%s/disable", sdk.reportsURL, domainID, reportsEndpoint, configsEndpointReports, id) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return ReportConfig{}, sdkerr + } + + var rc ReportConfig + if err := json.Unmarshal(body, &rc); err != nil { + return ReportConfig{}, errors.NewSDKError(err) + } + + return rc, nil +} + +func (sdk mgSDK) UpdateReportTemplate(ctx context.Context, cfg ReportConfig, domainID, token string) errors.SDKError { + data, err := json.Marshal(cfg) + if err != nil { + return errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s/%s/%s/%s/template", sdk.reportsURL, domainID, reportsEndpoint, configsEndpointReports, cfg.ID) + + _, _, sdkerr := sdk.processRequest(ctx, http.MethodPut, url, token, data, nil, http.StatusNoContent) + return sdkerr +} + +func (sdk mgSDK) ViewReportTemplate(ctx context.Context, id, domainID, token string) (ReportTemplate, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s/%s/%s/template", sdk.reportsURL, domainID, reportsEndpoint, configsEndpointReports, id) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return "", sdkerr + } + + var rt ReportTemplate + if err := json.Unmarshal(body, &rt); err != nil { + return "", errors.NewSDKError(err) + } + + return rt, nil +} + +func (sdk mgSDK) DeleteReportTemplate(ctx context.Context, id, domainID, token string) errors.SDKError { + url := fmt.Sprintf("%s/%s/%s/%s/%s/template", sdk.reportsURL, domainID, reportsEndpoint, configsEndpointReports, id) + + _, _, sdkerr := sdk.processRequest(ctx, http.MethodDelete, url, token, nil, nil, http.StatusNoContent, http.StatusOK) + return sdkerr +} + +func (sdk mgSDK) GenerateReport( + ctx context.Context, + config ReportConfig, + action ReportAction, + domainID, + token string, +) (ReportPage, *ReportFile, errors.SDKError) { + data, err := json.Marshal(config) + if err != nil { + return ReportPage{}, nil, errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s/%s?action=%s", + sdk.reportsURL, + domainID, + reportsEndpoint, + action, + ) + + headers, body, sdkerr := sdk.processRequest( + ctx, + http.MethodPost, + url, + token, + data, + nil, + http.StatusOK, + ) + if sdkerr != nil { + return ReportPage{}, nil, sdkerr + } + + // ✅ Handle Download Action + if action == DownloadReportAction { + file := &ReportFile{ + Name: extractFilename(headers.Get("Content-Disposition")), + Format: "pdf", + Data: body, + } + return ReportPage{}, file, nil + } + + // ✅ Handle JSON response (view/email) + var rp ReportPage + if err := json.Unmarshal(body, &rp); err != nil { + return ReportPage{}, nil, errors.NewSDKError(err) + } + + return rp, nil, nil +} + +func extractFilename(contentDisposition string) string { + const prefix = "filename=" + if idx := strings.Index(contentDisposition, prefix); idx != -1 { + return strings.Trim(contentDisposition[idx+len(prefix):], `"`) + } + return "report" +} diff --git a/pkg/sdk/reports_test.go b/pkg/sdk/reports_test.go new file mode 100644 index 000000000..feadfa40a --- /dev/null +++ b/pkg/sdk/reports_test.go @@ -0,0 +1,867 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk_test + +import ( + "context" + "errors" + "net/http/httptest" + "testing" + "time" + + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + pkgSch "github.com/absmach/supermq/pkg/schedule" + "github.com/absmach/supermq/pkg/sdk" + "github.com/absmach/supermq/reports" + "github.com/absmach/supermq/reports/api" + rmocks "github.com/absmach/supermq/reports/mocks" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const ( + reportConfigID = "report-config-1" + reportName = "daily-report" + reportUpdatedName = "updated daily-report" + reportDescription = "Daily temperature report" + reportUpdatedDesc = "updated Daily temperature report" + validTemplate = ` + + + {{$.Title}} + + + +
+

{{$.Title}}

+

Generated on: {{$.GeneratedDate}}

+
+
+

Messages

+ {{range .Messages}} +
+

Time: {{formatTime .Time}}

+

Value: {{formatValue .}}

+
+ {{end}} +
+ +` +) + +var ( + now = time.Now().UTC().Truncate(time.Minute) + future = now.Add(1 * time.Hour) + schedule = pkgSch.Schedule{ + StartDateTime: future, + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: future, + } + metrics = []reports.ReqMetric{ + { + ChannelID: "channel1", + ClientIDs: []string{"client1"}, + Name: "metric_name", + }, + } + config = reports.MetricConfig{ + From: "now()-1h", + To: "now()", + Title: "test_title", + Aggregation: reports.AggConfig{AggType: reports.AggregationAVG, Interval: "1h"}, + } + email = reports.EmailSetting{ + To: []string{"test@example.com"}, + Subject: "Test Report", + } + + testReportConfig = sdk.ReportConfig{ + ID: reportConfigID, + Name: reportName, + Description: reportDescription, + DomainID: domainID, + Status: "enabled", + Schedule: schedule, + Metrics: metrics, + Config: &config, + Email: &email, + } +) + +func setupReports() (*httptest.Server, *rmocks.Service, *authnmocks.Authentication) { + rsvc := new(rmocks.Service) + log := smqlog.NewMock() + authn := new(authnmocks.Authentication) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + mux := chi.NewRouter() + _ = api.MakeHandler(rsvc, am, mux, log, "") + return httptest.NewServer(mux), rsvc, authn +} + +func TestAddReportConfig(t *testing.T) { + rs, rsvc, auth := setupReports() + defer rs.Close() + + conf := sdk.Config{ + ReportsURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcCfg := reports.ReportConfig{ + ID: reportConfigID, + Name: reportName, + Description: reportDescription, + DomainID: domainID, + Status: reports.EnabledStatus, + Schedule: schedule, + Metrics: []reports.ReqMetric{ + { + ChannelID: "channel1", + ClientIDs: []string{"client1"}, + Name: "metric_name", + }, + }, + Config: &reports.MetricConfig{ + From: "now()-1h", + To: "now()", + Title: "test_title", + Aggregation: reports.AggConfig{AggType: reports.AggregationAVG, Interval: "1h"}, + }, + Email: &reports.EmailSetting{ + To: []string{"test@example.com"}, + Subject: "Test Report", + }, + } + + cases := []struct { + desc string + cfg sdk.ReportConfig + token string + session smqauthn.Session + svcRes reports.ReportConfig + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "add report config successfully", + cfg: testReportConfig, + token: validToken, + svcRes: svcCfg, + }, + { + desc: "add report config with empty token", + cfg: sdk.ReportConfig{Name: "daily-report"}, + token: "", + wantErr: true, + svcErr: errors.New("missing or invalid bearer user token"), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("AddReportConfig", mock.Anything, tc.session, mock.Anything).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.AddReportConfig(context.Background(), tc.cfg, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestViewReportConfig(t *testing.T) { + rs, rsvc, auth := setupReports() + defer rs.Close() + + conf := sdk.Config{ + ReportsURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcCfg := reports.ReportConfig{ + ID: reportConfigID, + Name: reportName, + Description: reportDescription, + DomainID: domainID, + Status: reports.EnabledStatus, + Metrics: metrics, + Config: &config, + Email: &email, + } + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + svcRes reports.ReportConfig + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "view report config successfully", + id: reportConfigID, + token: validToken, + svcRes: svcCfg, + }, + { + desc: "view report config with empty token", + id: reportConfigID, + token: "", + wantErr: true, + }, + { + desc: "view non-existent report config", + id: "non-existent", + token: validToken, + svcErr: errors.New("not found"), + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("ViewReportConfig", mock.Anything, tc.session, tc.id, mock.Anything).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.ViewReportConfig(context.Background(), tc.id, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateReportConfig(t *testing.T) { + rs, rsvc, auth := setupReports() + defer rs.Close() + + conf := sdk.Config{ + ReportsURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + updatedConfig := testReportConfig + updatedConfig.Name = reportUpdatedName + updatedConfig.Description = reportUpdatedDesc + + svcCfg := reports.ReportConfig{ + ID: reportConfigID, + Name: reportUpdatedName, + Description: reportUpdatedDesc, + DomainID: domainID, + Status: reports.EnabledStatus, + Metrics: metrics, + Config: &config, + Email: &email, + } + + cases := []struct { + desc string + cfg sdk.ReportConfig + token string + session smqauthn.Session + svcRes reports.ReportConfig + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "update report config successfully", + cfg: updatedConfig, + token: validToken, + svcRes: svcCfg, + }, + { + desc: "update report config with empty token", + cfg: sdk.ReportConfig{ID: reportConfigID, Name: "updated-report"}, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("UpdateReportConfig", mock.Anything, tc.session, mock.Anything).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.UpdateReportConfig(context.Background(), tc.cfg, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateReportSchedule(t *testing.T) { + rs, rsvc, auth := setupReports() + defer rs.Close() + + conf := sdk.Config{ + ReportsURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcCfg := reports.ReportConfig{ + ID: reportConfigID, + Name: reportName, + Status: reports.EnabledStatus, + } + + cases := []struct { + desc string + cfg sdk.ReportConfig + token string + session smqauthn.Session + svcRes reports.ReportConfig + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "update report schedule successfully", + cfg: sdk.ReportConfig{ID: reportConfigID, Schedule: map[string]any{"cron": "0 9 * * *"}}, + token: validToken, + svcRes: svcCfg, + }, + { + desc: "update report schedule with empty token", + cfg: sdk.ReportConfig{ID: reportConfigID}, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("UpdateReportSchedule", mock.Anything, tc.session, mock.Anything).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.UpdateReportSchedule(context.Background(), tc.cfg, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestRemoveReportConfig(t *testing.T) { + rs, rsvc, auth := setupReports() + defer rs.Close() + + conf := sdk.Config{ + ReportsURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "remove report config successfully", + id: reportConfigID, + token: validToken, + }, + { + desc: "remove report config with empty token", + id: reportConfigID, + token: "", + wantErr: true, + }, + { + desc: "remove non-existent report config", + id: "non-existent", + token: validToken, + svcErr: errors.New("not found"), + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("RemoveReportConfig", mock.Anything, tc.session, tc.id).Return(tc.svcErr) + err := mgsdk.RemoveReportConfig(context.Background(), tc.id, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestListReportsConfig(t *testing.T) { + rs, rsvc, auth := setupReports() + defer rs.Close() + + conf := sdk.Config{ + ReportsURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcPage := reports.ReportConfigPage{} + + cases := []struct { + desc string + pm sdk.PageMetadata + token string + session smqauthn.Session + svcRes reports.ReportConfigPage + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "list reports config successfully", + pm: sdk.PageMetadata{Offset: 0, Limit: 10}, + token: validToken, + svcRes: svcPage, + }, + { + desc: "list reports config with filters", + pm: sdk.PageMetadata{ + Limit: 10, + Name: "daily", + Status: "enabled", + Dir: "desc", + Order: "created_at", + }, + token: validToken, + svcRes: svcPage, + }, + { + desc: "list reports config with empty metadata excludes filter params", + pm: sdk.PageMetadata{}, + token: validToken, + svcRes: reports.ReportConfigPage{}, + }, + { + desc: "list reports config with empty token", + pm: sdk.PageMetadata{Limit: 10}, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("ListReportsConfig", mock.Anything, tc.session, mock.Anything).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.ListReportsConfig(context.Background(), tc.pm, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotNil(t, result) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestEnableReportConfig(t *testing.T) { + rs, rsvc, auth := setupReports() + defer rs.Close() + + conf := sdk.Config{ + ReportsURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcCfg := reports.ReportConfig{ + ID: reportConfigID, + Status: reports.EnabledStatus, + } + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + svcRes reports.ReportConfig + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "enable report config successfully", + id: reportConfigID, + token: validToken, + svcRes: svcCfg, + }, + { + desc: "enable report config with empty token", + id: reportConfigID, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("EnableReportConfig", mock.Anything, tc.session, tc.id).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.EnableReportConfig(context.Background(), tc.id, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDisableReportConfig(t *testing.T) { + rs, rsvc, auth := setupReports() + defer rs.Close() + + conf := sdk.Config{ + ReportsURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcCfg := reports.ReportConfig{ + ID: reportConfigID, + Status: reports.DisabledStatus, + } + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + svcRes reports.ReportConfig + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "disable report config successfully", + id: reportConfigID, + token: validToken, + svcRes: svcCfg, + }, + { + desc: "disable report config with empty token", + id: reportConfigID, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("DisableReportConfig", mock.Anything, tc.session, tc.id).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.DisableReportConfig(context.Background(), tc.id, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateReportTemplate(t *testing.T) { + rs, rsvc, auth := setupReports() + defer rs.Close() + + conf := sdk.Config{ + ReportsURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + cases := []struct { + desc string + cfg sdk.ReportConfig + token string + session smqauthn.Session + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "update report template successfully", + cfg: sdk.ReportConfig{ + ID: reportConfigID, + ReportTemplate: validTemplate, + }, + token: validToken, + }, + { + desc: "update report template with empty token", + cfg: sdk.ReportConfig{ID: reportConfigID}, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("UpdateReportTemplate", mock.Anything, tc.session, mock.Anything).Return(tc.svcErr) + err := mgsdk.UpdateReportTemplate(context.Background(), tc.cfg, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestViewReportTemplate(t *testing.T) { + rs, rsvc, auth := setupReports() + defer rs.Close() + + conf := sdk.Config{ + ReportsURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcTmpl := reports.ReportTemplate(validTemplate) + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + svcRes reports.ReportTemplate + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "view report template successfully", + id: reportConfigID, + token: validToken, + svcRes: svcTmpl, + }, + { + desc: "view report template with empty token", + id: reportConfigID, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("ViewReportTemplate", mock.Anything, tc.session, tc.id).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.ViewReportTemplate(context.Background(), tc.id, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDeleteReportTemplate(t *testing.T) { + rs, rsvc, auth := setupReports() + defer rs.Close() + + conf := sdk.Config{ + ReportsURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "delete report template successfully", + id: reportConfigID, + token: validToken, + }, + { + desc: "delete report template with empty token", + id: reportConfigID, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("DeleteReportTemplate", mock.Anything, tc.session, tc.id).Return(tc.svcErr) + err := mgsdk.DeleteReportTemplate(context.Background(), tc.id, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestGenerateReport(t *testing.T) { + rs, rsvc, auth := setupReports() + defer rs.Close() + + conf := sdk.Config{ + ReportsURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcPage := reports.ReportPage{} + + config := sdk.ReportConfig{ + ID: reportConfigID, + Name: reportName, + Description: reportDescription, + DomainID: domainID, + Metrics: metrics, + Config: &config, + ReportTemplate: reports.ReportTemplate(validTemplate), + } + + cases := []struct { + desc string + cfg sdk.ReportConfig + action sdk.ReportAction + token string + session smqauthn.Session + svcRes reports.ReportPage + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "generate report successfully", + cfg: config, + action: sdk.ViewReportAction, + token: validToken, + svcRes: svcPage, + }, + { + desc: "generate report with download action", + cfg: config, + action: sdk.DownloadReportAction, + token: validToken, + svcRes: svcPage, + }, + { + desc: "generate report with empty token", + cfg: config, + action: sdk.ViewReportAction, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{ + DomainUserID: domainID + "_" + validID, + UserID: validID, + DomainID: domainID, + } + } + + authCall := auth.On( + "Authenticate", + mock.Anything, + tc.token, + ).Return(tc.session, tc.authenticateErr) + + svcCall := rsvc.On( + "GenerateReport", + mock.Anything, + tc.session, + mock.Anything, + mock.Anything, + ).Return(tc.svcRes, tc.svcErr) + + page, file, err := mgsdk.GenerateReport( + context.Background(), + tc.cfg, + tc.action, + domainID, + tc.token, + ) + + assert.Equal(t, tc.wantErr, err != nil) + + if !tc.wantErr { + if tc.action == sdk.DownloadReportAction { + // download should return file + assert.NotNil(t, file) + } else { + // view/email should return page + assert.Equal(t, tc.svcRes.Total, page.Total) + } + } + + svcCall.Unset() + authCall.Unset() + }) + } +} diff --git a/pkg/sdk/responses.go b/pkg/sdk/responses.go index 6a4545bc2..9074608e6 100644 --- a/pkg/sdk/responses.go +++ b/pkg/sdk/responses.go @@ -64,6 +64,18 @@ type DomainsPage struct { PageRes } +// BootstrapPage contains list of bootstrap configs in a page with proper metadata. +type BootstrapPage struct { + Configs []BootstrapConfig `json:"configs"` + PageRes +} + +// SubscriptionPage contains list of subscriptions in a page with proper metadata. +type SubscriptionPage struct { + Subscriptions []Subscription `json:"subscriptions"` + PageRes +} + type roleActionsRes struct { Actions []string `json:"actions"` } diff --git a/pkg/sdk/rules.go b/pkg/sdk/rules.go new file mode 100644 index 000000000..f57a667cf --- /dev/null +++ b/pkg/sdk/rules.go @@ -0,0 +1,200 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/absmach/supermq/pkg/errors" +) + +const rulesEndpoint = "rules" + +// Rule represents a rule configuration. +type Rule struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + DomainID string `json:"domain,omitempty"` + Metadata Metadata `json:"metadata,omitempty"` + Tags []string `json:"tags,omitempty"` + InputChannel string `json:"input_channel,omitempty"` + InputTopic string `json:"input_topic,omitempty"` + Logic any `json:"logic,omitempty"` + Outputs any `json:"outputs,omitempty"` + Schedule any `json:"schedule,omitempty"` + Status string `json:"status,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + CreatedBy string `json:"created_by,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` + UpdatedBy string `json:"updated_by,omitempty"` +} + +type Page struct { + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Total uint64 `json:"total"` + Rules []Rule `json:"rules"` +} + +func (sdk mgSDK) AddRule(ctx context.Context, r Rule, domainID, token string) (Rule, errors.SDKError) { + data, err := json.Marshal(r) + if err != nil { + return Rule{}, errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s/%s", sdk.rulesEngineURL, domainID, rulesEndpoint) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, data, nil, http.StatusCreated, http.StatusOK) + if sdkerr != nil { + return Rule{}, sdkerr + } + + var a Rule + if err := json.Unmarshal(body, &a); err != nil { + return Rule{}, errors.NewSDKError(err) + } + + return a, nil +} + +func (sdk mgSDK) ViewRule(ctx context.Context, id, domainID, token string) (Rule, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s/%s", sdk.rulesEngineURL, domainID, rulesEndpoint, id) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return Rule{}, sdkerr + } + + var a Rule + if err := json.Unmarshal(body, &a); err != nil { + return Rule{}, errors.NewSDKError(err) + } + + return a, nil +} + +func (sdk mgSDK) UpdateRule(ctx context.Context, r Rule, domainID, token string) (Rule, errors.SDKError) { + data, err := json.Marshal(r) + if err != nil { + return Rule{}, errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s/%s/%s", sdk.rulesEngineURL, domainID, rulesEndpoint, r.ID) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK) + if sdkerr != nil { + return Rule{}, sdkerr + } + + var a Rule + if err := json.Unmarshal(body, &a); err != nil { + return Rule{}, errors.NewSDKError(err) + } + + return a, nil +} + +func (sdk mgSDK) UpdateRuleTags(ctx context.Context, r Rule, domainID, token string) (Rule, errors.SDKError) { + data, err := json.Marshal(map[string]any{"tags": r.Tags}) + if err != nil { + return Rule{}, errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s/%s/%s/tags", sdk.rulesEngineURL, domainID, rulesEndpoint, r.ID) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK) + if sdkerr != nil { + return Rule{}, sdkerr + } + + var a Rule + if err := json.Unmarshal(body, &a); err != nil { + return Rule{}, errors.NewSDKError(err) + } + + return a, nil +} + +func (sdk mgSDK) UpdateRuleSchedule(ctx context.Context, r Rule, domainID, token string) (Rule, errors.SDKError) { + data, err := json.Marshal(map[string]any{"schedule": r.Schedule}) + if err != nil { + return Rule{}, errors.NewSDKError(err) + } + + url := fmt.Sprintf("%s/%s/%s/%s/schedule", sdk.rulesEngineURL, domainID, rulesEndpoint, r.ID) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPatch, url, token, data, nil, http.StatusOK) + if sdkerr != nil { + return Rule{}, sdkerr + } + + var a Rule + if err := json.Unmarshal(body, &a); err != nil { + return Rule{}, errors.NewSDKError(err) + } + + return a, nil +} + +func (sdk mgSDK) ListRules(ctx context.Context, pm PageMetadata, domainID, token string) (Page, errors.SDKError) { + endpoint := fmt.Sprintf("%s/%s", domainID, rulesEndpoint) + url, err := sdk.withQueryParams(sdk.rulesEngineURL, endpoint, pm) + if err != nil { + return Page{}, errors.NewSDKError(err) + } + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodGet, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return Page{}, sdkerr + } + + var ap Page + if err := json.Unmarshal(body, &ap); err != nil { + return Page{}, errors.NewSDKError(err) + } + + return ap, nil +} + +func (sdk mgSDK) RemoveRule(ctx context.Context, id, domainID, token string) errors.SDKError { + url := fmt.Sprintf("%s/%s/%s/%s", sdk.rulesEngineURL, domainID, rulesEndpoint, id) + + _, _, sdkerr := sdk.processRequest(ctx, http.MethodDelete, url, token, nil, nil, http.StatusNoContent, http.StatusOK) + return sdkerr +} + +func (sdk mgSDK) EnableRule(ctx context.Context, id, domainID, token string) (Rule, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s/%s/enable", sdk.rulesEngineURL, domainID, rulesEndpoint, id) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return Rule{}, sdkerr + } + + var a Rule + if err := json.Unmarshal(body, &a); err != nil { + return Rule{}, errors.NewSDKError(err) + } + + return a, nil +} + +func (sdk mgSDK) DisableRule(ctx context.Context, id, domainID, token string) (Rule, errors.SDKError) { + url := fmt.Sprintf("%s/%s/%s/%s/disable", sdk.rulesEngineURL, domainID, rulesEndpoint, id) + + _, body, sdkerr := sdk.processRequest(ctx, http.MethodPost, url, token, nil, nil, http.StatusOK) + if sdkerr != nil { + return Rule{}, sdkerr + } + + var a Rule + if err := json.Unmarshal(body, &a); err != nil { + return Rule{}, errors.NewSDKError(err) + } + + return a, nil +} diff --git a/pkg/sdk/rules_test.go b/pkg/sdk/rules_test.go new file mode 100644 index 000000000..5211903b5 --- /dev/null +++ b/pkg/sdk/rules_test.go @@ -0,0 +1,586 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk_test + +import ( + "context" + "errors" + "net/http/httptest" + "testing" + + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/pkg/sdk" + "github.com/absmach/supermq/re" + "github.com/absmach/supermq/re/api" + remocks "github.com/absmach/supermq/re/mocks" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const ruleID = "rule-1" + +var testRule = sdk.Rule{ + ID: ruleID, + Name: "temperature-rule", + InputChannel: "chan-1", + InputTopic: "sensors/temperature", + Status: "enabled", + Tags: []string{"temperature", "alerts"}, +} + +func setupRules() (*httptest.Server, *remocks.Service, *authnmocks.Authentication) { + rsvc := new(remocks.Service) + log := smqlog.NewMock() + authn := new(authnmocks.Authentication) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + mux := chi.NewRouter() + _ = api.MakeHandler(rsvc, am, mux, log, "") + return httptest.NewServer(mux), rsvc, authn +} + +func TestAddRule(t *testing.T) { + rs, rsvc, auth := setupRules() + defer rs.Close() + + conf := sdk.Config{ + RulesEngineURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcRule := re.Rule{ + ID: ruleID, + Name: "temperature-rule", + InputChannel: "chan-1", + Status: re.EnabledStatus, + } + + cases := []struct { + desc string + rule sdk.Rule + token string + session smqauthn.Session + svcRes re.Rule + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "add rule successfully", + rule: sdk.Rule{Name: "temp-rule", InputChannel: "chan-1"}, + token: validToken, + svcRes: svcRule, + }, + { + desc: "add rule with empty token", + rule: sdk.Rule{Name: "temp-rule"}, + token: "", + wantErr: true, + }, + { + desc: "add rule with bad request", + rule: sdk.Rule{}, + token: validToken, + svcErr: errors.New("bad request"), + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("AddRule", mock.Anything, tc.session, mock.Anything).Return(tc.svcRes, []roles.RoleProvision(nil), tc.svcErr) + result, err := mgsdk.AddRule(context.Background(), tc.rule, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestViewRule(t *testing.T) { + rs, rsvc, auth := setupRules() + defer rs.Close() + + conf := sdk.Config{ + RulesEngineURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcRule := re.Rule{ + ID: ruleID, + Name: "temperature-rule", + InputChannel: "chan-1", + Status: re.EnabledStatus, + } + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + svcRes re.Rule + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "view rule successfully", + id: ruleID, + token: validToken, + svcRes: svcRule, + }, + { + desc: "view rule with empty token", + id: ruleID, + token: "", + wantErr: true, + }, + { + desc: "view non-existent rule", + id: "non-existent", + token: validToken, + svcErr: errors.New("not found"), + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("ViewRule", mock.Anything, tc.session, tc.id, mock.Anything).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.ViewRule(context.Background(), tc.id, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateRule(t *testing.T) { + rs, rsvc, auth := setupRules() + defer rs.Close() + + conf := sdk.Config{ + RulesEngineURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + updatedRule := testRule + updatedRule.Name = "updated-rule" + + svcRule := re.Rule{ + ID: ruleID, + Name: "updated-rule", + InputChannel: "chan-1", + InputTopic: "sensors/temperature", + Status: re.EnabledStatus, + Tags: []string{"temperature", "alerts"}, + } + + cases := []struct { + desc string + rule sdk.Rule + token string + session smqauthn.Session + svcRes re.Rule + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "update rule successfully", + rule: updatedRule, + token: validToken, + svcRes: svcRule, + }, + { + desc: "update rule with empty token", + rule: updatedRule, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("UpdateRule", mock.Anything, tc.session, mock.Anything).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.UpdateRule(context.Background(), tc.rule, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateRuleTags(t *testing.T) { + rs, rsvc, auth := setupRules() + defer rs.Close() + + conf := sdk.Config{ + RulesEngineURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcRule := re.Rule{ + ID: ruleID, + Tags: []string{"new-tag"}, + } + + cases := []struct { + desc string + rule sdk.Rule + token string + session smqauthn.Session + svcRes re.Rule + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "update rule tags successfully", + rule: sdk.Rule{ID: ruleID, Tags: []string{"new-tag"}}, + token: validToken, + svcRes: svcRule, + }, + { + desc: "update rule tags with empty token", + rule: sdk.Rule{ID: ruleID}, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("UpdateRuleTags", mock.Anything, tc.session, mock.Anything).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.UpdateRuleTags(context.Background(), tc.rule, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateRuleSchedule(t *testing.T) { + rs, rsvc, auth := setupRules() + defer rs.Close() + + conf := sdk.Config{ + RulesEngineURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcRule := re.Rule{ + ID: ruleID, + } + + cases := []struct { + desc string + rule sdk.Rule + token string + session smqauthn.Session + svcRes re.Rule + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "update rule schedule successfully", + rule: sdk.Rule{ID: ruleID, Schedule: map[string]any{"cron": "0 * * * *"}}, + token: validToken, + svcRes: svcRule, + }, + { + desc: "update rule schedule with empty token", + rule: sdk.Rule{ID: ruleID}, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("UpdateRuleSchedule", mock.Anything, tc.session, mock.Anything).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.UpdateRuleSchedule(context.Background(), tc.rule, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestListRules(t *testing.T) { + rs, rsvc, auth := setupRules() + defer rs.Close() + + conf := sdk.Config{ + RulesEngineURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcPage := re.Page{} + + cases := []struct { + desc string + pm sdk.PageMetadata + token string + session smqauthn.Session + svcRes re.Page + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "list rules successfully", + pm: sdk.PageMetadata{Offset: 0, Limit: 10}, + token: validToken, + svcRes: svcPage, + }, + { + desc: "list rules with filters", + pm: sdk.PageMetadata{ + Limit: 5, + Name: "temp", + Status: "enabled", + InputChannel: "chan-1", + Tag: "temperature", + Dir: "desc", + Order: "created_at", + }, + token: validToken, + svcRes: svcPage, + }, + { + desc: "list rules with empty metadata excludes filter params", + pm: sdk.PageMetadata{}, + token: validToken, + svcRes: re.Page{}, + }, + { + desc: "list rules with empty token", + pm: sdk.PageMetadata{Limit: 10}, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("ListRules", mock.Anything, tc.session, mock.Anything).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.ListRules(context.Background(), tc.pm, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotNil(t, result) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestEnableRule(t *testing.T) { + rs, rsvc, auth := setupRules() + defer rs.Close() + + conf := sdk.Config{ + RulesEngineURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcRule := re.Rule{ + ID: ruleID, + Status: re.EnabledStatus, + } + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + svcRes re.Rule + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "enable rule successfully", + id: ruleID, + token: validToken, + svcRes: svcRule, + }, + { + desc: "enable rule with empty token", + id: ruleID, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("EnableRule", mock.Anything, tc.session, tc.id).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.EnableRule(context.Background(), tc.id, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDisableRule(t *testing.T) { + rs, rsvc, auth := setupRules() + defer rs.Close() + + conf := sdk.Config{ + RulesEngineURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + svcRule := re.Rule{ + ID: ruleID, + Status: re.DisabledStatus, + } + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + svcRes re.Rule + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "disable rule successfully", + id: ruleID, + token: validToken, + svcRes: svcRule, + }, + { + desc: "disable rule with empty token", + id: ruleID, + token: "", + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("DisableRule", mock.Anything, tc.session, tc.id).Return(tc.svcRes, tc.svcErr) + result, err := mgsdk.DisableRule(context.Background(), tc.id, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + if !tc.wantErr { + assert.NotEmpty(t, result.ID) + } + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestRemoveRule(t *testing.T) { + rs, rsvc, auth := setupRules() + defer rs.Close() + + conf := sdk.Config{ + RulesEngineURL: rs.URL, + } + mgsdk := sdk.NewSDK(conf) + + cases := []struct { + desc string + id string + token string + session smqauthn.Session + svcErr error + authenticateErr error + wantErr bool + }{ + { + desc: "remove rule successfully", + id: ruleID, + token: validToken, + }, + { + desc: "remove rule with empty token", + id: ruleID, + token: "", + wantErr: true, + }, + { + desc: "remove non-existent rule", + id: "non-existent", + token: validToken, + svcErr: errors.New("not found"), + wantErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} + } + authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) + svcCall := rsvc.On("RemoveRule", mock.Anything, tc.session, tc.id).Return(tc.svcErr) + err := mgsdk.RemoveRule(context.Background(), tc.id, domainID, tc.token) + assert.Equal(t, tc.wantErr, err != nil) + svcCall.Unset() + authCall.Unset() + }) + } +} diff --git a/pkg/sdk/sdk.go b/pkg/sdk/sdk.go index f1fd5c10b..28a4e0292 100644 --- a/pkg/sdk/sdk.go +++ b/pkg/sdk/sdk.go @@ -8,16 +8,20 @@ import ( "context" "crypto/tls" "encoding/json" + "errors" "fmt" "io" "log" + "net" "net/http" "net/url" "strconv" "strings" + "syscall" "time" - "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/certs" + smqerrors "github.com/absmach/supermq/pkg/errors" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "moul.io/http2curl" ) @@ -164,6 +168,22 @@ type PageMetadata struct { EndLevel int64 `json:"end_level,omitempty"` CreatedFrom time.Time `json:"created_from,omitempty"` CreatedTo time.Time `json:"created_to,omitempty"` + Dir string `json:"dir,omitempty"` + Tag string `json:"tag,omitempty"` + InputChannel string `json:"input_channel,omitempty"` + RuleID string `json:"rule_id,omitempty"` + ChannelID string `json:"channel_id,omitempty"` + ClientID string `json:"client_id,omitempty"` + Subtopic string `json:"subtopic,omitempty"` + AssigneeID string `json:"assignee_id,omitempty"` + Severity uint8 `json:"severity,omitempty"` + UpdatedBy string `json:"updated_by,omitempty"` + AssignedBy string `json:"assigned_by,omitempty"` + AcknowledgedBy string `json:"acknowledged_by,omitempty"` + ResolvedBy string `json:"resolved_by,omitempty"` + EntityID string `json:"entity_id,omitempty"` + CommonName string `json:"common_name,omitempty"` + TTL string `json:"ttl,omitempty"` } type Role struct { @@ -193,6 +213,82 @@ type Credentials struct { Secret string `json:"secret,omitempty"` // password or token } +// CertStatus represents the status of a certificate. +type CertStatus int + +const ( + CertValid CertStatus = iota + CertRevoked CertStatus = iota + CertUnknown CertStatus = iota +) + +func (c CertStatus) String() string { + switch c { + case CertValid: + return "Valid" + case CertRevoked: + return "Revoked" + default: + return "Unknown" + } +} + +func (c CertStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(c.String()) +} + +// Certificate holds certificate data returned by the certs service SDK. +type Certificate struct { + SerialNumber string `json:"serial_number,omitempty"` + Certificate string `json:"certificate,omitempty"` + Key string `json:"key,omitempty"` + Revoked bool `json:"revoked,omitempty"` + ExpiryTime time.Time `json:"expiry_time,omitempty"` + EntityID string `json:"entity_id,omitempty"` + DownloadUrl string `json:"-"` +} + +// CertificatePage holds a page of certificates. +type CertificatePage struct { + Total uint64 `json:"total"` + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Certificates []Certificate `json:"certificates,omitempty"` +} + +// CertificateBundle holds CA and certificate data for download. +type CertificateBundle struct { + CA []byte `json:"ca"` + Certificate []byte `json:"certificate"` + PrivateKey []byte `json:"private_key"` +} + +// OCSPResponse holds the OCSP status response for a certificate. +type OCSPResponse struct { + Status CertStatus `json:"status"` + SerialNumber string `json:"serial_number"` + RevokedAt *time.Time `json:"revoked_at,omitempty"` + ProducedAt *time.Time `json:"produced_at,omitempty"` + ThisUpdate *time.Time `json:"this_update,omitempty"` + NextUpdate *time.Time `json:"next_update,omitempty"` + Certificate []byte `json:"certificate,omitempty"` + IssuerHash string `json:"issuer_hash,omitempty"` + RevocationReason int `json:"revocation_reason,omitempty"` +} + +// Options holds certificate subject options for issuance. +type Options struct { + CommonName string `json:"common_name"` + Organization []string `json:"organization"` + OrganizationalUnit []string `json:"organizational_unit"` + Country []string `json:"country"` + Province []string `json:"province"` + Locality []string `json:"locality"` + StreetAddress []string `json:"street_address"` + PostalCode []string `json:"postal_code"` + DnsNames []string `json:"dns_names"` +} + // SDK contains SuperMQ API. type SDK interface { // CreateUser registers supermq user. @@ -209,21 +305,21 @@ type SDK interface { // } // user, _ := sdk.CreateUser(ctx, user) // fmt.Println(user) - CreateUser(ctx context.Context, user User, token string) (User, errors.SDKError) + CreateUser(ctx context.Context, user User, token string) (User, smqerrors.SDKError) // SendVerification sends a verification email to the user. // // example: // err := sdk.SendVerification("token") // fmt.Println(err) - SendVerification(ctx context.Context, token string) errors.SDKError + SendVerification(ctx context.Context, token string) smqerrors.SDKError // VerifyEmail verifies the user's email address using the provided token. // // example: // err := sdk.VerifyEmail("verificationToken") // fmt.Println(user) - VerifyEmail(ctx context.Context, verificationToken string) errors.SDKError + VerifyEmail(ctx context.Context, verificationToken string) smqerrors.SDKError // User returns user object by id. // @@ -231,7 +327,7 @@ type SDK interface { // ctx := context.Background() // user, _ := sdk.User(ctx, "userID", "token") // fmt.Println(user) - User(ctx context.Context, id, token string) (User, errors.SDKError) + User(ctx context.Context, id, token string) (User, smqerrors.SDKError) // Users returns list of users. // @@ -244,7 +340,7 @@ type SDK interface { // } // users, _ := sdk.Users(ctx, pm, "token") // fmt.Println(users) - Users(ctx context.Context, pm PageMetadata, token string) (UsersPage, errors.SDKError) + Users(ctx context.Context, pm PageMetadata, token string) (UsersPage, smqerrors.SDKError) // UserProfile returns user logged in. // @@ -252,7 +348,7 @@ type SDK interface { // ctx := context.Background() // user, _ := sdk.UserProfile(ctx, "token") // fmt.Println(user) - UserProfile(ctx context.Context, token string) (User, errors.SDKError) + UserProfile(ctx context.Context, token string) (User, smqerrors.SDKError) // UpdateUser updates existing user. // @@ -267,7 +363,7 @@ type SDK interface { // } // user, _ := sdk.UpdateUser(ctx, user, "token") // fmt.Println(user) - UpdateUser(ctx context.Context, user User, token string) (User, errors.SDKError) + UpdateUser(ctx context.Context, user User, token string) (User, smqerrors.SDKError) // UpdateUserEmail updates the user's email // @@ -281,7 +377,7 @@ type SDK interface { // } // user, _ := sdk.UpdateUserEmail(ctx, user, "token") // fmt.Println(user) - UpdateUserEmail(ctx context.Context, user User, token string) (User, errors.SDKError) + UpdateUserEmail(ctx context.Context, user User, token string) (User, smqerrors.SDKError) // UpdateUserTags updates the user's tags. // @@ -293,7 +389,7 @@ type SDK interface { // } // user, _ := sdk.UpdateUserTags(ctx, user, "token") // fmt.Println(user) - UpdateUserTags(ctx context.Context, user User, token string) (User, errors.SDKError) + UpdateUserTags(ctx context.Context, user User, token string) (User, smqerrors.SDKError) // UpdateUsername updates the user's Username. // @@ -307,7 +403,7 @@ type SDK interface { // } // user, _ := sdk.UpdateUsername(ctx, user, "token") // fmt.Println(user) - UpdateUsername(ctx context.Context, user User, token string) (User, errors.SDKError) + UpdateUsername(ctx context.Context, user User, token string) (User, smqerrors.SDKError) // UpdateProfilePicture updates the user's profile picture. // @@ -319,7 +415,7 @@ type SDK interface { // } // user, _ := sdk.UpdateProfilePicture(ctx, user, "token") // fmt.Println(user) - UpdateProfilePicture(ctx context.Context, user User, token string) (User, errors.SDKError) + UpdateProfilePicture(ctx context.Context, user User, token string) (User, smqerrors.SDKError) // UpdateUserRole updates the user's role. // @@ -331,7 +427,7 @@ type SDK interface { // } // user, _ := sdk.UpdateUserRole(ctx, user, "token") // fmt.Println(user) - UpdateUserRole(ctx context.Context, user User, token string) (User, errors.SDKError) + UpdateUserRole(ctx context.Context, user User, token string) (User, smqerrors.SDKError) // ResetPasswordRequest sends a password request email to a user. // @@ -339,7 +435,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.ResetPasswordRequest(ctx, "example@email.com") // fmt.Println(err) - ResetPasswordRequest(ctx context.Context, email string) errors.SDKError + ResetPasswordRequest(ctx context.Context, email string) smqerrors.SDKError // ResetPassword changes a user's password to the one passed in the argument. // @@ -347,7 +443,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.ResetPassword(ctx, "password","password","token") // fmt.Println(err) - ResetPassword(ctx context.Context, password, confPass, token string) errors.SDKError + ResetPassword(ctx context.Context, password, confPass, token string) smqerrors.SDKError // UpdatePassword updates user password. // @@ -355,7 +451,7 @@ type SDK interface { // ctx := context.Background() // user, _ := sdk.UpdatePassword(ctx, "oldPass", "newPass", "token") // fmt.Println(user) - UpdatePassword(ctx context.Context, oldPass, newPass, token string) (User, errors.SDKError) + UpdatePassword(ctx context.Context, oldPass, newPass, token string) (User, smqerrors.SDKError) // EnableUser changes the status of the user to enabled. // @@ -363,7 +459,7 @@ type SDK interface { // ctx := context.Background() // user, _ := sdk.EnableUser(ctx, "userID", "token") // fmt.Println(user) - EnableUser(ctx context.Context, id, token string) (User, errors.SDKError) + EnableUser(ctx context.Context, id, token string) (User, smqerrors.SDKError) // DisableUser changes the status of the user to disabled. // @@ -371,7 +467,7 @@ type SDK interface { // ctx := context.Background() // user, _ := sdk.DisableUser(ctx, "userID", "token") // fmt.Println(user) - DisableUser(ctx context.Context, id, token string) (User, errors.SDKError) + DisableUser(ctx context.Context, id, token string) (User, smqerrors.SDKError) // DeleteUser deletes a user with the given id. // @@ -379,7 +475,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.DeleteUser(ctx, "userID", "token") // fmt.Println(err) - DeleteUser(ctx context.Context, id, token string) errors.SDKError + DeleteUser(ctx context.Context, id, token string) smqerrors.SDKError // CreateToken receives credentials and returns user token. // @@ -391,7 +487,7 @@ type SDK interface { // } // token, _ := sdk.CreateToken(ctx, lt) // fmt.Println(token) - CreateToken(ctx context.Context, lt Login) (Token, errors.SDKError) + CreateToken(ctx context.Context, lt Login) (Token, smqerrors.SDKError) // RefreshToken receives credentials and returns user token. // @@ -399,7 +495,7 @@ type SDK interface { // ctx := context.Background() // token, _ := sdk.RefreshToken(ctx, "refresh_token") // fmt.Println(token) - RefreshToken(ctx context.Context, token string) (Token, errors.SDKError) + RefreshToken(ctx context.Context, token string) (Token, smqerrors.SDKError) // SeachUsers filters users and returns a page result. // @@ -412,7 +508,7 @@ type SDK interface { // } // users, _ := sdk.SearchUsers(ctx, pm, "token") // fmt.Println(users) - SearchUsers(ctx context.Context, pm PageMetadata, token string) (UsersPage, errors.SDKError) + SearchUsers(ctx context.Context, pm PageMetadata, token string) (UsersPage, smqerrors.SDKError) // CreateClient registers new client and returns its id. // @@ -426,7 +522,7 @@ type SDK interface { // } // client, _ := sdk.CreateClient(ctx, client, "domainID", "token") // fmt.Println(client) - CreateClient(ctx context.Context, client Client, domainID, token string) (Client, errors.SDKError) + CreateClient(ctx context.Context, client Client, domainID, token string) (Client, smqerrors.SDKError) // CreateClients registers new clients and returns their ids. // @@ -448,7 +544,7 @@ type SDK interface { // } // clients, _ := sdk.CreateClients(ctx, clients, "domainID", "token") // fmt.Println(clients) - CreateClients(ctx context.Context, client []Client, domainID, token string) ([]Client, errors.SDKError) + CreateClients(ctx context.Context, client []Client, domainID, token string) ([]Client, smqerrors.SDKError) // Filters clients and returns a page result. // @@ -461,7 +557,7 @@ type SDK interface { // } // clients, _ := sdk.Clients(ctx, pm, "domainID", "token") // fmt.Println(clients) - Clients(ctx context.Context, pm PageMetadata, domainID, token string) (ClientsPage, errors.SDKError) + Clients(ctx context.Context, pm PageMetadata, domainID, token string) (ClientsPage, smqerrors.SDKError) // Client returns client object by id. // @@ -469,7 +565,7 @@ type SDK interface { // ctx := context.Background() // client, _ := sdk.Client(ctx, "clientID", "domainID", "token") // fmt.Println(client) - Client(ctx context.Context, id, domainID, token string) (Client, errors.SDKError) + Client(ctx context.Context, id, domainID, token string) (Client, smqerrors.SDKError) // UpdateClient updates existing client. // @@ -484,7 +580,7 @@ type SDK interface { // } // client, _ := sdk.UpdateClient(ctx, client, "domainID", "token") // fmt.Println(client) - UpdateClient(ctx context.Context, client Client, domainID, token string) (Client, errors.SDKError) + UpdateClient(ctx context.Context, client Client, domainID, token string) (Client, smqerrors.SDKError) // UpdateClientTags updates the client's tags. // @@ -496,7 +592,7 @@ type SDK interface { // } // client, _ := sdk.UpdateClientTags(ctx, client, "domainID", "token") // fmt.Println(client) - UpdateClientTags(ctx context.Context, client Client, domainID, token string) (Client, errors.SDKError) + UpdateClientTags(ctx context.Context, client Client, domainID, token string) (Client, smqerrors.SDKError) // UpdateClientSecret updates the client's secret // @@ -504,7 +600,7 @@ type SDK interface { // ctx := context.Background() // client, err := sdk.UpdateClientSecret(ctx, "clientID", "newSecret", "domainID," "token") // fmt.Println(client) - UpdateClientSecret(ctx context.Context, id, secret, domainID, token string) (Client, errors.SDKError) + UpdateClientSecret(ctx context.Context, id, secret, domainID, token string) (Client, smqerrors.SDKError) // EnableClient changes client status to enabled. // @@ -512,7 +608,7 @@ type SDK interface { // ctx := context.Background() // client, _ := sdk.EnableClient(ctx, "clientID", "domainID", "token") // fmt.Println(client) - EnableClient(ctx context.Context, id, domainID, token string) (Client, errors.SDKError) + EnableClient(ctx context.Context, id, domainID, token string) (Client, smqerrors.SDKError) // DisableClient changes client status to disabled - soft delete. // @@ -520,7 +616,7 @@ type SDK interface { // ctx := context.Background() // client, _ := sdk.DisableClient(ctx, "clientID", "domainID", "token") // fmt.Println(client) - DisableClient(ctx context.Context, id, domainID, token string) (Client, errors.SDKError) + DisableClient(ctx context.Context, id, domainID, token string) (Client, smqerrors.SDKError) // DeleteClient deletes a client with the given id. // @@ -528,7 +624,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.DeleteClient(ctx, "clientID", "domainID", "token") // fmt.Println(err) - DeleteClient(ctx context.Context, id, domainID, token string) errors.SDKError + DeleteClient(ctx context.Context, id, domainID, token string) smqerrors.SDKError // SetClientParent sets the parent group of a client. // @@ -536,7 +632,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.SetClientParent(ctx, "clientID", "domainID", "groupID", "token") // fmt.Println(err) - SetClientParent(ctx context.Context, id, domainID, groupID, token string) errors.SDKError + SetClientParent(ctx context.Context, id, domainID, groupID, token string) smqerrors.SDKError // RemoveClientParent removes the parent group of a client. // @@ -544,7 +640,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.RemoveClientParent(ctx, "clientID", "domainID", "groupID", "token") // fmt.Println(err) - RemoveClientParent(ctx context.Context, id, domainID, groupID, token string) errors.SDKError + RemoveClientParent(ctx context.Context, id, domainID, groupID, token string) smqerrors.SDKError // CreateClientRole creates new client role and returns its id. // @@ -557,7 +653,7 @@ type SDK interface { // } // role, _ := sdk.CreateClientRole(ctx, "clientID", "domainID", rq, "token") // fmt.Println(role) - CreateClientRole(ctx context.Context, id, domainID string, rq RoleReq, token string) (Role, errors.SDKError) + CreateClientRole(ctx context.Context, id, domainID string, rq RoleReq, token string) (Role, smqerrors.SDKError) // ClientRoles returns client roles. // @@ -569,7 +665,7 @@ type SDK interface { // } // roles, _ := sdk.ClientRoles(ctx, "clientID", "domainID", pm, "token") // fmt.Println(roles) - ClientRoles(ctx context.Context, id, domainID string, pm PageMetadata, token string) (RolesPage, errors.SDKError) + ClientRoles(ctx context.Context, id, domainID string, pm PageMetadata, token string) (RolesPage, smqerrors.SDKError) // ClientRole returns client role object by roleID. // @@ -577,7 +673,7 @@ type SDK interface { // ctx := context.Background() // role, _ := sdk.ClientRole(ctx, "clientID", "roleID", "domainID", "token") // fmt.Println(role) - ClientRole(ctx context.Context, id, roleID, domainID, token string) (Role, errors.SDKError) + ClientRole(ctx context.Context, id, roleID, domainID, token string) (Role, smqerrors.SDKError) // UpdateClientRole updates existing client role name. // @@ -585,7 +681,7 @@ type SDK interface { // ctx := context.Background() // role, _ := sdk.UpdateClientRole(ctx, "clientID", "roleID", "newName", "domainID", "token") // fmt.Println(role) - UpdateClientRole(ctx context.Context, id, roleID, newName, domainID string, token string) (Role, errors.SDKError) + UpdateClientRole(ctx context.Context, id, roleID, newName, domainID string, token string) (Role, smqerrors.SDKError) // DeleteClientRole deletes a client role with the given clientID and roleID. // @@ -593,7 +689,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.DeleteClientRole(ctx, "clientID", "roleID", "domainID", "token") // fmt.Println(err) - DeleteClientRole(ctx context.Context, id, roleID, domainID, token string) errors.SDKError + DeleteClientRole(ctx context.Context, id, roleID, domainID, token string) smqerrors.SDKError // AddClientRoleActions adds actions to a client role. // @@ -602,7 +698,7 @@ type SDK interface { // actions := []string{"read", "update"} // actions, _ := sdk.AddClientRoleActions(ctx, "clientID", "roleID", "domainID", actions, "token") // fmt.Println(actions) - AddClientRoleActions(ctx context.Context, id, roleID, domainID string, actions []string, token string) ([]string, errors.SDKError) + AddClientRoleActions(ctx context.Context, id, roleID, domainID string, actions []string, token string) ([]string, smqerrors.SDKError) // ClientRoleActions returns client role actions by roleID. // @@ -610,7 +706,7 @@ type SDK interface { // ctx := context.Background() // actions, _ := sdk.ClientRoleActions(ctx, "clientID", "roleID", "domainID", "token") // fmt.Println(actions) - ClientRoleActions(ctx context.Context, id, roleID, domainID string, token string) ([]string, errors.SDKError) + ClientRoleActions(ctx context.Context, id, roleID, domainID string, token string) ([]string, smqerrors.SDKError) // RemoveClientRoleActions removes actions from a client role. // @@ -619,7 +715,7 @@ type SDK interface { // actions := []string{"read", "update"} // err := sdk.RemoveClientRoleActions(ctx, "clientID", "roleID", "domainID", actions, "token") // fmt.Println(err) - RemoveClientRoleActions(ctx context.Context, id, roleID, domainID string, actions []string, token string) errors.SDKError + RemoveClientRoleActions(ctx context.Context, id, roleID, domainID string, actions []string, token string) smqerrors.SDKError // RemoveAllClientRoleActions removes all actions from a client role. // @@ -627,7 +723,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.RemoveAllClientRoleActions(ctx, "clientID", "roleID", "domainID", "token") // fmt.Println(err) - RemoveAllClientRoleActions(ctx context.Context, id, roleID, domainID, token string) errors.SDKError + RemoveAllClientRoleActions(ctx context.Context, id, roleID, domainID, token string) smqerrors.SDKError // AddClientRoleMembers adds members to a client role. // @@ -636,7 +732,7 @@ type SDK interface { // members := []string{"member_id_1", "member_id_2"} // members, _ := sdk.AddClientRoleMembers(ctx, "clientID", "roleID", "domainID", members, "token") // fmt.Println(members) - AddClientRoleMembers(ctx context.Context, id, roleID, domainID string, members []string, token string) ([]string, errors.SDKError) + AddClientRoleMembers(ctx context.Context, id, roleID, domainID string, members []string, token string) ([]string, smqerrors.SDKError) // ClientRoleMembers returns client role members by roleID. // @@ -648,7 +744,7 @@ type SDK interface { // } // members, _ := sdk.ClientRoleMembers(ctx, "clientID", "roleID", "domainID", pm,"token") // fmt.Println(members) - ClientRoleMembers(ctx context.Context, id, roleID, domainID string, pm PageMetadata, token string) (RoleMembersPage, errors.SDKError) + ClientRoleMembers(ctx context.Context, id, roleID, domainID string, pm PageMetadata, token string) (RoleMembersPage, smqerrors.SDKError) // RemoveClientRoleMembers removes members from a client role. // @@ -657,7 +753,7 @@ type SDK interface { // members := []string{"member_id_1", "member_id_2"} // err := sdk.RemoveClientRoleMembers(ctx, "clientID", "roleID", "domainID", members, "token") // fmt.Println(err) - RemoveClientRoleMembers(ctx context.Context, id, roleID, domainID string, members []string, token string) errors.SDKError + RemoveClientRoleMembers(ctx context.Context, id, roleID, domainID string, members []string, token string) smqerrors.SDKError // RemoveAllClientRoleMembers removes all members from a client role. // @@ -665,7 +761,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.RemoveAllClientRoleMembers(ctx, "clientID", "roleID", "domainID", "token") // fmt.Println(err) - RemoveAllClientRoleMembers(ctx context.Context, id, roleID, domainID, token string) errors.SDKError + RemoveAllClientRoleMembers(ctx context.Context, id, roleID, domainID, token string) smqerrors.SDKError // AvailableClientRoleActions returns available actions for a client role. // @@ -673,7 +769,7 @@ type SDK interface { // ctx := context.Background() // actions, _ := sdk.AvailableClientRoleActions(ctx, "domainID", "token") // fmt.Println(actions) - AvailableClientRoleActions(ctx context.Context, domainID, token string) ([]string, errors.SDKError) + AvailableClientRoleActions(ctx context.Context, domainID, token string) ([]string, smqerrors.SDKError) // ListClientMembers list all members from all roles in a client . // @@ -685,7 +781,7 @@ type SDK interface { // } // members, _ := sdk.ListClientMembers(ctx, "client_id","domainID", pm, "token") // fmt.Println(members) - ListClientMembers(ctx context.Context, clientID, domainID string, pm PageMetadata, token string) (EntityMembersPage, errors.SDKError) + ListClientMembers(ctx context.Context, clientID, domainID string, pm PageMetadata, token string) (EntityMembersPage, smqerrors.SDKError) // CreateGroup creates new group and returns its id. // @@ -699,7 +795,7 @@ type SDK interface { // } // group, _ := sdk.CreateGroup(ctx, group, "domainID", "token") // fmt.Println(group) - CreateGroup(ctx context.Context, group Group, domainID, token string) (Group, errors.SDKError) + CreateGroup(ctx context.Context, group Group, domainID, token string) (Group, smqerrors.SDKError) // Groups returns page of groups. // @@ -712,7 +808,7 @@ type SDK interface { // } // groups, _ := sdk.Groups(ctx, pm, "domainID", "token") // fmt.Println(groups) - Groups(ctx context.Context, pm PageMetadata, domainID, token string) (GroupsPage, errors.SDKError) + Groups(ctx context.Context, pm PageMetadata, domainID, token string) (GroupsPage, smqerrors.SDKError) // Group returns users group object by id. // @@ -720,7 +816,7 @@ type SDK interface { // ctx := context.Background() // group, _ := sdk.Group(ctx, "groupID", "domainID", "token") // fmt.Println(group) - Group(ctx context.Context, id, domainID, token string) (Group, errors.SDKError) + Group(ctx context.Context, id, domainID, token string) (Group, smqerrors.SDKError) // UpdateGroup updates existing group. // @@ -735,7 +831,7 @@ type SDK interface { // } // group, _ := sdk.UpdateGroup(ctx, group, "domainID", "token") // fmt.Println(group) - UpdateGroup(ctx context.Context, group Group, domainID, token string) (Group, errors.SDKError) + UpdateGroup(ctx context.Context, group Group, domainID, token string) (Group, smqerrors.SDKError) // UpdateGroupTags updates tags for existing group. // @@ -747,7 +843,7 @@ type SDK interface { // } // group, _ := sdk.UpdateGroupTags(ctx, group, "domainID", "token") // fmt.Println(group) - UpdateGroupTags(ctx context.Context, group Group, domainID, token string) (Group, errors.SDKError) + UpdateGroupTags(ctx context.Context, group Group, domainID, token string) (Group, smqerrors.SDKError) // SetGroupParent sets the parent group of a group. // @@ -755,7 +851,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.SetGroupParent(ctx, "groupID", "domainID", "groupID", "token") // fmt.Println(err) - SetGroupParent(ctx context.Context, id, domainID, groupID, token string) errors.SDKError + SetGroupParent(ctx context.Context, id, domainID, groupID, token string) smqerrors.SDKError // RemoveGroupParent removes the parent group of a group. // @@ -763,7 +859,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.RemoveGroupParent(ctx, "groupID", "domainID", "groupID", "token") // fmt.Println(err) - RemoveGroupParent(ctx context.Context, id, domainID, groupID, token string) errors.SDKError + RemoveGroupParent(ctx context.Context, id, domainID, groupID, token string) smqerrors.SDKError // AddChildren adds children groups to a group. // @@ -772,7 +868,7 @@ type SDK interface { // groupIDs := []string{"groupID1", "groupID2"} // err := sdk.AddChildren(ctx, "groupID", "domainID", groupIDs, "token") // fmt.Println(err) - AddChildren(ctx context.Context, id, domainID string, groupIDs []string, token string) errors.SDKError + AddChildren(ctx context.Context, id, domainID string, groupIDs []string, token string) smqerrors.SDKError // RemoveChildren removes children groups from a group. // @@ -781,7 +877,7 @@ type SDK interface { // groupIDs := []string{"groupID1", "groupID2"} // err := sdk.RemoveChildren(ctx, "groupID", "domainID", groupIDs, "token") // fmt.Println(err) - RemoveChildren(ctx context.Context, id, domainID string, groupIDs []string, token string) errors.SDKError + RemoveChildren(ctx context.Context, id, domainID string, groupIDs []string, token string) smqerrors.SDKError // RemoveAllChildren removes all children groups from a group. // @@ -789,7 +885,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.RemoveAllChildren(ctx, "groupID", "domainID", "token") // fmt.Println(err) - RemoveAllChildren(ctx context.Context, id, domainID, token string) errors.SDKError + RemoveAllChildren(ctx context.Context, id, domainID, token string) smqerrors.SDKError // Children returns page of children groups. // @@ -801,7 +897,7 @@ type SDK interface { // } // groups, _ := sdk.Children(ctx, "groupID", "domainID", pm, "token") // fmt.Println(groups) - Children(ctx context.Context, id, domainID string, pm PageMetadata, token string) (GroupsPage, errors.SDKError) + Children(ctx context.Context, id, domainID string, pm PageMetadata, token string) (GroupsPage, smqerrors.SDKError) // EnableGroup changes group status to enabled. // @@ -809,7 +905,7 @@ type SDK interface { // ctx := context.Background() // group, _ := sdk.EnableGroup(ctx, "groupID", "domainID", "token") // fmt.Println(group) - EnableGroup(ctx context.Context, id, domainID, token string) (Group, errors.SDKError) + EnableGroup(ctx context.Context, id, domainID, token string) (Group, smqerrors.SDKError) // DisableGroup changes group status to disabled - soft delete. // @@ -817,7 +913,7 @@ type SDK interface { // ctx := context.Background() // group, _ := sdk.DisableGroup(ctx, "groupID", "domainID", "token") // fmt.Println(group) - DisableGroup(ctx context.Context, id, domainID, token string) (Group, errors.SDKError) + DisableGroup(ctx context.Context, id, domainID, token string) (Group, smqerrors.SDKError) // DeleteGroup delete given group id. // @@ -825,7 +921,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.DeleteGroup(ctx, "groupID", "domainID", "token") // fmt.Println(err) - DeleteGroup(ctx context.Context, id, domainID, token string) errors.SDKError + DeleteGroup(ctx context.Context, id, domainID, token string) smqerrors.SDKError // Hierarchy returns page of groups hierarchy. // @@ -838,7 +934,7 @@ type SDK interface { // } // groups, _ := sdk.Hierarchy(ctx, "groupID", "domainID", pm, "token") // fmt.Println(groups) - Hierarchy(ctx context.Context, id, domainID string, pm PageMetadata, token string) (GroupsHierarchyPage, errors.SDKError) + Hierarchy(ctx context.Context, id, domainID string, pm PageMetadata, token string) (GroupsHierarchyPage, smqerrors.SDKError) // CreateGroupRole creates new group role and returns its id. // @@ -851,7 +947,7 @@ type SDK interface { // } // role, _ := sdk.CreateGroupRole(ctx, "groupID", "domainID", rq, "token") // fmt.Println(role) - CreateGroupRole(ctx context.Context, id, domainID string, rq RoleReq, token string) (Role, errors.SDKError) + CreateGroupRole(ctx context.Context, id, domainID string, rq RoleReq, token string) (Role, smqerrors.SDKError) // GroupRoles returns group roles. // @@ -863,7 +959,7 @@ type SDK interface { // } // roles, _ := sdk.GroupRoles(ctx, "groupID", "domainID",pm, "token") // fmt.Println(roles) - GroupRoles(ctx context.Context, id, domainID string, pm PageMetadata, token string) (RolesPage, errors.SDKError) + GroupRoles(ctx context.Context, id, domainID string, pm PageMetadata, token string) (RolesPage, smqerrors.SDKError) // GroupRole returns group role object by roleID. // @@ -871,7 +967,7 @@ type SDK interface { // ctx := context.Background() // role, _ := sdk.GroupRole(ctx, "groupID", "roleID", "domainID", "token") // fmt.Println(role) - GroupRole(ctx context.Context, id, roleID, domainID, token string) (Role, errors.SDKError) + GroupRole(ctx context.Context, id, roleID, domainID, token string) (Role, smqerrors.SDKError) // UpdateGroupRole updates existing group role name. // @@ -879,7 +975,7 @@ type SDK interface { // ctx := context.Background() // role, _ := sdk.UpdateGroupRole(ctx, "groupID", "roleID", "newName", "domainID", "token") // fmt.Println(role) - UpdateGroupRole(ctx context.Context, id, roleID, newName, domainID string, token string) (Role, errors.SDKError) + UpdateGroupRole(ctx context.Context, id, roleID, newName, domainID string, token string) (Role, smqerrors.SDKError) // DeleteGroupRole deletes a group role with the given groupID and roleID. // @@ -887,7 +983,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.DeleteGroupRole(ctx, "groupID", "roleID", "domainID", "token") // fmt.Println(err) - DeleteGroupRole(ctx context.Context, id, roleID, domainID, token string) errors.SDKError + DeleteGroupRole(ctx context.Context, id, roleID, domainID, token string) smqerrors.SDKError // AddGroupRoleActions adds actions to a group role. // @@ -896,7 +992,7 @@ type SDK interface { // actions := []string{"read", "update"} // actions, _ := sdk.AddGroupRoleActions(ctx, "groupID", "roleID", "domainID", actions, "token") // fmt.Println(actions) - AddGroupRoleActions(ctx context.Context, id, roleID, domainID string, actions []string, token string) ([]string, errors.SDKError) + AddGroupRoleActions(ctx context.Context, id, roleID, domainID string, actions []string, token string) ([]string, smqerrors.SDKError) // GroupRoleActions returns group role actions by roleID. // @@ -904,7 +1000,7 @@ type SDK interface { // ctx := context.Background() // actions, _ := sdk.GroupRoleActions(ctx, "groupID", "roleID", "domainID", "token") // fmt.Println(actions) - GroupRoleActions(ctx context.Context, id, roleID, domainID string, token string) ([]string, errors.SDKError) + GroupRoleActions(ctx context.Context, id, roleID, domainID string, token string) ([]string, smqerrors.SDKError) // RemoveGroupRoleActions removes actions from a group role. // @@ -913,7 +1009,7 @@ type SDK interface { // actions := []string{"read", "update"} // err := sdk.RemoveGroupRoleActions(ctx, "groupID", "roleID", "domainID", actions, "token") // fmt.Println(err) - RemoveGroupRoleActions(ctx context.Context, id, roleID, domainID string, actions []string, token string) errors.SDKError + RemoveGroupRoleActions(ctx context.Context, id, roleID, domainID string, actions []string, token string) smqerrors.SDKError // RemoveAllGroupRoleActions removes all actions from a group role. // @@ -921,7 +1017,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.RemoveAllGroupRoleActions(ctx, "groupID", "roleID", "domainID", "token") // fmt.Println(err) - RemoveAllGroupRoleActions(ctx context.Context, id, roleID, domainID, token string) errors.SDKError + RemoveAllGroupRoleActions(ctx context.Context, id, roleID, domainID, token string) smqerrors.SDKError // AddGroupRoleMembers adds members to a group role. // @@ -930,7 +1026,7 @@ type SDK interface { // members := []string{"member_id_1", "member_id_2"} // members, _ := sdk.AddGroupRoleMembers(ctx, "groupID", "roleID", "domainID", members, "token") // fmt.Println(members) - AddGroupRoleMembers(ctx context.Context, id, roleID, domainID string, members []string, token string) ([]string, errors.SDKError) + AddGroupRoleMembers(ctx context.Context, id, roleID, domainID string, members []string, token string) ([]string, smqerrors.SDKError) // GroupRoleMembers returns group role members by roleID. // @@ -942,7 +1038,7 @@ type SDK interface { // } // members, _ := sdk.GroupRoleMembers(ctx, "groupID", "roleID", "domainID", "token") // fmt.Println(members) - GroupRoleMembers(ctx context.Context, id, roleID, domainID string, pm PageMetadata, token string) (RoleMembersPage, errors.SDKError) + GroupRoleMembers(ctx context.Context, id, roleID, domainID string, pm PageMetadata, token string) (RoleMembersPage, smqerrors.SDKError) // RemoveGroupRoleMembers removes members from a group role. // @@ -951,7 +1047,7 @@ type SDK interface { // members := []string{"member_id_1", "member_id_2"} // err := sdk.RemoveGroupRoleMembers(ctx, "groupID", "roleID", "domainID", members, "token") // fmt.Println(err) - RemoveGroupRoleMembers(ctx context.Context, id, roleID, domainID string, members []string, token string) errors.SDKError + RemoveGroupRoleMembers(ctx context.Context, id, roleID, domainID string, members []string, token string) smqerrors.SDKError // RemoveAllGroupRoleMembers removes all members from a group role. // @@ -959,7 +1055,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.RemoveAllGroupRoleMembers(ctx, "groupID", "roleID", "domainID", "token") // fmt.Println(err) - RemoveAllGroupRoleMembers(ctx context.Context, id, roleID, domainID, token string) errors.SDKError + RemoveAllGroupRoleMembers(ctx context.Context, id, roleID, domainID, token string) smqerrors.SDKError // AvailableGroupRoleActions returns available actions for a group role. // @@ -967,7 +1063,7 @@ type SDK interface { // ctx := context.Background() // actions, _ := sdk.AvailableGroupRoleActions(ctx, "groupID", "token") // fmt.Println(actions) - AvailableGroupRoleActions(ctx context.Context, id, token string) ([]string, errors.SDKError) + AvailableGroupRoleActions(ctx context.Context, id, token string) ([]string, smqerrors.SDKError) // ListGroupMembers list all members from all roles in a group . // @@ -979,7 +1075,7 @@ type SDK interface { // } // members, _ := sdk.ListGroupMembers(ctx, "group_id","domainID", pm, "token") // fmt.Println(members) - ListGroupMembers(ctx context.Context, groupID, domainID string, pm PageMetadata, token string) (EntityMembersPage, errors.SDKError) + ListGroupMembers(ctx context.Context, groupID, domainID string, pm PageMetadata, token string) (EntityMembersPage, smqerrors.SDKError) // CreateChannel creates new channel and returns its id. // @@ -993,7 +1089,7 @@ type SDK interface { // } // channel, _ := sdk.CreateChannel(ctx, channel, "domainID", "token") // fmt.Println(channel) - CreateChannel(ctx context.Context, channel Channel, domainID, token string) (Channel, errors.SDKError) + CreateChannel(ctx context.Context, channel Channel, domainID, token string) (Channel, smqerrors.SDKError) // CreateChannels creates new channels and returns their ids. // @@ -1015,7 +1111,7 @@ type SDK interface { // } // channels, _ := sdk.CreateChannels(ctx, channels, "domainID", "token") // fmt.Println(channels) - CreateChannels(ctx context.Context, channels []Channel, domainID, token string) ([]Channel, errors.SDKError) + CreateChannels(ctx context.Context, channels []Channel, domainID, token string) ([]Channel, smqerrors.SDKError) // Channels returns page of channels. // @@ -1028,7 +1124,7 @@ type SDK interface { // } // channels, _ := sdk.Channels(ctx, pm, "domainID", "token") // fmt.Println(channels) - Channels(ctx context.Context, pm PageMetadata, domainID, token string) (ChannelsPage, errors.SDKError) + Channels(ctx context.Context, pm PageMetadata, domainID, token string) (ChannelsPage, smqerrors.SDKError) // Channel returns channel data by id. // @@ -1036,7 +1132,7 @@ type SDK interface { // ctx := context.Background() // channel, _ := sdk.Channel(ctx, "channelID", "domainID", "token") // fmt.Println(channel) - Channel(ctx context.Context, id, domainID, token string) (Channel, errors.SDKError) + Channel(ctx context.Context, id, domainID, token string) (Channel, smqerrors.SDKError) // UpdateChannel updates existing channel. // @@ -1051,7 +1147,7 @@ type SDK interface { // } // channel, _ := sdk.UpdateChannel(ctx, channel, "domainID", "token") // fmt.Println(channel) - UpdateChannel(ctx context.Context, channel Channel, domainID, token string) (Channel, errors.SDKError) + UpdateChannel(ctx context.Context, channel Channel, domainID, token string) (Channel, smqerrors.SDKError) // UpdateChannelTags updates the channel's tags. // @@ -1063,7 +1159,7 @@ type SDK interface { // } // channel, _ := sdk.UpdateChannelTags(ctx, channel, "domainID", "token") // fmt.Println(channel) - UpdateChannelTags(ctx context.Context, c Channel, domainID, token string) (Channel, errors.SDKError) + UpdateChannelTags(ctx context.Context, c Channel, domainID, token string) (Channel, smqerrors.SDKError) // EnableChannel changes channel status to enabled. // @@ -1071,7 +1167,7 @@ type SDK interface { // ctx := context.Background() // channel, _ := sdk.EnableChannel(ctx, "channelID", "domainID", "token") // fmt.Println(channel) - EnableChannel(ctx context.Context, id, domainID, token string) (Channel, errors.SDKError) + EnableChannel(ctx context.Context, id, domainID, token string) (Channel, smqerrors.SDKError) // DisableChannel changes channel status to disabled - soft delete. // @@ -1079,7 +1175,7 @@ type SDK interface { // ctx := context.Background() // channel, _ := sdk.DisableChannel(ctx, "channelID", "domainID", "token") // fmt.Println(channel) - DisableChannel(ctx context.Context, id, domainID, token string) (Channel, errors.SDKError) + DisableChannel(ctx context.Context, id, domainID, token string) (Channel, smqerrors.SDKError) // DeleteChannel delete given group id. // @@ -1087,7 +1183,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.DeleteChannel(ctx, "channelID", "domainID", "token") // fmt.Println(err) - DeleteChannel(ctx context.Context, id, domainID, token string) errors.SDKError + DeleteChannel(ctx context.Context, id, domainID, token string) smqerrors.SDKError // SetChannelParent sets the parent group of a channel. // @@ -1095,7 +1191,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.SetChannelParent(ctx, "channelID", "domainID", "groupID", "token") // fmt.Println(err) - SetChannelParent(ctx context.Context, id, domainID, groupID, token string) errors.SDKError + SetChannelParent(ctx context.Context, id, domainID, groupID, token string) smqerrors.SDKError // RemoveChannelParent removes the parent group of a channel. // @@ -1103,7 +1199,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.RemoveChannelParent(ctx, "channelID", "domainID", "groupID", "token") // fmt.Println(err) - RemoveChannelParent(ctx context.Context, id, domainID, groupID, token string) errors.SDKError + RemoveChannelParent(ctx context.Context, id, domainID, groupID, token string) smqerrors.SDKError // Connect bulk connects clients to channels specified by id. // @@ -1116,7 +1212,7 @@ type SDK interface { // } // err := sdk.Connect(ctx, conns, "domainID", "token") // fmt.Println(err) - Connect(ctx context.Context, conn Connection, domainID, token string) errors.SDKError + Connect(ctx context.Context, conn Connection, domainID, token string) smqerrors.SDKError // Disconnect // @@ -1129,7 +1225,7 @@ type SDK interface { // } // err := sdk.Disconnect(ctx, conns, "domainID", "token") // fmt.Println(err) - Disconnect(ctx context.Context, conn Connection, domainID, token string) errors.SDKError + Disconnect(ctx context.Context, conn Connection, domainID, token string) smqerrors.SDKError // ConnectClient connects client to specified channel by id. // @@ -1138,7 +1234,7 @@ type SDK interface { // clientIDs := []string{"client_id_1", "client_id_2"} // err := sdk.ConnectClients(ctx, "channelID", clientIDs, []string{"Publish", "Subscribe"}, "domainID", "token") // fmt.Println(err) - ConnectClients(ctx context.Context, channelID string, clientIDs, connTypes []string, domainID, token string) errors.SDKError + ConnectClients(ctx context.Context, channelID string, clientIDs, connTypes []string, domainID, token string) smqerrors.SDKError // DisconnectClient disconnect client from specified channel by id. // @@ -1147,7 +1243,7 @@ type SDK interface { // clientIDs := []string{"client_id_1", "client_id_2"} // err := sdk.DisconnectClients(ctx, "channelID", clientIDs, []string{"Publish", "Subscribe"}, "domainID", "token") // fmt.Println(err) - DisconnectClients(ctx context.Context, channelID string, clientIDs, connTypes []string, domainID, token string) errors.SDKError + DisconnectClients(ctx context.Context, channelID string, clientIDs, connTypes []string, domainID, token string) smqerrors.SDKError // ListChannelMembers list all members from all roles in a channel . // @@ -1159,7 +1255,7 @@ type SDK interface { // } // members, _ := sdk.ListChannelMembers(ctx, "channel_id","domainID", pm, "token") // fmt.Println(members) - ListChannelMembers(ctx context.Context, channelID, domainID string, pm PageMetadata, token string) (EntityMembersPage, errors.SDKError) + ListChannelMembers(ctx context.Context, channelID, domainID string, pm PageMetadata, token string) (EntityMembersPage, smqerrors.SDKError) // SendMessage send message to specified channel. // @@ -1168,21 +1264,21 @@ type SDK interface { // msg := '[{"bn":"some-base-name:","bt":1.276020076001e+09, "bu":"A","bver":5, "n":"voltage","u":"V","v":120.1}, {"n":"current","t":-5,"v":1.2}, {"n":"current","t":-4,"v":1.3}]' // err := sdk.SendMessage(ctx, "domainID", "topic", msg, "clientSecret") // fmt.Println(err) - SendMessage(ctx context.Context, domainID, topic, msg, secret string) errors.SDKError + SendMessage(ctx context.Context, domainID, topic, msg, secret string) smqerrors.SDKError // SetContentType sets message content type. // // example: // err := sdk.SetContentType("application/json") // fmt.Println(err) - SetContentType(ct ContentType) errors.SDKError + SetContentType(ct ContentType) smqerrors.SDKError // Health returns service health check. // // example: // health, _ := sdk.Health("service") // fmt.Println(health) - Health(service string) (HealthInfo, errors.SDKError) + Health(service string) (HealthInfo, smqerrors.SDKError) // CreateDomain creates new domain and returns its details. // @@ -1196,7 +1292,7 @@ type SDK interface { // } // domain, _ := sdk.CreateDomain(ctx, group, "token") // fmt.Println(domain) - CreateDomain(ctx context.Context, d Domain, token string) (Domain, errors.SDKError) + CreateDomain(ctx context.Context, d Domain, token string) (Domain, smqerrors.SDKError) // Domain retrieve domain information of given domain ID . // @@ -1204,7 +1300,7 @@ type SDK interface { // ctx := context.Background() // domain, _ := sdk.Domain(ctx, "domainID", "token") // fmt.Println(domain) - Domain(ctx context.Context, domainID, token string) (Domain, errors.SDKError) + Domain(ctx context.Context, domainID, token string) (Domain, smqerrors.SDKError) // UpdateDomain updates details of the given domain ID. // @@ -1219,7 +1315,7 @@ type SDK interface { // } // domain, _ := sdk.UpdateDomain(ctx, domain, "token") // fmt.Println(domain) - UpdateDomain(ctx context.Context, d Domain, token string) (Domain, errors.SDKError) + UpdateDomain(ctx context.Context, d Domain, token string) (Domain, smqerrors.SDKError) // Domains returns list of domain for the given filters. // @@ -1233,7 +1329,7 @@ type SDK interface { // } // domains, _ := sdk.Domains(ctx, pm, "token") // fmt.Println(domains) - Domains(ctx context.Context, pm PageMetadata, token string) (DomainsPage, errors.SDKError) + Domains(ctx context.Context, pm PageMetadata, token string) (DomainsPage, smqerrors.SDKError) // EnableDomain changes the status of the domain to enabled. // @@ -1241,7 +1337,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.EnableDomain(ctx, "domainID", "token") // fmt.Println(err) - EnableDomain(ctx context.Context, domainID, token string) errors.SDKError + EnableDomain(ctx context.Context, domainID, token string) smqerrors.SDKError // DisableDomain changes the status of the domain to disabled. // @@ -1249,7 +1345,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.DisableDomain(ctx, "domainID", "token") // fmt.Println(err) - DisableDomain(ctx context.Context, domainID, token string) errors.SDKError + DisableDomain(ctx context.Context, domainID, token string) smqerrors.SDKError // FreezeDomain changes the status of the domain to frozen. // @@ -1257,7 +1353,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.FreezeDomain(ctx, "domainID", "token") // fmt.Println(err) - FreezeDomain(ctx context.Context, domainID, token string) errors.SDKError + FreezeDomain(ctx context.Context, domainID, token string) smqerrors.SDKError // CreateDomainRole creates new domain role and returns its id. // @@ -1270,7 +1366,7 @@ type SDK interface { // } // role, _ := sdk.CreateDomainRole(ctx, "domainID", rq, "token") // fmt.Println(role) - CreateDomainRole(ctx context.Context, id string, rq RoleReq, token string) (Role, errors.SDKError) + CreateDomainRole(ctx context.Context, id string, rq RoleReq, token string) (Role, smqerrors.SDKError) // DomainRoles returns domain roles. // @@ -1282,7 +1378,7 @@ type SDK interface { // } // roles, _ := sdk.DomainRoles(ctx, "domainID", pm, "token") // fmt.Println(roles) - DomainRoles(ctx context.Context, id string, pm PageMetadata, token string) (RolesPage, errors.SDKError) + DomainRoles(ctx context.Context, id string, pm PageMetadata, token string) (RolesPage, smqerrors.SDKError) // DomainRole returns domain role object by roleID. // @@ -1290,7 +1386,7 @@ type SDK interface { // ctx := context.Background() // role, _ := sdk.DomainRole(ctx, "domainID", "roleID", "token") // fmt.Println(role) - DomainRole(ctx context.Context, id, roleID, token string) (Role, errors.SDKError) + DomainRole(ctx context.Context, id, roleID, token string) (Role, smqerrors.SDKError) // UpdateDomainRole updates existing domain role name. // @@ -1298,7 +1394,7 @@ type SDK interface { // ctx := context.Background() // role, _ := sdk.UpdateDomainRole(ctx, "domainID", "roleID", "newName", "token") // fmt.Println(role) - UpdateDomainRole(ctx context.Context, id, roleID, newName string, token string) (Role, errors.SDKError) + UpdateDomainRole(ctx context.Context, id, roleID, newName string, token string) (Role, smqerrors.SDKError) // DeleteDomainRole deletes a domain role with the given domainID and roleID. // @@ -1306,7 +1402,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.DeleteDomainRole(ctx, "domainID", "roleID", "token") // fmt.Println(err) - DeleteDomainRole(ctx context.Context, id, roleID, token string) errors.SDKError + DeleteDomainRole(ctx context.Context, id, roleID, token string) smqerrors.SDKError // AddDomainRoleActions adds actions to a domain role. // @@ -1315,7 +1411,7 @@ type SDK interface { // actions := []string{"read", "update"} // actions, _ := sdk.AddDomainRoleActions(ctx, "domainID", "roleID", actions, "token") // fmt.Println(actions) - AddDomainRoleActions(ctx context.Context, id, roleID string, actions []string, token string) ([]string, errors.SDKError) + AddDomainRoleActions(ctx context.Context, id, roleID string, actions []string, token string) ([]string, smqerrors.SDKError) // DomainRoleActions returns domain role actions by roleID. // @@ -1323,7 +1419,7 @@ type SDK interface { // ctx := context.Background() // actions, _ := sdk.DomainRoleActions(ctx, "domainID", "roleID", "token") // fmt.Println(actions) - DomainRoleActions(ctx context.Context, id, roleID string, token string) ([]string, errors.SDKError) + DomainRoleActions(ctx context.Context, id, roleID string, token string) ([]string, smqerrors.SDKError) // RemoveDomainRoleActions removes actions from a domain role. // @@ -1332,7 +1428,7 @@ type SDK interface { // actions := []string{"read", "update"} // err := sdk.RemoveDomainRoleActions(ctx, "domainID", "roleID", actions, "token") // fmt.Println(err) - RemoveDomainRoleActions(ctx context.Context, id, roleID string, actions []string, token string) errors.SDKError + RemoveDomainRoleActions(ctx context.Context, id, roleID string, actions []string, token string) smqerrors.SDKError // RemoveAllDomainRoleActions removes all actions from a domain role. // @@ -1340,7 +1436,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.RemoveAllDomainRoleActions(ctx, "domainID", "roleID", "token") // fmt.Println(err) - RemoveAllDomainRoleActions(ctx context.Context, id, roleID, token string) errors.SDKError + RemoveAllDomainRoleActions(ctx context.Context, id, roleID, token string) smqerrors.SDKError // AddDomainRoleMembers adds members to a domain role. // @@ -1349,7 +1445,7 @@ type SDK interface { // members := []string{"member_id_1", "member_id_2"} // members, _ := sdk.AddDomainRoleMembers(ctx, "domainID", "roleID", members, "token") // fmt.Println(members) - AddDomainRoleMembers(ctx context.Context, id, roleID string, members []string, token string) ([]string, errors.SDKError) + AddDomainRoleMembers(ctx context.Context, id, roleID string, members []string, token string) ([]string, smqerrors.SDKError) // DomainRoleMembers returns domain role members by roleID. // @@ -1361,7 +1457,7 @@ type SDK interface { // } // members, _ := sdk.DomainRoleMembers(ctx, "domainID", "roleID", "token") // fmt.Println(members) - DomainRoleMembers(ctx context.Context, id, roleID string, pm PageMetadata, token string) (RoleMembersPage, errors.SDKError) + DomainRoleMembers(ctx context.Context, id, roleID string, pm PageMetadata, token string) (RoleMembersPage, smqerrors.SDKError) // RemoveDomainRoleMembers removes members from a domain role. // @@ -1370,7 +1466,7 @@ type SDK interface { // members := []string{"member_id_1", "member_id_2"} // err := sdk.RemoveDomainRoleMembers(ctx, "domainID", "roleID", members, "token") // fmt.Println(err) - RemoveDomainRoleMembers(ctx context.Context, id, roleID string, members []string, token string) errors.SDKError + RemoveDomainRoleMembers(ctx context.Context, id, roleID string, members []string, token string) smqerrors.SDKError // RemoveAllDomainRoleMembers removes all members from a domain role. // @@ -1378,7 +1474,7 @@ type SDK interface { // ctx := context.Background() // err := sdk.RemoveAllDomainRoleMembers(ctx, "domainID", "roleID", "token") // fmt.Println(err) - RemoveAllDomainRoleMembers(ctx context.Context, id, roleID, token string) errors.SDKError + RemoveAllDomainRoleMembers(ctx context.Context, id, roleID, token string) smqerrors.SDKError // AvailableDomainRoleActions returns available actions for a domain role. // @@ -1386,7 +1482,7 @@ type SDK interface { // ctx := context.Background() // actions, _ := sdk.AvailableDomainRoleActions(ctx, "token") // fmt.Println(actions) - AvailableDomainRoleActions(ctx context.Context, token string) ([]string, errors.SDKError) + AvailableDomainRoleActions(ctx context.Context, token string) ([]string, smqerrors.SDKError) // ListDomainUsers returns list of users for the given domain ID and filters. // @@ -1398,7 +1494,7 @@ type SDK interface { // } // members, _ := sdk.ListDomainMembers(ctx, "domain_id", pm, "token") // fmt.Println(members) - ListDomainMembers(ctx context.Context, domainID string, pm PageMetadata, token string) (EntityMembersPage, errors.SDKError) + ListDomainMembers(ctx context.Context, domainID string, pm PageMetadata, token string) (EntityMembersPage, smqerrors.SDKError) // SendInvitation sends an invitation to the email address associated with the given user. // @@ -1463,6 +1559,216 @@ type SDK interface { // invitations, _ := sdk.DomainInvitations(ctx, "domainID", pm, "token") // fmt.Println(invitations) DomainInvitations(ctx context.Context, pm PageMetadata, token, domainID string) (invitations InvitationPage, err error) + + // AddBootstrap add bootstrap configuration + AddBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) (string, smqerrors.SDKError) + + // ViewBootstrap returns Client Config with given ID belonging to the user identified by the given token. + ViewBootstrap(ctx context.Context, id, domainID, token string) (BootstrapConfig, smqerrors.SDKError) + + // UpdateBootstrap updates editable fields of the provided Config. + UpdateBootstrap(ctx context.Context, cfg BootstrapConfig, domainID, token string) smqerrors.SDKError + + // UpdateBootstrapCerts updates bootstrap config certificates. + UpdateBootstrapCerts(ctx context.Context, id string, clientCert, clientKey, ca string, domainID, token string) (BootstrapConfig, smqerrors.SDKError) + + // UpdateBootstrapConnection updates connections performs update of the channel list corresponding Client is connected to. + UpdateBootstrapConnection(ctx context.Context, id string, channels []string, domainID, token string) smqerrors.SDKError + + // RemoveBootstrap removes Config with specified token that belongs to the user identified by the given token. + RemoveBootstrap(ctx context.Context, id, domainID, token string) smqerrors.SDKError + + // Bootstrap returns Config to the Client with provided external ID using external key. + Bootstrap(ctx context.Context, externalID, externalKey string) (BootstrapConfig, smqerrors.SDKError) + + // BootstrapSecure retrieves a configuration with given external ID and encrypted external key. + BootstrapSecure(ctx context.Context, externalID, externalKey, cryptoKey string) (BootstrapConfig, smqerrors.SDKError) + + // Bootstraps retrieves a list of managed configs. + Bootstraps(ctx context.Context, pm PageMetadata, domainID, token string) (BootstrapPage, smqerrors.SDKError) + + // Whitelist updates Client state Config with given ID belonging to the user identified by the given token. + Whitelist(ctx context.Context, clientID string, state int, domainID, token string) smqerrors.SDKError + + // ReadMessages reads messages of specified channel. + ReadMessages(ctx context.Context, pm MessagePageMetadata, chanID, domainID, token string) (MessagesPage, smqerrors.SDKError) + + // CreateSubscription creates a new subscription. + CreateSubscription(ctx context.Context, topic, contact, token string) (string, smqerrors.SDKError) + + // ListSubscriptions list subscriptions given list parameters. + ListSubscriptions(ctx context.Context, pm PageMetadata, token string) (SubscriptionPage, smqerrors.SDKError) + + // ViewSubscription retrieves a subscription with the provided id. + ViewSubscription(ctx context.Context, id, token string) (Subscription, smqerrors.SDKError) + + // DeleteSubscription removes a subscription with the provided id. + DeleteSubscription(ctx context.Context, id, token string) smqerrors.SDKError + + // UpdateAlarm updates an existing alarm. + UpdateAlarm(ctx context.Context, alarm Alarm, domainID, token string) (Alarm, smqerrors.SDKError) + + // ViewAlarm retrieves an alarm by its ID. + ViewAlarm(ctx context.Context, id, domainID, token string) (Alarm, smqerrors.SDKError) + + // ListAlarms retrieves a page of alarms. + ListAlarms(ctx context.Context, pm PageMetadata, domainID, token string) (AlarmsPage, smqerrors.SDKError) + + // DeleteAlarm deletes an alarm. + DeleteAlarm(ctx context.Context, id, domainID, token string) smqerrors.SDKError + + // AddReportConfig creates a new report configuration. + AddReportConfig(ctx context.Context, cfg ReportConfig, domainID, token string) (ReportConfig, smqerrors.SDKError) + + // ViewReportConfig retrieves a report config by its ID. + ViewReportConfig(ctx context.Context, id, domainID, token string) (ReportConfig, smqerrors.SDKError) + + // UpdateReportConfig updates an existing report configuration. + UpdateReportConfig(ctx context.Context, cfg ReportConfig, domainID, token string) (ReportConfig, smqerrors.SDKError) + + // UpdateReportSchedule updates an existing report configuration's schedule. + UpdateReportSchedule(ctx context.Context, cfg ReportConfig, domainID, token string) (ReportConfig, smqerrors.SDKError) + + // RemoveReportConfig deletes a report config. + RemoveReportConfig(ctx context.Context, id, domainID, token string) smqerrors.SDKError + + // ListReportsConfig retrieves a page of report configs. + ListReportsConfig(ctx context.Context, pm PageMetadata, domainID, token string) (ReportConfigPage, smqerrors.SDKError) + + // EnableReportConfig enables a report config. + EnableReportConfig(ctx context.Context, id, domainID, token string) (ReportConfig, smqerrors.SDKError) + + // DisableReportConfig disables a report config. + DisableReportConfig(ctx context.Context, id, domainID, token string) (ReportConfig, smqerrors.SDKError) + + // UpdateReportTemplate updates a report template. + UpdateReportTemplate(ctx context.Context, cfg ReportConfig, domainID, token string) smqerrors.SDKError + + // ViewReportTemplate retrieves a report template. + ViewReportTemplate(ctx context.Context, id, domainID, token string) (ReportTemplate, smqerrors.SDKError) + + // DeleteReportTemplate deletes a report template. + DeleteReportTemplate(ctx context.Context, id, domainID, token string) smqerrors.SDKError + + // GenerateReport generates a report from a configuration. + GenerateReport(ctx context.Context, config ReportConfig, action ReportAction, domainID, token string) (ReportPage, *ReportFile, smqerrors.SDKError) + + // AddRule creates a new rule. + AddRule(ctx context.Context, r Rule, domainID, token string) (Rule, smqerrors.SDKError) + + // ViewRule retrieves a rule by its ID. + ViewRule(ctx context.Context, id, domainID, token string) (Rule, smqerrors.SDKError) + + // UpdateRule updates an existing rule. + UpdateRule(ctx context.Context, r Rule, domainID, token string) (Rule, smqerrors.SDKError) + + // UpdateRuleTags updates an existing rule's tags. + UpdateRuleTags(ctx context.Context, r Rule, domainID, token string) (Rule, smqerrors.SDKError) + + // UpdateRuleSchedule updates an existing rule's schedule. + UpdateRuleSchedule(ctx context.Context, r Rule, domainID, token string) (Rule, smqerrors.SDKError) + + // ListRules retrieves a page of rules. + ListRules(ctx context.Context, pm PageMetadata, domainID, token string) (Page, smqerrors.SDKError) + + // RemoveRule deletes a rule. + RemoveRule(ctx context.Context, id, domainID, token string) smqerrors.SDKError + + // EnableRule enables a rule. + EnableRule(ctx context.Context, id, domainID, token string) (Rule, smqerrors.SDKError) + + // DisableRule disables a rule. + DisableRule(ctx context.Context, id, domainID, token string) (Rule, smqerrors.SDKError) + + // IssueCert issues a certificate for an entity. + // + // example: + // cert, _ := sdk.IssueCert(context.Background(), "entityID", "8760h", []string{"127.0.0.1"}, sdk.Options{CommonName: "cn"}, "domainID", "token") + IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, opts Options, domainID, token string) (Certificate, smqerrors.SDKError) + + // RevokeCert revokes a certificate by serial number. + // + // example: + // err := sdk.RevokeCert(context.Background(), "serialNumber", "domainID", "token") + RevokeCert(ctx context.Context, serialNumber, domainID, token string) smqerrors.SDKError + + // RenewCert renews a certificate by serial number. + // + // example: + // cert, _ := sdk.RenewCert(context.Background(), "serialNumber", "domainID", "token") + RenewCert(ctx context.Context, serialNumber, domainID, token string) (Certificate, smqerrors.SDKError) + + // ListCerts lists certificates matching the given metadata filter. + // + // example: + // page, _ := sdk.ListCerts(context.Background(), sdk.PageMetadata{Limit: 10}, "domainID", "token") + ListCerts(ctx context.Context, pm PageMetadata, domainID, token string) (CertificatePage, smqerrors.SDKError) + + // DeleteCert deletes all certificates for the given entity ID. + // + // example: + // err := sdk.DeleteCert(context.Background(), "entityID", "domainID", "token") + DeleteCert(ctx context.Context, entityID, domainID, token string) smqerrors.SDKError + + // ViewCert retrieves a certificate by serial number. + // + // example: + // cert, _ := sdk.ViewCert(context.Background(), "serialNumber", "domainID", "token") + ViewCert(ctx context.Context, serialNumber, domainID, token string) (Certificate, smqerrors.SDKError) + + // OCSP checks the revocation status of a certificate. + // + // example: + // resp, _ := sdk.OCSP(context.Background(), "serialNumber", "") + OCSP(ctx context.Context, serialNumber, cert string) (OCSPResponse, smqerrors.SDKError) + + // ViewCA views the signing CA certificate. + // + // example: + // cert, _ := sdk.ViewCA(context.Background()) + ViewCA(ctx context.Context) (Certificate, smqerrors.SDKError) + + // DownloadCA downloads the signing CA certificate bundle. + // + // example: + // bundle, _ := sdk.DownloadCA(context.Background()) + DownloadCA(ctx context.Context) (CertificateBundle, smqerrors.SDKError) + + // IssueFromCSR issues a certificate from a provided CSR. + // + // example: + // cert, _ := sdk.IssueFromCSR(context.Background(), "entityID", "8760h", csrPEM, "domainID", "token") + IssueFromCSR(ctx context.Context, entityID, ttl, csr, domainID, token string) (Certificate, smqerrors.SDKError) + + // IssueFromCSRInternal issues a certificate from a CSR using agent authentication. + // + // example: + // cert, _ := sdk.IssueFromCSRInternal(context.Background(), "entityID", "8760h", csrPEM, "agentToken") + IssueFromCSRInternal(ctx context.Context, entityID, ttl, csr, token string) (Certificate, smqerrors.SDKError) + + // GenerateCRL generates a Certificate Revocation List. + // + // example: + // crl, _ := sdk.GenerateCRL(context.Background()) + GenerateCRL(ctx context.Context) ([]byte, smqerrors.SDKError) + + // RevokeAll revokes all certificates for an entity ID. + // + // example: + // err := sdk.RevokeAll(context.Background(), "entityID", "domainID", "token") + RevokeAll(ctx context.Context, entityID, domainID, token string) smqerrors.SDKError + + // EntityID gets the entity ID for a certificate by serial number. + // + // example: + // id, _ := sdk.EntityID(context.Background(), "serialNumber", "domainID", "token") + EntityID(ctx context.Context, serialNumber, domainID, token string) (string, smqerrors.SDKError) + + // CreateCSR creates a Certificate Signing Request from metadata and a private key. + // + // example: + // csr, _ := sdk.CreateCSR(context.Background(), metadata, privateKeyBytes) + CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey any) (certs.CSR, smqerrors.SDKError) } type mgSDK struct { @@ -1475,6 +1781,11 @@ type mgSDK struct { domainsURL string journalURL string HostURL string + bootstrapURL string + readersURL string + alarmsURL string + reportsURL string + rulesEngineURL string msgContentType ContentType client *http.Client @@ -1493,6 +1804,11 @@ type Config struct { DomainsURL string JournalURL string HostURL string + BootstrapURL string + ReaderURL string + AlarmsURL string + ReportsURL string + RulesEngineURL string MsgContentType ContentType TLSVerification bool @@ -1512,12 +1828,18 @@ func NewSDK(conf Config) SDK { domainsURL: conf.DomainsURL, journalURL: conf.JournalURL, HostURL: conf.HostURL, + bootstrapURL: conf.BootstrapURL, + readersURL: conf.ReaderURL, + alarmsURL: conf.AlarmsURL, + reportsURL: conf.ReportsURL, + rulesEngineURL: conf.RulesEngineURL, msgContentType: conf.MsgContentType, client: &http.Client{Transport: otelhttp.NewTransport(&http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: !conf.TLSVerification, }, + IdleConnTimeout: 90 * time.Second, })}, curlFlag: conf.CurlFlag, roles: conf.Roles, @@ -1526,13 +1848,13 @@ func NewSDK(conf Config) SDK { // processRequest creates and send a new HTTP request, and checks for errors in the HTTP response. // It then returns the response headers, the response body, and the associated error(s) (if any). -func (sdk mgSDK) processRequest(ctx context.Context, method, reqUrl, token string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, errors.SDKError) { +func (sdk mgSDK) processRequest(ctx context.Context, method, reqUrl, token string, data []byte, headers map[string]string, expectedRespCodes ...int) (http.Header, []byte, smqerrors.SDKError) { if sdk.roles { reqUrl = fmt.Sprintf("%s?roles=%v", reqUrl, true) } req, err := http.NewRequestWithContext(ctx, method, reqUrl, bytes.NewReader(data)) if err != nil { - return make(http.Header), []byte{}, errors.NewSDKError(err) + return make(http.Header), []byte{}, smqerrors.NewSDKError(err) } // Sets a default value for the Content-Type. @@ -1545,7 +1867,7 @@ func (sdk mgSDK) processRequest(ctx context.Context, method, reqUrl, token strin if token != "" { if !strings.Contains(token, ClientPrefix) { - token = fmt.Sprintf("%s%s", BearerPrefix, token) + token = BearerPrefix + token } req.Header.Set("Authorization", token) } @@ -1553,25 +1875,35 @@ func (sdk mgSDK) processRequest(ctx context.Context, method, reqUrl, token strin if sdk.curlFlag { curlCommand, err := http2curl.GetCurlCommand(req) if err != nil { - return nil, nil, errors.NewSDKError(err) + return nil, nil, smqerrors.NewSDKError(err) } log.Println(curlCommand.String()) } resp, err := sdk.client.Do(req) if err != nil { - return make(http.Header), []byte{}, errors.NewSDKError(err) + var opErr *net.OpError + switch { + case errors.Is(err, syscall.ECONNRESET): + return make(http.Header), []byte{}, smqerrors.NewSDKError(fmt.Errorf("request failed: connection reset by peer: %w", err)) + case errors.As(err, &opErr): + return make(http.Header), []byte{}, smqerrors.NewSDKError(fmt.Errorf("request failed: network error (%s): %w", opErr.Op, err)) + case errors.Is(err, io.EOF): + return make(http.Header), []byte{}, smqerrors.NewSDKError(fmt.Errorf("request failed: connection closed unexpectedly: %w", err)) + default: + return make(http.Header), []byte{}, smqerrors.NewSDKError(fmt.Errorf("request failed: %w", err)) + } } defer resp.Body.Close() - sdkErr := errors.CheckError(resp, expectedRespCodes...) + sdkErr := smqerrors.CheckError(resp, expectedRespCodes...) if sdkErr != nil { return make(http.Header), []byte{}, sdkErr } body, err := io.ReadAll(resp.Body) if err != nil { - return make(http.Header), []byte{}, errors.NewSDKError(err) + return make(http.Header), []byte{}, smqerrors.NewSDKError(err) } return resp.Header, body, nil @@ -1639,7 +1971,7 @@ func (pm PageMetadata) query() (string, error) { if pm.Metadata != nil { md, err := json.Marshal(pm.Metadata) if err != nil { - return "", errors.NewSDKError(err) + return "", smqerrors.NewSDKError(err) } q.Add("metadata", string(md)) } @@ -1712,6 +2044,15 @@ func (pm PageMetadata) query() (string, error) { } q.Add("with_attributes", strconv.FormatBool(pm.WithAttributes)) q.Add("with_metadata", strconv.FormatBool(pm.WithMetadata)) + if pm.EntityID != "" { + q.Add("entity_id", pm.EntityID) + } + if pm.CommonName != "" { + q.Add("common_name", pm.CommonName) + } + if pm.TTL != "" { + q.Add("ttl", pm.TTL) + } return q.Encode(), nil } diff --git a/pkg/sdk/setup_test.go b/pkg/sdk/setup_test.go index 0aaffdc2c..eeba97eed 100644 --- a/pkg/sdk/setup_test.go +++ b/pkg/sdk/setup_test.go @@ -11,7 +11,9 @@ import ( "time" "github.com/absmach/supermq/channels" + chmocks "github.com/absmach/supermq/channels/mocks" "github.com/absmach/supermq/clients" + climocks "github.com/absmach/supermq/clients/mocks" "github.com/absmach/supermq/domains" groups "github.com/absmach/supermq/groups" "github.com/absmach/supermq/internal/nullable" @@ -31,7 +33,7 @@ const ( InvalidEmail = "invalidemail" secret = "strongsecret" invalidToken = "invalid" - contentType = "application/senml+json" + contentType = sdk.CTJSON invalid = "invalid" wrongID = "wrongID" roleName = "roleName" @@ -49,6 +51,9 @@ var ( total uint64 = 200 passRegex = regexp.MustCompile("^.{8,}$") validID = testsutil.GenerateUUID(&testing.T{}) + + clientsGRPCClient *climocks.ClientsServiceClient + channelsGRPCClient *chmocks.ChannelsServiceClient ) func generateUUID(t *testing.T) string { diff --git a/pkg/sdk/transport_test.go b/pkg/sdk/transport_test.go new file mode 100644 index 000000000..4c37e6160 --- /dev/null +++ b/pkg/sdk/transport_test.go @@ -0,0 +1,157 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package sdk_test + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + + "github.com/absmach/supermq/pkg/sdk" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestTransport verifies that IdleConnTimeout=90s keeps connections pooled for +// healthy servers, and that network errors (EOF, reset) surface as descriptive errors. +func TestTransport(t *testing.T) { + cases := []struct { + desc string + serverFunc func(t *testing.T) (url string, cleanup func()) + ctxFunc func() context.Context + wantErr bool + errContains string + }{ + { + desc: "make request successfully with connection reuse", + serverFunc: func(t *testing.T) (string, func()) { + t.Helper() + var connCount atomic.Int32 + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"clients":[{"id":"1","name":"test-client"}]}`)) + })) + srv.Config.ConnState = func(_ net.Conn, state http.ConnState) { + if state == http.StateNew { + connCount.Add(1) + } + } + srv.Start() + return srv.URL, func() { + srv.Close() + assert.Equal(t, int32(1), connCount.Load(), "expected connections to be reused (keep-alives enabled)") + } + }, + wantErr: false, + }, + { + desc: "make request with server closing connection", + serverFunc: func(t *testing.T) (string, func()) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + conn.Close() + } + }() + return "http://" + ln.Addr().String(), func() { ln.Close() } + }, + wantErr: true, + errContains: "request failed", + }, + { + desc: "make request with connection reset by peer", + serverFunc: func(t *testing.T) (string, func()) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + + tcpConn, ok := conn.(*net.TCPConn) + if ok { + _ = tcpConn.SetLinger(0) + } + conn.Close() + } + }() + return "http://" + ln.Addr().String(), func() { ln.Close() } + }, + wantErr: true, + errContains: "request failed", + }, + { + desc: "make request with unreachable server", + serverFunc: func(t *testing.T) (string, func()) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := ln.Addr().String() + ln.Close() + return "http://" + addr, func() {} + }, + wantErr: true, + errContains: "request failed", + }, + { + desc: "make request with cancelled context", + serverFunc: func(t *testing.T) (string, func()) { + t.Helper() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + return srv.URL, srv.Close + }, + ctxFunc: func() context.Context { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx + }, + wantErr: true, + errContains: "request failed", + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + url, cleanup := tc.serverFunc(t) + defer cleanup() + + smqsdk := sdk.NewSDK(sdk.Config{ClientsURL: url}) + + ctx := context.Background() + if tc.ctxFunc != nil { + ctx = tc.ctxFunc() + } + + client := sdk.Client{Name: "test-client"} + for i := 0; i < 2; i++ { + _, err := smqsdk.CreateClients(ctx, []sdk.Client{client}, domainID, validToken) + if tc.wantErr { + require.Error(t, err) + if tc.errContains != "" { + assert.True(t, strings.Contains(err.Error(), tc.errContains), + "expected error %q to contain %q", err.Error(), tc.errContains) + } + break + } + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/ticker/mocks/ticker.go b/pkg/ticker/mocks/ticker.go new file mode 100644 index 000000000..5e676a90c --- /dev/null +++ b/pkg/ticker/mocks/ticker.go @@ -0,0 +1,121 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "time" + + mock "github.com/stretchr/testify/mock" +) + +// NewTicker creates a new instance of Ticker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTicker(t interface { + mock.TestingT + Cleanup(func()) +}) *Ticker { + mock := &Ticker{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Ticker is an autogenerated mock type for the Ticker type +type Ticker struct { + mock.Mock +} + +type Ticker_Expecter struct { + mock *mock.Mock +} + +func (_m *Ticker) EXPECT() *Ticker_Expecter { + return &Ticker_Expecter{mock: &_m.Mock} +} + +// Stop provides a mock function for the type Ticker +func (_mock *Ticker) Stop() { + _mock.Called() + return +} + +// Ticker_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop' +type Ticker_Stop_Call struct { + *mock.Call +} + +// Stop is a helper method to define mock.On call +func (_e *Ticker_Expecter) Stop() *Ticker_Stop_Call { + return &Ticker_Stop_Call{Call: _e.mock.On("Stop")} +} + +func (_c *Ticker_Stop_Call) Run(run func()) *Ticker_Stop_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Ticker_Stop_Call) Return() *Ticker_Stop_Call { + _c.Call.Return() + return _c +} + +func (_c *Ticker_Stop_Call) RunAndReturn(run func()) *Ticker_Stop_Call { + _c.Run(run) + return _c +} + +// Tick provides a mock function for the type Ticker +func (_mock *Ticker) Tick() <-chan time.Time { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Tick") + } + + var r0 <-chan time.Time + if returnFunc, ok := ret.Get(0).(func() <-chan time.Time); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan time.Time) + } + } + return r0 +} + +// Ticker_Tick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Tick' +type Ticker_Tick_Call struct { + *mock.Call +} + +// Tick is a helper method to define mock.On call +func (_e *Ticker_Expecter) Tick() *Ticker_Tick_Call { + return &Ticker_Tick_Call{Call: _e.mock.On("Tick")} +} + +func (_c *Ticker_Tick_Call) Run(run func()) *Ticker_Tick_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Ticker_Tick_Call) Return(timeCh <-chan time.Time) *Ticker_Tick_Call { + _c.Call.Return(timeCh) + return _c +} + +func (_c *Ticker_Tick_Call) RunAndReturn(run func() <-chan time.Time) *Ticker_Tick_Call { + _c.Call.Return(run) + return _c +} diff --git a/pkg/ticker/ticker.go b/pkg/ticker/ticker.go new file mode 100644 index 000000000..7220bde07 --- /dev/null +++ b/pkg/ticker/ticker.go @@ -0,0 +1,23 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package ticker + +import "time" + +type Ticker interface { + Tick() <-chan time.Time + Stop() +} + +type timeTicker struct { + *time.Ticker +} + +func NewTicker(d time.Duration) Ticker { + return &timeTicker{time.NewTicker(d)} +} + +func (t *timeTicker) Tick() <-chan time.Time { + return t.C +} diff --git a/provision/README.md b/provision/README.md new file mode 100644 index 000000000..238675ab7 --- /dev/null +++ b/provision/README.md @@ -0,0 +1,232 @@ +# Provision service + +Provision service provides an HTTP API to create initial Magistrala resources for gateways or edge deployments. It can create clients and channels based on a configurable layout, optionally create bootstrap configurations, whitelist clients, and issue X.509 certificates for mTLS. + +For gateways to communicate with [Magistrala][magistrala], configuration is required (MQTT host, client, channels, certificates). A gateway can fetch bootstrap configuration from the [Bootstrap][bootstrap] service using its `` and ``. The [Agent][agent] service is typically used on gateways to retrieve that configuration. + +You can create bootstrap configuration directly via [Bootstrap][bootstrap] or through Provision. [Magistrala UI][mgxui] uses the Bootstrap service; Provision is intended to automate gateway setups where one physical gateway may require multiple clients and channels (for example, [Agent][agent] and [Export][export]). This setup is defined as a **provision layout**. + +## Configuration + +The service is configured using environment variables and/or a TOML config file. Defaults below are from `provision/config.go`. Docker add-on examples are in `docker/addons/provision/docker-compose.yaml` and [docker/.env](https://github.com/absmach/magistrala/blob/main/docker/.env). The binary reads `MG_PROVISION_*` variables; the add-on compose file uses `MG_PROVISION_*`, so ensure the container receives the expected names. + +### Core service + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_PROVISION_HTTP_PORT` | Provision service listening port | `9016` | +| `MG_PROVISION_LOG_LEVEL` | Service log level | `info` | +| `MG_PROVISION_ENV_CLIENTS_TLS` | SDK TLS verification | `false` | +| `MG_PROVISION_SERVER_CERT` | HTTPS server certificate | "" | +| `MG_PROVISION_SERVER_KEY` | HTTPS server key | "" | +| `MG_SEND_TELEMETRY` | Send telemetry to Magistrala call-home server | `true` | +| `MG_MQTT_ADAPTER_INSTANCE_ID` | Instance ID used in health output | "" | + +### Magistrala endpoints and credentials + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_PROVISION_USERS_LOCATION` | Users service URL | `http://localhost` | +| `MG_PROVISION_CLIENTS_LOCATION` | Clients service URL | `http://localhost` | +| `MG_PROVISION_CERTS_LOCATION` | Certs service URL (certs SDK) | `http://localhost` | +| `MG_PROVISION_BS_SVC_URL` | Bootstrap service URL | `http://localhost:9000` | +| `MG_PROVISION_CERTS_SVC_URL` | Certs service URL (Magistrala SDK) | `http://localhost:9019` | +| `MG_PROVISION_USERNAME` | Magistrala username | `user` | +| `MG_PROVISION_PASS` | Magistrala password | `test` | +| `MG_PROVISION_API_KEY` | Magistrala authentication token | "" | +| `MG_PROVISION_EMAIL` | Magistrala user email | `test@example.com` | +| `MG_PROVISION_DOMAIN_ID` | Default domain ID (unused by HTTP API) | "" | + +### Provisioning behavior + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_PROVISION_CONFIG_FILE` | Provision config file | `config.toml` | +| `MG_PROVISION_X509_PROVISIONING` | Issue client certificates during provisioning | `false` | +| `MG_PROVISION_BS_CONFIG_PROVISIONING` | Save client config in Bootstrap | `true` | +| `MG_PROVISION_BS_AUTO_WHITELIST` | Auto-whitelist client | `true` | +| `MG_PROVISION_BS_CONTENT` | Bootstrap config content (JSON string) | "" | +| `MG_PROVISION_CERTS_HOURS_VALID` | Client cert validity period | `2400h` | + +## Features + +- **Layout-driven provisioning**: Create clients and channels from a predefined layout. +- **Bootstrap integration**: Create bootstrap configs and optionally whitelist clients. +- **X.509 certificates**: Issue client certificates during provisioning when enabled. +- **Gateway metadata**: Enrich gateway clients with control/data/export channel IDs. +- **Observability**: `/metrics` and `/health` endpoints. + +## Provision layout + +Provision layout is configured in a TOML file (see `provision/configs/config.toml` or `docker/addons/provision/configs/config.toml`). If the file exists, it is loaded and any missing fields are filled with env values. The layout defines which clients and channels will be created when calling `/mapping`. + +Default behavior (when no config file is loaded) creates one client and two channels: `control` and `data`. + +Notes: + +- At least one client must include `external_id` in metadata. This value is replaced with the `external_id` from the provisioning request and is used for bootstrap creation. +- Channel metadata `type` is reserved for `control`, `data`, and `export` and is used to enrich gateway metadata. +- Bootstrap content can be provided via `bootstrap.content` in the TOML file or as JSON through `MG_PROVISION_BS_CONTENT`. + +Example layout: + +```toml +[[clients]] + name = "client" + + [clients.metadata] + external_id = "xxxxxx" + +[[channels]] + name = "control-channel" + + [channels.metadata] + type = "control" + +[[channels]] + name = "data-channel" + + [channels.metadata] + type = "data" + +[[channels]] + name = "export-channel" + + [channels.metadata] + type = "data" +``` + +## Authentication + +Provision uses Magistrala APIs and requires a valid token. There are three ways to provide it: + +- `Authorization: Bearer ` on each request. +- `MG_PROVISION_API_KEY` in env or TOML (used when no header token is provided). +- `MG_PROVISION_USERNAME` and `MG_PROVISION_PASS` in env or TOML (used to create an access token when no header token is provided). + +`POST /{domainID}/mapping` can create its own token using API key or username/password if no `Authorization` header is provided. The `Authorization` header takes precedence when present. `GET /{domainID}/mapping` always requires a bearer token. + +## Architecture + +### Runtime flow + +1. The service loads configuration from env and optionally merges a config file. +2. `POST /{domainID}/mapping` validates the request and ensures a token exists. +3. Clients are created from the configured layout (external ID is injected into metadata). +4. Channels are created with names prefixed by the request `name`. +5. If enabled, bootstrap configs are created and clients are whitelisted (connected to channels). +6. If X.509 provisioning is enabled, certificates are issued and returned in the response. + +## Running + +Provision service can be run standalone or via Docker Compose. + +Standalone: + +```bash +make provision + +MG_PROVISION_BS_SVC_URL=http://localhost:9013 \ +MG_PROVISION_CLIENTS_LOCATION=http://localhost:9006 \ +MG_PROVISION_USERS_LOCATION=http://localhost:9002 \ +MG_PROVISION_CONFIG_FILE=provision/configs/config.toml \ +./build/provision +``` + +Docker Compose (add-on): + +```bash +docker compose -f docker/docker-compose.yaml -f docker/addons/provision/docker-compose.yaml up provision +``` + +## Usage + +The Provision service exposes the following endpoints: + +| Operation | Method & Path | Description | +| --- | --- | --- | +| `provision` | `POST /{domainID}/mapping` | Create clients, channels, bootstrap config, and optional certs | +| `mapping` | `GET /{domainID}/mapping` | Return bootstrap content from config | +| `health` | `GET /health` | Service health check | + +### Example: Provision a gateway + +When credentials are available via env/config, you can omit the `Authorization` header. `Content-Type` must be exactly `application/json`. + +```bash +curl -s -S -X POST http://localhost://mapping \ + -H 'Content-Type: application/json' \ + -d '{"name": "gateway-a", "external_id": "33:52:77:99:43", "external_key": "223334fw2"}' +``` + +If you want to supply a token explicitly: + +```bash +curl -s -S -X POST http://localhost://mapping \ + -H "Authorization: Bearer " \ + -H 'Content-Type: application/json' \ + -d '{"name": "gateway-a", "external_id": "", "external_key": ""}' +``` + +Response contains created clients, channels, and optional certificate data: + +```json +{ + "clients": [ + { + "id": "c22b0c0f-8c03-40da-a06b-37ed3a72c8d1", + "name": "client", + "key": "007cce56-e0eb-40d6-b2b9-ed348a97d1eb", + "metadata": { + "external_id": "33:52:79:C3:43" + } + } + ], + "channels": [ + { + "id": "064c680e-181b-4b58-975e-6983313a5170", + "name": "control-channel", + "metadata": { + "type": "control" + } + }, + { + "id": "579da92d-6078-4801-a18a-dd1cfa2aa44f", + "name": "data-channel", + "metadata": { + "type": "data" + } + } + ], + "whitelisted": { + "c22b0c0f-8c03-40da-a06b-37ed3a72c8d1": true + } +} +``` + +### Example: Read bootstrap mapping + +```bash +curl -s -S -X GET http://localhost://mapping \ + -H "Authorization: Bearer " \ + -H 'Content-Type: application/json' +``` + +## Certificates + +When `MG_PROVISION_X509_PROVISIONING=true`, the provisioning flow issues certificates for each client and returns them in the response as `client_cert`, `client_key`, and `ca_cert`. The certificate TTL is controlled by `MG_PROVISION_CERTS_HOURS_VALID`. + +## Testing + +```bash +go test ./provision/... +``` + +For an in-depth explanation of our Provision Service, see the [official documentation][doc]. + +[doc]: https://docs.magistrala.absmach.eu/dev-guide/provision/ +[magistrala]: https://github.com/absmach/magistrala +[bootstrap]: https://github.com/absmach/magistrala/tree/main/bootstrap +[export]: https://github.com/absmach/export +[agent]: https://github.com/absmach/agent +[mgxui]: https://github.com/absmach/magistrala/ui diff --git a/provision/api/doc.go b/provision/api/doc.go new file mode 100644 index 000000000..2424852cc --- /dev/null +++ b/provision/api/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package api contains API-related concerns: endpoint definitions, middlewares +// and all resource representations. +package api diff --git a/provision/api/endpoint.go b/provision/api/endpoint.go new file mode 100644 index 000000000..7b50f8e66 --- /dev/null +++ b/provision/api/endpoint.go @@ -0,0 +1,75 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/provision" + "github.com/go-kit/kit/endpoint" +) + +func doProvision(svc provision.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + req := request.(provisionReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + res, err := svc.Provision(ctx, session.DomainID, req.token, req.Name, req.ExternalID, req.ExternalKey) + if err != nil { + return nil, err + } + + provisionResponse := provisionRes{ + Clients: res.Clients, + Channels: res.Channels, + ClientCert: res.ClientCert, + ClientKey: res.ClientKey, + CACert: res.CACert, + Whitelisted: res.Whitelisted, + } + + return provisionResponse, nil + } +} + +func getMapping(svc provision.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + res := svc.Mapping() + + return mappingRes{Data: res}, nil + } +} + +func issueCert(svc provision.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + req := request.(certReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + cert, key, err := svc.Cert(ctx, session.DomainID, req.token, req.ClientID, req.TTL) + if err != nil { + return nil, err + } + + return certRes{ + Certificate: cert, + Key: key, + }, nil + } +} diff --git a/provision/api/endpoint_test.go b/provision/api/endpoint_test.go new file mode 100644 index 000000000..1354016d5 --- /dev/null +++ b/provision/api/endpoint_test.go @@ -0,0 +1,329 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api_test + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/internal/testsutil" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/provision" + "github.com/absmach/supermq/provision/api" + mocks "github.com/absmach/supermq/provision/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + validToken = "valid" + validContenType = "application/json" + validID = testsutil.GenerateUUID(&testing.T{}) + userID = testsutil.GenerateUUID(&testing.T{}) + domainID = testsutil.GenerateUUID(&testing.T{}) + validSession = smqauthn.Session{ + DomainUserID: auth.EncodeDomainUserID(domainID, userID), + UserID: userID, + DomainID: domainID, + } +) + +type testRequest struct { + client *http.Client + method string + url string + token string + contentType string + body io.Reader +} + +func (tr testRequest) make() (*http.Response, error) { + req, err := http.NewRequest(tr.method, tr.url, tr.body) + if err != nil { + return nil, err + } + + if tr.token != "" { + req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token) + } + + if tr.contentType != "" { + req.Header.Set("Content-Type", tr.contentType) + } + + return tr.client.Do(req) +} + +func newProvisionServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentication) { + svc := new(mocks.Service) + + logger := smqlog.NewMock() + authn := new(authnmocks.Authentication) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + mux := api.MakeHandler(svc, am, logger, "test") + return httptest.NewServer(mux), svc, authn +} + +func TestProvision(t *testing.T) { + is, svc, authn := newProvisionServer() + + cases := []struct { + desc string + token string + domainID string + data string + contentType string + status int + authnRes smqauthn.Session + authnErr error + svcErr error + }{ + { + desc: "valid request", + token: validToken, + domainID: validID, + data: fmt.Sprintf(`{"name": "test", "external_id": "%s", "external_key": "%s"}`, validID, validID), + status: http.StatusCreated, + contentType: validContenType, + authnRes: validSession, + svcErr: nil, + }, + { + desc: "request with empty external id", + token: validToken, + domainID: validID, + data: fmt.Sprintf(`{"name": "test", "external_key": "%s"}`, validID), + status: http.StatusBadRequest, + contentType: validContenType, + authnRes: validSession, + }, + { + desc: "request with empty external key", + token: validToken, + domainID: validID, + data: fmt.Sprintf(`{"name": "test", "external_id": "%s"}`, validID), + status: http.StatusUnauthorized, + contentType: validContenType, + authnRes: validSession, + svcErr: nil, + }, + { + desc: "empty token", + token: "", + domainID: validID, + data: fmt.Sprintf(`{"name": "test", "external_id": "%s", "external_key": "%s"}`, validID, validID), + status: http.StatusUnauthorized, + contentType: validContenType, + authnRes: smqauthn.Session{}, + authnErr: errors.ErrAuthentication, + svcErr: nil, + }, + { + desc: "invalid content type", + token: validToken, + domainID: validID, + data: fmt.Sprintf(`{"name": "test", "external_id": "%s", "external_key": "%s"}`, validID, validID), + status: http.StatusUnsupportedMediaType, + contentType: "text/plain", + authnRes: validSession, + svcErr: nil, + }, + { + desc: "invalid request", + token: validToken, + domainID: validID, + data: `data`, + status: http.StatusBadRequest, + contentType: validContenType, + authnRes: validSession, + svcErr: nil, + }, + { + desc: "service error", + token: validToken, + domainID: validID, + data: fmt.Sprintf(`{"name": "test", "external_id": "%s", "external_key": "%s"}`, validID, validID), + status: http.StatusForbidden, + contentType: validContenType, + authnRes: validSession, + svcErr: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + repocall := svc.On("Provision", mock.Anything, validID, tc.token, "test", validID, validID).Return(provision.Result{}, tc.svcErr) + req := testRequest{ + client: is.Client(), + method: http.MethodPost, + url: is.URL + fmt.Sprintf("/%s/mapping", tc.domainID), + token: tc.token, + contentType: tc.contentType, + body: strings.NewReader(tc.data), + } + + resp, err := req.make() + assert.Nil(t, err, tc.desc) + assert.Equal(t, tc.status, resp.StatusCode, tc.desc) + authCall.Unset() + repocall.Unset() + }) + } +} + +func TestMapping(t *testing.T) { + is, svc, authn := newProvisionServer() + + cases := []struct { + desc string + token string + domainID string + contentType string + status int + authnRes smqauthn.Session + authnErr error + svcErr error + }{ + { + desc: "valid request", + token: validToken, + domainID: validID, + status: http.StatusOK, + contentType: validContenType, + svcErr: nil, + authnRes: validSession, + authnErr: nil, + }, + { + desc: "empty token", + token: "", + domainID: validID, + status: http.StatusUnauthorized, + contentType: validContenType, + svcErr: nil, + authnRes: smqauthn.Session{}, + authnErr: errors.ErrAuthentication, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + repocall := svc.On("Mapping").Return(map[string]any{}, tc.svcErr) + req := testRequest{ + client: is.Client(), + method: http.MethodGet, + url: is.URL + fmt.Sprintf("/%s/mapping", tc.domainID), + token: tc.token, + contentType: tc.contentType, + } + + resp, err := req.make() + assert.Nil(t, err, tc.desc) + assert.Equal(t, tc.status, resp.StatusCode, tc.desc) + authCall.Unset() + repocall.Unset() + }) + } +} + +func TestCert(t *testing.T) { + is, svc, authn := newProvisionServer() + + cases := []struct { + desc string + token string + domainID string + data string + contentType string + status int + authnRes smqauthn.Session + authnErr error + svcErr error + }{ + { + desc: "valid request", + token: validToken, + domainID: validID, + data: fmt.Sprintf(`{"client_id": "%s", "ttl": "1h"}`, validID), + status: http.StatusCreated, + contentType: validContenType, + authnRes: validSession, + svcErr: nil, + }, + { + desc: "empty token", + token: "", + domainID: validID, + data: fmt.Sprintf(`{"client_id": "%s", "ttl": "1h"}`, validID), + status: http.StatusUnauthorized, + contentType: validContenType, + authnRes: smqauthn.Session{}, + authnErr: errors.ErrAuthentication, + svcErr: nil, + }, + { + desc: "invalid content type", + token: validToken, + domainID: validID, + data: fmt.Sprintf(`{"client_id": "%s", "ttl": "1h"}`, validID), + status: http.StatusUnsupportedMediaType, + contentType: "text/plain", + authnRes: validSession, + svcErr: nil, + }, + { + desc: "invalid request", + token: validToken, + domainID: validID, + data: `data`, + status: http.StatusBadRequest, + contentType: validContenType, + authnRes: validSession, + svcErr: nil, + }, + { + desc: "service error", + token: validToken, + domainID: validID, + data: fmt.Sprintf(`{"client_id": "%s", "ttl": "1h"}`, validID), + status: http.StatusForbidden, + contentType: validContenType, + authnRes: validSession, + svcErr: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + repocall := svc.On("Cert", mock.Anything, validID, tc.token, validID, "1h").Return("cert", "key", tc.svcErr) + req := testRequest{ + client: is.Client(), + method: http.MethodPost, + url: is.URL + fmt.Sprintf("/%s/cert", tc.domainID), + token: tc.token, + contentType: tc.contentType, + body: strings.NewReader(tc.data), + } + + resp, err := req.make() + assert.Nil(t, err, tc.desc) + assert.Equal(t, tc.status, resp.StatusCode, tc.desc) + authCall.Unset() + repocall.Unset() + }) + } +} diff --git a/provision/api/requests.go b/provision/api/requests.go new file mode 100644 index 000000000..c03865cf7 --- /dev/null +++ b/provision/api/requests.go @@ -0,0 +1,43 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import apiutil "github.com/absmach/supermq/api/http/util" + +type provisionReq struct { + token string + Name string `json:"name"` + ExternalID string `json:"external_id"` + ExternalKey string `json:"external_key"` +} + +func (req provisionReq) validate() error { + if req.ExternalID == "" { + return apiutil.ErrMissingID + } + + if req.ExternalKey == "" { + return apiutil.ErrBearerKey + } + + if req.Name == "" { + return apiutil.ErrMissingName + } + + return nil +} + +type certReq struct { + token string + ClientID string `json:"client_id"` + TTL string `json:"ttl,omitempty"` +} + +func (req certReq) validate() error { + if req.ClientID == "" { + return apiutil.ErrMissingID + } + + return nil +} diff --git a/provision/api/requests_test.go b/provision/api/requests_test.go new file mode 100644 index 000000000..31fe8e493 --- /dev/null +++ b/provision/api/requests_test.go @@ -0,0 +1,58 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "fmt" + "testing" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestProvisioReq(t *testing.T) { + cases := []struct { + desc string + req provisionReq + err error + }{ + { + desc: "valid request", + req: provisionReq{ + token: "token", + Name: "name", + ExternalID: testsutil.GenerateUUID(t), + ExternalKey: testsutil.GenerateUUID(t), + }, + err: nil, + }, + { + desc: "empty external id", + req: provisionReq{ + token: "token", + Name: "name", + ExternalID: "", + ExternalKey: testsutil.GenerateUUID(t), + }, + err: apiutil.ErrMissingID, + }, + { + desc: "empty external key", + req: provisionReq{ + token: "token", + Name: "name", + ExternalID: testsutil.GenerateUUID(t), + ExternalKey: "", + }, + err: apiutil.ErrBearerKey, + }, + } + + for _, tc := range cases { + err := tc.req.validate() + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected `%v` got `%v`", tc.desc, tc.err, err)) + } +} diff --git a/provision/api/responses.go b/provision/api/responses.go new file mode 100644 index 000000000..8f276ac91 --- /dev/null +++ b/provision/api/responses.go @@ -0,0 +1,72 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "encoding/json" + "net/http" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/pkg/sdk" +) + +var _ supermq.Response = (*provisionRes)(nil) + +type provisionRes struct { + Clients []sdk.Client `json:"clients"` + Channels []sdk.Channel `json:"channels"` + ClientCert map[string]string `json:"client_cert,omitempty"` + ClientKey map[string]string `json:"client_key,omitempty"` + CACert string `json:"ca_cert,omitempty"` + Whitelisted map[string]bool `json:"whitelisted,omitempty"` +} + +func (res provisionRes) Code() int { + return http.StatusCreated +} + +func (res provisionRes) Headers() map[string]string { + return map[string]string{} +} + +func (res provisionRes) Empty() bool { + return false +} + +type mappingRes struct { + Data any +} + +func (res mappingRes) Code() int { + return http.StatusOK +} + +func (res mappingRes) Headers() map[string]string { + return map[string]string{} +} + +func (res mappingRes) Empty() bool { + return false +} + +type certRes struct { + Certificate string `json:"certificate"` + Key string `json:"key"` +} + +func (res certRes) Code() int { + return http.StatusCreated +} + +func (res certRes) Headers() map[string]string { + return map[string]string{} +} + +func (res certRes) Empty() bool { + return false +} + +func (res mappingRes) MarshalJSON() ([]byte, error) { + return json.Marshal(res.Data) +} diff --git a/provision/api/transport.go b/provision/api/transport.go new file mode 100644 index 000000000..b4d696d90 --- /dev/null +++ b/provision/api/transport.go @@ -0,0 +1,96 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + + "github.com/absmach/supermq" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/provision" + "github.com/go-chi/chi/v5" + kithttp "github.com/go-kit/kit/transport/http" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +const ( + contentType = "application/json" +) + +// MakeHandler returns a HTTP handler for API endpoints. +func MakeHandler(svc provision.Service, authn smqauthn.AuthNMiddleware, logger *slog.Logger, instanceID string) http.Handler { + opts := []kithttp.ServerOption{ + kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), + } + + r := chi.NewRouter() + + r.Route("/{domainID}", func(r chi.Router) { + r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware()) + r.Route("/mapping", func(r chi.Router) { + r.Post("/", kithttp.NewServer( + doProvision(svc), + decodeProvisionRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + r.Get("/", kithttp.NewServer( + getMapping(svc), + decodeMappingRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + }) + r.Post("/cert", kithttp.NewServer( + issueCert(svc), + decodeCertRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + }) + r.Handle("/metrics", promhttp.Handler()) + r.Get("/health", supermq.Health("provision", instanceID)) + + return r +} + +func decodeProvisionRequest(_ context.Context, r *http.Request) (any, error) { + if r.Header.Get("Content-Type") != contentType { + return nil, apiutil.ErrUnsupportedContentType + } + + req := provisionReq{ + token: apiutil.ExtractBearerToken(r), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + return req, nil +} + +func decodeMappingRequest(_ context.Context, r *http.Request) (any, error) { + return nil, nil +} + +func decodeCertRequest(_ context.Context, r *http.Request) (any, error) { + if r.Header.Get("Content-Type") != contentType { + return nil, apiutil.ErrUnsupportedContentType + } + + req := certReq{ + token: apiutil.ExtractBearerToken(r), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + return req, nil +} diff --git a/provision/config.go b/provision/config.go new file mode 100644 index 000000000..e0154c77f --- /dev/null +++ b/provision/config.go @@ -0,0 +1,104 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package provision + +import ( + "fmt" + "os" + + "github.com/absmach/supermq/channels" + "github.com/absmach/supermq/clients" + "github.com/absmach/supermq/pkg/errors" + "github.com/pelletier/go-toml" +) + +var errFailedToReadConfig = errors.New("failed to read config file") + +// ServiceConf represents service config. +type ServiceConf struct { + Port string `toml:"port" env:"MG_PROVISION_HTTP_PORT" envDefault:"9016"` + LogLevel string `toml:"log_level" env:"MG_PROVISION_LOG_LEVEL" envDefault:"info"` + TLS bool `toml:"tls" env:"MG_PROVISION_ENV_CLIENTS_TLS" envDefault:"false"` + ServerCert string `toml:"server_cert" env:"MG_PROVISION_SERVER_CERT" envDefault:""` + ServerKey string `toml:"server_key" env:"MG_PROVISION_SERVER_KEY" envDefault:""` + ClientsURL string `toml:"clients_url" env:"MG_PROVISION_CLIENTS_URL" envDefault:"http://localhost"` + ChannelsURL string `toml:"channels_url" env:"MG_PROVISION_CHANNELS_URL" envDefault:"http://localhost"` + UsersURL string `toml:"users_url" env:"MG_PROVISION_USERS_URL" envDefault:"http://localhost"` + CertsURL string `toml:"certs_url" env:"MG_PROVISION_CERTS_URL" envDefault:"http://localhost"` + MgEmail string `toml:"mg_email" env:"MG_PROVISION_EMAIL" envDefault:"test@example.com"` + MgUsername string `toml:"mg_username" env:"MG_PROVISION_USERNAME" envDefault:"user"` + MgPass string `toml:"mg_pass" env:"MG_PROVISION_PASS" envDefault:"test"` + MgDomainID string `toml:"mg_domain_id" env:"MG_PROVISION_DOMAIN_ID" envDefault:""` + MgAPIKey string `toml:"mg_api_key" env:"MG_PROVISION_API_KEY" envDefault:""` + MgBSURL string `toml:"mg_bs_url" env:"MG_PROVISION_BS_SVC_URL" envDefault:"http://localhost:9000"` +} + +// Bootstrap represetns the Bootstrap config. +type Bootstrap struct { + X509Provision bool `toml:"x509_provision" env:"MG_PROVISION_X509_PROVISIONING" envDefault:"false"` + Provision bool `toml:"provision" env:"MG_PROVISION_BS_CONFIG_PROVISIONING" envDefault:"true"` + AutoWhiteList bool `toml:"autowhite_list" env:"MG_PROVISION_BS_AUTO_WHITELIST" envDefault:"true"` + Content map[string]any `toml:"content"` +} + +// Gateway represetns the Gateway config. +type Gateway struct { + Type string `toml:"type" json:"type"` + ExternalID string `toml:"external_id" json:"external_id"` + ExternalKey string `toml:"external_key" json:"external_key"` + CtrlChannelID string `toml:"ctrl_channel_id" json:"ctrl_channel_id"` + DataChannelID string `toml:"data_channel_id" json:"data_channel_id"` + ExportChannelID string `toml:"export_channel_id" json:"export_channel_id"` + CfgID string `toml:"cfg_id" json:"cfg_id"` +} + +// Cert represetns the certificate config. +type Cert struct { + TTL string `json:"ttl" toml:"ttl" env:"MG_PROVISION_CERTS_HOURS_VALID" envDefault:"2400h"` +} + +// Config struct of Provision. +type Config struct { + File string `toml:"file" env:"MG_PROVISION_CONFIG_FILE" envDefault:"config.toml"` + Server ServiceConf `toml:"server" mapstructure:"server"` + Bootstrap Bootstrap `toml:"bootstrap" mapstructure:"bootstrap"` + Clients []clients.Client `toml:"clients" mapstructure:"clients"` + Channels []channels.Channel `toml:"channels" mapstructure:"channels"` + Cert Cert `toml:"cert" mapstructure:"cert"` + BSContent string `env:"MG_PROVISION_BS_CONTENT" envDefault:""` + SendTelemetry bool `env:"MG_SEND_TELEMETRY" envDefault:"true"` + InstanceID string `env:"MG_MQTT_ADAPTER_INSTANCE_ID" envDefault:""` +} + +// Save - store config in a file. +func Save(c Config, file string) error { + if file == "" { + return errors.ErrEmptyPath + } + + b, err := toml.Marshal(c) + if err != nil { + return errors.Wrap(errFailedToReadConfig, err) + } + if err := os.WriteFile(file, b, 0o644); err != nil { + return fmt.Errorf("Error writing toml: %w", err) + } + + return nil +} + +// Read - retrieve config from a file. +func Read(file string) (Config, error) { + data, err := os.ReadFile(file) + if err != nil { + return Config{}, errors.Wrap(errFailedToReadConfig, err) + } + + var c Config + if err := toml.Unmarshal(data, &c); err != nil { + return Config{}, fmt.Errorf("Error unmarshaling toml: %w", err) + } + + return c, nil +} diff --git a/provision/config_test.go b/provision/config_test.go new file mode 100644 index 000000000..5d2683fe4 --- /dev/null +++ b/provision/config_test.go @@ -0,0 +1,229 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package provision_test + +import ( + "fmt" + "os" + "testing" + + "github.com/absmach/supermq/channels" + "github.com/absmach/supermq/clients" + "github.com/absmach/supermq/pkg/connections" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/provision" + "github.com/pelletier/go-toml" + "github.com/stretchr/testify/assert" +) + +var ( + validConfig = provision.Config{ + Server: provision.ServiceConf{ + Port: "9016", + LogLevel: "info", + TLS: false, + }, + Bootstrap: provision.Bootstrap{ + X509Provision: true, + Provision: true, + AutoWhiteList: true, + Content: map[string]any{ + "test": "test", + }, + }, + Clients: []clients.Client{ + { + ID: "1234567890", + Name: "test", + Tags: []string{"test"}, + Metadata: map[string]any{ + "test": "test", + }, + PrivateMetadata: clients.Metadata{}, + Actions: []string{}, + AccessProviderRoleActions: []string{}, + ConnectionTypes: []connections.ConnType{}, + }, + }, + Channels: []channels.Channel{ + { + ID: "1234567890", + Name: "test", + Tags: []string{"test"}, + Metadata: map[string]any{ + "test": "test", + }, + Actions: []string{}, + AccessProviderRoleActions: []string{}, + ConnectionTypes: []connections.ConnType{}, + }, + }, + Cert: provision.Cert{}, + SendTelemetry: true, + InstanceID: "1234567890", + } + validConfigFile = "./config.toml" + invalidConfig = provision.Config{ + Bootstrap: provision.Bootstrap{ + Content: map[string]any{ + "invalid": make(chan int), + }, + }, + } + invalidConfigFile = "./invalid.toml" +) + +func createInvalidConfigFile() error { + config := map[string]any{ + "invalid": "invalid", + } + b, err := toml.Marshal(config) + if err != nil { + return err + } + + f, err := os.Create(invalidConfigFile) + if err != nil { + return err + } + + if _, err = f.Write(b); err != nil { + return err + } + + return nil +} + +func createValidConfigFile() error { + b, err := toml.Marshal(validConfig) + if err != nil { + return err + } + + f, err := os.Create(validConfigFile) + if err != nil { + return err + } + + if _, err = f.Write(b); err != nil { + return err + } + + return nil +} + +func TestSave(t *testing.T) { + cases := []struct { + desc string + cfg provision.Config + file string + err error + }{ + { + desc: "save valid config", + cfg: validConfig, + file: validConfigFile, + err: nil, + }, + { + desc: "save valid config with empty file name", + cfg: validConfig, + file: "", + err: errors.ErrEmptyPath, + }, + { + desc: "save empty config with valid config file", + cfg: provision.Config{}, + file: validConfigFile, + err: nil, + }, + { + desc: "save empty config with empty file name", + cfg: provision.Config{}, + file: "", + err: errors.ErrEmptyPath, + }, + { + desc: "save invalid config", + cfg: invalidConfig, + file: invalidConfigFile, + err: errors.New("failed to read config file"), + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + err := provision.Save(c.cfg, c.file) + assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected: %v, got: %v", c.err, err)) + + if err == nil { + defer func() { + if c.file != "" { + err := os.Remove(c.file) + assert.NoError(t, err) + } + }() + + cfg, err := provision.Read(c.file) + if c.cfg.Bootstrap.Content == nil { + c.cfg.Bootstrap.Content = map[string]any{} + } + assert.Equal(t, c.err, err) + assert.Equal(t, c.cfg, cfg) + } + }) + } +} + +func TestRead(t *testing.T) { + err := createInvalidConfigFile() + assert.NoError(t, err) + + err = createValidConfigFile() + assert.NoError(t, err) + + t.Cleanup(func() { + err := os.Remove(invalidConfigFile) + assert.NoError(t, err) + err = os.Remove(validConfigFile) + assert.NoError(t, err) + }) + + cases := []struct { + desc string + file string + cfg provision.Config + err error + }{ + { + desc: "read valid config", + file: validConfigFile, + cfg: validConfig, + err: nil, + }, + { + desc: "read invalid config", + file: invalidConfigFile, + cfg: invalidConfig, + err: nil, + }, + { + desc: "read empty config", + file: "", + cfg: provision.Config{}, + err: errors.New("failed to read config file"), + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + cfg, err := provision.Read(c.file) + if c.desc == "read invalid config" { + c.cfg.Bootstrap.Content = nil + } + assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected: %v, got: %v", c.err, err)) + assert.Equal(t, c.cfg, cfg) + }) + } +} diff --git a/provision/configs/config.toml b/provision/configs/config.toml new file mode 100644 index 000000000..650ed3518 --- /dev/null +++ b/provision/configs/config.toml @@ -0,0 +1,47 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +file = "config.toml" + +[bootstrap] + autowhite_list = true + content = "" + provision = true + x509_provision = false + + +[server] + LogLevel = "info" + ca_certs = "" + http_port = "8190" + mg_api_key = "" + mg_bs_url = "http://localhost:9013" + mg_certs_url = "http://localhost:9019" + mg_pass = "" + mg_user = "" + mqtt_url = "" + port = "" + server_cert = "" + server_key = "" + clients_location = "http://localhost:9006" + tls = true + users_location = "" + +[[clients]] + name = "client" + + [client.metadata] + external_id = "xxxxxx" + + +[[channels]] + name = "control-channel" + + [channels.metadata] + type = "control" + +[[channels]] + name = "data-channel" + + [channels.metadata] + type = "data" diff --git a/provision/doc.go b/provision/doc.go new file mode 100644 index 000000000..e9b855294 --- /dev/null +++ b/provision/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package provision contains domain concept definitions needed to support +// Provision service feature, i.e. automate provision process. +package provision diff --git a/provision/middleware/logging.go b/provision/middleware/logging.go new file mode 100644 index 000000000..f54d1734f --- /dev/null +++ b/provision/middleware/logging.go @@ -0,0 +1,71 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/supermq/provision" +) + +var _ provision.Service = (*loggingMiddleware)(nil) + +type loggingMiddleware struct { + logger *slog.Logger + svc provision.Service +} + +// NewLogging adds logging facilities to the core service. +func NewLogging(svc provision.Service, logger *slog.Logger) provision.Service { + return &loggingMiddleware{logger, svc} +} + +func (lm *loggingMiddleware) Provision(ctx context.Context, domainID, token, name, externalID, externalKey string) (res provision.Result, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("name", name), + slog.String("external_id", externalID), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Provision failed", args...) + return + } + lm.logger.Info("Provision completed successfully", args...) + }(time.Now()) + + return lm.svc.Provision(ctx, domainID, token, name, externalID, externalKey) +} + +func (lm *loggingMiddleware) Cert(ctx context.Context, domainID, token, clientID, duration string) (cert, key string, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("client_id", clientID), + slog.String("ttl", duration), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Client certificate creation failed", args...) + return + } + lm.logger.Info("Client certificate created successfully", args...) + }(time.Now()) + + return lm.svc.Cert(ctx, domainID, token, clientID, duration) +} + +func (lm *loggingMiddleware) Mapping() (res map[string]any) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + lm.logger.Info("Mapping completed successfully", args...) + }(time.Now()) + + return lm.svc.Mapping() +} diff --git a/provision/mocks/service.go b/provision/mocks/service.go new file mode 100644 index 000000000..455333968 --- /dev/null +++ b/provision/mocks/service.go @@ -0,0 +1,269 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/provision" + mock "github.com/stretchr/testify/mock" +) + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +type Service_Expecter struct { + mock *mock.Mock +} + +func (_m *Service) EXPECT() *Service_Expecter { + return &Service_Expecter{mock: &_m.Mock} +} + +// Cert provides a mock function for the type Service +func (_mock *Service) Cert(ctx context.Context, domainID string, token string, clientID string, duration string) (string, string, error) { + ret := _mock.Called(ctx, domainID, token, clientID, duration) + + if len(ret) == 0 { + panic("no return value specified for Cert") + } + + var r0 string + var r1 string + var r2 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) (string, string, error)); ok { + return returnFunc(ctx, domainID, token, clientID, duration) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string) string); ok { + r0 = returnFunc(ctx, domainID, token, clientID, duration) + } else { + r0 = ret.Get(0).(string) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string) string); ok { + r1 = returnFunc(ctx, domainID, token, clientID, duration) + } else { + r1 = ret.Get(1).(string) + } + if returnFunc, ok := ret.Get(2).(func(context.Context, string, string, string, string) error); ok { + r2 = returnFunc(ctx, domainID, token, clientID, duration) + } else { + r2 = ret.Error(2) + } + return r0, r1, r2 +} + +// Service_Cert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Cert' +type Service_Cert_Call struct { + *mock.Call +} + +// Cert is a helper method to define mock.On call +// - ctx context.Context +// - domainID string +// - token string +// - clientID string +// - duration string +func (_e *Service_Expecter) Cert(ctx interface{}, domainID interface{}, token interface{}, clientID interface{}, duration interface{}) *Service_Cert_Call { + return &Service_Cert_Call{Call: _e.mock.On("Cert", ctx, domainID, token, clientID, duration)} +} + +func (_c *Service_Cert_Call) Run(run func(ctx context.Context, domainID string, token string, clientID string, duration string)) *Service_Cert_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_Cert_Call) Return(s string, s1 string, err error) *Service_Cert_Call { + _c.Call.Return(s, s1, err) + return _c +} + +func (_c *Service_Cert_Call) RunAndReturn(run func(ctx context.Context, domainID string, token string, clientID string, duration string) (string, string, error)) *Service_Cert_Call { + _c.Call.Return(run) + return _c +} + +// Mapping provides a mock function for the type Service +func (_mock *Service) Mapping() map[string]any { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Mapping") + } + + var r0 map[string]any + if returnFunc, ok := ret.Get(0).(func() map[string]any); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]any) + } + } + return r0 +} + +// Service_Mapping_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Mapping' +type Service_Mapping_Call struct { + *mock.Call +} + +// Mapping is a helper method to define mock.On call +func (_e *Service_Expecter) Mapping() *Service_Mapping_Call { + return &Service_Mapping_Call{Call: _e.mock.On("Mapping")} +} + +func (_c *Service_Mapping_Call) Run(run func()) *Service_Mapping_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Service_Mapping_Call) Return(stringToV map[string]any) *Service_Mapping_Call { + _c.Call.Return(stringToV) + return _c +} + +func (_c *Service_Mapping_Call) RunAndReturn(run func() map[string]any) *Service_Mapping_Call { + _c.Call.Return(run) + return _c +} + +// Provision provides a mock function for the type Service +func (_mock *Service) Provision(ctx context.Context, domainID string, token string, name string, externalID string, externalKey string) (provision.Result, error) { + ret := _mock.Called(ctx, domainID, token, name, externalID, externalKey) + + if len(ret) == 0 { + panic("no return value specified for Provision") + } + + var r0 provision.Result + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) (provision.Result, error)); ok { + return returnFunc(ctx, domainID, token, name, externalID, externalKey) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, string, string) provision.Result); ok { + r0 = returnFunc(ctx, domainID, token, name, externalID, externalKey) + } else { + r0 = ret.Get(0).(provision.Result) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string, string, string) error); ok { + r1 = returnFunc(ctx, domainID, token, name, externalID, externalKey) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_Provision_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Provision' +type Service_Provision_Call struct { + *mock.Call +} + +// Provision is a helper method to define mock.On call +// - ctx context.Context +// - domainID string +// - token string +// - name string +// - externalID string +// - externalKey string +func (_e *Service_Expecter) Provision(ctx interface{}, domainID interface{}, token interface{}, name interface{}, externalID interface{}, externalKey interface{}) *Service_Provision_Call { + return &Service_Provision_Call{Call: _e.mock.On("Provision", ctx, domainID, token, name, externalID, externalKey)} +} + +func (_c *Service_Provision_Call) Run(run func(ctx context.Context, domainID string, token string, name string, externalID string, externalKey string)) *Service_Provision_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + var arg5 string + if args[5] != nil { + arg5 = args[5].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + ) + }) + return _c +} + +func (_c *Service_Provision_Call) Return(result provision.Result, err error) *Service_Provision_Call { + _c.Call.Return(result, err) + return _c +} + +func (_c *Service_Provision_Call) RunAndReturn(run func(ctx context.Context, domainID string, token string, name string, externalID string, externalKey string) (provision.Result, error)) *Service_Provision_Call { + _c.Call.Return(run) + return _c +} diff --git a/provision/service.go b/provision/service.go new file mode 100644 index 000000000..3440e642e --- /dev/null +++ b/provision/service.go @@ -0,0 +1,417 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package provision + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + + certs "github.com/absmach/certs" + csdk "github.com/absmach/certs/sdk" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/sdk" + smqSDK "github.com/absmach/supermq/pkg/sdk" +) + +const ( + externalIDKey = "external_id" + gateway = "gateway" + Active = 1 + + control = "control" + data = "data" + export = "export" +) + +var ( + ErrUnauthorized = errors.NewAuthNError("unauthorized access") + ErrFailedToCreateToken = errors.NewAuthNError("failed to create access token") + ErrEmptyClientsList = errors.NewRequestError("clients list in configuration empty") + ErrClientUpdate = errors.NewRequestError("failed to update client") + ErrEmptyChannelsList = errors.NewRequestError("channels list in configuration is empty") + ErrFailedChannelCreation = errors.NewRequestError("failed to create channel") + ErrFailedChannelRetrieval = errors.NewRequestError("failed to retrieve channel") + ErrFailedClientCreation = errors.NewRequestError("failed to create client") + ErrFailedClientRetrieval = errors.NewRequestError("failed to retrieve client") + ErrMissingCredentials = errors.NewRequestError("missing credentials") + ErrFailedBootstrapRetrieval = errors.NewServiceError("failed to retrieve bootstrap") + ErrFailedCertCreation = errors.NewServiceError("failed to create certificates") + ErrFailedCertView = errors.NewServiceError("failed to view certificate") + ErrFailedBootstrap = errors.NewServiceError("failed to create bootstrap config") + ErrFailedBootstrapValidate = errors.NewServiceError("failed to validate bootstrap config creation") + ErrGatewayUpdate = errors.NewServiceError("failed to update gateway metadata") +) + +var _ Service = (*provisionService)(nil) + +// Service specifies Provision service API. +type Service interface { + // Provision is the only method this API specifies. Depending on the configuration, + // the following actions will can be executed: + // - create a Client based on external_id (eg. MAC address) + // - create multiple Channels + // - create Bootstrap configuration + // - whitelist Client in Bootstrap configuration == connect Client to Channels + Provision(ctx context.Context, domainID, token, name, externalID, externalKey string) (Result, error) + + // Mapping returns current configuration used for provision + // useful for using in ui to create configuration that matches + // one created with Provision method. + Mapping() map[string]any + + // Certs creates certificate for clients that communicate over mTLS + // A duration string is a possibly signed sequence of decimal numbers, + // each with optional fraction and a unit suffix, such as "300ms", "-1.5h" or "2h45m". + // Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". + Cert(ctx context.Context, domainID, token, clientID, duration string) (string, string, error) +} + +type provisionService struct { + logger *slog.Logger + sdk sdk.SDK + csdk csdk.SDK + conf Config +} + +// Result represent what is created with additional info. +type Result struct { + Clients []smqSDK.Client `json:"clients,omitempty"` + Channels []smqSDK.Channel `json:"channels,omitempty"` + ClientCert map[string]string `json:"client_cert,omitempty"` + ClientKey map[string]string `json:"client_key,omitempty"` + CACert string `json:"ca_cert,omitempty"` + Whitelisted map[string]bool `json:"whitelisted,omitempty"` + Error string `json:"error,omitempty"` +} + +// New returns new provision service. +func New(cfg Config, mgsdk sdk.SDK, certsSdk csdk.SDK, logger *slog.Logger) Service { + return &provisionService{ + logger: logger, + csdk: certsSdk, + conf: cfg, + sdk: mgsdk, + } +} + +// Mapping retrieves current configuration. +func (ps *provisionService) Mapping() map[string]any { + return ps.conf.Bootstrap.Content +} + +// Provision is provision method for creating setup according to +// provision layout specified in config.toml. +func (ps *provisionService) Provision(ctx context.Context, domainID, token, name, externalID, externalKey string) (res Result, err error) { + var channels []smqSDK.Channel + var clients []smqSDK.Client + defer ps.recover(ctx, &err, &clients, &channels, domainID, token) + + token, err = ps.createTokenIfEmpty(ctx, token) + if err != nil { + return res, errors.Wrap(ErrFailedToCreateToken, err) + } + + if len(ps.conf.Clients) == 0 { + return res, ErrEmptyClientsList + } + if len(ps.conf.Channels) == 0 { + return res, ErrEmptyChannelsList + } + for _, c := range ps.conf.Clients { + // If client in configs contains metadata with external_id + // set value for it from the provision request + if _, ok := c.Metadata[externalIDKey]; ok { + c.Metadata[externalIDKey] = externalID + } + + cli := smqSDK.Client{ + Metadata: c.Metadata, + } + if name == "" { + name = c.Name + } + cli.Name = name + cli, err := ps.sdk.CreateClient(ctx, cli, domainID, token) + if err != nil { + res.Error = err.Error() + return res, errors.Wrap(ErrFailedClientCreation, err) + } + + // Get newly created client (in order to get the key). + cli, err = ps.sdk.Client(ctx, cli.ID, domainID, token) + if err != nil { + e := errors.Wrap(err, fmt.Errorf("client id: %s", cli.ID)) + return res, errors.Wrap(ErrFailedClientRetrieval, e) + } + clients = append(clients, cli) + } + + for _, channel := range ps.conf.Channels { + ch := smqSDK.Channel{ + Name: name + "_" + channel.Name, + Metadata: smqSDK.Metadata(channel.Metadata), + } + ch, err := ps.sdk.CreateChannel(ctx, ch, domainID, token) + if err != nil { + return res, errors.Wrap(ErrFailedChannelCreation, err) + } + ch, err = ps.sdk.Channel(ctx, ch.ID, domainID, token) + if err != nil { + e := errors.Wrap(err, fmt.Errorf("channel id: %s", ch.ID)) + return res, errors.Wrap(ErrFailedChannelRetrieval, e) + } + channels = append(channels, ch) + } + + res = Result{ + Clients: clients, + Channels: channels, + Whitelisted: map[string]bool{}, + ClientCert: map[string]string{}, + ClientKey: map[string]string{}, + } + + var cert certs.Certificate + var bsConfig sdk.BootstrapConfig + for _, c := range clients { + var chanIDs []string + + for _, ch := range channels { + chanIDs = append(chanIDs, ch.ID) + } + content, err := json.Marshal(ps.conf.Bootstrap.Content) + if err != nil { + return Result{}, errors.Wrap(ErrFailedBootstrap, err) + } + + if ps.conf.Bootstrap.Provision && needsBootstrap(c) { + bsReq := sdk.BootstrapConfig{ + ClientID: c.ID, + ExternalID: externalID, + ExternalKey: externalKey, + Channels: chanIDs, + CACert: res.CACert, + ClientCert: string(cert.Certificate), + ClientKey: string(cert.Key), + Content: string(content), + } + bsid, err := ps.sdk.AddBootstrap(ctx, bsReq, domainID, token) + if err != nil { + return Result{}, errors.Wrap(ErrFailedBootstrap, err) + } + + bsConfig, err = ps.sdk.ViewBootstrap(ctx, bsid, domainID, token) + if err != nil { + return Result{}, errors.Wrap(ErrFailedBootstrapValidate, err) + } + } + + if ps.conf.Bootstrap.X509Provision { + var cert csdk.Certificate + + cert, err = ps.csdk.IssueCert(ctx, c.ID, ps.conf.Cert.TTL, nil, csdk.Options{}, domainID, token) + if err != nil { + e := errors.Wrap(err, fmt.Errorf("client id: %s", c.ID)) + return res, errors.Wrap(ErrFailedCertCreation, e) + } + cert, err := ps.csdk.ViewCert(ctx, cert.SerialNumber, domainID, token) + if err != nil { + return res, errors.Wrap(ErrFailedCertView, err) + } + + res.ClientCert[c.ID] = cert.Certificate + res.ClientKey[c.ID] = cert.Key + res.CACert = "" + + if needsBootstrap(c) { + if _, err = ps.sdk.UpdateBootstrapCerts(ctx, bsConfig.ClientID, cert.Certificate, cert.Key, "", domainID, token); err != nil { + return Result{}, errors.Wrap(ErrFailedCertCreation, err) + } + } + } + + if ps.conf.Bootstrap.AutoWhiteList { + if err := ps.sdk.Whitelist(ctx, c.ID, Active, domainID, token); err != nil { + res.Error = err.Error() + return res, ErrClientUpdate + } + res.Whitelisted[c.ID] = true + } + } + + if err = ps.updateGateway(ctx, domainID, token, bsConfig, channels); err != nil { + return res, err + } + return res, nil +} + +func (ps *provisionService) Cert(ctx context.Context, domainID, token, clientID, ttl string) (string, string, error) { + token, err := ps.createTokenIfEmpty(ctx, token) + if err != nil { + return "", "", errors.Wrap(ErrFailedToCreateToken, err) + } + + c, err := ps.sdk.Client(ctx, clientID, domainID, token) + if err != nil { + return "", "", errors.Wrap(ErrUnauthorized, err) + } + cert, err := ps.csdk.IssueCert(ctx, c.ID, ps.conf.Cert.TTL, []string{}, csdk.Options{}, domainID, token) + if err != nil { + return "", "", errors.Wrap(ErrFailedCertCreation, err) + } + cert, err = ps.csdk.ViewCert(ctx, cert.SerialNumber, domainID, token) + if err != nil { + return "", "", errors.Wrap(ErrFailedCertView, err) + } + return cert.Certificate, cert.Key, err +} + +func (ps *provisionService) createTokenIfEmpty(ctx context.Context, token string) (string, error) { + if token != "" { + return token, nil + } + + // If no token in request is provided + // use API key provided in config file or env + if ps.conf.Server.MgAPIKey != "" { + return ps.conf.Server.MgAPIKey, nil + } + + // If no API key use username and password provided to create access token. + if ps.conf.Server.MgUsername == "" || ps.conf.Server.MgPass == "" { + return token, ErrMissingCredentials + } + + u := smqSDK.Login{ + Username: ps.conf.Server.MgUsername, + Password: ps.conf.Server.MgPass, + } + tkn, err := ps.sdk.CreateToken(ctx, u) + if err != nil { + return token, errors.Wrap(ErrFailedToCreateToken, err) + } + + return tkn.AccessToken, nil +} + +func (ps *provisionService) updateGateway(ctx context.Context, domainID, token string, bs sdk.BootstrapConfig, channels []smqSDK.Channel) error { + var gw Gateway + for _, ch := range channels { + switch ch.Metadata["type"] { + case control: + gw.CtrlChannelID = ch.ID + case data: + gw.DataChannelID = ch.ID + case export: + gw.ExportChannelID = ch.ID + } + } + gw.ExternalID = bs.ExternalID + gw.ExternalKey = bs.ExternalKey + gw.CfgID = bs.ClientID + gw.Type = gateway + + c, sdkerr := ps.sdk.Client(ctx, bs.ClientID, domainID, token) + if sdkerr != nil { + return errors.Wrap(ErrGatewayUpdate, sdkerr) + } + b, err := json.Marshal(gw) + if err != nil { + return errors.Wrap(ErrGatewayUpdate, err) + } + if err := json.Unmarshal(b, &c.Metadata); err != nil { + return errors.Wrap(ErrGatewayUpdate, err) + } + if _, err := ps.sdk.UpdateClient(ctx, c, domainID, token); err != nil { + return errors.Wrap(ErrGatewayUpdate, err) + } + return nil +} + +func (ps *provisionService) errLog(err error) { + if err != nil { + ps.logger.Error(fmt.Sprintf("Error recovering: %s", err)) + } +} + +func clean(ctx context.Context, ps *provisionService, clients []smqSDK.Client, channels []smqSDK.Channel, domainID, token string) { + for _, t := range clients { + err := ps.sdk.DeleteClient(ctx, t.ID, domainID, token) + ps.errLog(err) + } + for _, c := range channels { + err := ps.sdk.DeleteChannel(ctx, c.ID, domainID, token) + ps.errLog(err) + } +} + +func (ps *provisionService) recover(ctx context.Context, e *error, ths *[]smqSDK.Client, chs *[]smqSDK.Channel, domainID, token string) { + if e == nil { + return + } + clients, channels, err := *ths, *chs, *e + + if errors.Contains(err, ErrFailedClientRetrieval) || errors.Contains(err, ErrFailedChannelCreation) { + for _, c := range clients { + err := ps.sdk.DeleteClient(ctx, c.ID, domainID, token) + ps.errLog(err) + } + return + } + + if errors.Contains(err, ErrFailedBootstrap) || errors.Contains(err, ErrFailedChannelRetrieval) { + clean(ctx, ps, clients, channels, domainID, token) + return + } + + if errors.Contains(err, ErrFailedBootstrapValidate) || errors.Contains(err, ErrFailedCertCreation) { + clean(ctx, ps, clients, channels, domainID, token) + for _, c := range clients { + if needsBootstrap(c) { + ps.errLog(ps.sdk.RemoveBootstrap(ctx, c.ID, domainID, token)) + } + } + return + } + + if errors.Contains(err, ErrFailedBootstrapValidate) || errors.Contains(err, ErrFailedCertCreation) { + clean(ctx, ps, clients, channels, domainID, token) + for _, c := range clients { + if needsBootstrap(c) { + bs, err := ps.sdk.ViewBootstrap(ctx, c.ID, domainID, token) + ps.errLog(errors.Wrap(ErrFailedBootstrapRetrieval, err)) + ps.errLog(ps.sdk.RemoveBootstrap(ctx, bs.ClientID, domainID, token)) + } + } + } + + if errors.Contains(err, ErrClientUpdate) || errors.Contains(err, ErrGatewayUpdate) { + clean(ctx, ps, clients, channels, domainID, token) + for _, c := range clients { + if ps.conf.Bootstrap.X509Provision && needsBootstrap(c) { + err := ps.csdk.RevokeCert(ctx, c.ID, domainID, token) + ps.errLog(err) + } + if needsBootstrap(c) { + bs, err := ps.sdk.ViewBootstrap(ctx, c.ID, domainID, token) + ps.errLog(errors.Wrap(ErrFailedBootstrapRetrieval, err)) + ps.errLog(ps.sdk.RemoveBootstrap(ctx, bs.ClientID, domainID, token)) + } + } + return + } +} + +func needsBootstrap(c smqSDK.Client) bool { + if c.Metadata == nil { + return false + } + + if _, ok := c.Metadata[externalIDKey]; ok { + return true + } + return false +} diff --git a/provision/service_test.go b/provision/service_test.go new file mode 100644 index 000000000..4ea275822 --- /dev/null +++ b/provision/service_test.go @@ -0,0 +1,237 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package provision_test + +import ( + "context" + "fmt" + "testing" + + csdk "github.com/absmach/certs/sdk" + csdkmocks "github.com/absmach/certs/sdk/mocks" + "github.com/absmach/supermq/internal/testsutil" + smqlog "github.com/absmach/supermq/logger" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + svcerr "github.com/absmach/supermq/pkg/errors/service" + smqSDK "github.com/absmach/supermq/pkg/sdk" + sdkmocks "github.com/absmach/supermq/pkg/sdk/mocks" + "github.com/absmach/supermq/provision" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var validToken = "valid" + +func TestMapping(t *testing.T) { + mgsdk := new(sdkmocks.SDK) + certs := new(csdkmocks.SDK) + svc := provision.New(validConfig, mgsdk, certs, smqlog.NewMock()) + + cases := []struct { + desc string + content map[string]any + sdkerr error + err error + }{ + { + desc: "valid request", + content: validConfig.Bootstrap.Content, + sdkerr: nil, + err: nil, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + content := svc.Mapping() + assert.Equal(t, c.content, content) + }) + } +} + +func TestCert(t *testing.T) { + cases := []struct { + desc string + config provision.Config + domainID string + token string + returnedToken string + clientID string + ttl string + serial string + cert string + key string + sdkClientErr error + sdkCertErr error + sdkTokenErr error + err error + }{ + { + desc: "valid", + config: validConfig, + domainID: testsutil.GenerateUUID(t), + token: validToken, + clientID: testsutil.GenerateUUID(t), + ttl: "1h", + cert: "cert", + key: "key", + sdkClientErr: nil, + sdkCertErr: nil, + sdkTokenErr: nil, + err: nil, + }, + { + desc: "empty token with config API key", + config: provision.Config{ + Server: provision.ServiceConf{MgAPIKey: "key"}, + Cert: provision.Cert{TTL: "1h"}, + }, + domainID: testsutil.GenerateUUID(t), + token: "", + returnedToken: "key", + clientID: testsutil.GenerateUUID(t), + ttl: "1h", + cert: "cert", + key: "key", + sdkClientErr: nil, + sdkCertErr: nil, + sdkTokenErr: nil, + err: nil, + }, + { + desc: "empty token with username and password", + config: provision.Config{ + Server: provision.ServiceConf{ + MgUsername: "testUsername", + MgPass: "12345678", + MgDomainID: testsutil.GenerateUUID(t), + }, + Cert: provision.Cert{TTL: "1h"}, + }, + domainID: testsutil.GenerateUUID(t), + token: "", + returnedToken: validToken, + clientID: testsutil.GenerateUUID(t), + ttl: "1h", + cert: "cert", + key: "key", + sdkClientErr: nil, + sdkCertErr: nil, + sdkTokenErr: nil, + err: nil, + }, + { + desc: "empty token with username and invalid password", + config: provision.Config{ + Server: provision.ServiceConf{ + MgUsername: "testUsername", + MgPass: "12345678", + MgDomainID: testsutil.GenerateUUID(t), + }, + Cert: provision.Cert{TTL: "1h"}, + }, + domainID: testsutil.GenerateUUID(t), + token: "", + clientID: testsutil.GenerateUUID(t), + ttl: "1h", + cert: "", + key: "", + sdkClientErr: nil, + sdkCertErr: nil, + sdkTokenErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, 401), + err: provision.ErrFailedToCreateToken, + }, + { + desc: "empty token with empty username and password", + config: provision.Config{ + Server: provision.ServiceConf{}, + Cert: provision.Cert{TTL: "1h"}, + }, + domainID: testsutil.GenerateUUID(t), + token: "", + clientID: testsutil.GenerateUUID(t), + ttl: "1h", + cert: "", + key: "", + sdkClientErr: nil, + sdkCertErr: nil, + sdkTokenErr: nil, + err: provision.ErrMissingCredentials, + }, + { + desc: "invalid clientID", + config: validConfig, + domainID: testsutil.GenerateUUID(t), + token: "invalid", + clientID: testsutil.GenerateUUID(t), + ttl: "1h", + cert: "", + key: "", + sdkClientErr: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, 401), + sdkCertErr: nil, + sdkTokenErr: nil, + err: provision.ErrUnauthorized, + }, + { + desc: "invalid clientID", + config: validConfig, + domainID: testsutil.GenerateUUID(t), + token: validToken, + clientID: "invalid", + ttl: "1h", + cert: "", + key: "", + sdkClientErr: errors.NewSDKErrorWithStatus(repoerr.ErrNotFound, 404), + sdkCertErr: nil, + sdkTokenErr: nil, + err: provision.ErrUnauthorized, + }, + { + desc: "failed to issue cert", + config: validConfig, + domainID: testsutil.GenerateUUID(t), + token: validToken, + clientID: testsutil.GenerateUUID(t), + ttl: "1h", + cert: "", + key: "", + sdkClientErr: nil, + sdkTokenErr: nil, + sdkCertErr: errors.NewSDKError(repoerr.ErrCreateEntity), + err: repoerr.ErrCreateEntity, + }, + } + mgsdk := new(sdkmocks.SDK) + certs := new(csdkmocks.SDK) + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + svc := provision.New(c.config, mgsdk, certs, smqlog.NewMock()) + + call1 := mgsdk.On("Client", mock.Anything, c.clientID, c.domainID, mock.Anything).Return(smqSDK.Client{ID: c.clientID}, c.sdkClientErr) + var call2 *mock.Call + switch c.token { + case "": + call2 = certs.On("IssueCert", context.Background(), c.clientID, c.config.Cert.TTL, []string{}, csdk.Options{}, c.domainID, c.returnedToken).Return(csdk.Certificate{SerialNumber: c.serial}, c.sdkCertErr) + default: + call2 = certs.On("IssueCert", context.Background(), c.clientID, c.config.Cert.TTL, []string{}, csdk.Options{}, c.domainID, c.token).Return(csdk.Certificate{SerialNumber: c.serial}, c.sdkCertErr) + } + call3 := certs.On("ViewCert", mock.Anything, c.serial, mock.Anything, mock.Anything).Return(csdk.Certificate{Certificate: c.cert, Key: c.key}, c.sdkCertErr) + + login := smqSDK.Login{ + Username: c.config.Server.MgUsername, + Password: c.config.Server.MgPass, + } + call4 := mgsdk.On("CreateToken", mock.Anything, login).Return(smqSDK.Token{AccessToken: validToken}, c.sdkTokenErr) + cert, key, err := svc.Cert(context.Background(), c.domainID, c.token, c.clientID, c.ttl) + assert.Equal(t, c.cert, cert) + assert.Equal(t, c.key, key) + assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected error %v, got %v", c.err, err)) + call1.Unset() + call2.Unset() + call3.Unset() + call4.Unset() + }) + } +} diff --git a/re/README.md b/re/README.md new file mode 100644 index 000000000..2bd52c3ee --- /dev/null +++ b/re/README.md @@ -0,0 +1,323 @@ +# Rules Engine + +The Magistrala Rules Engine (RE) processes incoming messages using user-defined scripts (Lua or Go) and routes the results to outputs such as channels, alarms, email, SenML writers, PostgreSQL, or Slack. It also supports scheduled rule execution and publishes rule events to the event store. + +## Configuration + +The service is configured using the following environment variables (values shown are from [docker/.env](https://github.com/absmach/magistrala/blob/main/docker/.env) as consumed by [docker/docker-compose.yaml](https://github.com/absmach/magistrala/blob/main/docker/docker-compose.yaml)): + +### Core service + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_RE_LOG_LEVEL` | Log level for the service | `debug` | +| `MG_RE_HTTP_HOST` | HTTP host to bind | `re` | +| `MG_RE_HTTP_PORT` | HTTP port to bind | `9008` | +| `MG_RE_HTTP_SERVER_CERT` | Path to PEM-encoded HTTPS server certificate | "" | +| `MG_RE_HTTP_SERVER_KEY` | Path to PEM-encoded HTTPS server key | "" | +| `MG_RE_INSTANCE_ID` | Instance ID for tracing/health | "" | +| `MG_MESSAGE_BROKER_URL` | Internal message broker URL | `nats://nats:4222` | +| `MG_ES_URL` | Event store broker URL | `nats://nats:4222` | +| `MG_JAEGER_URL` | Jaeger collector endpoint | `http://jaeger:4318/v1/traces` | +| `MG_JAEGER_TRACE_RATIO` | Trace sampling ratio | `1.0` | +| `MG_SEND_TELEMETRY` | Send telemetry to Magistrala call-home server | `true` | + +### Database + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_RE_DB_HOST` | PostgreSQL host | `re-db` | +| `MG_RE_DB_PORT` | PostgreSQL port | `5432` | +| `MG_RE_DB_USER` | PostgreSQL user | `magistrala` | +| `MG_RE_DB_PASS` | PostgreSQL password | `magistrala` | +| `MG_RE_DB_NAME` | PostgreSQL database name | `rules_engine` | +| `MG_RE_DB_SSL_MODE` | PostgreSQL SSL mode | `disable` | +| `MG_RE_DB_SSL_CERT` | PostgreSQL SSL client cert | "" | +| `MG_RE_DB_SSL_KEY` | PostgreSQL SSL client key | "" | +| `MG_RE_DB_SSL_ROOT_CERT` | PostgreSQL SSL root cert | "" | + +### Auth and domains gRPC + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_AUTH_GRPC_URL` | Auth gRPC endpoint | `auth:7001` | +| `MG_AUTH_GRPC_TIMEOUT` | Auth gRPC timeout | `300s` | +| `MG_AUTH_GRPC_CLIENT_CERT` | Auth gRPC client cert path | `${GRPC_MTLS:+./ssl/certs/auth-grpc-client.crt}` | +| `MG_AUTH_GRPC_CLIENT_KEY` | Auth gRPC client key path | `${GRPC_MTLS:+./ssl/certs/auth-grpc-client.key}` | +| `MG_AUTH_GRPC_SERVER_CA_CERTS` | Auth gRPC server CA path | `${GRPC_MTLS:+./ssl/certs/ca.crt}` | +| `MG_DOMAINS_GRPC_URL` | Domains gRPC endpoint | `domains:7003` | +| `MG_DOMAINS_GRPC_TIMEOUT` | Domains gRPC timeout | `300s` | +| `MG_DOMAINS_GRPC_CLIENT_CERT` | Domains gRPC client cert path | `${GRPC_MTLS:+./ssl/certs/domains-grpc-client.crt}` | +| `MG_DOMAINS_GRPC_CLIENT_KEY` | Domains gRPC client key path | `${GRPC_MTLS:+./ssl/certs/domains-grpc-client.key}` | +| `MG_DOMAINS_GRPC_SERVER_CA_CERTS` | Domains gRPC server CA path | `${GRPC_MTLS:+./ssl/certs/ca.crt}` | +| `MG_ALLOW_UNVERIFIED_USER` | Allow unverified users to access | `true` | + +### Readers gRPC + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_TIMESCALE_READER_GRPC_URL` | Readers gRPC endpoint | `timescale-reader:7011` | +| `MG_TIMESCALE_READER_GRPC_TIMEOUT` | Readers gRPC timeout | `300s` | +| `MG_TIMESCALE_READER_GRPC_CLIENT_CERT` | Readers gRPC client cert path | `${GRPC_MTLS:+./ssl/certs/reader-grpc-client.crt}` | +| `MG_TIMESCALE_READER_GRPC_CLIENT_CA_CERTS` | Readers gRPC server CA path | `${GRPC_MTLS:+./ssl/certs/ca.crt}` | +| `MG_TIMESCALE_READER_GRPC_CLIENT_KEY` | Readers gRPC client key path | `${GRPC_MTLS:+./ssl/certs/readers-grpc-client.key}` | + +### Email + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_EMAIL_HOST` | SMTP host | `smtp.mailtrap.io` | +| `MG_EMAIL_PORT` | SMTP port | `2525` | +| `MG_EMAIL_USERNAME` | SMTP username | `18bf7f70705139` | +| `MG_EMAIL_PASSWORD` | SMTP password | `2b0d302e775b1e` | +| `MG_EMAIL_FROM_ADDRESS` | Sender email address | `from@example.com` | +| `MG_EMAIL_FROM_NAME` | Sender display name | `Example` | +| `MG_EMAIL_TEMPLATE` | Email template path | `email.tmpl` | +| `MG_RE_EMAIL_TEMPLATE` | Template file mounted by Docker Compose | `re.tmpl` | + +### Callout + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_RE_CALLOUT_URLS` | Callout target URLs | "" | +| `MG_RE_CALLOUT_METHOD` | Callout HTTP method | `POST` | +| `MG_RE_CALLOUT_TLS_VERIFICATION` | TLS verification for callout | `false` | +| `MG_RE_CALLOUT_TIMEOUT` | Callout timeout | `10s` | +| `MG_RE_CALLOUT_CA_CERT` | Callout CA cert path | "" | +| `MG_RE_CALLOUT_CERT` | Callout client cert path | "" | +| `MG_RE_CALLOUT_KEY` | Callout client key path | "" | +| `MG_RE_CALLOUT_OPERATIONS` | Callout operations filter | "" | + +### Optional cache defaults (from code) + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_RE_CACHE_URL` | Cache URL | `redis://localhost:6379/0` | +| `MG_RE_CACHE_KEY_DURATION` | Cache key TTL | `10m` | + +## Features + +- **Rule execution**: Runs Lua or Go scripts for incoming messages. +- **Multiple outputs**: Channels, alarms, email, SenML writers, remote PostgreSQL, and Slack outputs. +- **Scheduling**: Runs rules at specific times with recurring intervals. +- **Filtering and matching**: Input channel filtering and NATS-style topic matching (`*`, `>`). +- **Observability**: `/metrics` Prometheus endpoint and Jaeger tracing support. +- **Payload limit**: Messages over 100 kB are rejected for processing. + +## Architecture + +### Runtime flow + +1. The service subscribes to all internal broker messages. +2. For each message, it lists enabled rules for the same domain and input channel. +3. It matches the rule `input_topic` against the message subtopic using NATS-style wildcards. +4. The rule logic (Lua or Go) is executed and the result is passed to configured outputs. + +### Message payloads + +In Lua, the engine injects a global `message` object: + +```lua +message = { + domain = "domain_id", + channel = "channel_id", + subtopic = "subtopic", + publisher = "client_id", + protocol = "nats", + created = timestamp, + payload = { ... } -- JSON object/array or a byte array if payload is not JSON +} +``` + +For Go scripts, the message is exposed as `messaging/m.message` and `main.logicFunction` must return a value. + +In rule definitions, `logic.type` uses numeric values: `0` = Lua, `1` = Go. + +If a script returns `false`, outputs are skipped. + +### Scheduling + +The scheduler runs on a 30-second ticker and selects enabled rules with a due time (`time`) earlier than now. It updates the next due time using `Schedule.NextDue()` and executes each rule with a synthetic message containing the scheduled timestamp. + +Recurring types are: `none`, `hourly`, `daily`, `weekly`, `monthly`. The `recurring_period` controls the interval (1 = every interval, 2 = every second interval, etc.). + +### Outputs + +Supported output types (`outputs.OutputType`) and their fields: + +| Output type | Fields | Notes | +| --- | --- | --- | +| `channels` | `channel`, `topic` | Republish result to another channel/topic. | +| `alarms` | none | Emits alarms from the script result. | +| `save_senml` | none | Forwards SenML to writers. | +| `email` | `to`, `subject`, `content` | `content` is a Go template. | +| `save_remote_pg` | `host`, `port`, `user`, `password`, `database`, `table`, `mapping` | `mapping` is a Go template that must render a JSON object. | +| `slack` | `token`, `channel_id`, `message` | `message` is a Go template. | + +Templates receive a `Message` (the incoming message) and a `Result` (the script output) value. + +## Data model + +### Rules table + +Defined in `re/postgres/init.go`: + +| Column | Type | Description | +| --- | --- | --- | +| `id` | `VARCHAR(36)` | Rule UUID (primary key) | +| `name` | `VARCHAR(1024)` | Rule name | +| `domain_id` | `VARCHAR(36)` | Domain ID | +| `metadata` | `JSONB` | Custom metadata | +| `tags` | `TEXT[]` | Rule tags | +| `created_by` | `VARCHAR(254)` | Creator user ID | +| `created_at` | `TIMESTAMP` | Creation timestamp | +| `updated_at` | `TIMESTAMP` | Last update timestamp | +| `updated_by` | `VARCHAR(254)` | Last updater user ID | +| `input_channel` | `VARCHAR(36)` | Input channel ID | +| `input_topic` | `TEXT` | Input topic (supports wildcards) | +| `outputs` | `JSONB` | Output definitions | +| `status` | `SMALLINT` | 0 = enabled, 1 = disabled, 2 = deleted | +| `logic_type` | `SMALLINT` | 0 = Lua, 1 = Go | +| `logic_value` | `BYTEA` | Script body | +| `start_datetime` | `TIMESTAMP` | Schedule start time | +| `time` | `TIMESTAMP` | Next scheduled execution time | +| `recurring` | `SMALLINT` | Recurring type | +| `recurring_period` | `SMALLINT` | Recurring period | + +## Deployment + +### Build and run locally + +```bash +make re + +MG_RE_LOG_LEVEL=debug \ +MG_RE_HTTP_PORT=9008 \ +MG_RE_DB_HOST=localhost \ +MG_RE_DB_PORT=5432 \ +MG_RE_DB_USER=magistrala \ +MG_RE_DB_PASS=magistrala \ +MG_RE_DB_NAME=rules_engine \ +MG_MESSAGE_BROKER_URL=nats://localhost:4222 \ +MG_ES_URL=nats://localhost:4222 \ +MG_AUTH_GRPC_URL=localhost:7001 \ +MG_AUTH_GRPC_TIMEOUT=300s \ +MG_DOMAINS_GRPC_URL=localhost:7003 \ +MG_DOMAINS_GRPC_TIMEOUT=300s \ +MG_TIMESCALE_READER_GRPC_URL=localhost:7011 \ +MG_TIMESCALE_READER_GRPC_TIMEOUT=300s \ +./build/re +``` + +### Docker Compose + +The service is available as a Docker container. Refer to [docker/docker-compose.yaml](https://github.com/absmach/magistrala/blob/main/docker/docker-compose.yaml) for the `re` and `re-db` services and their environment variables. For a full local stack, ensure auth, domains, readers, and the message broker are running. + +```bash +docker compose -f docker/docker-compose.yaml up re re-db +``` + +### Health check + +```bash +curl -X GET http://localhost:9008/health \ + -H "accept: application/health+json" +``` + +## Testing + +```bash +go test ./re/... +``` + +## Usage + +The Rules Engine service supports the following operations: + +| Operation | Method & Path | Description | +| --- | --- | --- | +| `createRule` | `POST /{domainID}/rules` | Create a new rule | +| `listRules` | `GET /{domainID}/rules` | List rules with filters | +| `viewRule` | `GET /{domainID}/rules/{ruleID}` | Retrieve a rule | +| `updateRule` | `PATCH /{domainID}/rules/{ruleID}` | Update a rule | +| `updateRuleTags` | `PATCH /{domainID}/rules/{ruleID}/tags` | Update rule tags | +| `updateRuleSchedule` | `PATCH /{domainID}/rules/{ruleID}/schedule` | Update rule schedule | +| `enableRule` | `POST /{domainID}/rules/{ruleID}/enable` | Enable a rule | +| `disableRule` | `POST /{domainID}/rules/{ruleID}/disable` | Disable a rule | +| `removeRule` | `DELETE /{domainID}/rules/{ruleID}` | Delete a rule | +| `health` | `GET /health` | Service health check | + +List filters: `offset`, `limit`, `name`, `input_channel`, `status`, `order` (`name`, `created_at`, `updated_at`), `dir` (`asc`, `desc`), and `tag`. + +### Example: Create a rule (Lua + alarms + channels) + +```bash +curl -X POST http://localhost:9008//rules \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "name": "High Temperature Alert", + "input_channel": "sensors", + "input_topic": "temperature.*", + "logic": { + "type": 0, + "value": "if message.payload.t > 30 then return {measurement=\"temperature\", value=tostring(message.payload.t), unit=\"C\", threshold=\"30\", cause=\"temp high\", severity=90} end" + }, + "outputs": [ + { "type": "alarms" }, + { "type": "channels", "channel": "alerts", "topic": "temperature" } + ], + "tags": ["temp", "alerts"], + "metadata": { "site": "lab" } + }' +``` + +### Example: List rules + +```bash +curl -X GET "http://localhost:9008//rules?status=enabled&input_channel=sensors&order=updated_at&dir=desc&tag=temp" \ + -H "Authorization: Bearer " +``` + +### Example: Update rule tags + +```bash +curl -X PATCH http://localhost:9008//rules//tags \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ "tags": ["temp", "critical"] }' +``` + +### Example: Update rule schedule + +```bash +curl -X PATCH http://localhost:9008//rules//schedule \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "schedule": { + "start_datetime": "2025-01-01T00:00:00Z", + "time": "2025-01-01T00:00:00Z", + "recurring": "hourly", + "recurring_period": 1 + } + }' +``` + +### Example: Enable a rule + +```bash +curl -X POST http://localhost:9008//rules//enable \ + -H "Authorization: Bearer " +``` + +### Example: Delete a rule + +```bash +curl -X DELETE http://localhost:9008//rules/ \ + -H "Authorization: Bearer " +``` + +For an in-depth explanation of our Rules Engine Service, see the [official documentation][doc]. + +[doc]: https://docs.magistrala.absmach.eu/dev-guide/rules-engine/ diff --git a/re/api/doc.go b/re/api/doc.go new file mode 100644 index 000000000..2424852cc --- /dev/null +++ b/re/api/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package api contains API-related concerns: endpoint definitions, middlewares +// and all resource representations. +package api diff --git a/re/api/endpoints.go b/re/api/endpoints.go new file mode 100644 index 000000000..d4707abbe --- /dev/null +++ b/re/api/endpoints.go @@ -0,0 +1,205 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/re" + "github.com/go-kit/kit/endpoint" +) + +func addRuleEndpoint(s re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(addRuleReq) + if err := req.validate(); err != nil { + return addRuleRes{}, err + } + rule, _, err := s.AddRule(ctx, session, req.Rule) + if err != nil { + return addRuleRes{}, err + } + return addRuleRes{Rule: rule, created: true}, nil + } +} + +func viewRuleEndpoint(s re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(viewRuleReq) + if err := req.validate(); err != nil { + return viewRuleRes{}, err + } + rule, err := s.ViewRule(ctx, session, req.id, req.withRoles) + if err != nil { + return viewRuleRes{}, err + } + return viewRuleRes{Rule: rule}, nil + } +} + +func updateRuleEndpoint(s re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(updateRuleReq) + if err := req.validate(); err != nil { + return updateRuleRes{}, err + } + rule, err := s.UpdateRule(ctx, session, req.Rule) + if err != nil { + return updateRuleRes{}, err + } + return updateRuleRes{Rule: rule}, nil + } +} + +func updateRuleTagsEndpoint(svc re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(updateRuleTagsReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthentication + } + + r := re.Rule{ + ID: req.id, + Tags: req.Tags, + } + res, err := svc.UpdateRuleTags(ctx, session, r) + if err != nil { + return nil, err + } + + return updateRuleRes{Rule: res}, nil + } +} + +func updateRuleScheduleEndpoint(s re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(updateRuleScheduleReq) + if err := req.validate(); err != nil { + return updateRuleRes{}, err + } + + rule := re.Rule{ + ID: req.id, + Schedule: req.Schedule, + } + + updatedRule, err := s.UpdateRuleSchedule(ctx, session, rule) + if err != nil { + return updateRuleRes{}, err + } + return updateRuleRes{Rule: updatedRule}, nil + } +} + +func listRulesEndpoint(s re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(listRulesReq) + if err := req.validate(); err != nil { + return pageRes{}, err + } + page, err := s.ListRules(ctx, session, req.PageMeta) + if err != nil { + return rulesPageRes{}, err + } + ret := rulesPageRes{ + Page: page, + } + return ret, nil + } +} + +func deleteRuleEndpoint(s re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(deleteRuleReq) + if err := req.validate(); err != nil { + return deleteRuleRes{}, err + } + err := s.RemoveRule(ctx, session, req.id) + if err != nil { + return deleteRuleRes{false}, err + } + return deleteRuleRes{true}, nil + } +} + +func enableRuleEndpoint(s re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(updateRuleStatusReq) + if err := req.validate(); err != nil { + return updateRuleStatusRes{}, err + } + + rule, err := s.EnableRule(ctx, session, req.id) + if err != nil { + return updateRuleStatusRes{}, err + } + + return updateRuleStatusRes{Rule: rule}, err + } +} + +func disableRuleEndpoint(s re.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(updateRuleStatusReq) + if err := req.validate(); err != nil { + return updateRuleStatusRes{}, err + } + + rule, err := s.DisableRule(ctx, session, req.id) + if err != nil { + return updateRuleStatusRes{}, err + } + + return updateRuleStatusRes{Rule: rule}, err + } +} diff --git a/re/api/endpoints_test.go b/re/api/endpoints_test.go new file mode 100644 index 000000000..2eb05dcd5 --- /dev/null +++ b/re/api/endpoints_test.go @@ -0,0 +1,1285 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api_test + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/0x6flab/namegenerator" + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/internal/testsutil" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/roles" + pkgSch "github.com/absmach/supermq/pkg/schedule" + "github.com/absmach/supermq/re" + "github.com/absmach/supermq/re/api" + "github.com/absmach/supermq/re/mocks" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const contentType = "application/json" + +var ( + namegen = namegenerator.NewGenerator() + domainID = testsutil.GenerateUUID(&testing.T{}) + userID = testsutil.GenerateUUID(&testing.T{}) + validID = testsutil.GenerateUUID(&testing.T{}) + validToken = "valid" + invalidToken = "invalid" + now = time.Now().UTC().Truncate(time.Minute) + future = now.Add(1 * time.Hour) + schedule = pkgSch.Schedule{ + StartDateTime: future, + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + } + rule = re.Rule{ + ID: validID, + Name: namegen.Generate(), + DomainID: domainID, + Schedule: schedule, + Metadata: re.Metadata{ + "name": "test", + }, + } + past = now.Add(-1 * time.Hour) + scheduleInPast = pkgSch.Schedule{ + StartDateTime: past, + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: past, + } +) + +type testRequest struct { + client *http.Client + method string + url string + contentType string + token string + body io.Reader +} + +func (tr testRequest) make() (*http.Response, error) { + req, err := http.NewRequest(tr.method, tr.url, tr.body) + if err != nil { + return nil, err + } + + if tr.token != "" { + req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token) + } + + if tr.contentType != "" { + req.Header.Set("Content-Type", tr.contentType) + } + + req.Header.Set("Referer", "http://localhost") + + return tr.client.Do(req) +} + +func newRuleEngineServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentication) { + svc := new(mocks.Service) + authn := new(authnmocks.Authentication) + + logger := smqlog.NewMock() + mux := chi.NewRouter() + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + + api.MakeHandler(svc, am, mux, logger, "") + + return httptest.NewServer(mux), svc, authn +} + +func toJSON(data any) string { + jsonData, err := json.Marshal(data) + if err != nil { + return "" + } + return string(jsonData) +} + +func TestAddRuleEndpoint(t *testing.T) { + ts, svc, authn := newRuleEngineServer() + defer ts.Close() + + ruleInPast := rule + ruleInPast.Schedule = scheduleInPast + + cases := []struct { + desc string + rule re.Rule + domainID string + token string + contentType string + status int + authnRes smqauthn.Session + authnErr error + svcRes re.Rule + svcErr error + err error + len int + }{ + { + desc: "add rule successfully", + rule: rule, + token: validToken, + contentType: contentType, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + status: http.StatusCreated, + svcRes: rule, + }, + { + desc: "add rule with invalid token", + rule: rule, + token: invalidToken, + authnRes: smqauthn.Session{}, + domainID: domainID, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "add rule with empty token", + token: "", + authnRes: smqauthn.Session{}, + domainID: domainID, + rule: rule, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "add rule with name that is too long", + token: validToken, + rule: re.Rule{ + ID: validID, + Name: strings.Repeat("a", 1025), + Logic: re.Script{ + Type: re.ScriptType(0), + Value: "return `test` end", + }, + }, + domainID: domainID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrNameSize, + }, + { + desc: "add rule with empty domainID", + token: validToken, + rule: rule, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "add rule with invalid content type", + token: validToken, + domainID: domainID, + rule: rule, + contentType: "application/xml", + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "add rule with startdatetime in past", + token: validToken, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + rule: ruleInPast, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "add rule with service error", + token: validToken, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + rule: rule, + contentType: contentType, + svcErr: svcerr.ErrCreateEntity, + status: http.StatusUnprocessableEntity, + err: svcerr.ErrCreateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + data := toJSON(tc.rule) + req := testRequest{ + client: ts.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/rules", ts.URL, tc.domainID), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(data), + } + + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + svcCall := svc.On("AddRule", mock.Anything, tc.authnRes, tc.rule).Return(tc.svcRes, []roles.RoleProvision{}, tc.svcErr) + res, err := req.make() + + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestViewRuleEndpoint(t *testing.T) { + ts, svc, authn := newRuleEngineServer() + defer ts.Close() + + cases := []struct { + desc string + id string + domainID string + token string + contentType string + status int + authnRes smqauthn.Session + authnErr error + svcRes re.Rule + svcErr error + err error + len int + }{ + { + desc: "view rule successfully", + id: rule.ID, + token: validToken, + contentType: contentType, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + status: http.StatusOK, + svcRes: rule, + }, + { + desc: "view rule with invalid token", + id: rule.ID, + token: invalidToken, + authnRes: smqauthn.Session{}, + domainID: domainID, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "view rule with empty token", + token: "", + authnRes: smqauthn.Session{}, + domainID: domainID, + id: rule.ID, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "view rule with empty domainID", + token: validToken, + id: rule.ID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "view rule with service error", + token: validToken, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + id: rule.ID, + contentType: contentType, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: ts.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/%s/rules/%s", ts.URL, tc.domainID, tc.id), + token: tc.token, + } + + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + svcCall := svc.On("ViewRule", mock.Anything, tc.authnRes, tc.id, false).Return(tc.svcRes, tc.svcErr) + res, err := req.make() + + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestListRulesEndpoint(t *testing.T) { + ts, svc, authn := newRuleEngineServer() + defer ts.Close() + + cases := []struct { + desc string + query string + domainID string + token string + session smqauthn.Session + listRulesResponse re.Page + status int + authnErr error + err error + }{ + { + desc: "list rules successfully", + domainID: domainID, + token: validToken, + status: http.StatusOK, + listRulesResponse: re.Page{ + Total: 1, + Rules: []re.Rule{rule}, + }, + err: nil, + }, + { + desc: "list rules with empty token", + domainID: domainID, + token: "", + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "list rules with invalid token", + domainID: domainID, + token: invalidToken, + status: http.StatusUnauthorized, + authnErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "list rules with offset", + domainID: domainID, + token: validToken, + listRulesResponse: re.Page{ + Total: 1, + + Rules: []re.Rule{rule}, + }, + query: "offset=1", + status: http.StatusOK, + err: nil, + }, + { + desc: "list rules with invalid offset", + domainID: domainID, + token: validToken, + query: "offset=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "list rules with limit", + domainID: domainID, + token: validToken, + listRulesResponse: re.Page{ + Total: 1, + + Rules: []re.Rule{rule}, + }, + query: "limit=1", + status: http.StatusOK, + err: nil, + }, + { + desc: "list rules with invalid limit", + domainID: domainID, + token: validToken, + query: "limit=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "list rules with invalid direction", + domainID: domainID, + token: validToken, + query: "dir=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidDirection, + }, + { + desc: "list rules with invalid order", + domainID: domainID, + token: validToken, + query: "order=invalid", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "list rule with limit that is too big", + domainID: domainID, + token: validToken, + query: "limit=10000", + status: http.StatusBadRequest, + err: apiutil.ErrLimitSize, + }, + { + desc: "list rules with input channel", + domainID: domainID, + token: validToken, + listRulesResponse: re.Page{ + Total: 1, + Rules: []re.Rule{rule}, + }, + query: "input_channel=input.channel", + status: http.StatusOK, + err: nil, + }, + { + desc: "list rules with duplicate input_channel", + domainID: domainID, + token: validToken, + query: "input_channel=1&input_channel=2", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "list rules with status", + domainID: domainID, + token: validToken, + listRulesResponse: re.Page{ + Total: 1, + Rules: []re.Rule{rule}, + }, + query: "status=enabled", + status: http.StatusOK, + err: nil, + }, + { + desc: "list rules with invalid status", + domainID: domainID, + token: validToken, + query: "status=invalid", + status: http.StatusBadRequest, + err: svcerr.ErrInvalidStatus, + }, + { + desc: "list rules with duplicate status", + domainID: domainID, + token: validToken, + query: "status=enabled&status=disabled", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "list rules with duplicate tags", + domainID: domainID, + token: validToken, + query: "tag=tag1&tag=tag2", + status: http.StatusBadRequest, + err: apiutil.ErrInvalidQueryParams, + }, + { + desc: "list rules with service error", + domainID: domainID, + token: validToken, + listRulesResponse: re.Page{}, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: ts.Client(), + method: http.MethodGet, + url: ts.URL + "/" + tc.domainID + "/rules?" + tc.query, + contentType: contentType, + token: tc.token, + } + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("ListRules", mock.Anything, tc.session, mock.Anything).Return(tc.listRulesResponse, tc.err) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var bodyRes respBody + err = json.NewDecoder(res.Body).Decode(&bodyRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if bodyRes.Err != "" || bodyRes.Message != "" { + err = errors.Wrap(errors.New(bodyRes.Err), errors.New(bodyRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateRulesEndpoint(t *testing.T) { + ts, svc, authn := newRuleEngineServer() + defer ts.Close() + + updateRuleReq := re.Rule{ + ID: rule.ID, + Name: rule.Name, + Logic: re.Script{ + Type: re.ScriptType(0), + Value: "return `test` end", + }, + InputChannel: testsutil.GenerateUUID(&testing.T{}), + Metadata: map[string]any{ + "name": "test", + }, + } + + cases := []struct { + desc string + token string + id string + domainID string + updateReq re.Rule + contentType string + session smqauthn.Session + svcResp re.Rule + svcErr error + status int + authnErr error + err error + }{ + { + desc: "update rule successfully", + token: validToken, + domainID: domainID, + id: rule.ID, + updateReq: updateRuleReq, + contentType: contentType, + svcResp: rule, + status: http.StatusOK, + err: nil, + }, + { + desc: "update rule with invalid token", + token: invalidToken, + session: smqauthn.Session{}, + domainID: domainID, + id: rule.ID, + updateReq: updateRuleReq, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "update rule with empty token", + token: "", + session: smqauthn.Session{}, + domainID: domainID, + id: rule.ID, + updateReq: updateRuleReq, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "update rule with empty domainID", + token: validToken, + id: rule.ID, + updateReq: updateRuleReq, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "update rule with name that is too long", + token: validToken, + id: validID, + domainID: domainID, + updateReq: re.Rule{ + ID: validID, + Name: strings.Repeat("a", 1025), + Logic: re.Script{ + Type: re.ScriptType(0), + Value: "return `test` end", + }, + }, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrNameSize, + }, + { + desc: "update rule with invalid content type", + token: validToken, + id: rule.ID, + domainID: domainID, + updateReq: updateRuleReq, + contentType: "application/xml", + svcResp: rule, + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "update rule with service error", + token: validToken, + id: rule.ID, + domainID: domainID, + updateReq: updateRuleReq, + contentType: contentType, + svcResp: re.Rule{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + data := toJSON(tc.updateReq) + req := testRequest{ + client: ts.Client(), + method: http.MethodPatch, + url: fmt.Sprintf("%s/%s/rules/%s", ts.URL, tc.domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(data), + } + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("UpdateRule", mock.Anything, tc.session, tc.updateReq).Return(tc.svcResp, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateRuleTagsEndpoint(t *testing.T) { + ts, svc, authn := newRuleEngineServer() + defer ts.Close() + + newTag := "newtag" + + cases := []struct { + desc string + token string + id string + domainID string + data string + contentType string + session smqauthn.Session + svcResp re.Rule + svcErr error + resp re.Rule + status int + authnErr error + err error + }{ + { + desc: "update rule tags successfully", + token: validToken, + domainID: domainID, + id: validID, + data: fmt.Sprintf(`{"tags":["%s"]}`, newTag), + contentType: contentType, + svcResp: rule, + status: http.StatusOK, + err: nil, + }, + { + desc: "update rule tags with invalid token", + token: invalidToken, + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + data: fmt.Sprintf(`{"tags":["%s"]}`, newTag), + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "update rule tags with empty token", + token: "", + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + data: fmt.Sprintf(`{"tags":["%s"]}`, newTag), + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "update rule tags with empty domainID", + token: validToken, + id: validID, + data: fmt.Sprintf(`{"tags":["%s"]}`, newTag), + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "update rule tags with invalid content type", + token: validToken, + id: validID, + domainID: domainID, + data: fmt.Sprintf(`{"tags":["%s"]}`, newTag), + contentType: "application/xml", + svcResp: rule, + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "update rule tags with service error", + token: validToken, + id: validID, + domainID: domainID, + data: fmt.Sprintf(`{"tags":["%s"]}`, newTag), + contentType: contentType, + svcResp: re.Rule{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "update rule with malformed request", + token: validToken, + id: validID, + domainID: domainID, + contentType: contentType, + data: fmt.Sprintf(`{"tags":["%s"}`, newTag), + status: http.StatusBadRequest, + err: apiutil.ErrMalformedRequestBody, + }, + { + desc: "update rule with empty id", + token: validToken, + id: "", + domainID: domainID, + contentType: contentType, + data: fmt.Sprintf(`{"tags":["%s"]}`, newTag), + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: ts.Client(), + method: http.MethodPatch, + url: fmt.Sprintf("%s/%s/rules/%s/tags", ts.URL, tc.domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(tc.data), + } + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("UpdateRuleTags", mock.Anything, tc.session, re.Rule{ID: tc.id, Tags: []string{newTag}}).Return(tc.svcResp, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateRuleScheduleEndpoint(t *testing.T) { + ts, svc, authn := newRuleEngineServer() + defer ts.Close() + + updateScheduleReq := pkgSch.Schedule{ + StartDateTime: future, + Time: future.Add(2 * time.Hour), + Recurring: pkgSch.Weekly, + RecurringPeriod: 2, + } + + ruleWithSchedule := rule + ruleWithSchedule.Schedule = updateScheduleReq + + cases := []struct { + desc string + token string + id string + domainID string + schedule pkgSch.Schedule + contentType string + session smqauthn.Session + svcResp re.Rule + svcErr error + status int + authnErr error + err error + }{ + { + desc: "update rule schedule successfully", + token: validToken, + domainID: domainID, + id: validID, + schedule: updateScheduleReq, + contentType: contentType, + svcResp: ruleWithSchedule, + status: http.StatusOK, + err: nil, + }, + { + desc: "update rule schedule with invalid token", + token: invalidToken, + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + schedule: updateScheduleReq, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "update rule schedule with empty token", + token: "", + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + schedule: updateScheduleReq, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "update rule schedule with empty domainID", + token: validToken, + id: validID, + schedule: updateScheduleReq, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "update rule schedule with invalid content type", + token: validToken, + id: validID, + domainID: domainID, + schedule: updateScheduleReq, + contentType: "application/xml", + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "update rule schedule with start_datetime in past", + token: validToken, + id: validID, + domainID: domainID, + schedule: pkgSch.Schedule{ + StartDateTime: past, + Time: future, + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + }, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "update rule schedule with service error", + token: validToken, + id: validID, + domainID: domainID, + schedule: updateScheduleReq, + contentType: contentType, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "update rule schedule with empty id", + token: validToken, + id: "", + domainID: domainID, + schedule: updateScheduleReq, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + data := toJSON(map[string]any{ + "schedule": tc.schedule, + }) + + req := testRequest{ + client: ts.Client(), + method: http.MethodPatch, + url: fmt.Sprintf("%s/%s/rules/%s/schedule", ts.URL, tc.domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(data), + } + + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("UpdateRuleSchedule", mock.Anything, mock.Anything, mock.Anything).Return(tc.svcResp, tc.svcErr) + + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestEnableRuleEndpoint(t *testing.T) { + ts, svc, authn := newRuleEngineServer() + defer ts.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session smqauthn.Session + svcResp re.Rule + svcErr error + status int + authnErr error + err error + }{ + { + desc: "enable rule successfully", + token: validToken, + domainID: domainID, + id: validID, + svcResp: rule, + svcErr: nil, + status: http.StatusOK, + err: nil, + }, + { + desc: "enable rule with invalid token", + token: invalidToken, + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "enable rule with empty token", + token: "", + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "enable rule with empty domainID", + token: validToken, + id: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "enable rule with service error", + token: validToken, + id: validID, + domainID: domainID, + svcResp: re.Rule{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "enable rule with empty id", + token: validToken, + id: "", + domainID: domainID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: ts.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/rules/%s/enable", ts.URL, tc.domainID, tc.id), + token: tc.token, + } + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("EnableRule", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDisableRuleEndpoint(t *testing.T) { + gs, svc, authn := newRuleEngineServer() + defer gs.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session smqauthn.Session + svcResp re.Rule + svcErr error + status int + authnErr error + err error + }{ + { + desc: "disable rule successfully", + token: validToken, + domainID: domainID, + id: validID, + svcResp: rule, + svcErr: nil, + status: http.StatusOK, + err: nil, + }, + { + desc: "disable rule with invalid token", + token: invalidToken, + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "disable rule with empty token", + token: "", + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "disable rule with empty domainID", + token: validToken, + id: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "disable rule with service error", + token: validToken, + id: validID, + domainID: domainID, + svcResp: re.Rule{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "disable rule with empty id", + token: validToken, + id: "", + domainID: domainID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: gs.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/rules/%s/disable", gs.URL, tc.domainID, tc.id), + token: tc.token, + } + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("DisableRule", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDeleteRuleEndpoint(t *testing.T) { + ts, svc, authn := newRuleEngineServer() + defer ts.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session smqauthn.Session + svcErr error + status int + authnErr error + err error + }{ + { + desc: "delete rule successfully", + token: validToken, + domainID: domainID, + id: validID, + svcErr: nil, + status: http.StatusNoContent, + err: nil, + }, + { + desc: "delete rule with invalid token", + token: invalidToken, + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "delete rule with empty token", + token: "", + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "delete rule with empty domainID", + token: validToken, + id: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "delete rule with service error", + token: validToken, + id: validID, + domainID: domainID, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: ts.Client(), + method: http.MethodDelete, + url: fmt.Sprintf("%s/%s/rules/%s", ts.URL, tc.domainID, tc.id), + token: tc.token, + } + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("RemoveRule", mock.Anything, tc.session, tc.id).Return(tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +type respBody struct { + Err string `json:"error"` + Message string `json:"message"` + Total uint64 `json:"total"` + ID string `json:"id"` + Status re.Status `json:"status"` +} diff --git a/re/api/requests.go b/re/api/requests.go new file mode 100644 index 000000000..d7b5a858d --- /dev/null +++ b/re/api/requests.go @@ -0,0 +1,137 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/schedule" + "github.com/absmach/supermq/re" +) + +const ( + maxLimitSize = 1000 + MaxNameSize = 1024 + MaxTitleSize = 37 +) + +type addRuleReq struct { + re.Rule +} + +func (req addRuleReq) validate() error { + if len(req.Name) > api.MaxNameSize || req.Name == "" { + return apiutil.ErrNameSize + } + if err := req.Rule.Schedule.Validate(); err != nil { + return errors.Wrap(err, apiutil.ErrValidation) + } + + return nil +} + +type viewRuleReq struct { + id string + withRoles bool +} + +func (req viewRuleReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + + return nil +} + +type listRulesReq struct { + re.PageMeta +} + +func (req listRulesReq) validate() error { + if req.Limit > maxLimitSize { + return apiutil.ErrLimitSize + } + + switch req.Order { + case "", api.NameKey, api.CreatedAtOrder, api.UpdatedAtOrder: + default: + return errors.Wrap(apiutil.ErrInvalidOrder, apiutil.ErrValidation) + } + + if req.Dir != api.AscDir && req.Dir != api.DescDir { + return apiutil.ErrInvalidDirection + } + + return nil +} + +type updateRuleReq struct { + Rule re.Rule +} + +func (req updateRuleReq) validate() error { + if req.Rule.ID == "" { + return apiutil.ErrMissingID + } + if len(req.Rule.Name) > api.MaxNameSize { + return apiutil.ErrNameSize + } + + return nil +} + +type updateRuleTagsReq struct { + id string + Tags []string `json:"tags,omitempty"` +} + +func (req updateRuleTagsReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + + return nil +} + +type updateRuleScheduleReq struct { + id string + Schedule schedule.Schedule `json:"schedule,omitempty"` +} + +func (req updateRuleScheduleReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + + if err := req.Schedule.Validate(); err != nil { + return errors.Wrap(err, apiutil.ErrValidation) + } + + return nil +} + +type updateRuleStatusReq struct { + id string +} + +func (req updateRuleStatusReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + + return nil +} + +type deleteRuleReq struct { + id string +} + +func (req deleteRuleReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + + return nil +} diff --git a/re/api/responses.go b/re/api/responses.go new file mode 100644 index 000000000..78d3063aa --- /dev/null +++ b/re/api/responses.go @@ -0,0 +1,138 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "fmt" + "net/http" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/re" +) + +var ( + _ supermq.Response = (*viewRuleRes)(nil) + _ supermq.Response = (*addRuleRes)(nil) + _ supermq.Response = (*updateRuleStatusRes)(nil) + _ supermq.Response = (*rulesPageRes)(nil) + _ supermq.Response = (*updateRuleRes)(nil) + _ supermq.Response = (*deleteRuleRes)(nil) +) + +type pageRes struct { + Limit uint64 `json:"limit,omitempty"` + Offset uint64 `json:"offset"` + Total uint64 `json:"total"` +} + +type addRuleRes struct { + re.Rule + created bool +} + +func (res addRuleRes) Code() int { + if res.created { + return http.StatusCreated + } + + return http.StatusOK +} + +func (res addRuleRes) Headers() map[string]string { + if res.created { + return map[string]string{ + "Location": fmt.Sprintf("/rules/%s", res.ID), + } + } + + return map[string]string{} +} + +func (res addRuleRes) Empty() bool { + return false +} + +type updateRuleRes struct { + re.Rule `json:",inline"` +} + +func (res updateRuleRes) Code() int { + return http.StatusOK +} + +func (res updateRuleRes) Headers() map[string]string { + return map[string]string{} +} + +func (res updateRuleRes) Empty() bool { + return false +} + +type viewRuleRes struct { + re.Rule `json:",inline"` +} + +func (res viewRuleRes) Code() int { + return http.StatusOK +} + +func (res viewRuleRes) Headers() map[string]string { + return map[string]string{} +} + +func (res viewRuleRes) Empty() bool { + return false +} + +type rulesPageRes struct { + re.Page `json:",inline"` +} + +func (res rulesPageRes) Code() int { + return http.StatusOK +} + +func (res rulesPageRes) Headers() map[string]string { + return map[string]string{} +} + +func (res rulesPageRes) Empty() bool { + return false +} + +type updateRuleStatusRes struct { + re.Rule `json:",inline"` +} + +func (res updateRuleStatusRes) Code() int { + return http.StatusOK +} + +func (res updateRuleStatusRes) Headers() map[string]string { + return map[string]string{} +} + +func (res updateRuleStatusRes) Empty() bool { + return false +} + +type deleteRuleRes struct { + deleted bool +} + +func (res deleteRuleRes) Code() int { + if res.deleted { + return http.StatusNoContent + } + + return http.StatusOK +} + +func (res deleteRuleRes) Headers() map[string]string { + return map[string]string{} +} + +func (res deleteRuleRes) Empty() bool { + return true +} diff --git a/re/api/transport.go b/re/api/transport.go new file mode 100644 index 000000000..64d9bc17b --- /dev/null +++ b/re/api/transport.go @@ -0,0 +1,247 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "strings" + + "github.com/absmach/supermq" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + roleManagerHttp "github.com/absmach/supermq/pkg/roles/rolemanager/api" + "github.com/absmach/supermq/re" + "github.com/go-chi/chi/v5" + kithttp "github.com/go-kit/kit/transport/http" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" +) + +const ( + ruleIdKey = "ruleID" + inputChannelKey = "input_channel" +) + +// MakeHandler creates an HTTP handler for the service endpoints. +func MakeHandler(svc re.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler { + opts := []kithttp.ServerOption{ + kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), + } + mux.Group(func(r chi.Router) { + r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware()) + r.Route("/{domainID}", func(r chi.Router) { + r.Route("/rules", func(r chi.Router) { + d := roleManagerHttp.NewDecoder("ruleID") + + r.Post("/", otelhttp.NewHandler(kithttp.NewServer( + addRuleEndpoint(svc), + decodeAddRuleRequest, + api.EncodeResponse, + opts..., + ), "create_rule").ServeHTTP) + + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + listRulesEndpoint(svc), + decodeListRulesRequest, + api.EncodeResponse, + opts..., + ), "list_rules").ServeHTTP) + + r = roleManagerHttp.EntityAvailableActionsRouter(svc, d, r, opts) + + r.Route("/{ruleID}", func(r chi.Router) { + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + viewRuleEndpoint(svc), + decodeViewRuleRequest, + api.EncodeResponse, + opts..., + ), "view_rule").ServeHTTP) + + r.Patch("/", otelhttp.NewHandler(kithttp.NewServer( + updateRuleEndpoint(svc), + decodeUpdateRuleRequest, + api.EncodeResponse, + opts..., + ), "update_rule").ServeHTTP) + + r.Patch("/tags", otelhttp.NewHandler(kithttp.NewServer( + updateRuleTagsEndpoint(svc), + decodeUpdateRuleTags, + api.EncodeResponse, + opts..., + ), "update_rule_tags").ServeHTTP) + + r.Patch("/schedule", otelhttp.NewHandler(kithttp.NewServer( + updateRuleScheduleEndpoint(svc), + decodeUpdateRuleScheduleRequest, + api.EncodeResponse, + opts..., + ), "update_rule_scheduler").ServeHTTP) + + r.Delete("/", otelhttp.NewHandler(kithttp.NewServer( + deleteRuleEndpoint(svc), + decodeDeleteRuleRequest, + api.EncodeResponse, + opts..., + ), "delete_rule").ServeHTTP) + + r.Post("/enable", otelhttp.NewHandler(kithttp.NewServer( + enableRuleEndpoint(svc), + decodeUpdateRuleStatusRequest, + api.EncodeResponse, + opts..., + ), "enable_rule").ServeHTTP) + + r.Post("/disable", otelhttp.NewHandler(kithttp.NewServer( + disableRuleEndpoint(svc), + decodeUpdateRuleStatusRequest, + api.EncodeResponse, + opts..., + ), "disable_rule").ServeHTTP) + + roleManagerHttp.EntityRoleMangerRouter(svc, d, r, opts) + }) + }) + }) + }) + + mux.Get("/health", supermq.Health("rule_engine", instanceID)) + mux.Handle("/metrics", promhttp.Handler()) + + return mux +} + +func decodeAddRuleRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, apiutil.ErrUnsupportedContentType + } + var rule re.Rule + if err := json.NewDecoder(r.Body).Decode(&rule); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + return addRuleReq{Rule: rule}, nil +} + +func decodeViewRuleRequest(_ context.Context, r *http.Request) (any, error) { + id := chi.URLParam(r, ruleIdKey) + withRoles, err := apiutil.ReadBoolQuery(r, api.RolesKey, false) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + return viewRuleReq{id: id, withRoles: withRoles}, nil +} + +func decodeUpdateRuleRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, apiutil.ErrUnsupportedContentType + } + var rule re.Rule + if err := json.NewDecoder(r.Body).Decode(&rule); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + rule.ID = chi.URLParam(r, ruleIdKey) + + return updateRuleReq{Rule: rule}, nil +} + +func decodeUpdateRuleTags(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := updateRuleTagsReq{ + id: chi.URLParam(r, ruleIdKey), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + return req, nil +} + +func decodeUpdateRuleScheduleRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := updateRuleScheduleReq{ + id: chi.URLParam(r, ruleIdKey), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + return req, nil +} + +func decodeUpdateRuleStatusRequest(_ context.Context, r *http.Request) (any, error) { + req := updateRuleStatusReq{ + id: chi.URLParam(r, ruleIdKey), + } + + return req, nil +} + +func decodeListRulesRequest(_ context.Context, r *http.Request) (any, error) { + offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + name, err := apiutil.ReadStringQuery(r, api.NameKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + ic, err := apiutil.ReadStringQuery(r, inputChannelKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefStatus) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + dir, err := apiutil.ReadStringQuery(r, api.DirKey, "desc") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + order, err := apiutil.ReadStringQuery(r, api.OrderKey, api.DefOrder) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + st, err := re.ToStatus(s) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + tag, err := apiutil.ReadStringQuery(r, api.TagKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + return listRulesReq{ + PageMeta: re.PageMeta{ + Offset: offset, + Limit: limit, + Name: name, + InputChannel: ic, + Status: st, + Dir: dir, + Order: order, + Tag: tag, + }, + }, nil +} + +func decodeDeleteRuleRequest(_ context.Context, r *http.Request) (any, error) { + id := chi.URLParam(r, ruleIdKey) + + return deleteRuleReq{id: id}, nil +} diff --git a/re/builtinroles.go b/re/builtinroles.go new file mode 100644 index 000000000..3b242c116 --- /dev/null +++ b/re/builtinroles.go @@ -0,0 +1,8 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import "github.com/absmach/supermq/pkg/roles" + +const BuiltInRoleAdmin roles.BuiltInRoleName = "admin" diff --git a/re/doc.go b/re/doc.go new file mode 100644 index 000000000..2c28de3bd --- /dev/null +++ b/re/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package re contain the domain concept definitions needed to +// support Magistrala Rule Egine services functionality. +package re diff --git a/mqtt/events/doc.go b/re/events/doc.go similarity index 83% rename from mqtt/events/doc.go rename to re/events/doc.go index 83ccf23cb..720686489 100644 --- a/mqtt/events/doc.go +++ b/re/events/doc.go @@ -2,5 +2,5 @@ // SPDX-License-Identifier: Apache-2.0 // Package events provides the domain concept definitions needed to support -// mqtt events functionality. +// clients events functionality. package events diff --git a/re/events/events.go b/re/events/events.go new file mode 100644 index 000000000..09fda86ec --- /dev/null +++ b/re/events/events.go @@ -0,0 +1,192 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package events + +import ( + "maps" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/events" + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/re" +) + +const ( + rulePrefix = "rule." + ruleCreate = rulePrefix + "create" + ruleList = rulePrefix + "list" + ruleView = rulePrefix + "view" + ruleUpdate = rulePrefix + "update" + ruleUpdateTags = rulePrefix + "update_tags" + ruleUpdateSchedule = rulePrefix + "update_schedule" + ruleEnable = rulePrefix + "enable" + ruleDisable = rulePrefix + "disable" + ruleRemove = rulePrefix + "remove" +) + +var ( + _ events.Event = (*createRuleEvent)(nil) + _ events.Event = (*listRuleEvent)(nil) + _ events.Event = (*viewRuleEvent)(nil) + _ events.Event = (*updateRuleEvent)(nil) + _ events.Event = (*updateRuleTagsEvent)(nil) + _ events.Event = (*updateRuleScheduleEvent)(nil) + _ events.Event = (*enableRuleEvent)(nil) + _ events.Event = (*disableRuleEvent)(nil) + _ events.Event = (*removeRuleEvent)(nil) +) + +type baseRuleEvent struct { + session authn.Session + requestID string +} + +func newBaseRuleEvent(session authn.Session, requestID string) baseRuleEvent { + return baseRuleEvent{ + session: session, + requestID: requestID, + } +} + +func (bre baseRuleEvent) Encode() map[string]any { + return map[string]any{ + "domain": bre.session.DomainID, + "user_id": bre.session.UserID, + "token_type": bre.session.Type.String(), + "super_admin": bre.session.SuperAdmin, + "request_id": bre.requestID, + } +} + +type createRuleEvent struct { + rule re.Rule + rolesProvisioned []roles.RoleProvision + baseRuleEvent +} + +func (cre createRuleEvent) Encode() (map[string]any, error) { + val, err := cre.rule.EventEncode() + if err != nil { + return map[string]any{}, err + } + maps.Copy(val, cre.baseRuleEvent.Encode()) + val["operation"] = ruleCreate + val["roles_provisioned"] = cre.rolesProvisioned + return val, nil +} + +type listRuleEvent struct { + re.PageMeta + baseRuleEvent +} + +// Encode implements the events.Event interface for listRuleEvent. +func (lre listRuleEvent) Encode() (map[string]any, error) { + val := lre.PageMeta.EventEncode() + maps.Copy(val, lre.baseRuleEvent.Encode()) + val["operation"] = ruleList + return val, nil +} + +type updateRuleEvent struct { + rule re.Rule + baseRuleEvent +} + +type viewRuleEvent struct { + rule re.Rule + baseRuleEvent +} + +func (vre viewRuleEvent) Encode() (map[string]any, error) { + val, err := vre.rule.EventEncode() + if err != nil { + return map[string]any{}, err + } + maps.Copy(val, vre.baseRuleEvent.Encode()) + val["operation"] = ruleView + return val, nil +} + +func (ure updateRuleEvent) Encode() (map[string]any, error) { + val, err := ure.rule.EventEncode() + if err != nil { + return map[string]any{}, err + } + maps.Copy(val, ure.baseRuleEvent.Encode()) + val["operation"] = ruleUpdate + return val, nil +} + +type updateRuleTagsEvent struct { + rule re.Rule + baseRuleEvent +} + +func (urte updateRuleTagsEvent) Encode() (map[string]any, error) { + val, err := urte.rule.EventEncode() + if err != nil { + return map[string]any{}, err + } + maps.Copy(val, urte.baseRuleEvent.Encode()) + val["operation"] = ruleUpdateTags + return val, nil +} + +type updateRuleScheduleEvent struct { + rule re.Rule + baseRuleEvent +} + +func (urse updateRuleScheduleEvent) Encode() (map[string]any, error) { + val, err := urse.rule.EventEncode() + if err != nil { + return map[string]any{}, err + } + maps.Copy(val, urse.baseRuleEvent.Encode()) + val["operation"] = ruleUpdateSchedule + return val, nil +} + +type disableRuleEvent struct { + rule re.Rule + baseRuleEvent +} + +func (dre disableRuleEvent) Encode() (map[string]any, error) { + val, err := dre.rule.EventEncode() + if err != nil { + return map[string]any{}, err + } + maps.Copy(val, dre.baseRuleEvent.Encode()) + val["operation"] = ruleDisable + return val, nil +} + +type enableRuleEvent struct { + rule re.Rule + baseRuleEvent +} + +func (ere enableRuleEvent) Encode() (map[string]any, error) { + val, err := ere.rule.EventEncode() + if err != nil { + return map[string]any{}, err + } + maps.Copy(val, ere.baseRuleEvent.Encode()) + val["operation"] = ruleEnable + return val, nil +} + +type removeRuleEvent struct { + id string + baseRuleEvent +} + +func (rre removeRuleEvent) Encode() (map[string]any, error) { + val := rre.baseRuleEvent.Encode() + val["id"] = rre.id + val["operation"] = ruleRemove + return val, nil +} diff --git a/re/events/streams.go b/re/events/streams.go new file mode 100644 index 000000000..36093a3fd --- /dev/null +++ b/re/events/streams.go @@ -0,0 +1,203 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package events + +import ( + "context" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/events" + "github.com/absmach/supermq/pkg/events/store" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/roles" + rmEvents "github.com/absmach/supermq/pkg/roles/rolemanager/events" + "github.com/absmach/supermq/re" + "github.com/go-chi/chi/v5/middleware" +) + +const ( + supermqPrefix = "supermq." + CreateStream = supermqPrefix + ruleCreate + ListStream = supermqPrefix + ruleList + ViewStream = supermqPrefix + ruleView + UpdateStream = supermqPrefix + ruleUpdate + UpdateTagsStream = supermqPrefix + ruleUpdateTags + UpdateScheduleStream = supermqPrefix + ruleUpdateSchedule + EnableStream = supermqPrefix + ruleEnable + DisableStream = supermqPrefix + ruleDisable + RemoveStream = supermqPrefix + ruleRemove +) + +var _ re.Service = (*eventStore)(nil) + +type eventStore struct { + events.Publisher + svc re.Service + rmEvents.RoleManagerEventStore +} + +// NewEventStoreMiddleware returns wrapper around rules service that sends +// events to event store. +func NewEventStoreMiddleware(ctx context.Context, svc re.Service, url string) (re.Service, error) { + publisher, err := store.NewPublisher(ctx, url, "re-es-pub") + if err != nil { + return nil, err + } + + res := rmEvents.NewRoleManagerEventStore("rules", rulePrefix, svc, publisher) + + return &eventStore{ + svc: svc, + Publisher: publisher, + RoleManagerEventStore: res, + }, nil +} + +func (es *eventStore) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, []roles.RoleProvision, error) { + rule, rps, err := es.svc.AddRule(ctx, session, r) + if err != nil { + return rule, rps, err + } + event := createRuleEvent{ + rule: rule, + rolesProvisioned: rps, + baseRuleEvent: newBaseRuleEvent(session, middleware.GetReqID(ctx)), + } + if err := es.Publish(ctx, CreateStream, event); err != nil { + return rule, rps, err + } + return rule, rps, nil +} + +func (es *eventStore) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) { + page, err := es.svc.ListRules(ctx, session, pm) + if err != nil { + return page, err + } + event := listRuleEvent{ + PageMeta: pm, + baseRuleEvent: newBaseRuleEvent(session, middleware.GetReqID(ctx)), + } + if err := es.Publish(ctx, ListStream, event); err != nil { + return page, err + } + return page, nil +} + +func (es *eventStore) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) { + rule, err := es.svc.ViewRule(ctx, session, id, withRoles) + if err != nil { + return rule, err + } + event := viewRuleEvent{ + rule: rule, + baseRuleEvent: newBaseRuleEvent(session, middleware.GetReqID(ctx)), + } + if err := es.Publish(ctx, ViewStream, event); err != nil { + return rule, err + } + return rule, nil +} + +func (es *eventStore) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + rule, err := es.svc.UpdateRule(ctx, session, r) + if err != nil { + return rule, err + } + event := updateRuleEvent{ + rule: rule, + baseRuleEvent: newBaseRuleEvent(session, middleware.GetReqID(ctx)), + } + if err := es.Publish(ctx, UpdateStream, event); err != nil { + return rule, err + } + return rule, nil +} + +func (es *eventStore) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + rule, err := es.svc.UpdateRuleTags(ctx, session, r) + if err != nil { + return rule, err + } + event := updateRuleTagsEvent{ + rule: rule, + baseRuleEvent: newBaseRuleEvent(session, middleware.GetReqID(ctx)), + } + if err := es.Publish(ctx, UpdateTagsStream, event); err != nil { + return rule, err + } + return rule, nil +} + +func (es *eventStore) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + rule, err := es.svc.UpdateRuleSchedule(ctx, session, r) + if err != nil { + return rule, err + } + event := updateRuleScheduleEvent{ + rule: rule, + baseRuleEvent: newBaseRuleEvent(session, middleware.GetReqID(ctx)), + } + if err := es.Publish(ctx, UpdateScheduleStream, event); err != nil { + return rule, err + } + return rule, nil +} + +func (es *eventStore) RemoveRule(ctx context.Context, session authn.Session, id string) error { + err := es.svc.RemoveRule(ctx, session, id) + if err != nil { + return err + } + event := removeRuleEvent{ + id: id, + baseRuleEvent: newBaseRuleEvent(session, middleware.GetReqID(ctx)), + } + if err := es.Publish(ctx, RemoveStream, event); err != nil { + return err + } + return nil +} + +func (es *eventStore) EnableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) { + rule, err := es.svc.EnableRule(ctx, session, id) + if err != nil { + return rule, err + } + event := enableRuleEvent{ + rule: rule, + baseRuleEvent: newBaseRuleEvent(session, middleware.GetReqID(ctx)), + } + if err := es.Publish(ctx, EnableStream, event); err != nil { + return rule, err + } + return rule, nil +} + +func (es *eventStore) DisableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) { + rule, err := es.svc.DisableRule(ctx, session, id) + if err != nil { + return rule, err + } + event := disableRuleEvent{ + rule: rule, + baseRuleEvent: newBaseRuleEvent(session, middleware.GetReqID(ctx)), + } + if err := es.Publish(ctx, DisableStream, event); err != nil { + return rule, err + } + return rule, nil +} + +func (es *eventStore) StartScheduler(ctx context.Context) error { + return es.svc.StartScheduler(ctx) +} + +func (es *eventStore) Handle(msg *messaging.Message) error { + return es.svc.Handle(msg) +} + +func (es *eventStore) Cancel() error { + return es.svc.Cancel() +} diff --git a/re/golang.go b/re/golang.go new file mode 100644 index 000000000..407de626c --- /dev/null +++ b/re/golang.go @@ -0,0 +1,104 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "reflect" + "regexp" + + "github.com/absmach/supermq/pkg/errors" + pkglog "github.com/absmach/supermq/pkg/logger" + "github.com/absmach/supermq/pkg/messaging" + golang "github.com/traefik/yaegi/interp" + "github.com/traefik/yaegi/stdlib" +) + +const logicFunction = "main.logicFunction" + +var ( + goKeywordRegex = regexp.MustCompile(`\bgo\s+func\s*\(|^\s*go\s+\w+\(|[;\s{]go\s+func\s*\(|[;\s{]go\s+\w+\(`) + panicRegex = regexp.MustCompile(`\bpanic\s*\(`) +) + +// Type message is an SMQ message with payload replaces by JSON deserialized payload. +type message struct { + Channel string `json:"channel,omitempty"` + ClientID string `json:"client_id,omitempty"` + Domain string `json:"domain,omitempty"` + Subtopic string `json:"subtopic,omitempty"` + Publisher string `json:"publisher,omitempty"` + Protocol string `json:"protocol,omitempty"` + Created int64 `json:"created,omitempty"` + Payload any `json:"payload,omitempty"` +} + +func (re *re) processGo(ctx context.Context, details []slog.Attr, r Rule, msg *messaging.Message) (ret pkglog.RunInfo) { + defer func() { + if r := recover(); r != nil { + ret = pkglog.RunInfo{ + Level: slog.LevelError, + Details: details, + Message: fmt.Sprintf("panic in Go script: %v", r), + } + } + }() + + i := golang.New(golang.Options{}) + if err := i.Use(stdlib.Symbols); err != nil { + return pkglog.RunInfo{Level: slog.LevelError, Details: details, Message: err.Error()} + } + m := message{ + Created: msg.Created, + ClientID: msg.ClientIdentity(), + Domain: msg.Domain, + Publisher: msg.Publisher, + Channel: msg.Channel, + Subtopic: msg.Subtopic, + Protocol: msg.Protocol, + } + var pld any + if err := json.Unmarshal(msg.Payload, &pld); err != nil { + pld = msg.Payload + } + m.Payload = pld + + err := i.Use(golang.Exports{ + "messaging/m": { + "message": reflect.ValueOf(m), + }, + }) + if err != nil { + return pkglog.RunInfo{Level: slog.LevelError, Details: details, Message: err.Error()} + } + if _, err = i.Eval(r.Logic.Value); err != nil { + return pkglog.RunInfo{Level: slog.LevelError, Details: details, Message: err.Error()} + } + ifc, err := i.Eval(logicFunction) + if err != nil { + return pkglog.RunInfo{Level: slog.LevelError, Details: details, Message: err.Error()} + } + f, ok := ifc.Interface().(func() any) + if !ok { + return pkglog.RunInfo{Level: slog.LevelError, Message: "invalid logic function signature", Details: details} + } + res := f() + if b, ok := res.(bool); ok && !b { + return pkglog.RunInfo{Level: slog.LevelInfo, Message: "logic returned false", Details: details} + } + for _, o := range r.Outputs { + if e := re.handleOutput(ctx, o, r, msg, res); e != nil { + err = errors.Wrap(e, err) + } + } + ret = pkglog.RunInfo{Level: slog.LevelInfo, Details: details, Message: "rule processed successfully"} + if err != nil { + ret.Level = slog.LevelError + ret.Message = fmt.Sprintf("failed to handle rule output: %s", err) + } + return ret +} diff --git a/re/handlers.go b/re/handlers.go new file mode 100644 index 000000000..a019bc13b --- /dev/null +++ b/re/handlers.go @@ -0,0 +1,164 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import ( + "context" + "fmt" + "log/slog" + "strconv" + "strings" + "time" + + "github.com/absmach/supermq/pkg/errors" + pkglog "github.com/absmach/supermq/pkg/logger" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/re/outputs" +) + +var ( + scheduledTrue = true + scheduledFalse = false +) + +const ( + maxPayload = 100 * 1024 + pldExceededFmt = "max payload size of 100kB exceeded: " + protocol = "nats" +) + +func (re *re) Handle(msg *messaging.Message) error { + // Limit payload for RE so we don't get to process large JSON. + if n := len(msg.Payload); n > maxPayload { + return errors.New(pldExceededFmt + strconv.Itoa(n)) + } + // Skip filtering by message topic and fetch all non-scheduled rules instead. + // It's cleaner and more efficient to match wildcards in Go, but we can + // revisit this if it ever becomes a performance bottleneck. + pm := PageMeta{ + Domain: msg.Domain, + InputChannel: msg.Channel, + Status: EnabledStatus, + Scheduled: &scheduledFalse, + } + ctx := context.Background() + page, err := re.repo.ListAllRules(ctx, pm) + if err != nil { + return err + } + for _, r := range page.Rules { + if matchSubject(msg.Subtopic, r.InputTopic) { + go func(ctx context.Context) { + re.runInfo <- re.process(ctx, r, msg) + }(ctx) + } + } + + return nil +} + +// Match NATS subject to support wildcards. +func matchSubject(published, subscribed string) bool { + p := strings.Split(published, ".") + s := strings.Split(subscribed, ".") + n := len(p) + + for i := range s { + if s[i] == ">" { + return true + } + if i >= n { + return false + } + if s[i] != "*" && p[i] != s[i] { + return false + } + } + return len(s) == n +} + +func (re *re) process(ctx context.Context, r Rule, msg *messaging.Message) pkglog.RunInfo { + details := []slog.Attr{ + slog.String("domain_id", r.DomainID), + slog.String("rule_id", r.ID), + slog.String("rule_name", r.Name), + slog.Time("exec_time", time.Now().UTC()), + } + switch r.Logic.Type { + case GoType: + return re.processGo(ctx, details, r, msg) + default: + return re.processLua(ctx, details, r, msg) + } +} + +func (re *re) handleOutput(ctx context.Context, o Runnable, r Rule, msg *messaging.Message, val any) error { + switch o := o.(type) { + case *outputs.Alarm: + o.AlarmsPub = re.alarmsPub + o.RuleID = r.ID + return o.Run(ctx, msg, val) + case *outputs.Email: + o.Emailer = re.email + return o.Run(ctx, msg, val) + case *outputs.ChannelPublisher: + o.RePubSub = re.rePubSub + return o.Run(ctx, msg, val) + case *outputs.SenML: + o.WritersPub = re.writersPub + return o.Run(ctx, msg, val) + case *outputs.Postgres, *outputs.Slack: + return o.Run(ctx, msg, val) + default: + return fmt.Errorf("unknown output type: %T", o) + } +} + +func (re *re) StartScheduler(ctx context.Context) error { + defer re.ticker.Stop() + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-re.ticker.Tick(): + due := time.Now().UTC() + pm := PageMeta{ + Status: EnabledStatus, + Scheduled: &scheduledTrue, + ScheduledBefore: &due, + } + + page, err := re.repo.ListAllRules(ctx, pm) + if err != nil { + re.runInfo <- pkglog.RunInfo{ + Level: slog.LevelError, + Message: fmt.Sprintf("failed to list rules: %s", err), + Details: []slog.Attr{slog.Time("due", due)}, + } + + continue + } + + for _, r := range page.Rules { + go func(rule Rule, dueTime time.Time) { + if _, err := re.repo.UpdateRuleDue(ctx, rule.ID, rule.Schedule.NextDue()); err != nil { + re.runInfo <- pkglog.RunInfo{Level: slog.LevelError, Message: fmt.Sprintf("failed to update rule: %s", err), Details: []slog.Attr{slog.Time("time", time.Now().UTC())}} + return + } + + msg := &messaging.Message{ + Domain: rule.DomainID, + Channel: rule.InputChannel, + Subtopic: rule.InputTopic, + Protocol: protocol, + Created: dueTime.Unix(), + } + re.runInfo <- re.process(ctx, rule, msg) + }(r, due) + } + // Reset due, it will reset in the page meta as well. + due = time.Now().UTC() + } + } +} diff --git a/re/lua.go b/re/lua.go new file mode 100644 index 000000000..8bc9e9085 --- /dev/null +++ b/re/lua.go @@ -0,0 +1,185 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + + "github.com/absmach/supermq/pkg/errors" + pkglog "github.com/absmach/supermq/pkg/logger" + "github.com/absmach/supermq/pkg/messaging" + "github.com/vadv/gopher-lua-libs/argparse" + "github.com/vadv/gopher-lua-libs/base64" + bit "github.com/vadv/gopher-lua-libs/bit" + "github.com/vadv/gopher-lua-libs/crypto" + "github.com/vadv/gopher-lua-libs/db" + "github.com/vadv/gopher-lua-libs/filepath" + client "github.com/vadv/gopher-lua-libs/http/client" + "github.com/vadv/gopher-lua-libs/ioutil" + luajson "github.com/vadv/gopher-lua-libs/json" + "github.com/vadv/gopher-lua-libs/regexp" + "github.com/vadv/gopher-lua-libs/storage" + "github.com/vadv/gopher-lua-libs/strings" + luatime "github.com/vadv/gopher-lua-libs/time" + "github.com/vadv/gopher-lua-libs/yaml" + lua "github.com/yuin/gopher-lua" +) + +const payloadKey = "payload" + +func (re *re) processLua(ctx context.Context, details []slog.Attr, r Rule, msg *messaging.Message) pkglog.RunInfo { + l := lua.NewState() + defer l.Close() + preload(l) + message := prepareMsg(l, msg) + + // Set the message object as a Lua global variable. + l.SetGlobal("message", message) + if err := l.DoString(r.Logic.Value); err != nil { + return pkglog.RunInfo{Level: slog.LevelError, Message: fmt.Sprintf("failed to run rule logic: %s", err), Details: details} + } + // Get the last result. + result := l.Get(-1) + if result == lua.LNil { + return pkglog.RunInfo{Level: slog.LevelWarn, Message: "rule with nil script result", Details: details} + } + // Converting Lua is an expensive operation, so + // don't do it if there are no outputs. + if len(r.Outputs) == 0 { + return pkglog.RunInfo{Level: slog.LevelWarn, Message: "rule with no outputs", Details: details} + } + var err error + res := convertLua(result) + + for _, o := range r.Outputs { + // If value is false, don't run the follow-up. + if v, ok := res.(bool); ok && !v { + return pkglog.RunInfo{Level: slog.LevelInfo, Message: "logic returned false", Details: details} + } + if e := re.handleOutput(ctx, o, r, msg, res); e != nil { + err = errors.Wrap(e, err) + } + } + ret := pkglog.RunInfo{Level: slog.LevelInfo, Message: "rule processed successfully", Details: details} + if err != nil { + ret.Level = slog.LevelError + ret.Message = fmt.Sprintf("failed to handle rule output: %s", err) + } + return ret +} + +func preload(l *lua.LState) { + db.Preload(l) + ioutil.Preload(l) + luajson.Preload(l) + yaml.Preload(l) + crypto.Preload(l) + regexp.Preload(l) + luatime.Preload(l) + storage.Preload(l) + base64.Preload(l) + argparse.Preload(l) + strings.Preload(l) + filepath.Preload(l) + client.Preload(l) + bit.Preload(l) +} + +func prepareMsg(l *lua.LState, msg *messaging.Message) lua.LValue { + message := l.NewTable() + message.RawSetString("domain", lua.LString(msg.Domain)) + message.RawSetString("channel", lua.LString(msg.Channel)) + message.RawSetString("subtopic", lua.LString(msg.Subtopic)) + message.RawSetString("client_id", lua.LString(msg.ClientIdentity())) + message.RawSetString("publisher", lua.LString(msg.Publisher)) + message.RawSetString("protocol", lua.LString(msg.Protocol)) + message.RawSetString("created", lua.LNumber(msg.Created)) + + var payload any + if err := json.Unmarshal(msg.GetPayload(), &payload); err != nil { + pld := l.NewTable() + // If message is not JSON, set binary payload and exit. + for i, b := range msg.Payload { + // Lua tables are 1-indexed. + pld.Insert(i+1, lua.LNumber(b)) + } + message.RawSetString(payloadKey, pld) + return message + } + + // Payload is JSON, set the correct value. + message.RawSetString(payloadKey, traverseJson(l, payload)) + return message +} + +func traverseJson(l *lua.LState, value any) lua.LValue { + switch val := value.(type) { + case string: + return lua.LString(val) + case float64: + return lua.LNumber(val) + case int: + return lua.LNumber(float64(val)) + case json.Number: + if num, err := val.Float64(); err != nil { + return lua.LNumber(num) + } + return lua.LNil + case bool: + return lua.LBool(val) + case []any: + t := l.NewTable() + for i, j := range val { + t.RawSetInt(i+1, traverseJson(l, j)) + } + return t + case map[string]any: + t := l.NewTable() + for k, v := range val { + t.RawSetString(k, traverseJson(l, v)) + } + return t + default: + return lua.LNil + } +} + +func convertLua(lv lua.LValue) any { + switch v := lv.(type) { + case *lua.LTable: + isArray := true + v.ForEach(func(key, value lua.LValue) { + if key.Type() != lua.LTNumber { + isArray = false + } + }) + + if isArray { + arr := []any{} + v.ForEach(func(key, value lua.LValue) { + arr = append(arr, convertLua(value)) + }) + return arr + } + + obj := map[string]any{} + v.ForEach(func(key, value lua.LValue) { + obj[key.String()] = convertLua(value) + }) + return obj + case lua.LString: + return string(v) + case lua.LNumber: + return float64(v) + case lua.LBool: + return bool(v) + case *lua.LNilType: + return nil + default: + return v.String() + } +} diff --git a/re/middleware/authorization.go b/re/middleware/authorization.go new file mode 100644 index 000000000..3235b25bd --- /dev/null +++ b/re/middleware/authorization.go @@ -0,0 +1,191 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/absmach/supermq/pkg/authn" + smqauthz "github.com/absmach/supermq/pkg/authz" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/permissions" + "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/pkg/roles" + rolemgr "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" + "github.com/absmach/supermq/re" + "github.com/absmach/supermq/re/operations" +) + +var ( + errDomainCreateRules = errors.New("not authorized to create rules in domain") + errDomainViewRules = errors.New("not authorized to view rules in domain") + errDomainUpdateRules = errors.New("not authorized to update rules in domain") + errDomainDeleteRules = errors.New("not authorized to delete rules in domain") +) + +type authorizationMiddleware struct { + svc re.Service + authz smqauthz.Authorization + entitiesOps permissions.EntitiesOperations[permissions.Operation] + rolemgr.RoleManagerAuthorizationMiddleware +} + +// AuthorizationMiddleware adds authorization to the re service. +func AuthorizationMiddleware(svc re.Service, authz smqauthz.Authorization, entitiesOps permissions.EntitiesOperations[permissions.Operation], roleOps permissions.Operations[permissions.RoleOperation]) (re.Service, error) { + if err := entitiesOps.Validate(); err != nil { + return nil, err + } + ram, err := rolemgr.NewAuthorization(operations.EntityType, svc, authz, roleOps) + if err != nil { + return nil, err + } + return &authorizationMiddleware{ + svc: svc, + authz: authz, + entitiesOps: entitiesOps, + RoleManagerAuthorizationMiddleware: ram, + }, nil +} + +func (am *authorizationMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, []roles.RoleProvision, error) { + if err := am.authorize(ctx, operations.OpAddRule, session, policies.DomainType, session.DomainID); err != nil { + return re.Rule{}, nil, errors.Wrap(errDomainCreateRules, err) + } + + return am.svc.AddRule(ctx, session, r) +} + +func (am *authorizationMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) { + if err := am.authorize(ctx, operations.OpViewRule, session, operations.EntityType, id); err != nil { + return re.Rule{}, errors.Wrap(errDomainViewRules, err) + } + + return am.svc.ViewRule(ctx, session, id, withRoles) +} + +func (am *authorizationMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + if err := am.authorize(ctx, operations.OpUpdateRule, session, operations.EntityType, r.ID); err != nil { + return re.Rule{}, errors.Wrap(errDomainUpdateRules, err) + } + + return am.svc.UpdateRule(ctx, session, r) +} + +func (am *authorizationMiddleware) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + if err := am.authorize(ctx, operations.OpUpdateRuleTags, session, operations.EntityType, r.ID); err != nil { + return re.Rule{}, errors.Wrap(errDomainUpdateRules, err) + } + + return am.svc.UpdateRuleTags(ctx, session, r) +} + +func (am *authorizationMiddleware) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + if err := am.authorize(ctx, operations.OpUpdateRuleSchedule, session, operations.EntityType, r.ID); err != nil { + return re.Rule{}, errors.Wrap(errDomainUpdateRules, err) + } + + return am.svc.UpdateRuleSchedule(ctx, session, r) +} + +func (am *authorizationMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: + session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return re.Page{}, err + } + + return am.svc.ListRules(ctx, session, pm) +} + +func (am *authorizationMiddleware) RemoveRule(ctx context.Context, session authn.Session, id string) error { + if err := am.authorize(ctx, operations.OpRemoveRule, session, operations.EntityType, id); err != nil { + return errors.Wrap(errDomainDeleteRules, err) + } + + return am.svc.RemoveRule(ctx, session, id) +} + +func (am *authorizationMiddleware) EnableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) { + if err := am.authorize(ctx, operations.OpEnableRule, session, operations.EntityType, id); err != nil { + return re.Rule{}, errors.Wrap(errDomainUpdateRules, err) + } + + return am.svc.EnableRule(ctx, session, id) +} + +func (am *authorizationMiddleware) DisableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) { + if err := am.authorize(ctx, operations.OpDisableRule, session, operations.EntityType, id); err != nil { + return re.Rule{}, errors.Wrap(errDomainUpdateRules, err) + } + + return am.svc.DisableRule(ctx, session, id) +} + +func (am *authorizationMiddleware) StartScheduler(ctx context.Context) error { + return am.svc.StartScheduler(ctx) +} + +func (am *authorizationMiddleware) Handle(msg *messaging.Message) error { + return am.svc.Handle(msg) +} + +func (am *authorizationMiddleware) Cancel() error { + return am.svc.Cancel() +} + +func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions.Operation, session authn.Session, objType, obj string) error { + perm, err := am.entitiesOps.GetPermission(operations.EntityType, op) + if err != nil { + return err + } + + pr := smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Object: obj, + ObjectType: objType, + Permission: perm.String(), + } + + var pat *smqauthz.PATReq + if session.PatID != "" { + opName := am.entitiesOps.OperationName(operations.EntityType, op) + pat = &smqauthz.PATReq{ + UserID: session.UserID, + PatID: session.PatID, + EntityID: session.DomainID, + EntityType: operations.EntityType, + Operation: opName, + Domain: session.DomainID, + } + } + + if err := am.authz.Authorize(ctx, pr, pat); err != nil { + return err + } + + return nil +} + +func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session authn.Session) error { + if session.Role != authn.SuperAdminRole { + return svcerr.ErrSuperAdminAction + } + if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{ + SubjectType: policies.UserType, + Subject: session.UserID, + Permission: policies.AdminPermission, + ObjectType: policies.PlatformType, + Object: policies.MagistralaObject, + }, nil); err != nil { + return err + } + return nil +} diff --git a/re/middleware/callout.go b/re/middleware/callout.go new file mode 100644 index 000000000..e354910ba --- /dev/null +++ b/re/middleware/callout.go @@ -0,0 +1,196 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "time" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/callout" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/permissions" + "github.com/absmach/supermq/pkg/policies" + mgPolicies "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/pkg/roles" + rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" + "github.com/absmach/supermq/re" + "github.com/absmach/supermq/re/operations" +) + +var _ re.Service = (*calloutMiddleware)(nil) + +type calloutMiddleware struct { + svc re.Service + callout callout.Callout + entitiesOps permissions.EntitiesOperations[permissions.Operation] + rolemw.RoleManagerCalloutMiddleware +} + +const entityType = "rule" + +func NewCallout(svc re.Service, callout callout.Callout, entitiesOps permissions.EntitiesOperations[permissions.Operation], roleOps permissions.Operations[permissions.RoleOperation]) (re.Service, error) { + call, err := rolemw.NewCallout(mgPolicies.RulesType, svc, callout, roleOps) + if err != nil { + return nil, err + } + + if err := entitiesOps.Validate(); err != nil { + return nil, err + } + + return &calloutMiddleware{ + svc: svc, + callout: callout, + entitiesOps: entitiesOps, + RoleManagerCalloutMiddleware: call, + }, nil +} + +func (cm *calloutMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, []roles.RoleProvision, error) { + params := map[string]any{ + "entities": r, + "count": 1, + } + + if err := cm.callOut(ctx, session, operations.OpAddRule, params); err != nil { + return re.Rule{}, nil, err + } + + return cm.svc.AddRule(ctx, session, r) +} + +func (cm *calloutMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) { + params := map[string]any{ + "entity_id": id, + } + + if err := cm.callOut(ctx, session, operations.OpViewRule, params); err != nil { + return re.Rule{}, err + } + + return cm.svc.ViewRule(ctx, session, id, withRoles) +} + +func (cm *calloutMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + params := map[string]any{ + "entity_id": r.ID, + } + + if err := cm.callOut(ctx, session, operations.OpUpdateRule, params); err != nil { + return re.Rule{}, err + } + + return cm.svc.UpdateRule(ctx, session, r) +} + +func (cm *calloutMiddleware) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + params := map[string]any{ + "entity_id": r.ID, + } + + if err := cm.callOut(ctx, session, operations.OpUpdateRuleTags, params); err != nil { + return re.Rule{}, err + } + + return cm.svc.UpdateRuleTags(ctx, session, r) +} + +func (cm *calloutMiddleware) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + params := map[string]any{ + "entity_id": r.ID, + } + + if err := cm.callOut(ctx, session, operations.OpUpdateRuleSchedule, params); err != nil { + return re.Rule{}, err + } + + return cm.svc.UpdateRuleSchedule(ctx, session, r) +} + +func (cm *calloutMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) { + params := map[string]any{ + "pagemeta": pm, + } + + if err := cm.callOut(ctx, session, operations.OpListRules, params); err != nil { + return re.Page{}, err + } + + return cm.svc.ListRules(ctx, session, pm) +} + +func (cm *calloutMiddleware) RemoveRule(ctx context.Context, session authn.Session, id string) error { + params := map[string]any{ + "entity_id": id, + } + + if err := cm.callOut(ctx, session, operations.OpRemoveRule, params); err != nil { + return err + } + + return cm.svc.RemoveRule(ctx, session, id) +} + +func (cm *calloutMiddleware) EnableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) { + params := map[string]any{ + "entity_id": id, + } + + if err := cm.callOut(ctx, session, operations.OpEnableRule, params); err != nil { + return re.Rule{}, err + } + + return cm.svc.EnableRule(ctx, session, id) +} + +func (cm *calloutMiddleware) DisableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) { + params := map[string]any{ + "entity_id": id, + } + + if err := cm.callOut(ctx, session, operations.OpDisableRule, params); err != nil { + return re.Rule{}, err + } + + return cm.svc.DisableRule(ctx, session, id) +} + +func (cm *calloutMiddleware) StartScheduler(ctx context.Context) error { + return cm.svc.StartScheduler(ctx) +} + +func (cm *calloutMiddleware) Handle(msg *messaging.Message) error { + return cm.svc.Handle(msg) +} + +func (cm *calloutMiddleware) Cancel() error { + return cm.svc.Cancel() +} + +func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session, op permissions.Operation, pld map[string]any) error { + var entityID string + if id, ok := pld["entity_id"].(string); ok { + entityID = id + } + + req := callout.Request{ + BaseRequest: callout.BaseRequest{ + Operation: cm.entitiesOps.OperationName(entityType, op), + EntityType: entityType, + EntityID: entityID, + CallerID: session.UserID, + CallerType: policies.UserType, + DomainID: session.DomainID, + Time: time.Now().UTC(), + }, + Payload: pld, + } + + if err := cm.callout.Callout(ctx, req); err != nil { + return err + } + + return nil +} diff --git a/re/middleware/logging.go b/re/middleware/logging.go new file mode 100644 index 000000000..5d1fcb2c3 --- /dev/null +++ b/re/middleware/logging.go @@ -0,0 +1,250 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/roles" + rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" + "github.com/absmach/supermq/re" +) + +var _ re.Service = (*loggingMiddleware)(nil) + +type loggingMiddleware struct { + logger *slog.Logger + svc re.Service + rolemw.RoleManagerLoggingMiddleware +} + +func LoggingMiddleware(svc re.Service, logger *slog.Logger) re.Service { + return &loggingMiddleware{ + logger: logger, + svc: svc, + RoleManagerLoggingMiddleware: rolemw.NewLogging("re", svc, logger), + } +} + +func (lm *loggingMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (res re.Rule, rps []roles.RoleProvision, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.String("rule_name", r.Name), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Add rule failed", args...) + return + } + lm.logger.Info("Add rule completed successfully", args...) + }(time.Now()) + res, rps, err = lm.svc.AddRule(ctx, session, r) + return +} + +func (lm *loggingMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (res re.Rule, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("rule", + slog.String("id", res.ID), + slog.String("name", res.Name), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("View rule failed", args...) + return + } + lm.logger.Info("View rule completed successfully", args...) + }(time.Now()) + return lm.svc.ViewRule(ctx, session, id, withRoles) +} + +func (lm *loggingMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (res re.Rule, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("rule", + slog.String("id", r.ID), + slog.String("name", r.Name), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Update rule failed", args...) + return + } + lm.logger.Info("Update rule completed successfully", args...) + }(time.Now()) + return lm.svc.UpdateRule(ctx, session, r) +} + +func (lm *loggingMiddleware) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (res re.Rule, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("rule", + slog.String("id", r.ID), + slog.String("name", r.Name), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Update rule failed", args...) + return + } + lm.logger.Info("Update rule tags completed successfully", args...) + }(time.Now()) + return lm.svc.UpdateRuleTags(ctx, session, r) +} + +func (lm *loggingMiddleware) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (res re.Rule, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("rule", + slog.String("id", r.ID), + slog.Any("schedule", r.Schedule), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Update rule schedule failed", args...) + return + } + lm.logger.Info("Update rule schedule completed successfully", args...) + }(time.Now()) + return lm.svc.UpdateRuleSchedule(ctx, session, r) +} + +func (lm *loggingMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (pg re.Page, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("page", + slog.Uint64("offset", pm.Offset), + slog.Uint64("limit", pm.Limit), + slog.Uint64("total", pg.Total), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("List rules failed", args...) + return + } + lm.logger.Info("List rules completed successfully", args...) + }(time.Now()) + return lm.svc.ListRules(ctx, session, pm) +} + +func (lm *loggingMiddleware) RemoveRule(ctx context.Context, session authn.Session, id string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.String("rule_id", id), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Remove rule failed", args...) + return + } + lm.logger.Info("Remove rule completed successfully", args...) + }(time.Now()) + return lm.svc.RemoveRule(ctx, session, id) +} + +func (lm *loggingMiddleware) EnableRule(ctx context.Context, session authn.Session, id string) (res re.Rule, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("rule", + slog.String("id", res.ID), + slog.String("name", res.Name), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Enable rule failed", args...) + return + } + lm.logger.Info("Enable rule completed successfully", args...) + }(time.Now()) + return lm.svc.EnableRule(ctx, session, id) +} + +func (lm *loggingMiddleware) DisableRule(ctx context.Context, session authn.Session, id string) (res re.Rule, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("rule", + slog.String("id", res.ID), + slog.String("name", res.Name), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Disable rule failed", args...) + return + } + lm.logger.Info("Disable rule completed successfully", args...) + }(time.Now()) + return lm.svc.DisableRule(ctx, session, id) +} + +func (lm *loggingMiddleware) StartScheduler(ctx context.Context) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Start scheduler failed", args...) + return + } + lm.logger.Info("Start scheduler completed successfully", args...) + }(time.Now()) + return lm.svc.StartScheduler(ctx) +} + +func (lm *loggingMiddleware) Handle(msg *messaging.Message) (err error) { + defer func(begin time.Time) { + // Log only failure since the handlers are executed async and will always + // return nil error. The rest of the loggin is performed in main.go error loop. + if err != nil { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if msg != nil { + args = append(args, + slog.String("channel", msg.Channel), + slog.String("payload_size", fmt.Sprintf("%d", len(msg.Payload))), + ) + } + lm.logger.Warn("Message consumption completed", args...) + } + }(time.Now()) + + err = lm.svc.Handle(msg) + return +} + +func (lm *loggingMiddleware) Cancel() error { + return lm.svc.Cancel() +} diff --git a/re/middleware/metrics.go b/re/middleware/metrics.go new file mode 100644 index 000000000..9fa90bc6b --- /dev/null +++ b/re/middleware/metrics.go @@ -0,0 +1,137 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "time" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/roles" + rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" + "github.com/absmach/supermq/re" + "github.com/go-kit/kit/metrics" +) + +type metricsMiddleware struct { + counter metrics.Counter + latency metrics.Histogram + service re.Service + rolemw.RoleManagerMetricsMiddleware +} + +var _ re.Service = (*metricsMiddleware)(nil) + +func NewMetricsMiddleware(counter metrics.Counter, latency metrics.Histogram, service re.Service) re.Service { + return &metricsMiddleware{ + counter: counter, + latency: latency, + service: service, + RoleManagerMetricsMiddleware: rolemw.NewMetrics("re", service, counter, latency), + } +} + +func (mm *metricsMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, []roles.RoleProvision, error) { + defer func(begin time.Time) { + mm.counter.With("method", "add_rule").Add(1) + mm.latency.With("method", "add_rule").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.AddRule(ctx, session, r) +} + +func (mm *metricsMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) { + defer func(begin time.Time) { + mm.counter.With("method", "view_rule").Add(1) + mm.latency.With("method", "view_rule").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.ViewRule(ctx, session, id, withRoles) +} + +func (mm *metricsMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + defer func(begin time.Time) { + mm.counter.With("method", "update_rule").Add(1) + mm.latency.With("method", "update_rule").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.UpdateRule(ctx, session, r) +} + +func (mm *metricsMiddleware) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + defer func(begin time.Time) { + mm.counter.With("method", "update_rule_tags").Add(1) + mm.latency.With("method", "update_rule_tags").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.UpdateRuleTags(ctx, session, r) +} + +func (mm *metricsMiddleware) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + defer func(begin time.Time) { + mm.counter.With("method", "update_rule_schedule").Add(1) + mm.latency.With("method", "update_rule_schedule").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.UpdateRuleSchedule(ctx, session, r) +} + +func (mm *metricsMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) { + defer func(begin time.Time) { + mm.counter.With("method", "list_rules").Add(1) + mm.latency.With("method", "list_rules").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.ListRules(ctx, session, pm) +} + +func (mm *metricsMiddleware) RemoveRule(ctx context.Context, session authn.Session, id string) error { + defer func(begin time.Time) { + mm.counter.With("method", "remove_rule").Add(1) + mm.latency.With("method", "remove_rule").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.RemoveRule(ctx, session, id) +} + +func (mm *metricsMiddleware) EnableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) { + defer func(begin time.Time) { + mm.counter.With("method", "enable_rule").Add(1) + mm.latency.With("method", "enable_rule").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.EnableRule(ctx, session, id) +} + +func (mm *metricsMiddleware) DisableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) { + defer func(begin time.Time) { + mm.counter.With("method", "disable_rule").Add(1) + mm.latency.With("method", "disable_rule").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.DisableRule(ctx, session, id) +} + +func (mm *metricsMiddleware) Handle(msg *messaging.Message) error { + defer func(begin time.Time) { + mm.counter.With("method", "handle").Add(1) + mm.latency.With("method", "handle").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.Handle(msg) +} + +func (mm *metricsMiddleware) StartScheduler(ctx context.Context) error { + defer func(begin time.Time) { + mm.counter.With("method", "start_scheduler").Add(1) + mm.latency.With("method", "start_scheduler").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.StartScheduler(ctx) +} + +func (mm *metricsMiddleware) Cancel() error { + return mm.service.Cancel() +} diff --git a/re/middleware/tracing.go b/re/middleware/tracing.go new file mode 100644 index 000000000..6bc1abe60 --- /dev/null +++ b/re/middleware/tracing.go @@ -0,0 +1,137 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/roles" + rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" + smqTracing "github.com/absmach/supermq/pkg/tracing" + "github.com/absmach/supermq/re" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +type tracingMiddleware struct { + tracer trace.Tracer + svc re.Service + rolemw.RoleManagerTracing +} + +var _ re.Service = (*tracingMiddleware)(nil) + +func NewTracingMiddleware(tracer trace.Tracer, svc re.Service) re.Service { + return &tracingMiddleware{ + tracer: tracer, + svc: svc, + RoleManagerTracing: rolemw.NewTracing("re", svc, tracer), + } +} + +func (tm *tracingMiddleware) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, []roles.RoleProvision, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "add_rule", trace.WithAttributes( + attribute.String("name", r.Name), + attribute.String("domain_id", r.DomainID), + )) + defer span.End() + + return tm.svc.AddRule(ctx, session, r) +} + +func (tm *tracingMiddleware) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "view_rule", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.ViewRule(ctx, session, id, withRoles) +} + +func (tm *tracingMiddleware) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_rule", trace.WithAttributes( + attribute.String("id", r.ID), + )) + defer span.End() + + return tm.svc.UpdateRule(ctx, session, r) +} + +func (tm *tracingMiddleware) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_rule_tags", trace.WithAttributes( + attribute.String("id", r.ID), + )) + defer span.End() + + return tm.svc.UpdateRuleTags(ctx, session, r) +} + +func (tm *tracingMiddleware) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_rule_schedule", trace.WithAttributes( + attribute.String("id", r.ID), + )) + defer span.End() + + return tm.svc.UpdateRuleSchedule(ctx, session, r) +} + +func (tm *tracingMiddleware) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "list_rules", trace.WithAttributes( + attribute.Int("offset", int(pm.Offset)), + attribute.Int("limit", int(pm.Limit)), + )) + defer span.End() + + return tm.svc.ListRules(ctx, session, pm) +} + +func (tm *tracingMiddleware) RemoveRule(ctx context.Context, session authn.Session, id string) error { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "remove_rule", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.RemoveRule(ctx, session, id) +} + +func (tm *tracingMiddleware) EnableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "enable_rule", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.EnableRule(ctx, session, id) +} + +func (tm *tracingMiddleware) DisableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "disable_rule", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.DisableRule(ctx, session, id) +} + +func (tm *tracingMiddleware) Handle(msg *messaging.Message) error { + _, span := smqTracing.StartSpan(context.Background(), tm.tracer, "handle", trace.WithAttributes( + attribute.String("channel", msg.Channel), + attribute.String("subtopic", msg.Subtopic), + )) + defer span.End() + + return tm.svc.Handle(msg) +} + +func (tm *tracingMiddleware) StartScheduler(ctx context.Context) error { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "start_scheduler") + defer span.End() + + return tm.svc.StartScheduler(ctx) +} + +func (tm *tracingMiddleware) Cancel() error { + return tm.svc.Cancel() +} diff --git a/re/mocks/repository.go b/re/mocks/repository.go new file mode 100644 index 000000000..8ec838dd6 --- /dev/null +++ b/re/mocks/repository.go @@ -0,0 +1,2133 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + "time" + + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/re" + mock "github.com/stretchr/testify/mock" +) + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +type Repository_Expecter struct { + mock *mock.Mock +} + +func (_m *Repository) EXPECT() *Repository_Expecter { + return &Repository_Expecter{mock: &_m.Mock} +} + +// AddRoles provides a mock function for the type Repository +func (_mock *Repository) AddRoles(ctx context.Context, rps []roles.RoleProvision) ([]roles.RoleProvision, error) { + ret := _mock.Called(ctx, rps) + + if len(ret) == 0 { + panic("no return value specified for AddRoles") + } + + var r0 []roles.RoleProvision + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, []roles.RoleProvision) ([]roles.RoleProvision, error)); ok { + return returnFunc(ctx, rps) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, []roles.RoleProvision) []roles.RoleProvision); ok { + r0 = returnFunc(ctx, rps) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]roles.RoleProvision) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, []roles.RoleProvision) error); ok { + r1 = returnFunc(ctx, rps) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_AddRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddRoles' +type Repository_AddRoles_Call struct { + *mock.Call +} + +// AddRoles is a helper method to define mock.On call +// - ctx context.Context +// - rps []roles.RoleProvision +func (_e *Repository_Expecter) AddRoles(ctx interface{}, rps interface{}) *Repository_AddRoles_Call { + return &Repository_AddRoles_Call{Call: _e.mock.On("AddRoles", ctx, rps)} +} + +func (_c *Repository_AddRoles_Call) Run(run func(ctx context.Context, rps []roles.RoleProvision)) *Repository_AddRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 []roles.RoleProvision + if args[1] != nil { + arg1 = args[1].([]roles.RoleProvision) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_AddRoles_Call) Return(roleProvisions []roles.RoleProvision, err error) *Repository_AddRoles_Call { + _c.Call.Return(roleProvisions, err) + return _c +} + +func (_c *Repository_AddRoles_Call) RunAndReturn(run func(ctx context.Context, rps []roles.RoleProvision) ([]roles.RoleProvision, error)) *Repository_AddRoles_Call { + _c.Call.Return(run) + return _c +} + +// AddRule provides a mock function for the type Repository +func (_mock *Repository) AddRule(ctx context.Context, r re.Rule) (re.Rule, error) { + ret := _mock.Called(ctx, r) + + if len(ret) == 0 { + panic("no return value specified for AddRule") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, re.Rule) (re.Rule, error)); ok { + return returnFunc(ctx, r) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, re.Rule) re.Rule); ok { + r0 = returnFunc(ctx, r) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, re.Rule) error); ok { + r1 = returnFunc(ctx, r) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_AddRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddRule' +type Repository_AddRule_Call struct { + *mock.Call +} + +// AddRule is a helper method to define mock.On call +// - ctx context.Context +// - r re.Rule +func (_e *Repository_Expecter) AddRule(ctx interface{}, r interface{}) *Repository_AddRule_Call { + return &Repository_AddRule_Call{Call: _e.mock.On("AddRule", ctx, r)} +} + +func (_c *Repository_AddRule_Call) Run(run func(ctx context.Context, r re.Rule)) *Repository_AddRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 re.Rule + if args[1] != nil { + arg1 = args[1].(re.Rule) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_AddRule_Call) Return(rule re.Rule, err error) *Repository_AddRule_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Repository_AddRule_Call) RunAndReturn(run func(ctx context.Context, r re.Rule) (re.Rule, error)) *Repository_AddRule_Call { + _c.Call.Return(run) + return _c +} + +// ListAllRules provides a mock function for the type Repository +func (_mock *Repository) ListAllRules(ctx context.Context, pm re.PageMeta) (re.Page, error) { + ret := _mock.Called(ctx, pm) + + if len(ret) == 0 { + panic("no return value specified for ListAllRules") + } + + var r0 re.Page + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) (re.Page, error)); ok { + return returnFunc(ctx, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, re.PageMeta) re.Page); ok { + r0 = returnFunc(ctx, pm) + } else { + r0 = ret.Get(0).(re.Page) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, re.PageMeta) error); ok { + r1 = returnFunc(ctx, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListAllRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAllRules' +type Repository_ListAllRules_Call struct { + *mock.Call +} + +// ListAllRules is a helper method to define mock.On call +// - ctx context.Context +// - pm re.PageMeta +func (_e *Repository_Expecter) ListAllRules(ctx interface{}, pm interface{}) *Repository_ListAllRules_Call { + return &Repository_ListAllRules_Call{Call: _e.mock.On("ListAllRules", ctx, pm)} +} + +func (_c *Repository_ListAllRules_Call) Run(run func(ctx context.Context, pm re.PageMeta)) *Repository_ListAllRules_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 re.PageMeta + if args[1] != nil { + arg1 = args[1].(re.PageMeta) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_ListAllRules_Call) Return(page re.Page, err error) *Repository_ListAllRules_Call { + _c.Call.Return(page, err) + return _c +} + +func (_c *Repository_ListAllRules_Call) RunAndReturn(run func(ctx context.Context, pm re.PageMeta) (re.Page, error)) *Repository_ListAllRules_Call { + _c.Call.Return(run) + return _c +} + +// ListEntityMembers provides a mock function for the type Repository +func (_mock *Repository) ListEntityMembers(ctx context.Context, entityID string, pageQuery roles.MembersRolePageQuery) (roles.MembersRolePage, error) { + ret := _mock.Called(ctx, entityID, pageQuery) + + if len(ret) == 0 { + panic("no return value specified for ListEntityMembers") + } + + var r0 roles.MembersRolePage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, roles.MembersRolePageQuery) (roles.MembersRolePage, error)); ok { + return returnFunc(ctx, entityID, pageQuery) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, roles.MembersRolePageQuery) roles.MembersRolePage); ok { + r0 = returnFunc(ctx, entityID, pageQuery) + } else { + r0 = ret.Get(0).(roles.MembersRolePage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, roles.MembersRolePageQuery) error); ok { + r1 = returnFunc(ctx, entityID, pageQuery) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListEntityMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListEntityMembers' +type Repository_ListEntityMembers_Call struct { + *mock.Call +} + +// ListEntityMembers is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - pageQuery roles.MembersRolePageQuery +func (_e *Repository_Expecter) ListEntityMembers(ctx interface{}, entityID interface{}, pageQuery interface{}) *Repository_ListEntityMembers_Call { + return &Repository_ListEntityMembers_Call{Call: _e.mock.On("ListEntityMembers", ctx, entityID, pageQuery)} +} + +func (_c *Repository_ListEntityMembers_Call) Run(run func(ctx context.Context, entityID string, pageQuery roles.MembersRolePageQuery)) *Repository_ListEntityMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 roles.MembersRolePageQuery + if args[2] != nil { + arg2 = args[2].(roles.MembersRolePageQuery) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_ListEntityMembers_Call) Return(membersRolePage roles.MembersRolePage, err error) *Repository_ListEntityMembers_Call { + _c.Call.Return(membersRolePage, err) + return _c +} + +func (_c *Repository_ListEntityMembers_Call) RunAndReturn(run func(ctx context.Context, entityID string, pageQuery roles.MembersRolePageQuery) (roles.MembersRolePage, error)) *Repository_ListEntityMembers_Call { + _c.Call.Return(run) + return _c +} + +// ListUserRules provides a mock function for the type Repository +func (_mock *Repository) ListUserRules(ctx context.Context, userID string, pm re.PageMeta) (re.Page, error) { + ret := _mock.Called(ctx, userID, pm) + + if len(ret) == 0 { + panic("no return value specified for ListUserRules") + } + + var r0 re.Page + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, re.PageMeta) (re.Page, error)); ok { + return returnFunc(ctx, userID, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, re.PageMeta) re.Page); ok { + r0 = returnFunc(ctx, userID, pm) + } else { + r0 = ret.Get(0).(re.Page) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, re.PageMeta) error); ok { + r1 = returnFunc(ctx, userID, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListUserRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserRules' +type Repository_ListUserRules_Call struct { + *mock.Call +} + +// ListUserRules is a helper method to define mock.On call +// - ctx context.Context +// - userID string +// - pm re.PageMeta +func (_e *Repository_Expecter) ListUserRules(ctx interface{}, userID interface{}, pm interface{}) *Repository_ListUserRules_Call { + return &Repository_ListUserRules_Call{Call: _e.mock.On("ListUserRules", ctx, userID, pm)} +} + +func (_c *Repository_ListUserRules_Call) Run(run func(ctx context.Context, userID string, pm re.PageMeta)) *Repository_ListUserRules_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 re.PageMeta + if args[2] != nil { + arg2 = args[2].(re.PageMeta) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_ListUserRules_Call) Return(page re.Page, err error) *Repository_ListUserRules_Call { + _c.Call.Return(page, err) + return _c +} + +func (_c *Repository_ListUserRules_Call) RunAndReturn(run func(ctx context.Context, userID string, pm re.PageMeta) (re.Page, error)) *Repository_ListUserRules_Call { + _c.Call.Return(run) + return _c +} + +// RemoveEntityMembers provides a mock function for the type Repository +func (_mock *Repository) RemoveEntityMembers(ctx context.Context, entityID string, members []string) error { + ret := _mock.Called(ctx, entityID, members) + + if len(ret) == 0 { + panic("no return value specified for RemoveEntityMembers") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string) error); ok { + r0 = returnFunc(ctx, entityID, members) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RemoveEntityMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveEntityMembers' +type Repository_RemoveEntityMembers_Call struct { + *mock.Call +} + +// RemoveEntityMembers is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - members []string +func (_e *Repository_Expecter) RemoveEntityMembers(ctx interface{}, entityID interface{}, members interface{}) *Repository_RemoveEntityMembers_Call { + return &Repository_RemoveEntityMembers_Call{Call: _e.mock.On("RemoveEntityMembers", ctx, entityID, members)} +} + +func (_c *Repository_RemoveEntityMembers_Call) Run(run func(ctx context.Context, entityID string, members []string)) *Repository_RemoveEntityMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RemoveEntityMembers_Call) Return(err error) *Repository_RemoveEntityMembers_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RemoveEntityMembers_Call) RunAndReturn(run func(ctx context.Context, entityID string, members []string) error) *Repository_RemoveEntityMembers_Call { + _c.Call.Return(run) + return _c +} + +// RemoveMemberFromAllRoles provides a mock function for the type Repository +func (_mock *Repository) RemoveMemberFromAllRoles(ctx context.Context, memberID string) error { + ret := _mock.Called(ctx, memberID) + + if len(ret) == 0 { + panic("no return value specified for RemoveMemberFromAllRoles") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, memberID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RemoveMemberFromAllRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveMemberFromAllRoles' +type Repository_RemoveMemberFromAllRoles_Call struct { + *mock.Call +} + +// RemoveMemberFromAllRoles is a helper method to define mock.On call +// - ctx context.Context +// - memberID string +func (_e *Repository_Expecter) RemoveMemberFromAllRoles(ctx interface{}, memberID interface{}) *Repository_RemoveMemberFromAllRoles_Call { + return &Repository_RemoveMemberFromAllRoles_Call{Call: _e.mock.On("RemoveMemberFromAllRoles", ctx, memberID)} +} + +func (_c *Repository_RemoveMemberFromAllRoles_Call) Run(run func(ctx context.Context, memberID string)) *Repository_RemoveMemberFromAllRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RemoveMemberFromAllRoles_Call) Return(err error) *Repository_RemoveMemberFromAllRoles_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RemoveMemberFromAllRoles_Call) RunAndReturn(run func(ctx context.Context, memberID string) error) *Repository_RemoveMemberFromAllRoles_Call { + _c.Call.Return(run) + return _c +} + +// RemoveRoles provides a mock function for the type Repository +func (_mock *Repository) RemoveRoles(ctx context.Context, roleIDs []string) error { + ret := _mock.Called(ctx, roleIDs) + + if len(ret) == 0 { + panic("no return value specified for RemoveRoles") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, []string) error); ok { + r0 = returnFunc(ctx, roleIDs) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RemoveRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveRoles' +type Repository_RemoveRoles_Call struct { + *mock.Call +} + +// RemoveRoles is a helper method to define mock.On call +// - ctx context.Context +// - roleIDs []string +func (_e *Repository_Expecter) RemoveRoles(ctx interface{}, roleIDs interface{}) *Repository_RemoveRoles_Call { + return &Repository_RemoveRoles_Call{Call: _e.mock.On("RemoveRoles", ctx, roleIDs)} +} + +func (_c *Repository_RemoveRoles_Call) Run(run func(ctx context.Context, roleIDs []string)) *Repository_RemoveRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 []string + if args[1] != nil { + arg1 = args[1].([]string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RemoveRoles_Call) Return(err error) *Repository_RemoveRoles_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RemoveRoles_Call) RunAndReturn(run func(ctx context.Context, roleIDs []string) error) *Repository_RemoveRoles_Call { + _c.Call.Return(run) + return _c +} + +// RemoveRule provides a mock function for the type Repository +func (_mock *Repository) RemoveRule(ctx context.Context, id string) error { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for RemoveRule") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RemoveRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveRule' +type Repository_RemoveRule_Call struct { + *mock.Call +} + +// RemoveRule is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) RemoveRule(ctx interface{}, id interface{}) *Repository_RemoveRule_Call { + return &Repository_RemoveRule_Call{Call: _e.mock.On("RemoveRule", ctx, id)} +} + +func (_c *Repository_RemoveRule_Call) Run(run func(ctx context.Context, id string)) *Repository_RemoveRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RemoveRule_Call) Return(err error) *Repository_RemoveRule_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RemoveRule_Call) RunAndReturn(run func(ctx context.Context, id string) error) *Repository_RemoveRule_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveAllRoles provides a mock function for the type Repository +func (_mock *Repository) RetrieveAllRoles(ctx context.Context, entityID string, limit uint64, offset uint64) (roles.RolePage, error) { + ret := _mock.Called(ctx, entityID, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for RetrieveAllRoles") + } + + var r0 roles.RolePage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) (roles.RolePage, error)); ok { + return returnFunc(ctx, entityID, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) roles.RolePage); ok { + r0 = returnFunc(ctx, entityID, limit, offset) + } else { + r0 = ret.Get(0).(roles.RolePage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint64, uint64) error); ok { + r1 = returnFunc(ctx, entityID, limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RetrieveAllRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveAllRoles' +type Repository_RetrieveAllRoles_Call struct { + *mock.Call +} + +// RetrieveAllRoles is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - limit uint64 +// - offset uint64 +func (_e *Repository_Expecter) RetrieveAllRoles(ctx interface{}, entityID interface{}, limit interface{}, offset interface{}) *Repository_RetrieveAllRoles_Call { + return &Repository_RetrieveAllRoles_Call{Call: _e.mock.On("RetrieveAllRoles", ctx, entityID, limit, offset)} +} + +func (_c *Repository_RetrieveAllRoles_Call) Run(run func(ctx context.Context, entityID string, limit uint64, offset uint64)) *Repository_RetrieveAllRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 uint64 + if args[2] != nil { + arg2 = args[2].(uint64) + } + var arg3 uint64 + if args[3] != nil { + arg3 = args[3].(uint64) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Repository_RetrieveAllRoles_Call) Return(rolePage roles.RolePage, err error) *Repository_RetrieveAllRoles_Call { + _c.Call.Return(rolePage, err) + return _c +} + +func (_c *Repository_RetrieveAllRoles_Call) RunAndReturn(run func(ctx context.Context, entityID string, limit uint64, offset uint64) (roles.RolePage, error)) *Repository_RetrieveAllRoles_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveByIDWithRoles provides a mock function for the type Repository +func (_mock *Repository) RetrieveByIDWithRoles(ctx context.Context, id string, memberID string) (re.Rule, error) { + ret := _mock.Called(ctx, id, memberID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveByIDWithRoles") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (re.Rule, error)); ok { + return returnFunc(ctx, id, memberID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) re.Rule); ok { + r0 = returnFunc(ctx, id, memberID) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, id, memberID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RetrieveByIDWithRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveByIDWithRoles' +type Repository_RetrieveByIDWithRoles_Call struct { + *mock.Call +} + +// RetrieveByIDWithRoles is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - memberID string +func (_e *Repository_Expecter) RetrieveByIDWithRoles(ctx interface{}, id interface{}, memberID interface{}) *Repository_RetrieveByIDWithRoles_Call { + return &Repository_RetrieveByIDWithRoles_Call{Call: _e.mock.On("RetrieveByIDWithRoles", ctx, id, memberID)} +} + +func (_c *Repository_RetrieveByIDWithRoles_Call) Run(run func(ctx context.Context, id string, memberID string)) *Repository_RetrieveByIDWithRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RetrieveByIDWithRoles_Call) Return(rule re.Rule, err error) *Repository_RetrieveByIDWithRoles_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Repository_RetrieveByIDWithRoles_Call) RunAndReturn(run func(ctx context.Context, id string, memberID string) (re.Rule, error)) *Repository_RetrieveByIDWithRoles_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveEntitiesRolesActionsMembers provides a mock function for the type Repository +func (_mock *Repository) RetrieveEntitiesRolesActionsMembers(ctx context.Context, entityIDs []string) ([]roles.EntityActionRole, []roles.EntityMemberRole, error) { + ret := _mock.Called(ctx, entityIDs) + + if len(ret) == 0 { + panic("no return value specified for RetrieveEntitiesRolesActionsMembers") + } + + var r0 []roles.EntityActionRole + var r1 []roles.EntityMemberRole + var r2 error + if returnFunc, ok := ret.Get(0).(func(context.Context, []string) ([]roles.EntityActionRole, []roles.EntityMemberRole, error)); ok { + return returnFunc(ctx, entityIDs) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, []string) []roles.EntityActionRole); ok { + r0 = returnFunc(ctx, entityIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]roles.EntityActionRole) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, []string) []roles.EntityMemberRole); ok { + r1 = returnFunc(ctx, entityIDs) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]roles.EntityMemberRole) + } + } + if returnFunc, ok := ret.Get(2).(func(context.Context, []string) error); ok { + r2 = returnFunc(ctx, entityIDs) + } else { + r2 = ret.Error(2) + } + return r0, r1, r2 +} + +// Repository_RetrieveEntitiesRolesActionsMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveEntitiesRolesActionsMembers' +type Repository_RetrieveEntitiesRolesActionsMembers_Call struct { + *mock.Call +} + +// RetrieveEntitiesRolesActionsMembers is a helper method to define mock.On call +// - ctx context.Context +// - entityIDs []string +func (_e *Repository_Expecter) RetrieveEntitiesRolesActionsMembers(ctx interface{}, entityIDs interface{}) *Repository_RetrieveEntitiesRolesActionsMembers_Call { + return &Repository_RetrieveEntitiesRolesActionsMembers_Call{Call: _e.mock.On("RetrieveEntitiesRolesActionsMembers", ctx, entityIDs)} +} + +func (_c *Repository_RetrieveEntitiesRolesActionsMembers_Call) Run(run func(ctx context.Context, entityIDs []string)) *Repository_RetrieveEntitiesRolesActionsMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 []string + if args[1] != nil { + arg1 = args[1].([]string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RetrieveEntitiesRolesActionsMembers_Call) Return(entityActionRoles []roles.EntityActionRole, entityMemberRoles []roles.EntityMemberRole, err error) *Repository_RetrieveEntitiesRolesActionsMembers_Call { + _c.Call.Return(entityActionRoles, entityMemberRoles, err) + return _c +} + +func (_c *Repository_RetrieveEntitiesRolesActionsMembers_Call) RunAndReturn(run func(ctx context.Context, entityIDs []string) ([]roles.EntityActionRole, []roles.EntityMemberRole, error)) *Repository_RetrieveEntitiesRolesActionsMembers_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveEntityRole provides a mock function for the type Repository +func (_mock *Repository) RetrieveEntityRole(ctx context.Context, entityID string, roleID string) (roles.Role, error) { + ret := _mock.Called(ctx, entityID, roleID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveEntityRole") + } + + var r0 roles.Role + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (roles.Role, error)); ok { + return returnFunc(ctx, entityID, roleID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) roles.Role); ok { + r0 = returnFunc(ctx, entityID, roleID) + } else { + r0 = ret.Get(0).(roles.Role) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, entityID, roleID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RetrieveEntityRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveEntityRole' +type Repository_RetrieveEntityRole_Call struct { + *mock.Call +} + +// RetrieveEntityRole is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - roleID string +func (_e *Repository_Expecter) RetrieveEntityRole(ctx interface{}, entityID interface{}, roleID interface{}) *Repository_RetrieveEntityRole_Call { + return &Repository_RetrieveEntityRole_Call{Call: _e.mock.On("RetrieveEntityRole", ctx, entityID, roleID)} +} + +func (_c *Repository_RetrieveEntityRole_Call) Run(run func(ctx context.Context, entityID string, roleID string)) *Repository_RetrieveEntityRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RetrieveEntityRole_Call) Return(role roles.Role, err error) *Repository_RetrieveEntityRole_Call { + _c.Call.Return(role, err) + return _c +} + +func (_c *Repository_RetrieveEntityRole_Call) RunAndReturn(run func(ctx context.Context, entityID string, roleID string) (roles.Role, error)) *Repository_RetrieveEntityRole_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveRole provides a mock function for the type Repository +func (_mock *Repository) RetrieveRole(ctx context.Context, roleID string) (roles.Role, error) { + ret := _mock.Called(ctx, roleID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveRole") + } + + var r0 roles.Role + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (roles.Role, error)); ok { + return returnFunc(ctx, roleID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) roles.Role); ok { + r0 = returnFunc(ctx, roleID) + } else { + r0 = ret.Get(0).(roles.Role) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, roleID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RetrieveRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveRole' +type Repository_RetrieveRole_Call struct { + *mock.Call +} + +// RetrieveRole is a helper method to define mock.On call +// - ctx context.Context +// - roleID string +func (_e *Repository_Expecter) RetrieveRole(ctx interface{}, roleID interface{}) *Repository_RetrieveRole_Call { + return &Repository_RetrieveRole_Call{Call: _e.mock.On("RetrieveRole", ctx, roleID)} +} + +func (_c *Repository_RetrieveRole_Call) Run(run func(ctx context.Context, roleID string)) *Repository_RetrieveRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RetrieveRole_Call) Return(role roles.Role, err error) *Repository_RetrieveRole_Call { + _c.Call.Return(role, err) + return _c +} + +func (_c *Repository_RetrieveRole_Call) RunAndReturn(run func(ctx context.Context, roleID string) (roles.Role, error)) *Repository_RetrieveRole_Call { + _c.Call.Return(run) + return _c +} + +// RoleAddActions provides a mock function for the type Repository +func (_mock *Repository) RoleAddActions(ctx context.Context, role roles.Role, actions []string) ([]string, error) { + ret := _mock.Called(ctx, role, actions) + + if len(ret) == 0 { + panic("no return value specified for RoleAddActions") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role, []string) ([]string, error)); ok { + return returnFunc(ctx, role, actions) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role, []string) []string); ok { + r0 = returnFunc(ctx, role, actions) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, roles.Role, []string) error); ok { + r1 = returnFunc(ctx, role, actions) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RoleAddActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleAddActions' +type Repository_RoleAddActions_Call struct { + *mock.Call +} + +// RoleAddActions is a helper method to define mock.On call +// - ctx context.Context +// - role roles.Role +// - actions []string +func (_e *Repository_Expecter) RoleAddActions(ctx interface{}, role interface{}, actions interface{}) *Repository_RoleAddActions_Call { + return &Repository_RoleAddActions_Call{Call: _e.mock.On("RoleAddActions", ctx, role, actions)} +} + +func (_c *Repository_RoleAddActions_Call) Run(run func(ctx context.Context, role roles.Role, actions []string)) *Repository_RoleAddActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RoleAddActions_Call) Return(ops []string, err error) *Repository_RoleAddActions_Call { + _c.Call.Return(ops, err) + return _c +} + +func (_c *Repository_RoleAddActions_Call) RunAndReturn(run func(ctx context.Context, role roles.Role, actions []string) ([]string, error)) *Repository_RoleAddActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleAddMembers provides a mock function for the type Repository +func (_mock *Repository) RoleAddMembers(ctx context.Context, role roles.Role, members []string) ([]string, error) { + ret := _mock.Called(ctx, role, members) + + if len(ret) == 0 { + panic("no return value specified for RoleAddMembers") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role, []string) ([]string, error)); ok { + return returnFunc(ctx, role, members) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role, []string) []string); ok { + r0 = returnFunc(ctx, role, members) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, roles.Role, []string) error); ok { + r1 = returnFunc(ctx, role, members) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RoleAddMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleAddMembers' +type Repository_RoleAddMembers_Call struct { + *mock.Call +} + +// RoleAddMembers is a helper method to define mock.On call +// - ctx context.Context +// - role roles.Role +// - members []string +func (_e *Repository_Expecter) RoleAddMembers(ctx interface{}, role interface{}, members interface{}) *Repository_RoleAddMembers_Call { + return &Repository_RoleAddMembers_Call{Call: _e.mock.On("RoleAddMembers", ctx, role, members)} +} + +func (_c *Repository_RoleAddMembers_Call) Run(run func(ctx context.Context, role roles.Role, members []string)) *Repository_RoleAddMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RoleAddMembers_Call) Return(strings []string, err error) *Repository_RoleAddMembers_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *Repository_RoleAddMembers_Call) RunAndReturn(run func(ctx context.Context, role roles.Role, members []string) ([]string, error)) *Repository_RoleAddMembers_Call { + _c.Call.Return(run) + return _c +} + +// RoleCheckActionsExists provides a mock function for the type Repository +func (_mock *Repository) RoleCheckActionsExists(ctx context.Context, roleID string, actions []string) (bool, error) { + ret := _mock.Called(ctx, roleID, actions) + + if len(ret) == 0 { + panic("no return value specified for RoleCheckActionsExists") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string) (bool, error)); ok { + return returnFunc(ctx, roleID, actions) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string) bool); ok { + r0 = returnFunc(ctx, roleID, actions) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { + r1 = returnFunc(ctx, roleID, actions) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RoleCheckActionsExists_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleCheckActionsExists' +type Repository_RoleCheckActionsExists_Call struct { + *mock.Call +} + +// RoleCheckActionsExists is a helper method to define mock.On call +// - ctx context.Context +// - roleID string +// - actions []string +func (_e *Repository_Expecter) RoleCheckActionsExists(ctx interface{}, roleID interface{}, actions interface{}) *Repository_RoleCheckActionsExists_Call { + return &Repository_RoleCheckActionsExists_Call{Call: _e.mock.On("RoleCheckActionsExists", ctx, roleID, actions)} +} + +func (_c *Repository_RoleCheckActionsExists_Call) Run(run func(ctx context.Context, roleID string, actions []string)) *Repository_RoleCheckActionsExists_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RoleCheckActionsExists_Call) Return(b bool, err error) *Repository_RoleCheckActionsExists_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *Repository_RoleCheckActionsExists_Call) RunAndReturn(run func(ctx context.Context, roleID string, actions []string) (bool, error)) *Repository_RoleCheckActionsExists_Call { + _c.Call.Return(run) + return _c +} + +// RoleCheckMembersExists provides a mock function for the type Repository +func (_mock *Repository) RoleCheckMembersExists(ctx context.Context, roleID string, members []string) (bool, error) { + ret := _mock.Called(ctx, roleID, members) + + if len(ret) == 0 { + panic("no return value specified for RoleCheckMembersExists") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string) (bool, error)); ok { + return returnFunc(ctx, roleID, members) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string) bool); ok { + r0 = returnFunc(ctx, roleID, members) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { + r1 = returnFunc(ctx, roleID, members) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RoleCheckMembersExists_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleCheckMembersExists' +type Repository_RoleCheckMembersExists_Call struct { + *mock.Call +} + +// RoleCheckMembersExists is a helper method to define mock.On call +// - ctx context.Context +// - roleID string +// - members []string +func (_e *Repository_Expecter) RoleCheckMembersExists(ctx interface{}, roleID interface{}, members interface{}) *Repository_RoleCheckMembersExists_Call { + return &Repository_RoleCheckMembersExists_Call{Call: _e.mock.On("RoleCheckMembersExists", ctx, roleID, members)} +} + +func (_c *Repository_RoleCheckMembersExists_Call) Run(run func(ctx context.Context, roleID string, members []string)) *Repository_RoleCheckMembersExists_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RoleCheckMembersExists_Call) Return(b bool, err error) *Repository_RoleCheckMembersExists_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *Repository_RoleCheckMembersExists_Call) RunAndReturn(run func(ctx context.Context, roleID string, members []string) (bool, error)) *Repository_RoleCheckMembersExists_Call { + _c.Call.Return(run) + return _c +} + +// RoleListActions provides a mock function for the type Repository +func (_mock *Repository) RoleListActions(ctx context.Context, roleID string) ([]string, error) { + ret := _mock.Called(ctx, roleID) + + if len(ret) == 0 { + panic("no return value specified for RoleListActions") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) ([]string, error)); ok { + return returnFunc(ctx, roleID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) []string); ok { + r0 = returnFunc(ctx, roleID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, roleID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RoleListActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleListActions' +type Repository_RoleListActions_Call struct { + *mock.Call +} + +// RoleListActions is a helper method to define mock.On call +// - ctx context.Context +// - roleID string +func (_e *Repository_Expecter) RoleListActions(ctx interface{}, roleID interface{}) *Repository_RoleListActions_Call { + return &Repository_RoleListActions_Call{Call: _e.mock.On("RoleListActions", ctx, roleID)} +} + +func (_c *Repository_RoleListActions_Call) Run(run func(ctx context.Context, roleID string)) *Repository_RoleListActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RoleListActions_Call) Return(strings []string, err error) *Repository_RoleListActions_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *Repository_RoleListActions_Call) RunAndReturn(run func(ctx context.Context, roleID string) ([]string, error)) *Repository_RoleListActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleListMembers provides a mock function for the type Repository +func (_mock *Repository) RoleListMembers(ctx context.Context, roleID string, limit uint64, offset uint64) (roles.MembersPage, error) { + ret := _mock.Called(ctx, roleID, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for RoleListMembers") + } + + var r0 roles.MembersPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) (roles.MembersPage, error)); ok { + return returnFunc(ctx, roleID, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) roles.MembersPage); ok { + r0 = returnFunc(ctx, roleID, limit, offset) + } else { + r0 = ret.Get(0).(roles.MembersPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint64, uint64) error); ok { + r1 = returnFunc(ctx, roleID, limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RoleListMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleListMembers' +type Repository_RoleListMembers_Call struct { + *mock.Call +} + +// RoleListMembers is a helper method to define mock.On call +// - ctx context.Context +// - roleID string +// - limit uint64 +// - offset uint64 +func (_e *Repository_Expecter) RoleListMembers(ctx interface{}, roleID interface{}, limit interface{}, offset interface{}) *Repository_RoleListMembers_Call { + return &Repository_RoleListMembers_Call{Call: _e.mock.On("RoleListMembers", ctx, roleID, limit, offset)} +} + +func (_c *Repository_RoleListMembers_Call) Run(run func(ctx context.Context, roleID string, limit uint64, offset uint64)) *Repository_RoleListMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 uint64 + if args[2] != nil { + arg2 = args[2].(uint64) + } + var arg3 uint64 + if args[3] != nil { + arg3 = args[3].(uint64) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Repository_RoleListMembers_Call) Return(membersPage roles.MembersPage, err error) *Repository_RoleListMembers_Call { + _c.Call.Return(membersPage, err) + return _c +} + +func (_c *Repository_RoleListMembers_Call) RunAndReturn(run func(ctx context.Context, roleID string, limit uint64, offset uint64) (roles.MembersPage, error)) *Repository_RoleListMembers_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveActions provides a mock function for the type Repository +func (_mock *Repository) RoleRemoveActions(ctx context.Context, role roles.Role, actions []string) error { + ret := _mock.Called(ctx, role, actions) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveActions") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role, []string) error); ok { + r0 = returnFunc(ctx, role, actions) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RoleRemoveActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveActions' +type Repository_RoleRemoveActions_Call struct { + *mock.Call +} + +// RoleRemoveActions is a helper method to define mock.On call +// - ctx context.Context +// - role roles.Role +// - actions []string +func (_e *Repository_Expecter) RoleRemoveActions(ctx interface{}, role interface{}, actions interface{}) *Repository_RoleRemoveActions_Call { + return &Repository_RoleRemoveActions_Call{Call: _e.mock.On("RoleRemoveActions", ctx, role, actions)} +} + +func (_c *Repository_RoleRemoveActions_Call) Run(run func(ctx context.Context, role roles.Role, actions []string)) *Repository_RoleRemoveActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RoleRemoveActions_Call) Return(err error) *Repository_RoleRemoveActions_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RoleRemoveActions_Call) RunAndReturn(run func(ctx context.Context, role roles.Role, actions []string) error) *Repository_RoleRemoveActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveAllActions provides a mock function for the type Repository +func (_mock *Repository) RoleRemoveAllActions(ctx context.Context, role roles.Role) error { + ret := _mock.Called(ctx, role) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveAllActions") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role) error); ok { + r0 = returnFunc(ctx, role) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RoleRemoveAllActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveAllActions' +type Repository_RoleRemoveAllActions_Call struct { + *mock.Call +} + +// RoleRemoveAllActions is a helper method to define mock.On call +// - ctx context.Context +// - role roles.Role +func (_e *Repository_Expecter) RoleRemoveAllActions(ctx interface{}, role interface{}) *Repository_RoleRemoveAllActions_Call { + return &Repository_RoleRemoveAllActions_Call{Call: _e.mock.On("RoleRemoveAllActions", ctx, role)} +} + +func (_c *Repository_RoleRemoveAllActions_Call) Run(run func(ctx context.Context, role roles.Role)) *Repository_RoleRemoveAllActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RoleRemoveAllActions_Call) Return(err error) *Repository_RoleRemoveAllActions_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RoleRemoveAllActions_Call) RunAndReturn(run func(ctx context.Context, role roles.Role) error) *Repository_RoleRemoveAllActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveAllMembers provides a mock function for the type Repository +func (_mock *Repository) RoleRemoveAllMembers(ctx context.Context, role roles.Role) error { + ret := _mock.Called(ctx, role) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveAllMembers") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role) error); ok { + r0 = returnFunc(ctx, role) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RoleRemoveAllMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveAllMembers' +type Repository_RoleRemoveAllMembers_Call struct { + *mock.Call +} + +// RoleRemoveAllMembers is a helper method to define mock.On call +// - ctx context.Context +// - role roles.Role +func (_e *Repository_Expecter) RoleRemoveAllMembers(ctx interface{}, role interface{}) *Repository_RoleRemoveAllMembers_Call { + return &Repository_RoleRemoveAllMembers_Call{Call: _e.mock.On("RoleRemoveAllMembers", ctx, role)} +} + +func (_c *Repository_RoleRemoveAllMembers_Call) Run(run func(ctx context.Context, role roles.Role)) *Repository_RoleRemoveAllMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RoleRemoveAllMembers_Call) Return(err error) *Repository_RoleRemoveAllMembers_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RoleRemoveAllMembers_Call) RunAndReturn(run func(ctx context.Context, role roles.Role) error) *Repository_RoleRemoveAllMembers_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveMembers provides a mock function for the type Repository +func (_mock *Repository) RoleRemoveMembers(ctx context.Context, role roles.Role, members []string) error { + ret := _mock.Called(ctx, role, members) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveMembers") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role, []string) error); ok { + r0 = returnFunc(ctx, role, members) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RoleRemoveMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveMembers' +type Repository_RoleRemoveMembers_Call struct { + *mock.Call +} + +// RoleRemoveMembers is a helper method to define mock.On call +// - ctx context.Context +// - role roles.Role +// - members []string +func (_e *Repository_Expecter) RoleRemoveMembers(ctx interface{}, role interface{}, members interface{}) *Repository_RoleRemoveMembers_Call { + return &Repository_RoleRemoveMembers_Call{Call: _e.mock.On("RoleRemoveMembers", ctx, role, members)} +} + +func (_c *Repository_RoleRemoveMembers_Call) Run(run func(ctx context.Context, role roles.Role, members []string)) *Repository_RoleRemoveMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RoleRemoveMembers_Call) Return(err error) *Repository_RoleRemoveMembers_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RoleRemoveMembers_Call) RunAndReturn(run func(ctx context.Context, role roles.Role, members []string) error) *Repository_RoleRemoveMembers_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRole provides a mock function for the type Repository +func (_mock *Repository) UpdateRole(ctx context.Context, ro roles.Role) (roles.Role, error) { + ret := _mock.Called(ctx, ro) + + if len(ret) == 0 { + panic("no return value specified for UpdateRole") + } + + var r0 roles.Role + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role) (roles.Role, error)); ok { + return returnFunc(ctx, ro) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role) roles.Role); ok { + r0 = returnFunc(ctx, ro) + } else { + r0 = ret.Get(0).(roles.Role) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, roles.Role) error); ok { + r1 = returnFunc(ctx, ro) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRole' +type Repository_UpdateRole_Call struct { + *mock.Call +} + +// UpdateRole is a helper method to define mock.On call +// - ctx context.Context +// - ro roles.Role +func (_e *Repository_Expecter) UpdateRole(ctx interface{}, ro interface{}) *Repository_UpdateRole_Call { + return &Repository_UpdateRole_Call{Call: _e.mock.On("UpdateRole", ctx, ro)} +} + +func (_c *Repository_UpdateRole_Call) Run(run func(ctx context.Context, ro roles.Role)) *Repository_UpdateRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateRole_Call) Return(role roles.Role, err error) *Repository_UpdateRole_Call { + _c.Call.Return(role, err) + return _c +} + +func (_c *Repository_UpdateRole_Call) RunAndReturn(run func(ctx context.Context, ro roles.Role) (roles.Role, error)) *Repository_UpdateRole_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRule provides a mock function for the type Repository +func (_mock *Repository) UpdateRule(ctx context.Context, r re.Rule) (re.Rule, error) { + ret := _mock.Called(ctx, r) + + if len(ret) == 0 { + panic("no return value specified for UpdateRule") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, re.Rule) (re.Rule, error)); ok { + return returnFunc(ctx, r) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, re.Rule) re.Rule); ok { + r0 = returnFunc(ctx, r) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, re.Rule) error); ok { + r1 = returnFunc(ctx, r) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRule' +type Repository_UpdateRule_Call struct { + *mock.Call +} + +// UpdateRule is a helper method to define mock.On call +// - ctx context.Context +// - r re.Rule +func (_e *Repository_Expecter) UpdateRule(ctx interface{}, r interface{}) *Repository_UpdateRule_Call { + return &Repository_UpdateRule_Call{Call: _e.mock.On("UpdateRule", ctx, r)} +} + +func (_c *Repository_UpdateRule_Call) Run(run func(ctx context.Context, r re.Rule)) *Repository_UpdateRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 re.Rule + if args[1] != nil { + arg1 = args[1].(re.Rule) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateRule_Call) Return(rule re.Rule, err error) *Repository_UpdateRule_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Repository_UpdateRule_Call) RunAndReturn(run func(ctx context.Context, r re.Rule) (re.Rule, error)) *Repository_UpdateRule_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRuleDue provides a mock function for the type Repository +func (_mock *Repository) UpdateRuleDue(ctx context.Context, id string, due time.Time) (re.Rule, error) { + ret := _mock.Called(ctx, id, due) + + if len(ret) == 0 { + panic("no return value specified for UpdateRuleDue") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, time.Time) (re.Rule, error)); ok { + return returnFunc(ctx, id, due) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, time.Time) re.Rule); ok { + r0 = returnFunc(ctx, id, due) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, time.Time) error); ok { + r1 = returnFunc(ctx, id, due) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateRuleDue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRuleDue' +type Repository_UpdateRuleDue_Call struct { + *mock.Call +} + +// UpdateRuleDue is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - due time.Time +func (_e *Repository_Expecter) UpdateRuleDue(ctx interface{}, id interface{}, due interface{}) *Repository_UpdateRuleDue_Call { + return &Repository_UpdateRuleDue_Call{Call: _e.mock.On("UpdateRuleDue", ctx, id, due)} +} + +func (_c *Repository_UpdateRuleDue_Call) Run(run func(ctx context.Context, id string, due time.Time)) *Repository_UpdateRuleDue_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 time.Time + if args[2] != nil { + arg2 = args[2].(time.Time) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_UpdateRuleDue_Call) Return(rule re.Rule, err error) *Repository_UpdateRuleDue_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Repository_UpdateRuleDue_Call) RunAndReturn(run func(ctx context.Context, id string, due time.Time) (re.Rule, error)) *Repository_UpdateRuleDue_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRuleSchedule provides a mock function for the type Repository +func (_mock *Repository) UpdateRuleSchedule(ctx context.Context, r re.Rule) (re.Rule, error) { + ret := _mock.Called(ctx, r) + + if len(ret) == 0 { + panic("no return value specified for UpdateRuleSchedule") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, re.Rule) (re.Rule, error)); ok { + return returnFunc(ctx, r) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, re.Rule) re.Rule); ok { + r0 = returnFunc(ctx, r) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, re.Rule) error); ok { + r1 = returnFunc(ctx, r) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateRuleSchedule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRuleSchedule' +type Repository_UpdateRuleSchedule_Call struct { + *mock.Call +} + +// UpdateRuleSchedule is a helper method to define mock.On call +// - ctx context.Context +// - r re.Rule +func (_e *Repository_Expecter) UpdateRuleSchedule(ctx interface{}, r interface{}) *Repository_UpdateRuleSchedule_Call { + return &Repository_UpdateRuleSchedule_Call{Call: _e.mock.On("UpdateRuleSchedule", ctx, r)} +} + +func (_c *Repository_UpdateRuleSchedule_Call) Run(run func(ctx context.Context, r re.Rule)) *Repository_UpdateRuleSchedule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 re.Rule + if args[1] != nil { + arg1 = args[1].(re.Rule) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateRuleSchedule_Call) Return(rule re.Rule, err error) *Repository_UpdateRuleSchedule_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Repository_UpdateRuleSchedule_Call) RunAndReturn(run func(ctx context.Context, r re.Rule) (re.Rule, error)) *Repository_UpdateRuleSchedule_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRuleStatus provides a mock function for the type Repository +func (_mock *Repository) UpdateRuleStatus(ctx context.Context, r re.Rule) (re.Rule, error) { + ret := _mock.Called(ctx, r) + + if len(ret) == 0 { + panic("no return value specified for UpdateRuleStatus") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, re.Rule) (re.Rule, error)); ok { + return returnFunc(ctx, r) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, re.Rule) re.Rule); ok { + r0 = returnFunc(ctx, r) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, re.Rule) error); ok { + r1 = returnFunc(ctx, r) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateRuleStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRuleStatus' +type Repository_UpdateRuleStatus_Call struct { + *mock.Call +} + +// UpdateRuleStatus is a helper method to define mock.On call +// - ctx context.Context +// - r re.Rule +func (_e *Repository_Expecter) UpdateRuleStatus(ctx interface{}, r interface{}) *Repository_UpdateRuleStatus_Call { + return &Repository_UpdateRuleStatus_Call{Call: _e.mock.On("UpdateRuleStatus", ctx, r)} +} + +func (_c *Repository_UpdateRuleStatus_Call) Run(run func(ctx context.Context, r re.Rule)) *Repository_UpdateRuleStatus_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 re.Rule + if args[1] != nil { + arg1 = args[1].(re.Rule) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateRuleStatus_Call) Return(rule re.Rule, err error) *Repository_UpdateRuleStatus_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Repository_UpdateRuleStatus_Call) RunAndReturn(run func(ctx context.Context, r re.Rule) (re.Rule, error)) *Repository_UpdateRuleStatus_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRuleTags provides a mock function for the type Repository +func (_mock *Repository) UpdateRuleTags(ctx context.Context, r re.Rule) (re.Rule, error) { + ret := _mock.Called(ctx, r) + + if len(ret) == 0 { + panic("no return value specified for UpdateRuleTags") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, re.Rule) (re.Rule, error)); ok { + return returnFunc(ctx, r) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, re.Rule) re.Rule); ok { + r0 = returnFunc(ctx, r) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, re.Rule) error); ok { + r1 = returnFunc(ctx, r) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateRuleTags_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRuleTags' +type Repository_UpdateRuleTags_Call struct { + *mock.Call +} + +// UpdateRuleTags is a helper method to define mock.On call +// - ctx context.Context +// - r re.Rule +func (_e *Repository_Expecter) UpdateRuleTags(ctx interface{}, r interface{}) *Repository_UpdateRuleTags_Call { + return &Repository_UpdateRuleTags_Call{Call: _e.mock.On("UpdateRuleTags", ctx, r)} +} + +func (_c *Repository_UpdateRuleTags_Call) Run(run func(ctx context.Context, r re.Rule)) *Repository_UpdateRuleTags_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 re.Rule + if args[1] != nil { + arg1 = args[1].(re.Rule) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateRuleTags_Call) Return(rule re.Rule, err error) *Repository_UpdateRuleTags_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Repository_UpdateRuleTags_Call) RunAndReturn(run func(ctx context.Context, r re.Rule) (re.Rule, error)) *Repository_UpdateRuleTags_Call { + _c.Call.Return(run) + return _c +} + +// ViewRule provides a mock function for the type Repository +func (_mock *Repository) ViewRule(ctx context.Context, id string) (re.Rule, error) { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for ViewRule") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (re.Rule, error)); ok { + return returnFunc(ctx, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) re.Rule); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ViewRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewRule' +type Repository_ViewRule_Call struct { + *mock.Call +} + +// ViewRule is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) ViewRule(ctx interface{}, id interface{}) *Repository_ViewRule_Call { + return &Repository_ViewRule_Call{Call: _e.mock.On("ViewRule", ctx, id)} +} + +func (_c *Repository_ViewRule_Call) Run(run func(ctx context.Context, id string)) *Repository_ViewRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_ViewRule_Call) Return(rule re.Rule, err error) *Repository_ViewRule_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Repository_ViewRule_Call) RunAndReturn(run func(ctx context.Context, id string) (re.Rule, error)) *Repository_ViewRule_Call { + _c.Call.Return(run) + return _c +} diff --git a/re/mocks/service.go b/re/mocks/service.go new file mode 100644 index 000000000..bf7c47ad1 --- /dev/null +++ b/re/mocks/service.go @@ -0,0 +1,2326 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/re" + mock "github.com/stretchr/testify/mock" +) + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +type Service_Expecter struct { + mock *mock.Mock +} + +func (_m *Service) EXPECT() *Service_Expecter { + return &Service_Expecter{mock: &_m.Mock} +} + +// AddRole provides a mock function for the type Service +func (_mock *Service) AddRole(ctx context.Context, session authn.Session, entityID string, roleName string, optionalActions []string, optionalMembers []string) (roles.RoleProvision, error) { + ret := _mock.Called(ctx, session, entityID, roleName, optionalActions, optionalMembers) + + if len(ret) == 0 { + panic("no return value specified for AddRole") + } + + var r0 roles.RoleProvision + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string, []string) (roles.RoleProvision, error)); ok { + return returnFunc(ctx, session, entityID, roleName, optionalActions, optionalMembers) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string, []string) roles.RoleProvision); ok { + r0 = returnFunc(ctx, session, entityID, roleName, optionalActions, optionalMembers) + } else { + r0 = ret.Get(0).(roles.RoleProvision) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, []string, []string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleName, optionalActions, optionalMembers) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_AddRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddRole' +type Service_AddRole_Call struct { + *mock.Call +} + +// AddRole is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleName string +// - optionalActions []string +// - optionalMembers []string +func (_e *Service_Expecter) AddRole(ctx interface{}, session interface{}, entityID interface{}, roleName interface{}, optionalActions interface{}, optionalMembers interface{}) *Service_AddRole_Call { + return &Service_AddRole_Call{Call: _e.mock.On("AddRole", ctx, session, entityID, roleName, optionalActions, optionalMembers)} +} + +func (_c *Service_AddRole_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleName string, optionalActions []string, optionalMembers []string)) *Service_AddRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + var arg5 []string + if args[5] != nil { + arg5 = args[5].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + ) + }) + return _c +} + +func (_c *Service_AddRole_Call) Return(roleProvision roles.RoleProvision, err error) *Service_AddRole_Call { + _c.Call.Return(roleProvision, err) + return _c +} + +func (_c *Service_AddRole_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleName string, optionalActions []string, optionalMembers []string) (roles.RoleProvision, error)) *Service_AddRole_Call { + _c.Call.Return(run) + return _c +} + +// AddRule provides a mock function for the type Service +func (_mock *Service) AddRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, []roles.RoleProvision, error) { + ret := _mock.Called(ctx, session, r) + + if len(ret) == 0 { + panic("no return value specified for AddRule") + } + + var r0 re.Rule + var r1 []roles.RoleProvision + var r2 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, re.Rule) (re.Rule, []roles.RoleProvision, error)); ok { + return returnFunc(ctx, session, r) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, re.Rule) re.Rule); ok { + r0 = returnFunc(ctx, session, r) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, re.Rule) []roles.RoleProvision); ok { + r1 = returnFunc(ctx, session, r) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]roles.RoleProvision) + } + } + if returnFunc, ok := ret.Get(2).(func(context.Context, authn.Session, re.Rule) error); ok { + r2 = returnFunc(ctx, session, r) + } else { + r2 = ret.Error(2) + } + return r0, r1, r2 +} + +// Service_AddRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddRule' +type Service_AddRule_Call struct { + *mock.Call +} + +// AddRule is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - r re.Rule +func (_e *Service_Expecter) AddRule(ctx interface{}, session interface{}, r interface{}) *Service_AddRule_Call { + return &Service_AddRule_Call{Call: _e.mock.On("AddRule", ctx, session, r)} +} + +func (_c *Service_AddRule_Call) Run(run func(ctx context.Context, session authn.Session, r re.Rule)) *Service_AddRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 re.Rule + if args[2] != nil { + arg2 = args[2].(re.Rule) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_AddRule_Call) Return(rule re.Rule, roleProvisions []roles.RoleProvision, err error) *Service_AddRule_Call { + _c.Call.Return(rule, roleProvisions, err) + return _c +} + +func (_c *Service_AddRule_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, []roles.RoleProvision, error)) *Service_AddRule_Call { + _c.Call.Return(run) + return _c +} + +// Cancel provides a mock function for the type Service +func (_mock *Service) Cancel() error { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Cancel") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func() error); ok { + r0 = returnFunc() + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_Cancel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Cancel' +type Service_Cancel_Call struct { + *mock.Call +} + +// Cancel is a helper method to define mock.On call +func (_e *Service_Expecter) Cancel() *Service_Cancel_Call { + return &Service_Cancel_Call{Call: _e.mock.On("Cancel")} +} + +func (_c *Service_Cancel_Call) Run(run func()) *Service_Cancel_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Service_Cancel_Call) Return(err error) *Service_Cancel_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_Cancel_Call) RunAndReturn(run func() error) *Service_Cancel_Call { + _c.Call.Return(run) + return _c +} + +// DisableRule provides a mock function for the type Service +func (_mock *Service) DisableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for DisableRule") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (re.Rule, error)); ok { + return returnFunc(ctx, session, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) re.Rule); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = returnFunc(ctx, session, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_DisableRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DisableRule' +type Service_DisableRule_Call struct { + *mock.Call +} + +// DisableRule is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) DisableRule(ctx interface{}, session interface{}, id interface{}) *Service_DisableRule_Call { + return &Service_DisableRule_Call{Call: _e.mock.On("DisableRule", ctx, session, id)} +} + +func (_c *Service_DisableRule_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_DisableRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_DisableRule_Call) Return(rule re.Rule, err error) *Service_DisableRule_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Service_DisableRule_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) (re.Rule, error)) *Service_DisableRule_Call { + _c.Call.Return(run) + return _c +} + +// EnableRule provides a mock function for the type Service +func (_mock *Service) EnableRule(ctx context.Context, session authn.Session, id string) (re.Rule, error) { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for EnableRule") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (re.Rule, error)); ok { + return returnFunc(ctx, session, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) re.Rule); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = returnFunc(ctx, session, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_EnableRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EnableRule' +type Service_EnableRule_Call struct { + *mock.Call +} + +// EnableRule is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) EnableRule(ctx interface{}, session interface{}, id interface{}) *Service_EnableRule_Call { + return &Service_EnableRule_Call{Call: _e.mock.On("EnableRule", ctx, session, id)} +} + +func (_c *Service_EnableRule_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_EnableRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_EnableRule_Call) Return(rule re.Rule, err error) *Service_EnableRule_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Service_EnableRule_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) (re.Rule, error)) *Service_EnableRule_Call { + _c.Call.Return(run) + return _c +} + +// Handle provides a mock function for the type Service +func (_mock *Service) Handle(msg *messaging.Message) error { + ret := _mock.Called(msg) + + if len(ret) == 0 { + panic("no return value specified for Handle") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(*messaging.Message) error); ok { + r0 = returnFunc(msg) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_Handle_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Handle' +type Service_Handle_Call struct { + *mock.Call +} + +// Handle is a helper method to define mock.On call +// - msg *messaging.Message +func (_e *Service_Expecter) Handle(msg interface{}) *Service_Handle_Call { + return &Service_Handle_Call{Call: _e.mock.On("Handle", msg)} +} + +func (_c *Service_Handle_Call) Run(run func(msg *messaging.Message)) *Service_Handle_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 *messaging.Message + if args[0] != nil { + arg0 = args[0].(*messaging.Message) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *Service_Handle_Call) Return(err error) *Service_Handle_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_Handle_Call) RunAndReturn(run func(msg *messaging.Message) error) *Service_Handle_Call { + _c.Call.Return(run) + return _c +} + +// ListAvailableActions provides a mock function for the type Service +func (_mock *Service) ListAvailableActions(ctx context.Context, session authn.Session) ([]string, error) { + ret := _mock.Called(ctx, session) + + if len(ret) == 0 { + panic("no return value specified for ListAvailableActions") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) ([]string, error)); ok { + return returnFunc(ctx, session) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) []string); ok { + r0 = returnFunc(ctx, session) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session) error); ok { + r1 = returnFunc(ctx, session) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListAvailableActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAvailableActions' +type Service_ListAvailableActions_Call struct { + *mock.Call +} + +// ListAvailableActions is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +func (_e *Service_Expecter) ListAvailableActions(ctx interface{}, session interface{}) *Service_ListAvailableActions_Call { + return &Service_ListAvailableActions_Call{Call: _e.mock.On("ListAvailableActions", ctx, session)} +} + +func (_c *Service_ListAvailableActions_Call) Run(run func(ctx context.Context, session authn.Session)) *Service_ListAvailableActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_ListAvailableActions_Call) Return(strings []string, err error) *Service_ListAvailableActions_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *Service_ListAvailableActions_Call) RunAndReturn(run func(ctx context.Context, session authn.Session) ([]string, error)) *Service_ListAvailableActions_Call { + _c.Call.Return(run) + return _c +} + +// ListEntityMembers provides a mock function for the type Service +func (_mock *Service) ListEntityMembers(ctx context.Context, session authn.Session, entityID string, pq roles.MembersRolePageQuery) (roles.MembersRolePage, error) { + ret := _mock.Called(ctx, session, entityID, pq) + + if len(ret) == 0 { + panic("no return value specified for ListEntityMembers") + } + + var r0 roles.MembersRolePage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, roles.MembersRolePageQuery) (roles.MembersRolePage, error)); ok { + return returnFunc(ctx, session, entityID, pq) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, roles.MembersRolePageQuery) roles.MembersRolePage); ok { + r0 = returnFunc(ctx, session, entityID, pq) + } else { + r0 = ret.Get(0).(roles.MembersRolePage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, roles.MembersRolePageQuery) error); ok { + r1 = returnFunc(ctx, session, entityID, pq) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListEntityMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListEntityMembers' +type Service_ListEntityMembers_Call struct { + *mock.Call +} + +// ListEntityMembers is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - pq roles.MembersRolePageQuery +func (_e *Service_Expecter) ListEntityMembers(ctx interface{}, session interface{}, entityID interface{}, pq interface{}) *Service_ListEntityMembers_Call { + return &Service_ListEntityMembers_Call{Call: _e.mock.On("ListEntityMembers", ctx, session, entityID, pq)} +} + +func (_c *Service_ListEntityMembers_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, pq roles.MembersRolePageQuery)) *Service_ListEntityMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 roles.MembersRolePageQuery + if args[3] != nil { + arg3 = args[3].(roles.MembersRolePageQuery) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_ListEntityMembers_Call) Return(membersRolePage roles.MembersRolePage, err error) *Service_ListEntityMembers_Call { + _c.Call.Return(membersRolePage, err) + return _c +} + +func (_c *Service_ListEntityMembers_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, pq roles.MembersRolePageQuery) (roles.MembersRolePage, error)) *Service_ListEntityMembers_Call { + _c.Call.Return(run) + return _c +} + +// ListRules provides a mock function for the type Service +func (_mock *Service) ListRules(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error) { + ret := _mock.Called(ctx, session, pm) + + if len(ret) == 0 { + panic("no return value specified for ListRules") + } + + var r0 re.Page + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, re.PageMeta) (re.Page, error)); ok { + return returnFunc(ctx, session, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, re.PageMeta) re.Page); ok { + r0 = returnFunc(ctx, session, pm) + } else { + r0 = ret.Get(0).(re.Page) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, re.PageMeta) error); ok { + r1 = returnFunc(ctx, session, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListRules_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListRules' +type Service_ListRules_Call struct { + *mock.Call +} + +// ListRules is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - pm re.PageMeta +func (_e *Service_Expecter) ListRules(ctx interface{}, session interface{}, pm interface{}) *Service_ListRules_Call { + return &Service_ListRules_Call{Call: _e.mock.On("ListRules", ctx, session, pm)} +} + +func (_c *Service_ListRules_Call) Run(run func(ctx context.Context, session authn.Session, pm re.PageMeta)) *Service_ListRules_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 re.PageMeta + if args[2] != nil { + arg2 = args[2].(re.PageMeta) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_ListRules_Call) Return(page re.Page, err error) *Service_ListRules_Call { + _c.Call.Return(page, err) + return _c +} + +func (_c *Service_ListRules_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, pm re.PageMeta) (re.Page, error)) *Service_ListRules_Call { + _c.Call.Return(run) + return _c +} + +// RemoveEntityMembers provides a mock function for the type Service +func (_mock *Service) RemoveEntityMembers(ctx context.Context, session authn.Session, entityID string, members []string) error { + ret := _mock.Called(ctx, session, entityID, members) + + if len(ret) == 0 { + panic("no return value specified for RemoveEntityMembers") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, []string) error); ok { + r0 = returnFunc(ctx, session, entityID, members) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RemoveEntityMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveEntityMembers' +type Service_RemoveEntityMembers_Call struct { + *mock.Call +} + +// RemoveEntityMembers is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - members []string +func (_e *Service_Expecter) RemoveEntityMembers(ctx interface{}, session interface{}, entityID interface{}, members interface{}) *Service_RemoveEntityMembers_Call { + return &Service_RemoveEntityMembers_Call{Call: _e.mock.On("RemoveEntityMembers", ctx, session, entityID, members)} +} + +func (_c *Service_RemoveEntityMembers_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, members []string)) *Service_RemoveEntityMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 []string + if args[3] != nil { + arg3 = args[3].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_RemoveEntityMembers_Call) Return(err error) *Service_RemoveEntityMembers_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RemoveEntityMembers_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, members []string) error) *Service_RemoveEntityMembers_Call { + _c.Call.Return(run) + return _c +} + +// RemoveMemberFromAllRoles provides a mock function for the type Service +func (_mock *Service) RemoveMemberFromAllRoles(ctx context.Context, session authn.Session, memberID string) error { + ret := _mock.Called(ctx, session, memberID) + + if len(ret) == 0 { + panic("no return value specified for RemoveMemberFromAllRoles") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, memberID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RemoveMemberFromAllRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveMemberFromAllRoles' +type Service_RemoveMemberFromAllRoles_Call struct { + *mock.Call +} + +// RemoveMemberFromAllRoles is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - memberID string +func (_e *Service_Expecter) RemoveMemberFromAllRoles(ctx interface{}, session interface{}, memberID interface{}) *Service_RemoveMemberFromAllRoles_Call { + return &Service_RemoveMemberFromAllRoles_Call{Call: _e.mock.On("RemoveMemberFromAllRoles", ctx, session, memberID)} +} + +func (_c *Service_RemoveMemberFromAllRoles_Call) Run(run func(ctx context.Context, session authn.Session, memberID string)) *Service_RemoveMemberFromAllRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_RemoveMemberFromAllRoles_Call) Return(err error) *Service_RemoveMemberFromAllRoles_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RemoveMemberFromAllRoles_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, memberID string) error) *Service_RemoveMemberFromAllRoles_Call { + _c.Call.Return(run) + return _c +} + +// RemoveRole provides a mock function for the type Service +func (_mock *Service) RemoveRole(ctx context.Context, session authn.Session, entityID string, roleID string) error { + ret := _mock.Called(ctx, session, entityID, roleID) + + if len(ret) == 0 { + panic("no return value specified for RemoveRole") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) error); ok { + r0 = returnFunc(ctx, session, entityID, roleID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RemoveRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveRole' +type Service_RemoveRole_Call struct { + *mock.Call +} + +// RemoveRole is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +func (_e *Service_Expecter) RemoveRole(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}) *Service_RemoveRole_Call { + return &Service_RemoveRole_Call{Call: _e.mock.On("RemoveRole", ctx, session, entityID, roleID)} +} + +func (_c *Service_RemoveRole_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string)) *Service_RemoveRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_RemoveRole_Call) Return(err error) *Service_RemoveRole_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RemoveRole_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string) error) *Service_RemoveRole_Call { + _c.Call.Return(run) + return _c +} + +// RemoveRule provides a mock function for the type Service +func (_mock *Service) RemoveRule(ctx context.Context, session authn.Session, id string) error { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for RemoveRule") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RemoveRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveRule' +type Service_RemoveRule_Call struct { + *mock.Call +} + +// RemoveRule is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) RemoveRule(ctx interface{}, session interface{}, id interface{}) *Service_RemoveRule_Call { + return &Service_RemoveRule_Call{Call: _e.mock.On("RemoveRule", ctx, session, id)} +} + +func (_c *Service_RemoveRule_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_RemoveRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_RemoveRule_Call) Return(err error) *Service_RemoveRule_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RemoveRule_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) error) *Service_RemoveRule_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveAllRoles provides a mock function for the type Service +func (_mock *Service) RetrieveAllRoles(ctx context.Context, session authn.Session, entityID string, limit uint64, offset uint64) (roles.RolePage, error) { + ret := _mock.Called(ctx, session, entityID, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for RetrieveAllRoles") + } + + var r0 roles.RolePage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, uint64, uint64) (roles.RolePage, error)); ok { + return returnFunc(ctx, session, entityID, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, uint64, uint64) roles.RolePage); ok { + r0 = returnFunc(ctx, session, entityID, limit, offset) + } else { + r0 = ret.Get(0).(roles.RolePage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, uint64, uint64) error); ok { + r1 = returnFunc(ctx, session, entityID, limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RetrieveAllRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveAllRoles' +type Service_RetrieveAllRoles_Call struct { + *mock.Call +} + +// RetrieveAllRoles is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - limit uint64 +// - offset uint64 +func (_e *Service_Expecter) RetrieveAllRoles(ctx interface{}, session interface{}, entityID interface{}, limit interface{}, offset interface{}) *Service_RetrieveAllRoles_Call { + return &Service_RetrieveAllRoles_Call{Call: _e.mock.On("RetrieveAllRoles", ctx, session, entityID, limit, offset)} +} + +func (_c *Service_RetrieveAllRoles_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, limit uint64, offset uint64)) *Service_RetrieveAllRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 uint64 + if args[3] != nil { + arg3 = args[3].(uint64) + } + var arg4 uint64 + if args[4] != nil { + arg4 = args[4].(uint64) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RetrieveAllRoles_Call) Return(rolePage roles.RolePage, err error) *Service_RetrieveAllRoles_Call { + _c.Call.Return(rolePage, err) + return _c +} + +func (_c *Service_RetrieveAllRoles_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, limit uint64, offset uint64) (roles.RolePage, error)) *Service_RetrieveAllRoles_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveRole provides a mock function for the type Service +func (_mock *Service) RetrieveRole(ctx context.Context, session authn.Session, entityID string, roleID string) (roles.Role, error) { + ret := _mock.Called(ctx, session, entityID, roleID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveRole") + } + + var r0 roles.Role + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) (roles.Role, error)); ok { + return returnFunc(ctx, session, entityID, roleID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) roles.Role); ok { + r0 = returnFunc(ctx, session, entityID, roleID) + } else { + r0 = ret.Get(0).(roles.Role) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RetrieveRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveRole' +type Service_RetrieveRole_Call struct { + *mock.Call +} + +// RetrieveRole is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +func (_e *Service_Expecter) RetrieveRole(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}) *Service_RetrieveRole_Call { + return &Service_RetrieveRole_Call{Call: _e.mock.On("RetrieveRole", ctx, session, entityID, roleID)} +} + +func (_c *Service_RetrieveRole_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string)) *Service_RetrieveRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_RetrieveRole_Call) Return(role roles.Role, err error) *Service_RetrieveRole_Call { + _c.Call.Return(role, err) + return _c +} + +func (_c *Service_RetrieveRole_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string) (roles.Role, error)) *Service_RetrieveRole_Call { + _c.Call.Return(run) + return _c +} + +// RoleAddActions provides a mock function for the type Service +func (_mock *Service) RoleAddActions(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string) ([]string, error) { + ret := _mock.Called(ctx, session, entityID, roleID, actions) + + if len(ret) == 0 { + panic("no return value specified for RoleAddActions") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) ([]string, error)); ok { + return returnFunc(ctx, session, entityID, roleID, actions) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) []string); ok { + r0 = returnFunc(ctx, session, entityID, roleID, actions) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, []string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID, actions) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RoleAddActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleAddActions' +type Service_RoleAddActions_Call struct { + *mock.Call +} + +// RoleAddActions is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - actions []string +func (_e *Service_Expecter) RoleAddActions(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, actions interface{}) *Service_RoleAddActions_Call { + return &Service_RoleAddActions_Call{Call: _e.mock.On("RoleAddActions", ctx, session, entityID, roleID, actions)} +} + +func (_c *Service_RoleAddActions_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string)) *Service_RoleAddActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RoleAddActions_Call) Return(ops []string, err error) *Service_RoleAddActions_Call { + _c.Call.Return(ops, err) + return _c +} + +func (_c *Service_RoleAddActions_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string) ([]string, error)) *Service_RoleAddActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleAddMembers provides a mock function for the type Service +func (_mock *Service) RoleAddMembers(ctx context.Context, session authn.Session, entityID string, roleID string, members []string) ([]string, error) { + ret := _mock.Called(ctx, session, entityID, roleID, members) + + if len(ret) == 0 { + panic("no return value specified for RoleAddMembers") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) ([]string, error)); ok { + return returnFunc(ctx, session, entityID, roleID, members) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) []string); ok { + r0 = returnFunc(ctx, session, entityID, roleID, members) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, []string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID, members) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RoleAddMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleAddMembers' +type Service_RoleAddMembers_Call struct { + *mock.Call +} + +// RoleAddMembers is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - members []string +func (_e *Service_Expecter) RoleAddMembers(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, members interface{}) *Service_RoleAddMembers_Call { + return &Service_RoleAddMembers_Call{Call: _e.mock.On("RoleAddMembers", ctx, session, entityID, roleID, members)} +} + +func (_c *Service_RoleAddMembers_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, members []string)) *Service_RoleAddMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RoleAddMembers_Call) Return(strings []string, err error) *Service_RoleAddMembers_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *Service_RoleAddMembers_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, members []string) ([]string, error)) *Service_RoleAddMembers_Call { + _c.Call.Return(run) + return _c +} + +// RoleCheckActionsExists provides a mock function for the type Service +func (_mock *Service) RoleCheckActionsExists(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string) (bool, error) { + ret := _mock.Called(ctx, session, entityID, roleID, actions) + + if len(ret) == 0 { + panic("no return value specified for RoleCheckActionsExists") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) (bool, error)); ok { + return returnFunc(ctx, session, entityID, roleID, actions) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) bool); ok { + r0 = returnFunc(ctx, session, entityID, roleID, actions) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, []string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID, actions) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RoleCheckActionsExists_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleCheckActionsExists' +type Service_RoleCheckActionsExists_Call struct { + *mock.Call +} + +// RoleCheckActionsExists is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - actions []string +func (_e *Service_Expecter) RoleCheckActionsExists(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, actions interface{}) *Service_RoleCheckActionsExists_Call { + return &Service_RoleCheckActionsExists_Call{Call: _e.mock.On("RoleCheckActionsExists", ctx, session, entityID, roleID, actions)} +} + +func (_c *Service_RoleCheckActionsExists_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string)) *Service_RoleCheckActionsExists_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RoleCheckActionsExists_Call) Return(b bool, err error) *Service_RoleCheckActionsExists_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *Service_RoleCheckActionsExists_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string) (bool, error)) *Service_RoleCheckActionsExists_Call { + _c.Call.Return(run) + return _c +} + +// RoleCheckMembersExists provides a mock function for the type Service +func (_mock *Service) RoleCheckMembersExists(ctx context.Context, session authn.Session, entityID string, roleID string, members []string) (bool, error) { + ret := _mock.Called(ctx, session, entityID, roleID, members) + + if len(ret) == 0 { + panic("no return value specified for RoleCheckMembersExists") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) (bool, error)); ok { + return returnFunc(ctx, session, entityID, roleID, members) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) bool); ok { + r0 = returnFunc(ctx, session, entityID, roleID, members) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, []string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID, members) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RoleCheckMembersExists_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleCheckMembersExists' +type Service_RoleCheckMembersExists_Call struct { + *mock.Call +} + +// RoleCheckMembersExists is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - members []string +func (_e *Service_Expecter) RoleCheckMembersExists(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, members interface{}) *Service_RoleCheckMembersExists_Call { + return &Service_RoleCheckMembersExists_Call{Call: _e.mock.On("RoleCheckMembersExists", ctx, session, entityID, roleID, members)} +} + +func (_c *Service_RoleCheckMembersExists_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, members []string)) *Service_RoleCheckMembersExists_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RoleCheckMembersExists_Call) Return(b bool, err error) *Service_RoleCheckMembersExists_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *Service_RoleCheckMembersExists_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, members []string) (bool, error)) *Service_RoleCheckMembersExists_Call { + _c.Call.Return(run) + return _c +} + +// RoleListActions provides a mock function for the type Service +func (_mock *Service) RoleListActions(ctx context.Context, session authn.Session, entityID string, roleID string) ([]string, error) { + ret := _mock.Called(ctx, session, entityID, roleID) + + if len(ret) == 0 { + panic("no return value specified for RoleListActions") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) ([]string, error)); ok { + return returnFunc(ctx, session, entityID, roleID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) []string); ok { + r0 = returnFunc(ctx, session, entityID, roleID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RoleListActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleListActions' +type Service_RoleListActions_Call struct { + *mock.Call +} + +// RoleListActions is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +func (_e *Service_Expecter) RoleListActions(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}) *Service_RoleListActions_Call { + return &Service_RoleListActions_Call{Call: _e.mock.On("RoleListActions", ctx, session, entityID, roleID)} +} + +func (_c *Service_RoleListActions_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string)) *Service_RoleListActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_RoleListActions_Call) Return(strings []string, err error) *Service_RoleListActions_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *Service_RoleListActions_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string) ([]string, error)) *Service_RoleListActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleListMembers provides a mock function for the type Service +func (_mock *Service) RoleListMembers(ctx context.Context, session authn.Session, entityID string, roleID string, limit uint64, offset uint64) (roles.MembersPage, error) { + ret := _mock.Called(ctx, session, entityID, roleID, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for RoleListMembers") + } + + var r0 roles.MembersPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, uint64, uint64) (roles.MembersPage, error)); ok { + return returnFunc(ctx, session, entityID, roleID, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, uint64, uint64) roles.MembersPage); ok { + r0 = returnFunc(ctx, session, entityID, roleID, limit, offset) + } else { + r0 = ret.Get(0).(roles.MembersPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, uint64, uint64) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID, limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RoleListMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleListMembers' +type Service_RoleListMembers_Call struct { + *mock.Call +} + +// RoleListMembers is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - limit uint64 +// - offset uint64 +func (_e *Service_Expecter) RoleListMembers(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, limit interface{}, offset interface{}) *Service_RoleListMembers_Call { + return &Service_RoleListMembers_Call{Call: _e.mock.On("RoleListMembers", ctx, session, entityID, roleID, limit, offset)} +} + +func (_c *Service_RoleListMembers_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, limit uint64, offset uint64)) *Service_RoleListMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 uint64 + if args[4] != nil { + arg4 = args[4].(uint64) + } + var arg5 uint64 + if args[5] != nil { + arg5 = args[5].(uint64) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + ) + }) + return _c +} + +func (_c *Service_RoleListMembers_Call) Return(membersPage roles.MembersPage, err error) *Service_RoleListMembers_Call { + _c.Call.Return(membersPage, err) + return _c +} + +func (_c *Service_RoleListMembers_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, limit uint64, offset uint64) (roles.MembersPage, error)) *Service_RoleListMembers_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveActions provides a mock function for the type Service +func (_mock *Service) RoleRemoveActions(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string) error { + ret := _mock.Called(ctx, session, entityID, roleID, actions) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveActions") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) error); ok { + r0 = returnFunc(ctx, session, entityID, roleID, actions) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RoleRemoveActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveActions' +type Service_RoleRemoveActions_Call struct { + *mock.Call +} + +// RoleRemoveActions is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - actions []string +func (_e *Service_Expecter) RoleRemoveActions(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, actions interface{}) *Service_RoleRemoveActions_Call { + return &Service_RoleRemoveActions_Call{Call: _e.mock.On("RoleRemoveActions", ctx, session, entityID, roleID, actions)} +} + +func (_c *Service_RoleRemoveActions_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string)) *Service_RoleRemoveActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RoleRemoveActions_Call) Return(err error) *Service_RoleRemoveActions_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RoleRemoveActions_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string) error) *Service_RoleRemoveActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveAllActions provides a mock function for the type Service +func (_mock *Service) RoleRemoveAllActions(ctx context.Context, session authn.Session, entityID string, roleID string) error { + ret := _mock.Called(ctx, session, entityID, roleID) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveAllActions") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) error); ok { + r0 = returnFunc(ctx, session, entityID, roleID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RoleRemoveAllActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveAllActions' +type Service_RoleRemoveAllActions_Call struct { + *mock.Call +} + +// RoleRemoveAllActions is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +func (_e *Service_Expecter) RoleRemoveAllActions(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}) *Service_RoleRemoveAllActions_Call { + return &Service_RoleRemoveAllActions_Call{Call: _e.mock.On("RoleRemoveAllActions", ctx, session, entityID, roleID)} +} + +func (_c *Service_RoleRemoveAllActions_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string)) *Service_RoleRemoveAllActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_RoleRemoveAllActions_Call) Return(err error) *Service_RoleRemoveAllActions_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RoleRemoveAllActions_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string) error) *Service_RoleRemoveAllActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveAllMembers provides a mock function for the type Service +func (_mock *Service) RoleRemoveAllMembers(ctx context.Context, session authn.Session, entityID string, roleID string) error { + ret := _mock.Called(ctx, session, entityID, roleID) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveAllMembers") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) error); ok { + r0 = returnFunc(ctx, session, entityID, roleID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RoleRemoveAllMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveAllMembers' +type Service_RoleRemoveAllMembers_Call struct { + *mock.Call +} + +// RoleRemoveAllMembers is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +func (_e *Service_Expecter) RoleRemoveAllMembers(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}) *Service_RoleRemoveAllMembers_Call { + return &Service_RoleRemoveAllMembers_Call{Call: _e.mock.On("RoleRemoveAllMembers", ctx, session, entityID, roleID)} +} + +func (_c *Service_RoleRemoveAllMembers_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string)) *Service_RoleRemoveAllMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_RoleRemoveAllMembers_Call) Return(err error) *Service_RoleRemoveAllMembers_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RoleRemoveAllMembers_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string) error) *Service_RoleRemoveAllMembers_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveMembers provides a mock function for the type Service +func (_mock *Service) RoleRemoveMembers(ctx context.Context, session authn.Session, entityID string, roleID string, members []string) error { + ret := _mock.Called(ctx, session, entityID, roleID, members) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveMembers") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) error); ok { + r0 = returnFunc(ctx, session, entityID, roleID, members) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RoleRemoveMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveMembers' +type Service_RoleRemoveMembers_Call struct { + *mock.Call +} + +// RoleRemoveMembers is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - members []string +func (_e *Service_Expecter) RoleRemoveMembers(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, members interface{}) *Service_RoleRemoveMembers_Call { + return &Service_RoleRemoveMembers_Call{Call: _e.mock.On("RoleRemoveMembers", ctx, session, entityID, roleID, members)} +} + +func (_c *Service_RoleRemoveMembers_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, members []string)) *Service_RoleRemoveMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RoleRemoveMembers_Call) Return(err error) *Service_RoleRemoveMembers_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RoleRemoveMembers_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, members []string) error) *Service_RoleRemoveMembers_Call { + _c.Call.Return(run) + return _c +} + +// StartScheduler provides a mock function for the type Service +func (_mock *Service) StartScheduler(ctx context.Context) error { + ret := _mock.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for StartScheduler") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = returnFunc(ctx) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_StartScheduler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StartScheduler' +type Service_StartScheduler_Call struct { + *mock.Call +} + +// StartScheduler is a helper method to define mock.On call +// - ctx context.Context +func (_e *Service_Expecter) StartScheduler(ctx interface{}) *Service_StartScheduler_Call { + return &Service_StartScheduler_Call{Call: _e.mock.On("StartScheduler", ctx)} +} + +func (_c *Service_StartScheduler_Call) Run(run func(ctx context.Context)) *Service_StartScheduler_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *Service_StartScheduler_Call) Return(err error) *Service_StartScheduler_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_StartScheduler_Call) RunAndReturn(run func(ctx context.Context) error) *Service_StartScheduler_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRoleName provides a mock function for the type Service +func (_mock *Service) UpdateRoleName(ctx context.Context, session authn.Session, entityID string, roleID string, newRoleName string) (roles.Role, error) { + ret := _mock.Called(ctx, session, entityID, roleID, newRoleName) + + if len(ret) == 0 { + panic("no return value specified for UpdateRoleName") + } + + var r0 roles.Role + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string) (roles.Role, error)); ok { + return returnFunc(ctx, session, entityID, roleID, newRoleName) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string) roles.Role); ok { + r0 = returnFunc(ctx, session, entityID, roleID, newRoleName) + } else { + r0 = ret.Get(0).(roles.Role) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID, newRoleName) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_UpdateRoleName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRoleName' +type Service_UpdateRoleName_Call struct { + *mock.Call +} + +// UpdateRoleName is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - newRoleName string +func (_e *Service_Expecter) UpdateRoleName(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, newRoleName interface{}) *Service_UpdateRoleName_Call { + return &Service_UpdateRoleName_Call{Call: _e.mock.On("UpdateRoleName", ctx, session, entityID, roleID, newRoleName)} +} + +func (_c *Service_UpdateRoleName_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, newRoleName string)) *Service_UpdateRoleName_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_UpdateRoleName_Call) Return(role roles.Role, err error) *Service_UpdateRoleName_Call { + _c.Call.Return(role, err) + return _c +} + +func (_c *Service_UpdateRoleName_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, newRoleName string) (roles.Role, error)) *Service_UpdateRoleName_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRule provides a mock function for the type Service +func (_mock *Service) UpdateRule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + ret := _mock.Called(ctx, session, r) + + if len(ret) == 0 { + panic("no return value specified for UpdateRule") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, re.Rule) (re.Rule, error)); ok { + return returnFunc(ctx, session, r) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, re.Rule) re.Rule); ok { + r0 = returnFunc(ctx, session, r) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, re.Rule) error); ok { + r1 = returnFunc(ctx, session, r) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_UpdateRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRule' +type Service_UpdateRule_Call struct { + *mock.Call +} + +// UpdateRule is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - r re.Rule +func (_e *Service_Expecter) UpdateRule(ctx interface{}, session interface{}, r interface{}) *Service_UpdateRule_Call { + return &Service_UpdateRule_Call{Call: _e.mock.On("UpdateRule", ctx, session, r)} +} + +func (_c *Service_UpdateRule_Call) Run(run func(ctx context.Context, session authn.Session, r re.Rule)) *Service_UpdateRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 re.Rule + if args[2] != nil { + arg2 = args[2].(re.Rule) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_UpdateRule_Call) Return(rule re.Rule, err error) *Service_UpdateRule_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Service_UpdateRule_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error)) *Service_UpdateRule_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRuleSchedule provides a mock function for the type Service +func (_mock *Service) UpdateRuleSchedule(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + ret := _mock.Called(ctx, session, r) + + if len(ret) == 0 { + panic("no return value specified for UpdateRuleSchedule") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, re.Rule) (re.Rule, error)); ok { + return returnFunc(ctx, session, r) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, re.Rule) re.Rule); ok { + r0 = returnFunc(ctx, session, r) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, re.Rule) error); ok { + r1 = returnFunc(ctx, session, r) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_UpdateRuleSchedule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRuleSchedule' +type Service_UpdateRuleSchedule_Call struct { + *mock.Call +} + +// UpdateRuleSchedule is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - r re.Rule +func (_e *Service_Expecter) UpdateRuleSchedule(ctx interface{}, session interface{}, r interface{}) *Service_UpdateRuleSchedule_Call { + return &Service_UpdateRuleSchedule_Call{Call: _e.mock.On("UpdateRuleSchedule", ctx, session, r)} +} + +func (_c *Service_UpdateRuleSchedule_Call) Run(run func(ctx context.Context, session authn.Session, r re.Rule)) *Service_UpdateRuleSchedule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 re.Rule + if args[2] != nil { + arg2 = args[2].(re.Rule) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_UpdateRuleSchedule_Call) Return(rule re.Rule, err error) *Service_UpdateRuleSchedule_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Service_UpdateRuleSchedule_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error)) *Service_UpdateRuleSchedule_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRuleTags provides a mock function for the type Service +func (_mock *Service) UpdateRuleTags(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error) { + ret := _mock.Called(ctx, session, r) + + if len(ret) == 0 { + panic("no return value specified for UpdateRuleTags") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, re.Rule) (re.Rule, error)); ok { + return returnFunc(ctx, session, r) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, re.Rule) re.Rule); ok { + r0 = returnFunc(ctx, session, r) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, re.Rule) error); ok { + r1 = returnFunc(ctx, session, r) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_UpdateRuleTags_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRuleTags' +type Service_UpdateRuleTags_Call struct { + *mock.Call +} + +// UpdateRuleTags is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - r re.Rule +func (_e *Service_Expecter) UpdateRuleTags(ctx interface{}, session interface{}, r interface{}) *Service_UpdateRuleTags_Call { + return &Service_UpdateRuleTags_Call{Call: _e.mock.On("UpdateRuleTags", ctx, session, r)} +} + +func (_c *Service_UpdateRuleTags_Call) Run(run func(ctx context.Context, session authn.Session, r re.Rule)) *Service_UpdateRuleTags_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 re.Rule + if args[2] != nil { + arg2 = args[2].(re.Rule) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_UpdateRuleTags_Call) Return(rule re.Rule, err error) *Service_UpdateRuleTags_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Service_UpdateRuleTags_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, r re.Rule) (re.Rule, error)) *Service_UpdateRuleTags_Call { + _c.Call.Return(run) + return _c +} + +// ViewRule provides a mock function for the type Service +func (_mock *Service) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error) { + ret := _mock.Called(ctx, session, id, withRoles) + + if len(ret) == 0 { + panic("no return value specified for ViewRule") + } + + var r0 re.Rule + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, bool) (re.Rule, error)); ok { + return returnFunc(ctx, session, id, withRoles) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, bool) re.Rule); ok { + r0 = returnFunc(ctx, session, id, withRoles) + } else { + r0 = ret.Get(0).(re.Rule) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, bool) error); ok { + r1 = returnFunc(ctx, session, id, withRoles) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ViewRule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewRule' +type Service_ViewRule_Call struct { + *mock.Call +} + +// ViewRule is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +// - withRoles bool +func (_e *Service_Expecter) ViewRule(ctx interface{}, session interface{}, id interface{}, withRoles interface{}) *Service_ViewRule_Call { + return &Service_ViewRule_Call{Call: _e.mock.On("ViewRule", ctx, session, id, withRoles)} +} + +func (_c *Service_ViewRule_Call) Run(run func(ctx context.Context, session authn.Session, id string, withRoles bool)) *Service_ViewRule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 bool + if args[3] != nil { + arg3 = args[3].(bool) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_ViewRule_Call) Return(rule re.Rule, err error) *Service_ViewRule_Call { + _c.Call.Return(rule, err) + return _c +} + +func (_c *Service_ViewRule_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string, withRoles bool) (re.Rule, error)) *Service_ViewRule_Call { + _c.Call.Return(run) + return _c +} diff --git a/re/operations/operations.go b/re/operations/operations.go new file mode 100644 index 000000000..9fe6e6031 --- /dev/null +++ b/re/operations/operations.go @@ -0,0 +1,62 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package operations + +import "github.com/absmach/supermq/pkg/permissions" + +const EntityType = "rule" + +// Rule Operations. +const ( + OpAddRule permissions.Operation = iota + OpViewRule + OpUpdateRule + OpUpdateRuleTags + OpUpdateRuleSchedule + OpRemoveRule + OpListRules + OpEnableRule + OpDisableRule +) + +func OperationDetails() map[permissions.Operation]permissions.OperationDetails { + return map[permissions.Operation]permissions.OperationDetails{ + OpAddRule: { + Name: "add", + PermissionRequired: true, + }, + OpViewRule: { + Name: "view", + PermissionRequired: true, + }, + OpUpdateRule: { + Name: "update", + PermissionRequired: true, + }, + OpUpdateRuleTags: { + Name: "update_tags", + PermissionRequired: true, + }, + OpUpdateRuleSchedule: { + Name: "update_schedule", + PermissionRequired: true, + }, + OpRemoveRule: { + Name: "delete", + PermissionRequired: true, + }, + OpListRules: { + Name: "list", + PermissionRequired: true, + }, + OpEnableRule: { + Name: "enable", + PermissionRequired: true, + }, + OpDisableRule: { + Name: "disable", + PermissionRequired: true, + }, + } +} diff --git a/re/outputs/alarm.go b/re/outputs/alarm.go new file mode 100644 index 000000000..5bb55aaf1 --- /dev/null +++ b/re/outputs/alarm.go @@ -0,0 +1,79 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package outputs + +import ( + "bytes" + "context" + "encoding/gob" + "encoding/json" + + "github.com/absmach/supermq/alarms" + "github.com/absmach/supermq/pkg/messaging" +) + +type Alarm struct { + AlarmsPub messaging.Publisher `json:"-"` + RuleID string `json:"rule_id"` +} + +func (a *Alarm) Run(ctx context.Context, msg *messaging.Message, val any) error { + data, err := json.Marshal(val) + if err != nil { + return err + } + + var alarmsList []alarms.Alarm + if err := json.Unmarshal(data, &alarmsList); err != nil { + var single alarms.Alarm + if err := json.Unmarshal(data, &single); err != nil { + return err + } + alarmsList = []alarms.Alarm{single} + } + + for _, alarm := range alarmsList { + if err := a.processAlarm(ctx, msg, alarm); err != nil { + return err + } + } + + return nil +} + +func (a *Alarm) processAlarm(ctx context.Context, msg *messaging.Message, alarm alarms.Alarm) error { + alarm.RuleID = a.RuleID + alarm.DomainID = msg.Domain + alarm.ClientID = msg.ClientIdentity() + alarm.ChannelID = msg.Channel + alarm.Subtopic = msg.Subtopic + + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(alarm); err != nil { + return err + } + + m := &messaging.Message{ + Domain: msg.Domain, + Publisher: msg.Publisher, + ClientId: msg.ClientIdentity(), + Created: msg.Created, + Channel: msg.Channel, + Subtopic: msg.Subtopic, + Protocol: msg.Protocol, + Payload: buf.Bytes(), + } + + topic := messaging.EncodeMessageTopic(msg) + if err := a.AlarmsPub.Publish(ctx, topic, m); err != nil { + return err + } + return nil +} + +func (a *Alarm) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]any{ + "type": AlarmsType.String(), + }) +} diff --git a/re/outputs/channel.go b/re/outputs/channel.go new file mode 100644 index 000000000..c48f1543d --- /dev/null +++ b/re/outputs/channel.go @@ -0,0 +1,52 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package outputs + +import ( + "context" + "encoding/json" + + "github.com/absmach/supermq/pkg/messaging" +) + +const protocol = "nats" + +type ChannelPublisher struct { + RePubSub messaging.PubSub `json:"-"` + Channel string `json:"channel"` + Topic string `json:"topic"` +} + +func (p *ChannelPublisher) Run(ctx context.Context, msg *messaging.Message, val any) error { + data, err := json.Marshal(val) + if err != nil { + return err + } + + m := &messaging.Message{ + Domain: msg.Domain, + Publisher: msg.Publisher, + ClientId: msg.ClientIdentity(), + Created: msg.Created, + Channel: p.Channel, + Subtopic: p.Topic, + Protocol: protocol, + Payload: data, + } + + topic := messaging.EncodeTopicSuffix(msg.Domain, p.Channel, p.Topic) + if err := p.RePubSub.Publish(ctx, topic, m); err != nil { + return err + } + + return nil +} + +func (cp *ChannelPublisher) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]string{ + "type": ChannelsType.String(), + "channel": cp.Channel, + "topic": cp.Topic, + }) +} diff --git a/re/outputs/doc.go b/re/outputs/doc.go new file mode 100644 index 000000000..47f1ac5ff --- /dev/null +++ b/re/outputs/doc.go @@ -0,0 +1,4 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package outputs diff --git a/re/outputs/email.go b/re/outputs/email.go new file mode 100644 index 000000000..45d1f6c05 --- /dev/null +++ b/re/outputs/email.go @@ -0,0 +1,54 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package outputs + +import ( + "bytes" + "context" + "encoding/json" + "text/template" + + "github.com/absmach/supermq/pkg/emailer" + "github.com/absmach/supermq/pkg/messaging" +) + +type Email struct { + To []string `json:"to"` + Subject string `json:"subject"` + Content string `json:"content"` + Emailer emailer.Emailer `json:"-"` +} + +func (e *Email) Run(ctx context.Context, msg *messaging.Message, val any) error { + templData := templateVal{ + Message: msg, + Result: val, + } + + tmpl, err := template.New("email").Parse(e.Content) + if err != nil { + return err + } + + var output bytes.Buffer + if err := tmpl.Execute(&output, templData); err != nil { + return err + } + + content := output.String() + + if err := e.Emailer.SendEmailNotification(e.To, "", e.Subject, "", "", content, "", make(map[string][]byte)); err != nil { + return err + } + return nil +} + +func (e *Email) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]any{ + "type": EmailType.String(), + "to": e.To, + "subject": e.Subject, + "content": e.Content, + }) +} diff --git a/re/outputs/outputs.go b/re/outputs/outputs.go new file mode 100644 index 000000000..b8bbfcdf6 --- /dev/null +++ b/re/outputs/outputs.go @@ -0,0 +1,68 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package outputs + +import ( + "encoding/json" + "strings" + + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/messaging" +) + +type templateVal struct { + Message *messaging.Message + Result any +} + +// OutputType is the indicator for type of the output +// so we can move it to the Go instead calling Go from Lua. +type OutputType uint + +const ( + ChannelsType OutputType = iota + AlarmsType + SaveSenMLType + EmailType + SaveRemotePgType + SlackType +) + +var ( + scriptKindToString = [...]string{"channels", "alarms", "save_senml", "email", "save_remote_pg", "slack"} + stringToScriptKind = map[string]OutputType{ + "channels": ChannelsType, + "alarms": AlarmsType, + "save_senml": SaveSenMLType, + "email": EmailType, + "save_remote_pg": SaveRemotePgType, + "slack": SlackType, + } +) + +func (s OutputType) String() string { + if int(s) < 0 || int(s) >= len(scriptKindToString) { + return "unknown" + } + return scriptKindToString[s] +} + +// MarshalJSON converts OutputType to JSON. +func (s *OutputType) MarshalJSON() ([]byte, error) { + return json.Marshal(s.String()) +} + +// UnmarshalJSON parses JSON string into OutputType. +func (s *OutputType) UnmarshalJSON(data []byte) error { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + lower := strings.ToLower(str) + if val, ok := stringToScriptKind[lower]; ok { + *s = val + return nil + } + return errors.New("invalid OutputType: " + str) +} diff --git a/re/outputs/postgres.go b/re/outputs/postgres.go new file mode 100644 index 000000000..2633163e0 --- /dev/null +++ b/re/outputs/postgres.go @@ -0,0 +1,107 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package outputs + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "strings" + "text/template" + + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/messaging" + _ "github.com/jackc/pgx/v5/stdlib" // required for SQL access + "github.com/jmoiron/sqlx" +) + +type Postgres struct { + Host string `json:"host"` + Port int `json:"port"` + User string `json:"user"` + Password string `json:"password"` + Database string `json:"database"` + Table string `json:"table"` + Mapping string `json:"mapping"` +} + +func (p *Postgres) Run(ctx context.Context, msg *messaging.Message, val any) error { + templData := templateVal{ + Message: msg, + Result: val, + } + + tmpl, err := template.New("postgres").Parse(p.Mapping) + if err != nil { + return err + } + + var output bytes.Buffer + if err := tmpl.Execute(&output, templData); err != nil { + return err + } + + mapping := output.String() + var columns map[string]any + if err = json.Unmarshal([]byte(mapping), &columns); err != nil { + return err + } + + connStr := fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + p.Host, p.Port, p.User, p.Password, p.Database, + ) + + db, err := sqlx.Open("pgx", connStr) + if err != nil { + return err + } + defer db.Close() + + if err := db.Ping(); err != nil { + return errors.Wrap(errors.New("failed to connect to DB"), err) + } + + var ( + cols []string + values []any + placeholders []string + ) + + i := 1 + for k, v := range columns { + cols = append(cols, k) + values = append(values, v) + placeholders = append(placeholders, fmt.Sprintf("$%d", i)) + i++ + } + + q := fmt.Sprintf( + `INSERT INTO %s (%s) VALUES (%s)`, + p.Table, + strings.Join(cols, ", "), + strings.Join(placeholders, ", "), + ) + + _, err = db.Exec(q, values...) + if err != nil { + return errors.Wrap(errors.New("failed to insert data"), err) + } + + return nil +} + +func (p *Postgres) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]any{ + "type": SaveRemotePgType.String(), + "host": p.Host, + "port": p.Port, + "user": p.User, + "password": p.Password, + "database": p.Database, + "table": p.Table, + "mapping": p.Mapping, + }) +} diff --git a/re/outputs/senml.go b/re/outputs/senml.go new file mode 100644 index 000000000..31f176223 --- /dev/null +++ b/re/outputs/senml.go @@ -0,0 +1,53 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package outputs + +import ( + "context" + "encoding/json" + + "github.com/absmach/senml" + "github.com/absmach/supermq/pkg/messaging" +) + +type SenML struct { + WritersPub messaging.Publisher `json:"-"` +} + +func (s *SenML) Run(ctx context.Context, msg *messaging.Message, val any) error { + // In case there is a single SenML value, convert to slice so we can decode. + if _, ok := val.([]any); !ok { + val = []any{val} + } + data, err := json.Marshal(val) + if err != nil { + return err + } + if _, err := senml.Decode(data, senml.JSON); err != nil { + return err + } + + m := &messaging.Message{ + Domain: msg.Domain, + Publisher: msg.Publisher, + ClientId: msg.ClientIdentity(), + Created: msg.Created, + Channel: msg.Channel, + Subtopic: msg.Subtopic, + Protocol: msg.Protocol, + Payload: data, + } + topic := messaging.EncodeMessageTopic(msg) + if err := s.WritersPub.Publish(ctx, topic, m); err != nil { + return err + } + + return nil +} + +func (senml *SenML) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]string{ + "type": SaveSenMLType.String(), + }) +} diff --git a/re/outputs/slack.go b/re/outputs/slack.go new file mode 100644 index 000000000..a43dd950d --- /dev/null +++ b/re/outputs/slack.go @@ -0,0 +1,72 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package outputs + +import ( + "bytes" + "context" + "encoding/json" + "text/template" + + "github.com/absmach/supermq/pkg/messaging" + "github.com/slack-go/slack" +) + +type Slack struct { + Token string `json:"token"` + ChannelID string `json:"channel_id"` + Message string `json:"message"` +} + +func (s *Slack) Run(ctx context.Context, msg *messaging.Message, val any) error { + templData := templateVal{ + Message: msg, + Result: val, + } + + tmpl, err := template.New("slack").Parse(s.Message) + if err != nil { + return err + } + + var output bytes.Buffer + if err := tmpl.Execute(&output, templData); err != nil { + return err + } + + mapping := output.String() + + var message slack.Msg + if err := json.Unmarshal([]byte(mapping), &message); err != nil { + return err + } + + slackClient := slack.New(s.Token) + + var opts []slack.MsgOption + + if message.Text != "" { + opts = append(opts, slack.MsgOptionText(message.Text, false)) + } + if len(message.Attachments) > 0 { + opts = append(opts, slack.MsgOptionAttachments(message.Attachments...)) + } + if len(message.Blocks.BlockSet) > 0 { + opts = append(opts, slack.MsgOptionBlocks(message.Blocks.BlockSet...)) + } + _, _, err = slackClient.PostMessage(s.ChannelID, opts...) + if err != nil { + return err + } + return nil +} + +func (s *Slack) MarshalJSON() ([]byte, error) { + return json.Marshal(map[string]any{ + "type": SlackType.String(), + "token": s.Token, + "channel_id": s.ChannelID, + "message": s.Message, + }) +} diff --git a/re/postgres/init.go b/re/postgres/init.go new file mode 100644 index 000000000..f7ad8191d --- /dev/null +++ b/re/postgres/init.go @@ -0,0 +1,86 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + dpostgres "github.com/absmach/supermq/domains/postgres" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres" + _ "github.com/jackc/pgx/v5/stdlib" // required for SQL access + migrate "github.com/rubenv/sql-migrate" +) + +func Migration() (*migrate.MemoryMigrationSource, error) { + rolesMigration, err := rolesPostgres.Migration(rolesTableNamePrefix, entityTableName, entityIDColumnName) + if err != nil { + return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err) + } + rulesMigration := &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "rules_01", + // VARCHAR(36) for colums with IDs as UUIDS have a maximum of 36 characters + // STATUS 0 to imply enabled and 1 to imply disabled + Up: []string{ + `CREATE TABLE IF NOT EXISTS rules ( + id VARCHAR(36) PRIMARY KEY, + name VARCHAR(1024), + domain_id VARCHAR(36) NOT NULL, + metadata JSONB, + created_by VARCHAR(254), + created_at TIMESTAMP, + updated_at TIMESTAMP, + updated_by VARCHAR(254), + input_channel VARCHAR(36), + input_topic TEXT, + outputs JSONB, + status SMALLINT NOT NULL DEFAULT 0 CHECK (status >= 0), + logic_type SMALLINT NOT NULL DEFAULT 0 CHECK (logic_type >= 0), + logic_value BYTEA, + time TIMESTAMP, + recurring SMALLINT, + recurring_period SMALLINT, + start_datetime TIMESTAMP + )`, + }, + Down: []string{ + `DROP TABLE IF EXISTS rules`, + }, + }, + { + Id: "rules_02", + Up: []string{ + `ALTER TABLE rules ADD COLUMN tags TEXT[];`, + }, + Down: []string{ + `ALTER TABLE rules DROP COLUMN tags;`, + }, + }, + { + Id: "rules_03", + Up: []string{ + `UPDATE rules + SET metadata = (COALESCE(metadata, '{}'::jsonb) - 'ui') || jsonb_build_object('flow', metadata->'ui') + WHERE metadata ? 'ui' AND jsonb_typeof(metadata->'ui') = 'string'`, + }, + Down: []string{ + `UPDATE rules + SET metadata = (COALESCE(metadata, '{}'::jsonb) - 'flow') || jsonb_build_object('ui', metadata->'flow') + WHERE metadata ? 'flow' AND jsonb_typeof(metadata->'flow') = 'string'`, + }, + }, + }, + } + + rulesMigration.Migrations = append(rulesMigration.Migrations, rolesMigration.Migrations...) + + domainsMigration, err := dpostgres.Migration() + if err != nil { + return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err) + } + rulesMigration.Migrations = append(rulesMigration.Migrations, domainsMigration.Migrations...) + + return rulesMigration, nil +} diff --git a/re/postgres/repository.go b/re/postgres/repository.go new file mode 100644 index 000000000..dc319f2a3 --- /dev/null +++ b/re/postgres/repository.go @@ -0,0 +1,556 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + api "github.com/absmach/supermq/api/http" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + mgPolicies "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/pkg/postgres" + rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres" + "github.com/absmach/supermq/re" +) + +const ( + rolesTableNamePrefix = "rules" + entityTableName = "rules" + entityIDColumnName = "id" +) + +type PostgresRepository struct { + DB postgres.Database + rolesPostgres.Repository +} + +func NewRepository(db postgres.Database) re.Repository { + rolesRepo := rolesPostgres.NewRepository(db, mgPolicies.RulesType, rolesTableNamePrefix, entityTableName, entityIDColumnName) + return &PostgresRepository{ + DB: db, + Repository: rolesRepo, + } +} + +func (repo *PostgresRepository) AddRule(ctx context.Context, r re.Rule) (re.Rule, error) { + q := ` + INSERT INTO rules (id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status) + VALUES (:id, :name, :domain_id, :tags, :metadata, :input_channel, :input_topic, :logic_type, :logic_value, + :outputs, :start_datetime, :time, :recurring, :recurring_period, :created_at, :created_by, :updated_at, :updated_by, :status) + RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; +` + dbr, err := ruleToDb(r) + if err != nil { + return re.Rule{}, err + } + row, err := repo.DB.NamedQueryContext(ctx, q, dbr) + if err != nil { + return re.Rule{}, postgres.HandleError(repoerr.ErrCreateEntity, err) + } + defer row.Close() + + var dbRule dbRule + if row.Next() { + if err := row.StructScan(&dbRule); err != nil { + return re.Rule{}, errors.Wrap(repoerr.ErrCreateEntity, err) + } + } + + rule, err := dbToRule(dbRule) + if err != nil { + return re.Rule{}, errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return rule, nil +} + +func (repo *PostgresRepository) ViewRule(ctx context.Context, id string) (re.Rule, error) { + q := ` + SELECT id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, outputs, + start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status + FROM rules + WHERE id = $1; + ` + row := repo.DB.QueryRowxContext(ctx, q, id) + if err := row.Err(); err != nil { + return re.Rule{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + var dbr dbRule + if err := row.StructScan(&dbr); err != nil { + return re.Rule{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + ret, err := dbToRule(dbr) + if err != nil { + return re.Rule{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return ret, nil +} + +func (repo *PostgresRepository) RetrieveByIDWithRoles(ctx context.Context, id, memberID string) (re.Rule, error) { + query := ` + WITH selected_rule AS ( + SELECT + r.id, + r.domain_id + FROM + rules r + WHERE + r.id = :id + LIMIT 1 + ), + selected_rule_roles AS ( + SELECT + rr.entity_id AS rule_id, + rrm.member_id AS member_id, + rr.id AS role_id, + rr."name" AS role_name, + jsonb_agg(DISTINCT rra."action") AS actions, + 'direct' AS access_type, + '' AS access_provider_id + FROM + rules_roles rr + JOIN + rules_role_members rrm ON rr.id = rrm.role_id + JOIN + rules_role_actions rra ON rr.id = rra.role_id + JOIN + selected_rule sr ON sr.id = rr.entity_id + AND rrm.member_id = :member_id + GROUP BY + rr.entity_id, rr.id, rr.name, rrm.member_id + ), + selected_domain_roles AS ( + SELECT + sr.id AS rule_id, + drm.member_id AS member_id, + dr.id AS role_id, + dr."name" AS role_name, + jsonb_agg(DISTINCT all_actions."action") AS actions, + 'domain' AS access_type, + dr.entity_id AS access_provider_id + FROM + domains d + JOIN + selected_rule sr ON sr.domain_id = d.id + JOIN + domains_roles dr ON dr.entity_id = d.id + JOIN + domains_role_members drm ON dr.id = drm.role_id + JOIN + domains_role_actions dra ON dr.id = dra.role_id + JOIN + domains_role_actions all_actions ON dr.id = all_actions.role_id + WHERE + drm.member_id = :member_id + AND dra."action" LIKE 'rule%' + GROUP BY + sr.id, dr.entity_id, dr.id, dr."name", drm.member_id + ), + all_roles AS ( + SELECT + srr.rule_id, + srr.member_id, + srr.role_id, + srr.role_name, + srr.actions, + srr.access_type, + srr.access_provider_id + FROM + selected_rule_roles srr + UNION + SELECT + sdr.rule_id, + sdr.member_id, + sdr.role_id, + sdr.role_name, + sdr.actions, + sdr.access_type, + sdr.access_provider_id + FROM + selected_domain_roles sdr + ), + final_roles AS ( + SELECT + ar.rule_id, + ar.member_id, + jsonb_agg( + jsonb_build_object( + 'role_id', ar.role_id, + 'role_name', ar.role_name, + 'actions', ar.actions, + 'access_type', ar.access_type, + 'access_provider_id', ar.access_provider_id + ) + ) AS roles + FROM all_roles ar + GROUP BY + ar.rule_id, ar.member_id + ) + SELECT + r2.id, + r2."name", + r2.domain_id, + r2.tags, + r2.metadata, + r2.input_channel, + r2.input_topic, + r2.outputs, + r2.status, + r2.logic_type, + r2.logic_value, + r2.time, + r2.recurring, + r2.recurring_period, + r2.start_datetime, + r2.created_at, + r2.created_by, + r2.updated_at, + r2.updated_by, + fr.member_id, + fr.roles + FROM rules r2 + JOIN final_roles fr ON fr.rule_id = r2.id + ` + parameters := map[string]any{ + "id": id, + "member_id": memberID, + } + row, err := repo.DB.NamedQueryContext(ctx, query, parameters) + if err != nil { + return re.Rule{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer row.Close() + + dbrule := dbRule{} + if !row.Next() { + return re.Rule{}, repoerr.ErrNotFound + } + + if err := row.StructScan(&dbrule); err != nil { + return re.Rule{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + r, err := dbToRule(dbrule) + if err != nil { + return re.Rule{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + return r, nil +} + +func (repo *PostgresRepository) UpdateRuleStatus(ctx context.Context, r re.Rule) (re.Rule, error) { + q := `UPDATE rules + SET status = :status, updated_at = :updated_at, updated_by = :updated_by + WHERE id = :id + RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status;` + + return repo.update(ctx, r, q) +} + +func (repo *PostgresRepository) UpdateRule(ctx context.Context, r re.Rule) (re.Rule, error) { + var query []string + var upq string + if r.Name != "" { + query = append(query, "name = :name,") + } + if r.Metadata != nil { + query = append(query, "metadata = :metadata,") + } + query = append(query, "input_channel = :input_channel,") + query = append(query, "input_topic = :input_topic,") + if r.Outputs != nil { + query = append(query, "outputs = :outputs, ") + } + if r.Logic.Value != "" { + query = append(query, "logic_type = :logic_type,") + query = append(query, "logic_value = :logic_value,") + } + + if len(query) > 0 { + upq = strings.Join(query, " ") + } + + q := fmt.Sprintf(` + UPDATE rules + SET %s updated_at = :updated_at, updated_by = :updated_by WHERE id = :id + RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + `, upq) + + return repo.update(ctx, r, q) +} + +func (repo *PostgresRepository) UpdateRuleTags(ctx context.Context, r re.Rule) (re.Rule, error) { + q := `UPDATE rules SET tags = :tags, updated_at = :updated_at, updated_by = :updated_by + WHERE id = :id AND status = :status + RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status;` + r.Status = re.EnabledStatus + + return repo.update(ctx, r, q) +} + +func (repo *PostgresRepository) UpdateRuleSchedule(ctx context.Context, r re.Rule) (re.Rule, error) { + q := ` + UPDATE rules + SET start_datetime = :start_datetime, time = :time, recurring = :recurring, + recurring_period = :recurring_period, updated_at = :updated_at, updated_by = :updated_by WHERE id = :id + RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + ` + return repo.update(ctx, r, q) +} + +func (repo *PostgresRepository) update(ctx context.Context, r re.Rule, query string) (re.Rule, error) { + dbr, err := ruleToDb(r) + if err != nil { + return re.Rule{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + row, err := repo.DB.NamedQueryContext(ctx, query, dbr) + if err != nil { + return re.Rule{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + defer row.Close() + if !row.Next() { + return re.Rule{}, repoerr.ErrNotFound + } + var dbRule dbRule + if err := row.StructScan(&dbRule); err != nil { + return re.Rule{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + rule, err := dbToRule(dbRule) + if err != nil { + return re.Rule{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + return rule, nil +} + +func (repo *PostgresRepository) RemoveRule(ctx context.Context, id string) error { + q := ` + DELETE FROM rules + WHERE id = $1; +` + result, err := repo.DB.ExecContext(ctx, q, id) + if err != nil { + return postgres.HandleError(repoerr.ErrRemoveEntity, err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + if rowsAffected == 0 { + return repoerr.ErrNotFound + } + + return nil +} + +func (repo *PostgresRepository) ListAllRules(ctx context.Context, pm re.PageMeta) (re.Page, error) { + pq := pageRulesQuery(pm) + orderClause := rulesOrderClause(pm) + pgData := rulesPageData(pm) + + q := fmt.Sprintf(` + SELECT id, name, domain_id, tags, input_channel, input_topic, logic_type, logic_value, outputs, + start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status + FROM rules r %s %s %s; + `, pq, orderClause, pgData) + rows, err := repo.DB.NamedQueryContext(ctx, q, pm) + if err != nil { + return re.Page{}, err + } + defer rows.Close() + + var rules []re.Rule + var r dbRule + for rows.Next() { + if err := rows.StructScan(&r); err != nil { + return re.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + ret, err := dbToRule(r) + if err != nil { + return re.Page{}, err + } + rules = append(rules, ret) + } + + cq := fmt.Sprintf(`SELECT COUNT(*) FROM rules r %s;`, pq) + + total, err := postgres.Total(ctx, repo.DB, cq, pm) + if err != nil { + return re.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + ret := re.Page{ + Total: total, + Offset: pm.Offset, + Limit: pm.Limit, + Rules: rules, + } + + return ret, nil +} + +func (repo *PostgresRepository) ListUserRules(ctx context.Context, userID string, pm re.PageMeta) (re.Page, error) { + pm.UserID = userID + pq := pageRulesQuery(pm) + orderClause := rulesOrderClause(pm) + pgData := rulesPageData(pm) + + userJoin := ` + INNER JOIN rules_roles rr ON rr.entity_id = r.id + INNER JOIN rules_role_members rrm ON rrm.role_id = rr.id AND rrm.member_id = :user_id + ` + + innerQ := fmt.Sprintf(` + SELECT DISTINCT r.id, r.name, r.domain_id, r.tags, r.input_channel, r.input_topic, r.logic_type, r.logic_value, r.outputs, + r.start_datetime, r.time, r.recurring, r.recurring_period, r.created_at, r.created_by, r.updated_at, r.updated_by, r.status + FROM rules r + %s + %s + `, userJoin, pq) + + q := fmt.Sprintf(` + SELECT * FROM (%s) AS sub %s %s; + `, innerQ, orderClause, pgData) + + rows, err := repo.DB.NamedQueryContext(ctx, q, pm) + if err != nil { + return re.Page{}, err + } + defer rows.Close() + + var rules []re.Rule + for rows.Next() { + var r dbRule + if err := rows.StructScan(&r); err != nil { + return re.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + ret, err := dbToRule(r) + if err != nil { + return re.Page{}, err + } + rules = append(rules, ret) + } + + cq := fmt.Sprintf(`SELECT COUNT(*) FROM (%s) AS count_sub;`, innerQ) + total, err := postgres.Total(ctx, repo.DB, cq, pm) + if err != nil { + return re.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return re.Page{ + Total: total, + Offset: pm.Offset, + Limit: pm.Limit, + Rules: rules, + }, nil +} + +func (repo *PostgresRepository) UpdateRuleDue(ctx context.Context, id string, due time.Time) (re.Rule, error) { + q := ` + UPDATE rules + SET time = :time, updated_at = :updated_at WHERE id = :id + RETURNING id, name, domain_id, tags, metadata, input_channel, input_topic, logic_type, logic_value, + outputs, start_datetime, time, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + ` + dbr := dbRule{ + ID: id, + UpdatedAt: time.Now().UTC(), + Time: sql.NullTime{Time: due}, + } + if !due.IsZero() { + dbr.Time.Valid = true + } + row, err := repo.DB.NamedQueryContext(ctx, q, dbr) + if err != nil { + return re.Rule{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + defer row.Close() + + var dbRule dbRule + if row.Next() { + if err := row.StructScan(&dbRule); err != nil { + return re.Rule{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + } + rule, err := dbToRule(dbRule) + if err != nil { + return re.Rule{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return rule, nil +} + +func rulesOrderClause(pm re.PageMeta) string { + dir := api.DescDir + if pm.Dir == api.AscDir { + dir = api.AscDir + } + + switch pm.Order { + case api.NameKey: + return fmt.Sprintf("ORDER BY name %s, id %s", dir, dir) + case api.CreatedAtOrder: + return fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir) + default: + return fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir) + } +} + +func rulesPageData(pm re.PageMeta) string { + pgData := "" + if pm.Limit != 0 { + pgData = "LIMIT :limit" + } + if pm.Offset != 0 { + pgData += " OFFSET :offset" + } + return pgData +} + +func pageRulesQuery(pm re.PageMeta) string { + var query []string + if pm.InputChannel != "" { + query = append(query, "r.input_channel = :input_channel") + } + if pm.Status != re.AllStatus { + query = append(query, "r.status = :status") + } + if pm.Domain != "" { + query = append(query, "r.domain_id = :domain_id") + } + if pm.Tag != "" { + query = append(query, "EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE '%' || :tag || '%')") + } + if pm.ScheduledBefore != nil { + query = append(query, "r.time < :scheduled_before") + } + if pm.ScheduledAfter != nil { + query = append(query, "r.time > :scheduled_after") + } + if pm.Name != "" { + query = append(query, "r.name ILIKE '%' || :name || '%'") + } + if pm.Scheduled != nil && !*pm.Scheduled { + query = append(query, "r.time IS NULL") + } + + var q string + if len(query) > 0 { + q = fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")) + } + + return q +} diff --git a/re/postgres/repository_test.go b/re/postgres/repository_test.go new file mode 100644 index 000000000..61ed53cb9 --- /dev/null +++ b/re/postgres/repository_test.go @@ -0,0 +1,1184 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "fmt" + "sort" + "testing" + "time" + + "github.com/0x6flab/namegenerator" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/schedule" + "github.com/absmach/supermq/pkg/uuid" + "github.com/absmach/supermq/re" + "github.com/absmach/supermq/re/outputs" + "github.com/absmach/supermq/re/postgres" + "github.com/stretchr/testify/assert" +) + +const ( + ascDir = "asc" + descDir = "desc" + nameOrder = "name" + createdAtOrder = "created_at" + updatedAtOrder = "updated_at" +) + +var ( + namegen = namegenerator.NewGenerator() + idProvider = uuid.New() +) + +func TestAddRule(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + Tags: []string{"test", "rule"}, + InputChannel: generateUUID(t), + InputTopic: "temperature", + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Outputs: re.Outputs{ + &outputs.Alarm{}, + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + Metadata: map[string]any{ + "key": "value", + }, + } + + scheduleName := namegen.Generate() + scheduleDomain := generateUUID(t) + scheduleChannel := generateUUID(t) + scheduleCreatedBy := generateUUID(t) + scheduleCreatedAt := time.Now().UTC().Truncate(time.Microsecond) + scheduleUpdatedBy := generateUUID(t) + scheduleUpdatedAt := time.Now().UTC().Truncate(time.Microsecond) + scheduleStartTime := time.Now().UTC().Add(time.Hour).Truncate(time.Microsecond) + scheduleTime := time.Now().UTC().Add(2 * time.Hour).Truncate(time.Microsecond) + + scheduleRule := re.Rule{ + ID: generateUUID(t), + Name: scheduleName, + DomainID: scheduleDomain, + InputChannel: scheduleChannel, + InputTopic: "humidity", + Logic: re.Script{ + Type: re.LuaType, + Value: "return value > 50", + }, + Schedule: schedule.Schedule{ + StartDateTime: scheduleStartTime, + Time: scheduleTime, + Recurring: schedule.Daily, + RecurringPeriod: 1, + }, + Status: re.EnabledStatus, + CreatedAt: scheduleCreatedAt, + CreatedBy: scheduleCreatedBy, + UpdatedAt: scheduleUpdatedAt, + UpdatedBy: scheduleUpdatedBy, + Metadata: re.Metadata{}, + } + + outputsName := namegen.Generate() + outputsDomain := generateUUID(t) + outputsChannel := generateUUID(t) + outputsCreatedBy := generateUUID(t) + outputsCreatedAt := time.Now().UTC().Truncate(time.Microsecond) + outputsUpdatedBy := generateUUID(t) + outputsUpdatedAt := time.Now().UTC().Truncate(time.Microsecond) + outputsRuleID := generateUUID(t) + + outputsRule := re.Rule{ + ID: outputsRuleID, + Name: outputsName, + DomainID: outputsDomain, + InputChannel: outputsChannel, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return true }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: generateUUID(t), + Topic: "alerts", + }, + &outputs.SenML{}, + }, + Status: re.EnabledStatus, + CreatedAt: outputsCreatedAt, + CreatedBy: outputsCreatedBy, + UpdatedAt: outputsUpdatedAt, + UpdatedBy: outputsUpdatedBy, + Metadata: re.Metadata{}, + } + + cases := []struct { + desc string + rule re.Rule + resp re.Rule + err error + }{ + { + desc: "valid rule", + rule: rule, + resp: rule, + err: nil, + }, + { + desc: "duplicate rule", + rule: rule, + resp: re.Rule{}, + err: repoerr.ErrConflict, + }, + + { + desc: "rule with schedule", + rule: scheduleRule, + resp: scheduleRule, + err: nil, + }, + { + desc: "rule with outputs", + rule: outputsRule, + resp: outputsRule, + err: nil, + }, + { + desc: "invalid metadata", + rule: re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Metadata: map[string]any{ + "key": make(chan int), + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + resp: re.Rule{}, + err: repoerr.ErrMalformedEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + addedRule, err := repo.AddRule(context.Background(), tc.rule) + if err == nil { + tc.resp.ID = addedRule.ID + assert.Equal(t, tc.resp, addedRule, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, addedRule)) + } + }) + } +} + +func TestViewRule(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + InputTopic: "temperature", + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + Metadata: map[string]any{ + "key": "value", + }, + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + id string + resp re.Rule + err error + }{ + { + desc: "valid rule", + id: rule.ID, + resp: rule, + err: nil, + }, + { + desc: "non existing rule", + id: generateUUID(t), + resp: re.Rule{}, + err: repoerr.ErrViewEntity, + }, + { + desc: "empty id", + id: "", + resp: re.Rule{}, + err: repoerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + retrievedRule, err := repo.ViewRule(context.Background(), tc.id) + assert.Equal(t, tc.resp, retrievedRule, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, retrievedRule)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + }) + } +} + +func TestUpdateRule(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + InputTopic: "temperature", + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + Metadata: map[string]any{ + "key": "value", + }, + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + newInputChannel := generateUUID(t) + newUpdatedBy := generateUUID(t) + + cases := []struct { + desc string + rule re.Rule + resp re.Rule + err error + }{ + { + desc: "valid rule update", + rule: re.Rule{ + ID: rule.ID, + Name: "updated-name", + InputChannel: newInputChannel, + InputTopic: "humidity", + Logic: re.Script{ + Type: re.LuaType, + Value: "return value > 30", + }, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: newUpdatedBy, + Metadata: map[string]any{ + "updated": "metadata", + }, + }, + resp: re.Rule{ + ID: rule.ID, + Name: "updated-name", + DomainID: rule.DomainID, + InputChannel: newInputChannel, + InputTopic: "humidity", + Logic: re.Script{ + Type: re.LuaType, + Value: "return value > 30", + }, + Status: rule.Status, + CreatedAt: rule.CreatedAt, + CreatedBy: rule.CreatedBy, + UpdatedAt: time.Time{}, + UpdatedBy: newUpdatedBy, + Metadata: map[string]any{ + "updated": "metadata", + }, + }, + err: nil, + }, + { + desc: "update non-existing rule", + rule: re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + InputChannel: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + resp: re.Rule{}, + err: repoerr.ErrNotFound, + }, + { + desc: "update with invalid metadata", + rule: re.Rule{ + ID: rule.ID, + InputChannel: generateUUID(t), + Metadata: map[string]any{ + "key": make(chan int), + }, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + resp: re.Rule{}, + err: repoerr.ErrUpdateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + updatedRule, err := repo.UpdateRule(context.Background(), tc.rule) + if tc.err == nil { + tc.resp.UpdatedAt = updatedRule.UpdatedAt + } + assert.Equal(t, tc.resp, updatedRule, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, updatedRule)) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + }) + } +} + +func TestUpdateRuleStatus(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + rule re.Rule + status re.Status + err error + }{ + { + desc: "disable rule", + rule: re.Rule{ + ID: rule.ID, + Status: re.DisabledStatus, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + status: re.DisabledStatus, + err: nil, + }, + { + desc: "enable rule", + rule: re.Rule{ + ID: rule.ID, + Status: re.EnabledStatus, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + status: re.EnabledStatus, + err: nil, + }, + { + desc: "update non-existing rule status", + rule: re.Rule{ + ID: generateUUID(t), + Status: re.DisabledStatus, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + updatedRule, err := repo.UpdateRuleStatus(context.Background(), tc.rule) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.rule.ID, updatedRule.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.ID, updatedRule.ID)) + assert.Equal(t, tc.status, updatedRule.Status, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.status, updatedRule.Status)) + assert.Equal(t, tc.rule.UpdatedBy, updatedRule.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.UpdatedBy, updatedRule.UpdatedBy)) + } + }) + } +} + +func TestUpdateRuleTags(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + Tags: []string{"tag1", "tag2"}, + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + rule re.Rule + tags []string + err error + }{ + { + desc: "update tags", + rule: re.Rule{ + ID: rule.ID, + Tags: []string{"newtag1", "newtag2", "newtag3"}, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + tags: []string{"newtag1", "newtag2", "newtag3"}, + err: nil, + }, + { + desc: "update non-existing rule tags", + rule: re.Rule{ + ID: generateUUID(t), + Tags: []string{"tag"}, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + updatedRule, err := repo.UpdateRuleTags(context.Background(), tc.rule) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.rule.ID, updatedRule.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.ID, updatedRule.ID)) + assert.Equal(t, tc.tags, updatedRule.Tags, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.tags, updatedRule.Tags)) + assert.Equal(t, tc.rule.UpdatedBy, updatedRule.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.UpdatedBy, updatedRule.UpdatedBy)) + } + }) + } +} + +func TestUpdateRuleSchedule(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + newSchedule := schedule.Schedule{ + StartDateTime: time.Now().UTC().Add(time.Hour).Truncate(time.Microsecond), + Time: time.Now().UTC().Add(2 * time.Hour).Truncate(time.Microsecond), + Recurring: schedule.Weekly, + RecurringPeriod: 2, + } + + cases := []struct { + desc string + rule re.Rule + schedule schedule.Schedule + err error + }{ + { + desc: "update schedule", + rule: re.Rule{ + ID: rule.ID, + Schedule: newSchedule, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + schedule: newSchedule, + err: nil, + }, + { + desc: "update non-existing rule schedule", + rule: re.Rule{ + ID: generateUUID(t), + Schedule: newSchedule, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + updatedRule, err := repo.UpdateRuleSchedule(context.Background(), tc.rule) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.rule.ID, updatedRule.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.ID, updatedRule.ID)) + assert.Equal(t, tc.schedule.Recurring, updatedRule.Schedule.Recurring, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.schedule.Recurring, updatedRule.Schedule.Recurring)) + assert.Equal(t, tc.schedule.RecurringPeriod, updatedRule.Schedule.RecurringPeriod, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.schedule.RecurringPeriod, updatedRule.Schedule.RecurringPeriod)) + assert.Equal(t, tc.rule.UpdatedBy, updatedRule.UpdatedBy, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.rule.UpdatedBy, updatedRule.UpdatedBy)) + } + }) + } +} + +func TestUpdateRuleDue(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Schedule: schedule.Schedule{ + Time: time.Now().UTC().Add(time.Hour).Truncate(time.Microsecond), + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + newDue := time.Now().UTC().Add(3 * time.Hour).Truncate(time.Microsecond) + + cases := []struct { + desc string + id string + due time.Time + err error + }{ + { + desc: "update due time", + id: rule.ID, + due: newDue, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + updatedRule, err := repo.UpdateRuleDue(context.Background(), tc.id, tc.due) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.id, updatedRule.ID, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.id, updatedRule.ID)) + assert.True(t, updatedRule.Schedule.Time.Sub(tc.due) < time.Second, fmt.Sprintf("%s: expected due time close to %v got %v\n", tc.desc, tc.due, updatedRule.Schedule.Time)) + } + }) + } +} + +func TestListRules(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + domainID := generateUUID(t) + channelID := generateUUID(t) + items := make([]re.Rule, 100) + + for i := range 100 { + items[i] = re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + InputChannel: channelID, + Tags: []string{fmt.Sprintf("tag%d", i%10)}, + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + if i%2 == 0 { + items[i].Status = re.DisabledStatus + } + if i%3 == 0 { + items[i].Schedule = schedule.Schedule{ + Time: time.Now().UTC().Add(time.Duration(i) * time.Hour), + Recurring: schedule.Daily, + } + } + rule, err := repo.AddRule(context.Background(), items[i]) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + items[i].ID = rule.ID + } + + cases := []struct { + desc string + pm re.PageMeta + count int + err error + }{ + { + desc: "list first page", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + }, + count: 10, + err: nil, + }, + { + desc: "list with offset", + pm: re.PageMeta{ + Offset: 10, + Limit: 20, + Status: re.AllStatus, + }, + count: 20, + err: nil, + }, + { + desc: "list by domain", + pm: re.PageMeta{ + Domain: domainID, + Offset: 0, + Limit: 200, + Status: re.AllStatus, + }, + count: 100, + err: nil, + }, + { + desc: "list by channel", + pm: re.PageMeta{ + InputChannel: channelID, + Offset: 0, + Limit: 200, + Status: re.AllStatus, + }, + count: 100, + err: nil, + }, + { + desc: "list enabled rules", + pm: re.PageMeta{ + Status: re.EnabledStatus, + Offset: 0, + Limit: 200, + }, + count: 50, + err: nil, + }, + { + desc: "list disabled rules", + pm: re.PageMeta{ + Status: re.DisabledStatus, + Offset: 0, + Limit: 200, + }, + count: 50, + err: nil, + }, + { + desc: "list by tag", + pm: re.PageMeta{ + Tag: "tag1", + Offset: 0, + Limit: 200, + Status: re.AllStatus, + }, + count: 10, + err: nil, + }, + { + desc: "list with zero limit returns all", + pm: re.PageMeta{ + Status: re.AllStatus, + }, + count: 100, + err: nil, + }, + { + desc: "list non-existing domain", + pm: re.PageMeta{ + Domain: generateUUID(t), + Offset: 0, + Limit: 10, + Status: re.AllStatus, + }, + count: 0, + err: nil, + }, + { + desc: "list ordered by name ascending", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + Order: nameOrder, + Dir: ascDir, + }, + count: 10, + err: nil, + }, + { + desc: "list ordered by name descending", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + Order: nameOrder, + Dir: descDir, + }, + count: 10, + err: nil, + }, + { + desc: "list ordered by created_at ascending", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + Order: createdAtOrder, + Dir: ascDir, + }, + count: 10, + err: nil, + }, + { + desc: "list ordered by created_at descending", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + Order: createdAtOrder, + Dir: descDir, + }, + count: 10, + err: nil, + }, + { + desc: "list ordered by updated_at ascending", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + Order: updatedAtOrder, + Dir: ascDir, + }, + count: 10, + err: nil, + }, + { + desc: "list ordered by updated_at descending", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + Order: updatedAtOrder, + Dir: descDir, + }, + count: 10, + err: nil, + }, + { + desc: "list with default order (updated_at desc)", + pm: re.PageMeta{ + Offset: 0, + Limit: 10, + Status: re.AllStatus, + }, + count: 10, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + page, err := repo.ListAllRules(context.Background(), tc.pm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + assert.Equal(t, tc.count, len(page.Rules), fmt.Sprintf("%s: expected %d rules, got %d", tc.desc, tc.count, len(page.Rules))) + if len(page.Rules) > 1 { + switch tc.pm.Order { + case nameOrder: + if tc.pm.Dir == ascDir { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].Name <= page.Rules[j].Name + }), "Expected names to be sorted ascending") + } else { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].Name >= page.Rules[j].Name + }), "Expected names to be sorted descending") + } + case createdAtOrder: + if tc.pm.Dir == ascDir { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].CreatedAt.Before(page.Rules[j].CreatedAt) + }), "Expected created_at to be sorted ascending") + } else { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].CreatedAt.After(page.Rules[j].CreatedAt) + }), "Expected created_at to be sorted descending") + } + case updatedAtOrder: + if tc.pm.Dir == ascDir { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].UpdatedAt.Before(page.Rules[j].UpdatedAt) + }), "Expected updated_at to be sorted ascending") + } else { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].UpdatedAt.After(page.Rules[j].UpdatedAt) + }), "Expected updated_at to be sorted descending") + } + } + } + }) + } +} + +func TestListUserRules(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + domainID := generateUUID(t) + userID := generateUUID(t) + otherUserID := generateUUID(t) + channelID := generateUUID(t) + + // Create 10 rules; assign the first 4 to userID via a role. + var allRules []re.Rule + for i := range 10 { + r := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + InputChannel: channelID, + Logic: re.Script{Type: re.LuaType, Value: "return true"}, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute).Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + rule, err := repo.AddRule(context.Background(), r) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + allRules = append(allRules, rule) + } + + // Assign userID to the first 4 rules via direct role INSERT. + for i := range 4 { + roleID := generateUUID(t) + _, err := db.Exec(`INSERT INTO rules_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", allRules[i].ID) + assert.Nil(t, err, fmt.Sprintf("insert rules_roles unexpected error: %s", err)) + _, err = db.Exec(`INSERT INTO rules_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, userID, allRules[i].ID) + assert.Nil(t, err, fmt.Sprintf("insert rules_role_members unexpected error: %s", err)) + } + + cases := []struct { + desc string + userID string + pm re.PageMeta + count int + err error + }{ + { + desc: "list user rules returns only accessible rules", + userID: userID, + pm: re.PageMeta{ + Offset: 0, + Limit: 100, + Status: re.AllStatus, + }, + count: 4, + err: nil, + }, + { + desc: "list user rules with offset", + userID: userID, + pm: re.PageMeta{ + Offset: 2, + Limit: 100, + Status: re.AllStatus, + }, + count: 2, + err: nil, + }, + { + desc: "list user rules with limit", + userID: userID, + pm: re.PageMeta{ + Offset: 0, + Limit: 2, + Status: re.AllStatus, + }, + count: 2, + err: nil, + }, + { + desc: "list user rules with domain filter", + userID: userID, + pm: re.PageMeta{ + Domain: domainID, + Offset: 0, + Limit: 100, + Status: re.AllStatus, + }, + count: 4, + err: nil, + }, + { + desc: "list user rules with channel filter", + userID: userID, + pm: re.PageMeta{ + InputChannel: channelID, + Offset: 0, + Limit: 100, + Status: re.AllStatus, + }, + count: 4, + err: nil, + }, + { + desc: "list user rules with non-existing domain returns 0", + userID: userID, + pm: re.PageMeta{ + Domain: generateUUID(t), + Offset: 0, + Limit: 100, + Status: re.AllStatus, + }, + count: 0, + err: nil, + }, + { + desc: "list rules for user with no role assignments returns 0", + userID: otherUserID, + pm: re.PageMeta{ + Offset: 0, + Limit: 100, + Status: re.AllStatus, + }, + count: 0, + err: nil, + }, + { + desc: "list user rules ordered by name ascending", + userID: userID, + pm: re.PageMeta{ + Offset: 0, + Limit: 100, + Status: re.AllStatus, + Order: nameOrder, + Dir: ascDir, + }, + count: 4, + err: nil, + }, + { + desc: "list user rules ordered by created_at descending", + userID: userID, + pm: re.PageMeta{ + Offset: 0, + Limit: 100, + Status: re.AllStatus, + Order: createdAtOrder, + Dir: descDir, + }, + count: 4, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + page, err := repo.ListUserRules(context.Background(), tc.userID, tc.pm) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + assert.Equal(t, tc.count, len(page.Rules), fmt.Sprintf("%s: expected %d rules, got %d", tc.desc, tc.count, len(page.Rules))) + if len(page.Rules) > 1 { + switch tc.pm.Order { + case nameOrder: + if tc.pm.Dir == ascDir { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].Name <= page.Rules[j].Name + }), "Expected names to be sorted ascending") + } + case createdAtOrder: + if tc.pm.Dir == descDir { + assert.True(t, sort.SliceIsSorted(page.Rules, func(i, j int) bool { + return page.Rules[i].CreatedAt.After(page.Rules[j].CreatedAt) + }), "Expected created_at to be sorted descending") + } + } + } + }) + } +} + +func TestRemoveRule(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM rules") + assert.Nil(t, err, fmt.Sprintf("clean rules unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + rule := re.Rule{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + InputChannel: generateUUID(t), + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + Status: re.EnabledStatus, + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + UpdatedBy: generateUUID(t), + } + rule, err := repo.AddRule(context.Background(), rule) + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "remove existing rule", + id: rule.ID, + err: nil, + }, + { + desc: "remove non-existing rule", + id: generateUUID(t), + err: repoerr.ErrNotFound, + }, + { + desc: "remove already removed rule", + id: rule.ID, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.RemoveRule(context.Background(), tc.id) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + }) + } +} + +func generateUUID(t *testing.T) string { + ulid, err := idProvider.ID() + assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + return ulid +} diff --git a/re/postgres/rule.go b/re/postgres/rule.go new file mode 100644 index 000000000..927bb7a6a --- /dev/null +++ b/re/postgres/rule.go @@ -0,0 +1,161 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "database/sql" + "encoding/json" + "time" + + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/pkg/schedule" + "github.com/absmach/supermq/re" + "github.com/jackc/pgtype" +) + +// dbRule represents the database structure for a Rule. +type dbRule struct { + ID string `db:"id"` + Name string `db:"name"` + DomainID string `db:"domain_id"` + Tags pgtype.TextArray `db:"tags,omitempty"` + Metadata []byte `db:"metadata,omitempty"` + InputChannel string `db:"input_channel"` + InputTopic sql.NullString `db:"input_topic"` + LogicType re.ScriptType `db:"logic_type"` + LogicValue string `db:"logic_value"` + Outputs []byte `db:"outputs"` + StartDateTime sql.NullTime `db:"start_datetime"` + Time sql.NullTime `db:"time"` + Recurring schedule.Recurring `db:"recurring"` + RecurringPeriod uint `db:"recurring_period"` + Status re.Status `db:"status"` + CreatedAt time.Time `db:"created_at"` + CreatedBy string `db:"created_by"` + UpdatedAt time.Time `db:"updated_at"` + UpdatedBy string `db:"updated_by"` + MemberID string `db:"member_id,omitempty"` + Roles json.RawMessage `db:"roles,omitempty"` +} + +func ruleToDb(r re.Rule) (dbRule, error) { + metadata := []byte("{}") + if len(r.Metadata) > 0 { + b, err := json.Marshal(r.Metadata) + if err != nil { + return dbRule{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + metadata = b + } + + start := sql.NullTime{Time: r.Schedule.StartDateTime} + if !r.Schedule.StartDateTime.IsZero() { + start.Valid = true + } + t := sql.NullTime{Time: r.Schedule.Time} + if !r.Schedule.Time.IsZero() { + t.Valid = true + } + var tags pgtype.TextArray + if err := tags.Set(r.Tags); err != nil { + return dbRule{}, err + } + + outputs, err := json.Marshal(r.Outputs) + if err != nil { + return dbRule{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + + return dbRule{ + ID: r.ID, + Name: r.Name, + DomainID: r.DomainID, + Tags: tags, + Metadata: metadata, + InputChannel: r.InputChannel, + InputTopic: toNullString(r.InputTopic), + LogicType: r.Logic.Type, + LogicValue: r.Logic.Value, + Outputs: outputs, + StartDateTime: start, + Time: t, + Recurring: r.Schedule.Recurring, + RecurringPeriod: r.Schedule.RecurringPeriod, + Status: r.Status, + CreatedAt: r.CreatedAt, + CreatedBy: r.CreatedBy, + UpdatedAt: r.UpdatedAt, + UpdatedBy: r.UpdatedBy, + }, nil +} + +func dbToRule(dto dbRule) (re.Rule, error) { + var metadata re.Metadata + if dto.Metadata != nil { + if err := json.Unmarshal(dto.Metadata, &metadata); err != nil { + return re.Rule{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + } + + var tags []string + for _, e := range dto.Tags.Elements { + tags = append(tags, e.String) + } + + var outputs re.Outputs + if dto.Outputs != nil { + if err := json.Unmarshal(dto.Outputs, &outputs); err != nil { + return re.Rule{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + } + + var roles []roles.MemberRoleActions + if dto.Roles != nil { + if err := json.Unmarshal(dto.Roles, &roles); err != nil { + return re.Rule{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + } + + return re.Rule{ + ID: dto.ID, + Name: dto.Name, + DomainID: dto.DomainID, + Tags: tags, + Metadata: metadata, + InputChannel: dto.InputChannel, + InputTopic: fromNullString(dto.InputTopic), + Logic: re.Script{ + Type: dto.LogicType, + Value: dto.LogicValue, + }, + Outputs: outputs, + Schedule: schedule.Schedule{ + StartDateTime: dto.StartDateTime.Time, + Time: dto.Time.Time, + Recurring: dto.Recurring, + RecurringPeriod: dto.RecurringPeriod, + }, + Status: dto.Status, + CreatedAt: dto.CreatedAt, + CreatedBy: dto.CreatedBy, + UpdatedAt: dto.UpdatedAt, + UpdatedBy: dto.UpdatedBy, + Roles: roles, + }, nil +} + +func toNullString(value string) sql.NullString { + if value == "" { + return sql.NullString{Valid: false} + } + return sql.NullString{String: value, Valid: true} +} + +func fromNullString(nullString sql.NullString) string { + if !nullString.Valid { + return "" + } + return nullString.String +} diff --git a/re/postgres/setup_test.go b/re/postgres/setup_test.go new file mode 100644 index 000000000..10d2dc962 --- /dev/null +++ b/re/postgres/setup_test.go @@ -0,0 +1,97 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "database/sql" + "fmt" + "log" + "os" + "testing" + "time" + + "github.com/absmach/supermq/pkg/postgres" + repostgres "github.com/absmach/supermq/re/postgres" + "github.com/jmoiron/sqlx" + dockertest "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "go.opentelemetry.io/otel" +) + +var ( + db *sqlx.DB + database postgres.Database + tracer = otel.Tracer("repo_tests") +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16.2-alpine", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + + // exponential backoff-retry, because the application in the container might not be ready to accept connections yet + pool.MaxWait = 120 * time.Second + if err := pool.Retry(func() error { + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + db, err := sql.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + dbConfig := postgres.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + migration, err := repostgres.Migration() + if err != nil { + log.Fatalf("Could not get migration: %s", err) + } + if db, err = postgres.Setup(dbConfig, *migration); err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + database = postgres.NewDatabase(db, dbConfig, tracer) + + code := m.Run() + + // Defers will not be run when using os.Exit + db.Close() + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/re/rule.go b/re/rule.go new file mode 100644 index 000000000..a92b61804 --- /dev/null +++ b/re/rule.go @@ -0,0 +1,258 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import ( + "context" + "encoding/json" + "time" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/pkg/schedule" + "github.com/absmach/supermq/re/outputs" +) + +const ( + LuaType ScriptType = iota + GoType +) + +const TimeLayout = "2006-01-02T15:04:05.999999Z" + +type ( + // ScriptType indicates Runtime type for the future versions + // that will support JS or Go runtimes alongside Lua. + ScriptType uint + + Metadata map[string]any + Script struct { + Type ScriptType `json:"type"` + Value string `json:"value"` + } +) + +var outputRegistry = map[outputs.OutputType]func() Runnable{ + outputs.AlarmsType: func() Runnable { return &outputs.Alarm{} }, + outputs.EmailType: func() Runnable { return &outputs.Email{} }, + outputs.SaveRemotePgType: func() Runnable { return &outputs.Postgres{} }, + outputs.ChannelsType: func() Runnable { return &outputs.ChannelPublisher{} }, + outputs.SaveSenMLType: func() Runnable { return &outputs.SenML{} }, + outputs.SlackType: func() Runnable { return &outputs.Slack{} }, +} + +type Rule struct { + ID string `json:"id"` + Name string `json:"name"` + DomainID string `json:"domain"` + Metadata Metadata `json:"metadata,omitempty"` + Tags []string `json:"tags,omitempty"` + InputChannel string `json:"input_channel"` + InputTopic string `json:"input_topic"` + Logic Script `json:"logic"` + Outputs Outputs `json:"outputs,omitempty"` + Schedule schedule.Schedule `json:"schedule,omitempty"` + Status Status `json:"status"` + CreatedAt time.Time `json:"created_at"` + CreatedBy string `json:"created_by"` + UpdatedAt time.Time `json:"updated_at"` + UpdatedBy string `json:"updated_by"` + Roles []roles.MemberRoleActions `json:"roles,omitempty"` +} + +// EventEncode converts a Rule struct to map[string]any at event producer. +func (r Rule) EventEncode() (map[string]any, error) { + m := map[string]any{ + "id": r.ID, + "name": r.Name, + "created_at": r.CreatedAt.Format(TimeLayout), + "created_by": r.CreatedBy, + "schedule": r.Schedule.EventEncode(), + "status": r.Status.String(), + } + + if r.Name != "" { + m["name"] = r.Name + } + + if r.DomainID != "" { + m["domain"] = r.DomainID + } + + if !r.UpdatedAt.IsZero() { + m["updated_at"] = r.UpdatedAt.Format(TimeLayout) + } + + if r.UpdatedBy != "" { + m["updated_by"] = r.UpdatedBy + } + + if len(r.Metadata) > 0 { + m["metadata"] = r.Metadata + } + + if len(r.Tags) > 0 { + m["tags"] = r.Tags + } + + if r.InputChannel != "" { + m["input_channel"] = r.InputChannel + } + + if r.InputTopic != "" { + m["input_topic"] = r.InputTopic + } + + if r.Logic.Value != "" { + m["logic"] = map[string]any{ + "type": r.Logic.Type, + "value": r.Logic.Value, + } + } + + return m, nil +} + +type Outputs []Runnable + +func (o *Outputs) UnmarshalJSON(data []byte) error { + var rawList []json.RawMessage + if err := json.Unmarshal(data, &rawList); err != nil { + return err + } + + var runnables []Runnable + for _, raw := range rawList { + var meta struct { + Type outputs.OutputType `json:"type"` + } + if err := json.Unmarshal(raw, &meta); err != nil { + return err + } + + factory, ok := outputRegistry[meta.Type] + if !ok { + return errors.New("unknown output type: " + meta.Type.String()) + } + + instance := factory() + if err := json.Unmarshal(raw, instance); err != nil { + return err + } + + runnables = append(runnables, instance) + } + v := Outputs(runnables) + *o = v + return nil +} + +type Runnable interface { + Run(ctx context.Context, msg *messaging.Message, val any) error +} + +// PageMeta contains page metadata that helps navigation. +type PageMeta struct { + Total uint64 `json:"total" db:"total"` + Offset uint64 `json:"offset" db:"offset"` + Limit uint64 `json:"limit" db:"limit"` + Dir string `json:"dir" db:"dir"` + Order string `json:"order" db:"order"` + Name string `json:"name" db:"name"` + InputChannel string `json:"input_channel,omitempty" db:"input_channel"` + InputTopic *string `json:"input_topic,omitempty" db:"input_topic"` + Scheduled *bool `json:"scheduled,omitempty"` + OutputChannel string `json:"output_channel,omitempty" db:"output_channel"` + Status Status `json:"status,omitempty" db:"status"` + Domain string `json:"domain_id,omitempty" db:"domain_id"` + Tag string `json:"tag,omitempty"` + ScheduledBefore *time.Time `json:"scheduled_before,omitempty" db:"scheduled_before"` // Filter rules scheduled before this time + ScheduledAfter *time.Time `json:"scheduled_after,omitempty" db:"scheduled_after"` // Filter rules scheduled after this time + Recurring *schedule.Recurring `json:"recurring,omitempty" db:"recurring"` // Filter by recurring type + UserID string `json:"user_id,omitempty" db:"user_id"` +} + +// EventEncode converts a PageMeta struct to map[string]any. +func (pm PageMeta) EventEncode() map[string]any { + m := map[string]any{ + "total": pm.Total, + "offset": pm.Offset, + "limit": pm.Limit, + "status": pm.Status.String(), + "domain_id": pm.Domain, + } + + if pm.Dir != "" { + m["dir"] = pm.Dir + } + if pm.Name != "" { + m["name"] = pm.Name + } + if pm.InputChannel != "" { + m["input_channel"] = pm.InputChannel + } + if pm.InputTopic != nil { + m["input_topic"] = *pm.InputTopic + } + if pm.Scheduled != nil { + m["scheduled"] = *pm.Scheduled + } + if pm.OutputChannel != "" { + m["output_channel"] = pm.OutputChannel + } + if pm.Tag != "" { + m["tag"] = pm.Tag + } + if pm.ScheduledBefore != nil { + m["scheduled_before"] = pm.ScheduledBefore.Format(time.RFC3339Nano) + } + if pm.ScheduledAfter != nil { + m["scheduled_after"] = pm.ScheduledAfter.Format(time.RFC3339Nano) + } + if pm.Recurring != nil { + m["recurring"] = pm.Recurring.String() + } + + return m +} + +type Page struct { + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Total uint64 `json:"total"` + Rules []Rule `json:"rules"` +} + +type Service interface { + messaging.MessageHandler + AddRule(ctx context.Context, session authn.Session, r Rule) (Rule, []roles.RoleProvision, error) + ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (Rule, error) + UpdateRule(ctx context.Context, session authn.Session, r Rule) (Rule, error) + UpdateRuleTags(ctx context.Context, session authn.Session, r Rule) (Rule, error) + UpdateRuleSchedule(ctx context.Context, session authn.Session, r Rule) (Rule, error) + ListRules(ctx context.Context, session authn.Session, pm PageMeta) (Page, error) + RemoveRule(ctx context.Context, session authn.Session, id string) error + EnableRule(ctx context.Context, session authn.Session, id string) (Rule, error) + DisableRule(ctx context.Context, session authn.Session, id string) (Rule, error) + + StartScheduler(ctx context.Context) error + roles.RoleManager +} + +type Repository interface { + AddRule(ctx context.Context, r Rule) (Rule, error) + ViewRule(ctx context.Context, id string) (Rule, error) + RetrieveByIDWithRoles(ctx context.Context, id, memberID string) (Rule, error) + UpdateRule(ctx context.Context, r Rule) (Rule, error) + UpdateRuleTags(ctx context.Context, r Rule) (Rule, error) + UpdateRuleSchedule(ctx context.Context, r Rule) (Rule, error) + RemoveRule(ctx context.Context, id string) error + UpdateRuleStatus(ctx context.Context, r Rule) (Rule, error) + ListAllRules(ctx context.Context, pm PageMeta) (Page, error) + ListUserRules(ctx context.Context, userID string, pm PageMeta) (Page, error) + UpdateRuleDue(ctx context.Context, id string, due time.Time) (Rule, error) + roles.Repository +} diff --git a/re/service.go b/re/service.go new file mode 100644 index 000000000..5b0d8c4da --- /dev/null +++ b/re/service.go @@ -0,0 +1,238 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import ( + "context" + "time" + + "github.com/absmach/supermq" + grpcReadersV1 "github.com/absmach/supermq/api/grpc/readers/v1" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/emailer" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + pkglog "github.com/absmach/supermq/pkg/logger" + "github.com/absmach/supermq/pkg/messaging" + "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/pkg/ticker" + "github.com/absmach/supermq/re/operations" +) + +var ( + ErrGoroutinesNotAllowed = errors.New("goroutines are not allowed in Go scripts") + ErrPanicNotAllowed = errors.New("panic is not allowed in Go scripts") +) + +type re struct { + repo Repository + runInfo chan pkglog.RunInfo + idp supermq.IDProvider + rePubSub messaging.PubSub + writersPub messaging.Publisher + alarmsPub messaging.Publisher + ticker ticker.Ticker + email emailer.Emailer + readers grpcReadersV1.ReadersServiceClient + roles.ProvisionManageService +} + +func NewService(repo Repository, runInfo chan pkglog.RunInfo, policy policies.Service, idp supermq.IDProvider, rePubSub messaging.PubSub, writersPub, alarmsPub messaging.Publisher, tck ticker.Ticker, emailer emailer.Emailer, readers grpcReadersV1.ReadersServiceClient, availableActions []roles.Action, builtInRoles map[roles.BuiltInRoleName][]roles.Action) (Service, error) { + rpms, err := roles.NewProvisionManageService(operations.EntityType, repo, policy, idp, availableActions, builtInRoles) + if err != nil { + return nil, err + } + return &re{ + repo: repo, + idp: idp, + runInfo: runInfo, + rePubSub: rePubSub, + writersPub: writersPub, + alarmsPub: alarmsPub, + ticker: tck, + email: emailer, + readers: readers, + ProvisionManageService: rpms, + }, nil +} + +func (re *re) AddRule(ctx context.Context, session authn.Session, r Rule) (retRule Rule, retRps []roles.RoleProvision, retErr error) { + if r.Logic.Type == GoType && goKeywordRegex.MatchString(r.Logic.Value) { + return Rule{}, nil, errors.Wrap(svcerr.ErrMalformedEntity, ErrGoroutinesNotAllowed) + } + if r.Logic.Type == GoType && panicRegex.MatchString(r.Logic.Value) { + return Rule{}, nil, errors.Wrap(svcerr.ErrMalformedEntity, ErrPanicNotAllowed) + } + + id, err := re.idp.ID() + if err != nil { + return Rule{}, nil, err + } + now := time.Now().UTC() + r.CreatedAt = now + r.ID = id + r.CreatedBy = session.UserID + r.DomainID = session.DomainID + r.Status = EnabledStatus + + if !r.Schedule.StartDateTime.IsZero() { + r.Schedule.StartDateTime = now + } + r.Schedule.Time = r.Schedule.StartDateTime + + rule, err := re.repo.AddRule(ctx, r) + if err != nil { + return Rule{}, nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + defer func() { + if retErr != nil { + if errRollBack := re.repo.RemoveRule(ctx, rule.ID); errRollBack != nil { + retErr = errors.Wrap(retErr, errors.Wrap(svcerr.ErrRollbackRepo, errRollBack)) + } + } + }() + + newBuiltInRoleMembers := map[roles.BuiltInRoleName][]roles.Member{ + BuiltInRoleAdmin: {roles.Member(session.UserID)}, + } + + optionalPolicies := []policies.Policy{ + { + SubjectType: policies.DomainType, + Subject: session.DomainID, + Relation: policies.DomainRelation, + ObjectType: operations.EntityType, + Object: rule.ID, + }, + } + + rps, err := re.AddNewEntitiesRoles(ctx, session.DomainID, session.UserID, []string{rule.ID}, optionalPolicies, newBuiltInRoleMembers) + if err != nil { + return Rule{}, nil, errors.Wrap(svcerr.ErrAddPolicies, err) + } + + return rule, rps, nil +} + +func (re *re) ViewRule(ctx context.Context, session authn.Session, id string, withRoles bool) (Rule, error) { + var rule Rule + var err error + switch withRoles { + case true: + rule, err = re.repo.RetrieveByIDWithRoles(ctx, id, session.UserID) + default: + rule, err = re.repo.ViewRule(ctx, id) + } + if err != nil { + return Rule{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + + return rule, nil +} + +func (re *re) UpdateRule(ctx context.Context, session authn.Session, r Rule) (Rule, error) { + if r.Logic.Type == GoType && goKeywordRegex.MatchString(r.Logic.Value) { + return Rule{}, errors.Wrap(svcerr.ErrMalformedEntity, ErrGoroutinesNotAllowed) + } + if r.Logic.Type == GoType && panicRegex.MatchString(r.Logic.Value) { + return Rule{}, errors.Wrap(svcerr.ErrMalformedEntity, ErrPanicNotAllowed) + } + + r.UpdatedAt = time.Now().UTC() + r.UpdatedBy = session.UserID + rule, err := re.repo.UpdateRule(ctx, r) + if err != nil { + return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + return rule, nil +} + +func (re *re) UpdateRuleTags(ctx context.Context, session authn.Session, r Rule) (Rule, error) { + r.UpdatedAt = time.Now().UTC() + r.UpdatedBy = session.UserID + rule, err := re.repo.UpdateRuleTags(ctx, r) + if err != nil { + return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + return rule, nil +} + +func (re *re) UpdateRuleSchedule(ctx context.Context, session authn.Session, r Rule) (Rule, error) { + r.UpdatedAt = time.Now().UTC() + r.UpdatedBy = session.UserID + rule, err := re.repo.UpdateRuleSchedule(ctx, r) + if err != nil { + return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + return rule, nil +} + +func (re *re) ListRules(ctx context.Context, session authn.Session, pm PageMeta) (Page, error) { + pm.Domain = session.DomainID + if session.SuperAdmin { + page, err := re.repo.ListAllRules(ctx, pm) + if err != nil { + return Page{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + return page, nil + } + page, err := re.repo.ListUserRules(ctx, session.UserID, pm) + if err != nil { + return Page{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + return page, nil +} + +func (re *re) RemoveRule(ctx context.Context, session authn.Session, id string) error { + if err := re.repo.RemoveRule(ctx, id); err != nil { + return errors.Wrap(svcerr.ErrRemoveEntity, err) + } + + return nil +} + +func (re *re) EnableRule(ctx context.Context, session authn.Session, id string) (Rule, error) { + status, err := ToStatus(Enabled) + if err != nil { + return Rule{}, err + } + r := Rule{ + ID: id, + UpdatedAt: time.Now().UTC(), + UpdatedBy: session.UserID, + Status: status, + } + rule, err := re.repo.UpdateRuleStatus(ctx, r) + if err != nil { + return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + return rule, nil +} + +func (re *re) DisableRule(ctx context.Context, session authn.Session, id string) (Rule, error) { + status, err := ToStatus(Disabled) + if err != nil { + return Rule{}, err + } + r := Rule{ + ID: id, + UpdatedAt: time.Now().UTC(), + UpdatedBy: session.UserID, + Status: status, + } + rule, err := re.repo.UpdateRuleStatus(ctx, r) + if err != nil { + return Rule{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + return rule, nil +} + +func (re *re) Cancel() error { + return nil +} diff --git a/re/service_test.go b/re/service_test.go new file mode 100644 index 000000000..a26e827de --- /dev/null +++ b/re/service_test.go @@ -0,0 +1,2070 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re_test + +import ( + "context" + "fmt" + "log/slog" + "testing" + "time" + + "github.com/0x6flab/namegenerator" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/authn" + emocks "github.com/absmach/supermq/pkg/emailer/mocks" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + svcerr "github.com/absmach/supermq/pkg/errors/service" + pkglog "github.com/absmach/supermq/pkg/logger" + "github.com/absmach/supermq/pkg/messaging" + pubsubmocks "github.com/absmach/supermq/pkg/messaging/mocks" + policymocks "github.com/absmach/supermq/pkg/policies/mocks" + "github.com/absmach/supermq/pkg/roles" + pkgSch "github.com/absmach/supermq/pkg/schedule" + tmocks "github.com/absmach/supermq/pkg/ticker/mocks" + "github.com/absmach/supermq/pkg/uuid" + "github.com/absmach/supermq/re" + "github.com/absmach/supermq/re/mocks" + "github.com/absmach/supermq/re/outputs" + readmocks "github.com/absmach/supermq/readers/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// unknownOutput is a mock output type that doesn't match any known output type. +type unknownOutput struct{} + +func (u *unknownOutput) Run(ctx context.Context, msg *messaging.Message, val any) error { + return nil +} + +func (u *unknownOutput) MarshalJSON() ([]byte, error) { + return []byte(`{"type": "unknown"}`), nil +} + +var ( + namegen = namegenerator.NewGenerator() + userID = testsutil.GenerateUUID(&testing.T{}) + domainID = testsutil.GenerateUUID(&testing.T{}) + ruleName = namegen.Generate() + ruleID = testsutil.GenerateUUID(&testing.T{}) + Tags = []string{"tag1", "tag2"} + inputChannel = "test.channel" + StartDateTime = time.Now().Add(-time.Hour) + schedule = pkgSch.Schedule{ + StartDateTime: StartDateTime, + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: time.Now().Add(-time.Hour), + } +) + +func newService(t *testing.T, runInfo chan pkglog.RunInfo) (re.Service, *mocks.Repository, *pubsubmocks.PubSub, *tmocks.Ticker, *emocks.Emailer, *policymocks.Service) { + repo := new(mocks.Repository) + mockTicker := new(tmocks.Ticker) + idProvider := uuid.NewMock() + pubsub := pubsubmocks.NewPubSub(t) + readersSvc := new(readmocks.ReadersServiceClient) + e := new(emocks.Emailer) + policy := new(policymocks.Service) + availableActions := []roles.Action{} + builtInRoles := map[roles.BuiltInRoleName][]roles.Action{ + "admin": availableActions, + } + svc, err := re.NewService(repo, runInfo, policy, idProvider, pubsub, pubsub, pubsub, mockTicker, e, readersSvc, availableActions, builtInRoles) + if err != nil { + t.Fatalf("Failed to create service: %v", err) + } + return svc, repo, pubsub, mockTicker, e, policy +} + +func TestAddRule(t *testing.T) { + // nolint:dogsled + svc, repo, _, _, _, policies := newService(t, make(chan pkglog.RunInfo)) + ruleName := namegen.Generate() + now := time.Now().Add(time.Hour) + cases := []struct { + desc string + session authn.Session + rule re.Rule + res re.Rule + err error + addPoliciesErr error + deletePolicies error + addRoleErr error + deleteErr error + }{ + { + desc: "Add rule successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + }, + res: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + err: nil, + addPoliciesErr: nil, + addRoleErr: nil, + deleteErr: nil, + }, + { + desc: "Add rule with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + }, + err: repoerr.ErrCreateEntity, + addPoliciesErr: nil, + deletePolicies: nil, + addRoleErr: nil, + deleteErr: nil, + }, + { + desc: "Add rule with non-zero StartDateTime", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + StartDateTime: now, + Recurring: pkgSch.Weekly, + RecurringPeriod: 2, + Time: now.Add(2 * time.Hour), + }, + }, + res: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + StartDateTime: now, + Recurring: pkgSch.Weekly, + RecurringPeriod: 2, + Time: now.Add(2 * time.Hour), + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + err: nil, + addPoliciesErr: nil, + addRoleErr: nil, + deleteErr: nil, + }, + { + desc: "Add rule with failed to add roles and failed to delete policies", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + }, + res: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + addRoleErr: svcerr.ErrCreateEntity, + deletePolicies: svcerr.ErrRemoveEntity, + err: svcerr.ErrRemoveEntity, + }, + { + desc: "Add rule with failed to add policies", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + }, + res: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + addPoliciesErr: svcerr.ErrAuthorization, + err: svcerr.ErrAddPolicies, + }, + { + desc: "Add rule with failed to add policies and failed rollback", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + }, + res: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + addPoliciesErr: svcerr.ErrAuthorization, + deleteErr: svcerr.ErrRemoveEntity, + err: svcerr.ErrRollbackRepo, + }, + { + desc: "Add rule with failed to add roles", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + }, + res: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + addRoleErr: svcerr.ErrCreateEntity, + err: svcerr.ErrAddPolicies, + }, + { + desc: "Add rule with Go script containing goroutines", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Logic: re.Script{ + Type: re.GoType, + Value: `func logicFunction() any { go func() {}(); return true }`, + }, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + }, + err: re.ErrGoroutinesNotAllowed, + addPoliciesErr: nil, + addRoleErr: nil, + deleteErr: nil, + }, + { + desc: "Add rule with Go script containing panic", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Logic: re.Script{ + Type: re.GoType, + Value: `func logicFunction() any { panic("error"); return true }`, + }, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + }, + err: re.ErrPanicNotAllowed, + addPoliciesErr: nil, + addRoleErr: nil, + deleteErr: nil, + }, + { + desc: "Add rule with failed to add roles and failed to delete policies", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + }, + res: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + addRoleErr: svcerr.ErrCreateEntity, + deletePolicies: svcerr.ErrRemoveEntity, + err: svcerr.ErrRemoveEntity, + }, + { + desc: "Add rule with failed to add policies", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + }, + res: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + addPoliciesErr: svcerr.ErrAuthorization, + err: svcerr.ErrAddPolicies, + }, + { + desc: "Add rule with failed to add policies and failed rollback", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + }, + res: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + addPoliciesErr: svcerr.ErrAuthorization, + deleteErr: svcerr.ErrRemoveEntity, + err: svcerr.ErrRollbackRepo, + }, + { + desc: "Add rule with failed to add roles", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + }, + res: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + addRoleErr: svcerr.ErrCreateEntity, + err: svcerr.ErrAddPolicies, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("AddRule", mock.Anything, mock.Anything).Return(tc.res, tc.err) + policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesErr) + policyCall2 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePolicies).Maybe() + repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.RoleProvision{}, tc.addRoleErr) + repoCall2 := repo.On("Remove", context.Background(), mock.Anything).Return(tc.deleteErr).Maybe() + res, _, err := svc.AddRule(context.Background(), tc.session, tc.rule) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.NotEmpty(t, res.ID, "expected non-empty result in ID") + assert.Equal(t, tc.rule.Name, res.Name) + assert.Equal(t, tc.rule.Schedule, res.Schedule) + } + policyCall.Unset() + policyCall2.Unset() + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + }) + } +} + +func TestViewRule(t *testing.T) { + // nolint:dogsled + svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo)) + + now := time.Now().Add(time.Hour) + cases := []struct { + desc string + session authn.Session + id string + res re.Rule + err error + }{ + { + desc: "view rule successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: ruleID, + res: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + err: nil, + }, + { + desc: "view rule with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: ruleID, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("ViewRule", mock.Anything, mock.Anything).Return(tc.res, tc.err) + res, err := svc.ViewRule(context.Background(), tc.session, tc.id, false) + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.res, res) + } + defer repoCall.Unset() + }) + } +} + +func TestUpdateRule(t *testing.T) { + // nolint:dogsled + svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo)) + + newName := namegen.Generate() + now := time.Now().Add(time.Hour) + cases := []struct { + desc string + session authn.Session + rule re.Rule + res re.Rule + err error + }{ + { + desc: "update rule successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: newName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + res: re.Rule{ + Name: newName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + UpdatedAt: now, + UpdatedBy: userID, + }, + err: nil, + }, + { + desc: "update rule with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + err: svcerr.ErrUpdateEntity, + }, + { + desc: "update rule with Go script containing goroutines", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Logic: re.Script{ + Type: re.GoType, + Value: `func logicFunction() any { go processData(); return true }`, + }, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + err: re.ErrGoroutinesNotAllowed, + }, + { + desc: "Update rule with Go script containing panic", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + rule: re.Rule{ + Name: ruleName, + ID: ruleID, + InputChannel: inputChannel, + Logic: re.Script{ + Type: re.GoType, + Value: `func logicFunction() any { panic("test panic"); return true }`, + }, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + }, + Status: re.EnabledStatus, + CreatedBy: userID, + DomainID: domainID, + }, + err: re.ErrPanicNotAllowed, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("UpdateRule", mock.Anything, mock.Anything).Return(tc.res, tc.err) + res, err := svc.UpdateRule(context.Background(), tc.session, tc.rule) + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.res, res) + } + defer repoCall.Unset() + }) + } +} + +func TestUpdateRuleTags(t *testing.T) { + // nolint:dogsled + svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo)) + + cases := []struct { + desc string + session authn.Session + updateReq re.Rule + repoResp re.Rule + repoErr error + err error + }{ + { + desc: "update rule tags successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + updateReq: re.Rule{ + ID: testsutil.GenerateUUID(t), + Tags: []string{"tag1", "tag2"}, + }, + repoResp: re.Rule{ + ID: testsutil.GenerateUUID(t), + Tags: []string{"tag1", "tag2"}, + }, + }, + { + desc: "update rule tags with repo error", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + updateReq: re.Rule{ + ID: testsutil.GenerateUUID(t), + Tags: []string{"tag1", "tag2"}, + }, + repoErr: repoerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("UpdateRuleTags", context.Background(), mock.Anything).Return(tc.repoResp, tc.repoErr) + got, err := svc.UpdateRuleTags(context.Background(), tc.session, tc.updateReq) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + if err == nil { + assert.Equal(t, tc.repoResp, got) + ok := repo.AssertCalled(t, "UpdateRuleTags", context.Background(), mock.Anything) + assert.True(t, ok, fmt.Sprintf("UpdateTags was not called on %s", tc.desc)) + } + repoCall.Unset() + }) + } +} + +func TestUpdateRuleSchedule(t *testing.T) { + // nolint:dogsled + svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo)) + + now := time.Now().UTC() + future := now.Add(2 * time.Hour) + newSchedule := pkgSch.Schedule{ + StartDateTime: future, + Time: future.Add(time.Hour), + Recurring: pkgSch.Weekly, + RecurringPeriod: 2, + } + + cases := []struct { + desc string + session authn.Session + updateReq re.Rule + repoResp re.Rule + repoErr error + err error + }{ + { + desc: "update rule schedule successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + updateReq: re.Rule{ + ID: testsutil.GenerateUUID(t), + Schedule: newSchedule, + }, + repoResp: re.Rule{ + ID: testsutil.GenerateUUID(t), + Schedule: newSchedule, + UpdatedAt: now, + UpdatedBy: userID, + }, + }, + { + desc: "update rule schedule with repo error", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + updateReq: re.Rule{ + ID: testsutil.GenerateUUID(t), + Schedule: newSchedule, + }, + repoErr: repoerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("UpdateRuleSchedule", context.Background(), mock.Anything).Return(tc.repoResp, tc.repoErr) + got, err := svc.UpdateRuleSchedule(context.Background(), tc.session, tc.updateReq) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) + if err == nil { + assert.Equal(t, tc.repoResp, got) + ok := repo.AssertCalled(t, "UpdateRuleSchedule", context.Background(), mock.Anything) + assert.True(t, ok, fmt.Sprintf("UpdateRuleSchedule was not called on %s", tc.desc)) + } + repoCall.Unset() + }) + } +} + +func TestListRules(t *testing.T) { + // nolint:dogsled + svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo)) + numRules := 50 + now := time.Now().Add(time.Hour) + var rules []re.Rule + for i := 0; i < numRules; i++ { + r := re.Rule{ + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + Status: re.EnabledStatus, + CreatedAt: now, + CreatedBy: userID, + Schedule: pkgSch.Schedule{ + Recurring: pkgSch.Daily, + Time: now.Add(1 * time.Hour), + RecurringPeriod: 1, + StartDateTime: now, + }, + } + rules = append(rules, r) + } + + goRule := re.Rule{ + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + Status: re.EnabledStatus, + CreatedAt: now, + CreatedBy: userID, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return true }", + }, + } + + cases := []struct { + desc string + session authn.Session + pageMeta re.PageMeta + res re.Page + err error + superAdmin bool + }{ + { + desc: "list rules successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + pageMeta: re.PageMeta{}, + res: re.Page{ + Total: uint64(numRules), + Offset: 0, + Limit: 10, + Rules: rules[0:10], + }, + err: nil, + }, + { + desc: "list rules with go type", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + pageMeta: re.PageMeta{}, + res: re.Page{ + Total: 1, + Offset: 0, + Limit: 10, + Rules: []re.Rule{goRule}, + }, + err: nil, + }, + { + desc: "list rules successfully with limit", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + pageMeta: re.PageMeta{ + Limit: 100, + }, + res: re.Page{ + Total: uint64(numRules), + Offset: 0, + Limit: 100, + Rules: rules[0:numRules], + }, + err: nil, + }, + { + desc: "list rules successfully with offset", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + pageMeta: re.PageMeta{ + Offset: 20, + Limit: 10, + }, + res: re.Page{ + Total: uint64(numRules), + Offset: 20, + Limit: 10, + Rules: rules[20:30], + }, + err: nil, + }, + { + desc: "list rules with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + pageMeta: re.PageMeta{}, + err: svcerr.ErrViewEntity, + }, + { + desc: "list rules as super admin successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + SuperAdmin: true, + }, + pageMeta: re.PageMeta{}, + res: re.Page{ + Total: uint64(numRules), + Offset: 0, + Limit: 10, + Rules: rules[0:10], + }, + superAdmin: true, + err: nil, + }, + { + desc: "list rules as super admin with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + SuperAdmin: true, + }, + pageMeta: re.PageMeta{}, + superAdmin: true, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + var repoCall *mock.Call + if tc.superAdmin { + repoCall = repo.On("ListAllRules", mock.Anything, mock.Anything).Return(tc.res, tc.err) + } else { + repoCall = repo.On("ListUserRules", mock.Anything, mock.Anything, mock.Anything).Return(tc.res, tc.err) + } + res, err := svc.ListRules(context.Background(), tc.session, tc.pageMeta) + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.res, res) + } + defer repoCall.Unset() + }) + } +} + +func TestRemoveRule(t *testing.T) { + // nolint:dogsled + svc, repo, _, _, _, policies := newService(t, make(chan pkglog.RunInfo)) + + cases := []struct { + desc string + session authn.Session + id string + err error + deletePoliciesErr error + }{ + { + desc: "remove rule successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: ruleID, + err: nil, + deletePoliciesErr: nil, + }, + { + desc: "remove rule with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: ruleID, + err: svcerr.ErrRemoveEntity, + deletePoliciesErr: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("RemoveRule", mock.Anything, mock.Anything).Return(tc.err) + policyCall := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesErr) + err := svc.RemoveRule(context.Background(), tc.session, tc.id) + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + policyCall.Unset() + repoCall.Unset() + }) + } +} + +func TestEnableRule(t *testing.T) { + // nolint:dogsled + svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo)) + + now := time.Now() + + cases := []struct { + desc string + session authn.Session + id string + status re.Status + res re.Rule + err error + }{ + { + desc: "enable rule successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: ruleID, + status: re.EnabledStatus, + res: re.Rule{ + ID: ruleID, + Name: ruleName, + DomainID: domainID, + InputChannel: inputChannel, + Status: re.EnabledStatus, + Schedule: schedule, + UpdatedBy: userID, + UpdatedAt: now, + }, + err: nil, + }, + { + desc: "enable rule with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: ruleID, + status: re.EnabledStatus, + err: svcerr.ErrUpdateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("UpdateRuleStatus", context.Background(), mock.Anything).Return(tc.res, tc.err) + res, err := svc.EnableRule(context.Background(), tc.session, tc.id) + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.res, res) + } + defer repoCall.Unset() + }) + } +} + +func TestDisableRule(t *testing.T) { + // nolint:dogsled + svc, repo, _, _, _, _ := newService(t, make(chan pkglog.RunInfo)) + + now := time.Now() + + cases := []struct { + desc string + session authn.Session + id string + status re.Status + res re.Rule + err error + }{ + { + desc: "disable rule successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: ruleID, + status: re.DisabledStatus, + res: re.Rule{ + ID: ruleID, + Name: ruleName, + DomainID: domainID, + InputChannel: inputChannel, + Status: re.DisabledStatus, + Schedule: schedule, + UpdatedBy: userID, + UpdatedAt: now, + }, + err: nil, + }, + { + desc: "disable rule with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: ruleID, + status: re.DisabledStatus, + err: svcerr.ErrUpdateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("UpdateRuleStatus", mock.Anything, mock.Anything).Return(tc.res, tc.err) + res, err := svc.DisableRule(context.Background(), tc.session, tc.id) + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.res, res) + } + defer repoCall.Unset() + }) + } +} + +func TestHandle(t *testing.T) { + svc, repo, pubmocks, _, emailer, _ := newService(t, make(chan pkglog.RunInfo)) + now := time.Now() + scheduled := false + + cases := []struct { + desc string + message *messaging.Message + page re.Page + listErr error + publishErr error + expectErr bool + }{ + { + desc: "consume message with empty rules", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{}, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script returning true", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: "return message.payload", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script returning false", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: "return false", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script with no outputs", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: "return message.payload", + }, + Outputs: re.Outputs{}, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script returning nil", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: "return nil", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script with invalid syntax", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: "invalid lua syntax {{{", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script and Alarm output", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 30.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: `return {severity = 2, description = "High temperature"}`, + }, + Outputs: re.Outputs{ + &outputs.Alarm{ + RuleID: testsutil.GenerateUUID(t), + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script and SenML output", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: `return {bn = "sensor1", n = "temperature", v = 25.5}`, + }, + Outputs: re.Outputs{ + &outputs.SenML{}, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script and Email output", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: `return message.payload`, + }, + Outputs: re.Outputs{ + &outputs.Email{ + To: []string{"test@example.com"}, + Subject: "Temperature Alert", + Content: "Temperature: {{.Result}}", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with rules using GoType", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return true }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType logic returning false", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return false }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType invalid logic value", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "invalid go code {{{", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType missing logicFunction", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func someOtherFunc() bool { return true }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType invalid function signature", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "var logicFunction = 42", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType function logicFunction properly named", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func logicFunction() any { return true }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType returning non-bool", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func() any { return \"not a bool\" }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType and JSON payload", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25, "humidity": 60}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return true }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType and invalid JSON payload", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`invalid json {{{`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return true }", + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with GoType script that panics", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"value": 42}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.GoType, + Value: `func logicFunction() any { panic("test") }`, + }, + Outputs: re.Outputs{ + &outputs.ChannelPublisher{ + Channel: "output.channel", + Topic: "output.topic", + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script and Postgres output", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5, "humidity": 60}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: `return message.payload`, + }, + Outputs: re.Outputs{ + &outputs.Postgres{ + Host: "localhost", + Port: 5432, + User: "test", + Password: "test", + Database: "testdb", + Table: "sensor_data", + Mapping: `{"temperature": {{.Result.temperature}}, "humidity": {{.Result.humidity}}}`, + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script and Slack output", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: `return message.payload`, + }, + Outputs: re.Outputs{ + &outputs.Slack{ + Token: "xoxb-test-token", + ChannelID: "C12345678", + Message: `{"text": "Temperature: {{.Result.temperature}}"}`, + }, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + { + desc: "consume message with Lua script and unknown output type", + message: &messaging.Message{ + Channel: inputChannel, + Created: now.Unix(), + Payload: []byte(`{"temperature": 25.5}`), + }, + page: re.Page{ + Rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + InputChannel: inputChannel, + Status: re.EnabledStatus, + Logic: re.Script{ + Type: re.LuaType, + Value: `return message.payload`, + }, + Outputs: re.Outputs{ + &unknownOutput{}, + }, + Schedule: schedule, + }, + }, + }, + listErr: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + var err error + + repoCall := repo.On("ListAllRules", mock.Anything, re.PageMeta{Domain: tc.message.Domain, InputChannel: tc.message.Channel, Scheduled: &scheduled}).Return(tc.page, tc.listErr).Run(func(args mock.Arguments) { + if tc.listErr != nil { + err = tc.listErr + } + }) + repoCall1 := pubmocks.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(tc.publishErr).Maybe() + repoCall2 := emailer.On("SendEmailNotification", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + + err = svc.Handle(tc.message) + assert.Nil(t, err) + + time.Sleep(100 * time.Millisecond) + + assert.True(t, errors.Contains(err, tc.listErr), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.listErr, err)) + + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + }) + } +} + +func TestStartScheduler(t *testing.T) { + now := time.Now().Truncate(time.Minute) + ri := make(chan pkglog.RunInfo) + // nolint:dogsled + svc, repo, _, ticker, _, _ := newService(t, ri) + + ctxCases := []struct { + desc string + err error + pageMeta re.PageMeta + page re.Page + listErr error + setupCtx func() (context.Context, context.CancelFunc) + }{ + { + desc: "start scheduler with canceled context", + err: context.Canceled, + pageMeta: re.PageMeta{ + Status: re.EnabledStatus, + ScheduledBefore: &now, + }, + setupCtx: func() (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return ctx, cancel + }, + }, + { + desc: "start scheduler with timeout", + err: context.DeadlineExceeded, + pageMeta: re.PageMeta{ + Status: re.EnabledStatus, + ScheduledBefore: &now, + }, + setupCtx: func() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), time.Millisecond) + }, + }, + { + desc: "start scheduler with deadline exceeded", + err: context.DeadlineExceeded, + pageMeta: re.PageMeta{ + Status: re.EnabledStatus, + ScheduledBefore: &now, + }, + page: re.Page{}, + setupCtx: func() (context.Context, context.CancelFunc) { + return context.WithDeadline(context.Background(), time.Now().Add(time.Millisecond)) + }, + }, + } + + for _, tc := range ctxCases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("ListAllRules", mock.Anything, mock.Anything).Return(tc.page, tc.listErr) + tickChan := make(chan time.Time) + tickCall := ticker.On("Tick").Return((<-chan time.Time)(tickChan)) + tickCall1 := ticker.On("Stop").Return() + ctx, cancel := tc.setupCtx() + defer cancel() + errc := make(chan error) + + go func() { + errc <- svc.StartScheduler(ctx) + }() + + err := <-errc + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v but got %v", tc.err, err)) + repoCall.Unset() + tickCall.Unset() + tickCall1.Unset() + }) + } + + schedulerCases := []struct { + desc string + rules []re.Rule + listErr error + updateDueErr error + expectedRunInfo int + }{ + { + desc: "start scheduler with successful rule processing", + rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + InputChannel: inputChannel, + Status: re.EnabledStatus, + Schedule: pkgSch.Schedule{ + StartDateTime: now.Add(-time.Hour), + Time: now.Add(time.Hour), + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + }, + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + }, + }, + listErr: nil, + updateDueErr: nil, + expectedRunInfo: 1, + }, + { + desc: "start scheduler with multiple rules", + rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + InputChannel: inputChannel, + Status: re.EnabledStatus, + Schedule: pkgSch.Schedule{ + StartDateTime: now.Add(-time.Hour), + Time: now.Add(time.Hour), + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + }, + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + }, + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + InputChannel: inputChannel, + Status: re.EnabledStatus, + Schedule: pkgSch.Schedule{ + StartDateTime: now.Add(-time.Hour), + Time: now.Add(time.Hour), + Recurring: pkgSch.Weekly, + RecurringPeriod: 1, + }, + Logic: re.Script{ + Type: re.GoType, + Value: "func() bool { return true }", + }, + }, + }, + listErr: nil, + updateDueErr: nil, + expectedRunInfo: 2, + }, + { + desc: "start scheduler with list rules error", + rules: []re.Rule{}, + listErr: repoerr.ErrViewEntity, + updateDueErr: nil, + expectedRunInfo: 1, + }, + { + desc: "start scheduler with update due error", + rules: []re.Rule{ + { + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + InputChannel: inputChannel, + Status: re.EnabledStatus, + Schedule: pkgSch.Schedule{ + StartDateTime: now.Add(-time.Hour), + Time: now.Add(time.Hour), + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + }, + Logic: re.Script{ + Type: re.LuaType, + Value: "return true", + }, + }, + }, + listErr: nil, + updateDueErr: repoerr.ErrUpdateEntity, + expectedRunInfo: 1, + }, + } + + for _, tc := range schedulerCases { + t.Run(tc.desc, func(t *testing.T) { + page := re.Page{ + Rules: tc.rules, + Total: uint64(len(tc.rules)), + } + + repoCall := repo.On("ListAllRules", mock.Anything, mock.Anything).Return(page, tc.listErr) + repoCall2 := repo.On("UpdateRuleDue", mock.Anything, mock.Anything, mock.Anything).Return(re.Rule{}, tc.updateDueErr) + tickChan := make(chan time.Time, 1) + tickCall := ticker.On("Tick").Return((<-chan time.Time)(tickChan)) + tickCall1 := ticker.On("Stop").Return() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + _ = svc.StartScheduler(ctx) + }() + + tickChan <- now + + collected := 0 + timeout := time.After(500 * time.Millisecond) + for collected < tc.expectedRunInfo { + select { + case info := <-ri: + collected++ + if tc.listErr != nil { + assert.Equal(t, slog.LevelError, info.Level) + assert.Contains(t, info.Message, "failed to list rules") + } else if tc.updateDueErr != nil { + assert.Equal(t, slog.LevelError, info.Level) + assert.Contains(t, info.Message, "failed to update rule") + } else { + assert.True(t, info.Level == slog.LevelInfo || info.Level == slog.LevelWarn || info.Level == slog.LevelError) + } + case <-timeout: + t.Fatalf("timeout waiting for runInfo messages, expected %d got %d", tc.expectedRunInfo, collected) + } + } + + cancel() + time.Sleep(50 * time.Millisecond) + + repoCall.Unset() + repoCall2.Unset() + tickCall.Unset() + tickCall1.Unset() + }) + } +} diff --git a/re/status.go b/re/status.go new file mode 100644 index 000000000..404c6927c --- /dev/null +++ b/re/status.go @@ -0,0 +1,80 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re + +import ( + "encoding/json" + "strings" + + svcerr "github.com/absmach/supermq/pkg/errors/service" +) + +// Status represents Rule status. +type Status uint8 + +// Possible User status values. +const ( + // EnabledStatus represents enabled Rule. + EnabledStatus Status = iota + // DisabledStatus represents disabled Rule. + DisabledStatus + // DeletedStatus represents a rule that will be deleted. + DeletedStatus + + // AllStatus is used for querying purposes to list rules irrespective + // of their status - both enabled and disabled. It is never stored in the + // database as the actual User status and should always be the largest + // value in this enumeration. + AllStatus +) + +// String representation of the possible status values. +const ( + Disabled = "disabled" + Enabled = "enabled" + Deleted = "deleted" + All = "all" + Unknown = "unknown" +) + +func (s Status) String() string { + switch s { + case DisabledStatus: + return Disabled + case EnabledStatus: + return Enabled + case DeletedStatus: + return Deleted + case AllStatus: + return All + default: + return Unknown + } +} + +// ToStatus converts string value to a valid status. +func ToStatus(status string) (Status, error) { + switch status { + case "", Enabled: + return EnabledStatus, nil + case Disabled: + return DisabledStatus, nil + case Deleted: + return DeletedStatus, nil + case All: + return AllStatus, nil + } + return Status(0), svcerr.ErrInvalidStatus +} + +func (s Status) MarshalJSON() ([]byte, error) { + return json.Marshal(s.String()) +} + +func (s *Status) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), "\"") + val, err := ToStatus(str) + *s = val + return err +} diff --git a/re/status_test.go b/re/status_test.go new file mode 100644 index 000000000..02de47531 --- /dev/null +++ b/re/status_test.go @@ -0,0 +1,205 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package re_test + +import ( + "encoding/json" + "testing" + + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/re" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestToStatus(t *testing.T) { + cases := []struct { + desc string + status string + res re.Status + err error + }{ + { + desc: "convert enabled status", + status: re.Enabled, + res: re.EnabledStatus, + err: nil, + }, + { + desc: "convert empty string to enabled status", + status: "", + res: re.EnabledStatus, + err: nil, + }, + { + desc: "convert disabled status", + status: re.Disabled, + res: re.DisabledStatus, + err: nil, + }, + { + desc: "convert deleted status", + status: re.Deleted, + res: re.DeletedStatus, + err: nil, + }, + { + desc: "convert all status", + status: re.All, + res: re.AllStatus, + err: nil, + }, + { + desc: "convert invalid status", + status: "invalid", + res: re.Status(0), + err: svcerr.ErrInvalidStatus, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + status, err := re.ToStatus(tc.status) + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.res, status) + }) + } +} + +func TestStatusString(t *testing.T) { + cases := []struct { + desc string + status re.Status + res string + }{ + { + desc: "enabled status to string", + status: re.EnabledStatus, + res: re.Enabled, + }, + { + desc: "disabled status to string", + status: re.DisabledStatus, + res: re.Disabled, + }, + { + desc: "deleted status to string", + status: re.DeletedStatus, + res: re.Deleted, + }, + { + desc: "all status to string", + status: re.AllStatus, + res: re.All, + }, + { + desc: "unknown status to string", + status: re.Status(99), + res: re.Unknown, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + assert.Equal(t, tc.res, tc.status.String()) + }) + } +} + +func TestStatusMarshalJSON(t *testing.T) { + cases := []struct { + desc string + status re.Status + res string + }{ + { + desc: "marshal enabled status", + status: re.EnabledStatus, + res: `"enabled"`, + }, + { + desc: "marshal disabled status", + status: re.DisabledStatus, + res: `"disabled"`, + }, + { + desc: "marshal deleted status", + status: re.DeletedStatus, + res: `"deleted"`, + }, + { + desc: "marshal all status", + status: re.AllStatus, + res: `"all"`, + }, + { + desc: "marshal unknown status", + status: re.Status(99), + res: `"unknown"`, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + data, err := json.Marshal(tc.status) + require.NoError(t, err) + assert.Equal(t, tc.res, string(data)) + }) + } +} + +func TestStatusUnmarshalJSON(t *testing.T) { + cases := []struct { + desc string + data string + res re.Status + err error + }{ + { + desc: "unmarshal enabled status", + data: `"enabled"`, + res: re.EnabledStatus, + err: nil, + }, + { + desc: "unmarshal disabled status", + data: `"disabled"`, + res: re.DisabledStatus, + err: nil, + }, + { + desc: "unmarshal deleted status", + data: `"deleted"`, + res: re.DeletedStatus, + err: nil, + }, + { + desc: "unmarshal all status", + data: `"all"`, + res: re.AllStatus, + err: nil, + }, + { + desc: "unmarshal empty string to enabled status", + data: `""`, + res: re.EnabledStatus, + err: nil, + }, + { + desc: "unmarshal invalid status", + data: `"invalid"`, + res: re.Status(0), + err: svcerr.ErrInvalidStatus, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + var status re.Status + err := json.Unmarshal([]byte(tc.data), &status) + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.res, status) + }) + } +} diff --git a/readers/api/doc.go b/readers/api/doc.go new file mode 100644 index 000000000..2424852cc --- /dev/null +++ b/readers/api/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package api contains API-related concerns: endpoint definitions, middlewares +// and all resource representations. +package api diff --git a/readers/api/grpc/client.go b/readers/api/grpc/client.go new file mode 100644 index 000000000..ede4f7876 --- /dev/null +++ b/readers/api/grpc/client.go @@ -0,0 +1,238 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package grpc + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + grpcReadersV1 "github.com/absmach/supermq/api/grpc/readers/v1" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/transformers/senml" + readers "github.com/absmach/supermq/readers" + "github.com/go-kit/kit/endpoint" + kitgrpc "github.com/go-kit/kit/transport/grpc" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const readersSvcName = "readers.v1.ReadersService" + +var _ grpcReadersV1.ReadersServiceClient = (*readersGrpcClient)(nil) + +type readersGrpcClient struct { + readMessages endpoint.Endpoint + timeout time.Duration +} + +// NewReadersClient returns new readers gRPC client instance. +func NewReadersClient(conn *grpc.ClientConn, timeout time.Duration) grpcReadersV1.ReadersServiceClient { + return &readersGrpcClient{ + readMessages: kitgrpc.NewClient( + conn, + readersSvcName, + "ReadMessages", + encodeReadMessagesRequest, + decodeReadMessagesResponse, + grpcReadersV1.ReadMessagesRes{}, + ).Endpoint(), + timeout: timeout, + } +} + +func (client readersGrpcClient) ReadMessages(ctx context.Context, in *grpcReadersV1.ReadMessagesReq, opts ...grpc.CallOption) (*grpcReadersV1.ReadMessagesRes, error) { + ctx, cancel := context.WithTimeout(ctx, client.timeout) + defer cancel() + + res, err := client.readMessages(ctx, readMessagesReq{ + chanID: in.GetChannelId(), + domain: in.GetDomainId(), + pageMeta: readers.PageMetadata{ + Offset: in.GetPageMetadata().GetOffset(), + Limit: in.GetPageMetadata().GetLimit(), + Comparator: in.GetPageMetadata().GetComparator(), + Aggregation: in.GetPageMetadata().GetAggregation().String(), + From: in.GetPageMetadata().GetFrom(), + To: in.GetPageMetadata().GetTo(), + Interval: in.GetPageMetadata().GetInterval(), + Subtopic: in.GetPageMetadata().GetSubtopic(), + Publisher: in.GetPageMetadata().GetPublisher(), + Protocol: in.GetPageMetadata().GetProtocol(), + Name: in.GetPageMetadata().GetName(), + Value: in.GetPageMetadata().GetValue(), + BoolValue: in.GetPageMetadata().GetBoolValue(), + StringValue: in.GetPageMetadata().GetStringValue(), + DataValue: in.GetPageMetadata().GetDataValue(), + Format: in.GetPageMetadata().GetFormat(), + }, + }) + if err != nil { + return &grpcReadersV1.ReadMessagesRes{}, decodeError(err) + } + + dpr := res.(readMessagesRes) + return &grpcReadersV1.ReadMessagesRes{ + Total: dpr.Total, + Messages: toResponseMessages(dpr.Messages), + PageMetadata: &grpcReadersV1.PageMetadata{ + Offset: dpr.PageMetadata.Offset, + Limit: dpr.PageMetadata.Limit, + }, + }, nil +} + +func decodeReadMessagesResponse(_ context.Context, grpcRes any) (any, error) { + res := grpcRes.(*grpcReadersV1.ReadMessagesRes) + return readMessagesRes{ + Total: res.Total, + Messages: fromResponseMessages(res.Messages), + PageMetadata: readers.PageMetadata{ + Offset: res.GetPageMetadata().GetOffset(), + Limit: res.GetPageMetadata().GetLimit(), + Order: res.GetPageMetadata().GetOrder(), + Dir: res.GetPageMetadata().GetDir(), + }, + }, nil +} + +func encodeReadMessagesRequest(_ context.Context, grpcReq any) (any, error) { + req := grpcReq.(readMessagesReq) + return &grpcReadersV1.ReadMessagesReq{ + ChannelId: req.chanID, + DomainId: req.domain, + PageMetadata: &grpcReadersV1.PageMetadata{ + Offset: req.pageMeta.Offset, + Limit: req.pageMeta.Limit, + Comparator: req.pageMeta.Comparator, + Aggregation: parseAggregation(req.pageMeta.Aggregation), + From: req.pageMeta.From, + To: req.pageMeta.To, + Interval: req.pageMeta.Interval, + Subtopic: req.pageMeta.Subtopic, + Publisher: req.pageMeta.Publisher, + Protocol: req.pageMeta.Protocol, + Name: req.pageMeta.Name, + Value: req.pageMeta.Value, + BoolValue: req.pageMeta.BoolValue, + StringValue: req.pageMeta.StringValue, + DataValue: req.pageMeta.DataValue, + Format: req.pageMeta.Format, + Order: req.pageMeta.Order, + Dir: req.pageMeta.Dir, + }, + }, nil +} + +func fromResponseMessages(protoMessages []*grpcReadersV1.Message) []readers.Message { + var messages []readers.Message + for _, m := range protoMessages { + switch msg := m.Payload.(type) { + case *grpcReadersV1.Message_Senml: + s := msg.Senml + base := s.GetBase() + typed := senml.Message{ + Channel: base.GetChannel(), + Subtopic: base.GetSubtopic(), + Publisher: base.GetPublisher(), + Protocol: base.GetProtocol(), + Name: s.GetName(), + Unit: s.GetUnit(), + Time: s.GetTime(), + UpdateTime: s.GetUpdateTime(), + Value: optionalFloat64(s.GetValue()), + StringValue: optionalString(s.GetStringValue()), + DataValue: optionalString(s.GetDataValue()), + BoolValue: optionalBool(s.GetBoolValue()), + Sum: optionalFloat64(s.GetSum()), + } + messages = append(messages, typed) + case *grpcReadersV1.Message_Json: + j := msg.Json + base := j.GetBase() + var p map[string]any + if err := json.Unmarshal(j.GetPayload(), &p); err != nil { + continue + } + messages = append(messages, map[string]any{ + "channel": base.GetChannel(), + "created": j.GetCreated(), + "subtopic": base.GetSubtopic(), + "publisher": base.GetPublisher(), + "protocol": base.GetProtocol(), + "payload": p, + }) + } + } + return messages +} + +func parseAggregation(agg string) grpcReadersV1.Aggregation { + switch strings.ToUpper(agg) { + case "MAX": + return grpcReadersV1.Aggregation_AGGREGATION_MAX + case "MIN": + return grpcReadersV1.Aggregation_AGGREGATION_MIN + case "SUM": + return grpcReadersV1.Aggregation_AGGREGATION_SUM + case "COUNT": + return grpcReadersV1.Aggregation_AGGREGATION_COUNT + case "AVG": + return grpcReadersV1.Aggregation_AGGREGATION_AVG + default: + return grpcReadersV1.Aggregation_AGGREGATION_UNSPECIFIED + } +} + +func decodeError(err error) error { + if st, ok := status.FromError(err); ok { + switch st.Code() { + case codes.Unauthenticated: + return errors.Wrap(svcerr.ErrAuthentication, errors.New(st.Message())) + case codes.PermissionDenied: + return errors.Wrap(svcerr.ErrAuthorization, errors.New(st.Message())) + case codes.InvalidArgument: + return errors.Wrap(errors.ErrMalformedEntity, errors.New(st.Message())) + case codes.FailedPrecondition: + return errors.Wrap(errors.ErrMalformedEntity, errors.New(st.Message())) + case codes.NotFound: + return errors.Wrap(svcerr.ErrNotFound, errors.New(st.Message())) + case codes.AlreadyExists: + return errors.Wrap(svcerr.ErrConflict, errors.New(st.Message())) + case codes.OK: + if msg := st.Message(); msg != "" { + return errors.Wrap(errors.ErrUnidentified, errors.New(msg)) + } + return nil + default: + return errors.Wrap(fmt.Errorf("unexpected gRPC status: %s (status code:%v)", st.Code().String(), st.Code()), errors.New(st.Message())) + } + } + return err +} + +func optionalString(v string) *string { + if v == "" { + return nil + } + return &v +} + +func optionalFloat64(v float64) *float64 { + if v == 0 { + return nil + } + return &v +} + +func optionalBool(v bool) *bool { + if !v { + return nil + } + return &v +} diff --git a/readers/api/grpc/doc.go b/readers/api/grpc/doc.go new file mode 100644 index 000000000..67672c6f0 --- /dev/null +++ b/readers/api/grpc/doc.go @@ -0,0 +1,5 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package grpc contains implementation of Readers service gRPC API. +package grpc diff --git a/readers/api/grpc/endpoint.go b/readers/api/grpc/endpoint.go new file mode 100644 index 000000000..fe28b04a3 --- /dev/null +++ b/readers/api/grpc/endpoint.go @@ -0,0 +1,31 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package grpc + +import ( + "context" + + readers "github.com/absmach/supermq/readers" + "github.com/go-kit/kit/endpoint" +) + +func readMessagesEndpoint(svc readers.MessageRepository) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(readMessagesReq) + if err := req.validate(); err != nil { + return readMessagesRes{}, err + } + + page, err := svc.ReadAll(req.chanID, req.pageMeta) + if err != nil { + return readMessagesRes{}, err + } + + return readMessagesRes{ + PageMetadata: page.PageMetadata, + Total: page.Total, + Messages: page.Messages, + }, nil + } +} diff --git a/readers/api/grpc/endpoint_test.go b/readers/api/grpc/endpoint_test.go new file mode 100644 index 000000000..20383fa71 --- /dev/null +++ b/readers/api/grpc/endpoint_test.go @@ -0,0 +1,230 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package grpc_test + +import ( + "context" + "encoding/json" + "fmt" + "net" + "testing" + "time" + + grpcReadersV1 "github.com/absmach/supermq/api/grpc/readers/v1" + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/absmach/supermq/readers" + grpcapi "github.com/absmach/supermq/readers/api/grpc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +const ( + port = 7071 + channelID = "testChannelID" + domain = "testDomain" + validID = "validID" + validToken = "valid" + inValidToken = "invalid" + testOffset = 0 + testLimit = 10 +) + +var authAddr = fmt.Sprintf("localhost:%d", port) + +func startGRPCServer(svc readers.MessageRepository, port int) *grpc.Server { + listener, _ := net.Listen("tcp", fmt.Sprintf(":%d", port)) + server := grpc.NewServer() + grpcReadersV1.RegisterReadersServiceServer(server, grpcapi.NewReadersServer(svc)) + go func() { + err := server.Serve(listener) + assert.Nil(&testing.T{}, err, fmt.Sprintf(`"Unexpected error creating reader server %s"`, err)) + }() + + return server +} + +func TestReadMessages(t *testing.T) { + conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) + grpcClient := grpcapi.NewReadersClient(conn, time.Second) + + tmp := readers.MessagesPage{ + Total: 1, + PageMetadata: readers.PageMetadata{ + Offset: 0, + Limit: 10, + }, + Messages: []readers.Message{ + map[string]any{ + "channel": "testChannel", + "created": int64(123456789), + "subtopic": "testSubtopic", + "publisher": "testPublisher", + "protocol": "testProtocol", + "payload": map[string]any{ + "temp": 23.5, + }, + }, + }, + } + + expectedPayload, err := json.Marshal(tmp.Messages[0].(map[string]any)["payload"]) + require.NoError(t, err) + + expectedRes := &grpcReadersV1.ReadMessagesRes{ + Total: 1, + Messages: []*grpcReadersV1.Message{ + { + Payload: &grpcReadersV1.Message_Json{ + Json: &grpcReadersV1.JsonMessage{ + Base: &grpcReadersV1.BaseMessage{ + Channel: "testChannel", + Subtopic: "testSubtopic", + Publisher: "testPublisher", + Protocol: "testProtocol", + }, + Created: 123456789, + Payload: expectedPayload, + }, + }, + }, + }, + PageMetadata: &grpcReadersV1.PageMetadata{ + Offset: 0, + Limit: 10, + }, + } + + cases := []struct { + desc string + token string + svcRes readers.MessagesPage + ReadMessagesReq *grpcReadersV1.ReadMessagesReq + ReadMessagesRes *grpcReadersV1.ReadMessagesRes + err error + }{ + { + desc: "read valid req", + token: validToken, + ReadMessagesReq: &grpcReadersV1.ReadMessagesReq{ + ChannelId: channelID, + DomainId: domain, + PageMetadata: &grpcReadersV1.PageMetadata{ + Offset: testOffset, + Limit: testLimit, + }, + }, + svcRes: tmp, + + ReadMessagesRes: expectedRes, + err: nil, + }, + { + desc: " read missing channel id", + token: validToken, + ReadMessagesReq: &grpcReadersV1.ReadMessagesReq{ + ChannelId: "", + DomainId: domain, + PageMetadata: &grpcReadersV1.PageMetadata{ + Offset: testOffset, + Limit: testLimit, + }, + }, + ReadMessagesRes: &grpcReadersV1.ReadMessagesRes{}, + err: apiutil.ErrMissingID, + }, + { + desc: "read valid SenML message", + token: validToken, + ReadMessagesReq: &grpcReadersV1.ReadMessagesReq{ + ChannelId: channelID, + DomainId: domain, + PageMetadata: &grpcReadersV1.PageMetadata{ + Offset: testOffset, + Limit: testLimit, + }, + }, + svcRes: readers.MessagesPage{ + Total: 1, + PageMetadata: readers.PageMetadata{ + Offset: 0, + Limit: 10, + }, + Messages: []readers.Message{ + senml.Message{ + Channel: "senmlChannel", + Subtopic: "senmlSub", + Publisher: "senmlPublisher", + Protocol: "mqtt", + Name: "temperature", + Unit: "C", + Time: 1672531200, + UpdateTime: 1672531300, + Value: float64Ptr(22.5), + StringValue: stringPtr("ok"), + DataValue: stringPtr("binary"), + BoolValue: boolPtr(true), + Sum: float64Ptr(123.4), + }, + }, + }, + ReadMessagesRes: &grpcReadersV1.ReadMessagesRes{ + Total: 1, + PageMetadata: &grpcReadersV1.PageMetadata{ + Offset: 0, + Limit: 10, + }, + Messages: []*grpcReadersV1.Message{ + { + Payload: &grpcReadersV1.Message_Senml{ + Senml: &grpcReadersV1.SenMLMessage{ + Base: &grpcReadersV1.BaseMessage{ + Channel: "senmlChannel", + Subtopic: "senmlSub", + Publisher: "senmlPublisher", + Protocol: "mqtt", + }, + Name: "temperature", + Unit: "C", + Time: 1672531200, + UpdateTime: 1672531300, + Value: float64Ptr(22.5), + StringValue: stringPtr("ok"), + DataValue: stringPtr("binary"), + BoolValue: boolPtr(true), + Sum: float64Ptr(123.4), + }, + }, + }, + }, + }, + }, + } + + for _, tc := range cases { + repoCall := svc.On("ReadAll", mock.Anything, mock.Anything).Return(tc.svcRes, tc.err) + dpr, err := grpcClient.ReadMessages(context.Background(), tc.ReadMessagesReq) + assert.Equal(t, tc.ReadMessagesRes.Messages, dpr.Messages, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.ReadMessagesRes.Messages, dpr.Messages)) + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + } +} + +func float64Ptr(v float64) *float64 { + return &v +} + +func stringPtr(v string) *string { + return &v +} + +func boolPtr(v bool) *bool { + return &v +} diff --git a/readers/api/grpc/request.go b/readers/api/grpc/request.go new file mode 100644 index 000000000..c53ca5cf2 --- /dev/null +++ b/readers/api/grpc/request.go @@ -0,0 +1,69 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package grpc + +import ( + "slices" + "strings" + "time" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/readers" +) + +const maxLimitSize = 1000 + +var validAggregations = []string{"MAX", "MIN", "AVG", "SUM", "COUNT"} + +type readMessagesReq struct { + chanID string + domain string + pageMeta readers.PageMetadata +} + +func (req readMessagesReq) validate() error { + if req.chanID == "" { + return apiutil.ErrMissingID + } + if req.domain == "" { + return apiutil.ErrMissingID + } + + if req.pageMeta.Limit < 1 || req.pageMeta.Limit > maxLimitSize { + return apiutil.ErrLimitSize + } + + if req.pageMeta.Comparator != "" && + req.pageMeta.Comparator != readers.EqualKey && + req.pageMeta.Comparator != readers.LowerThanKey && + req.pageMeta.Comparator != readers.LowerThanEqualKey && + req.pageMeta.Comparator != readers.GreaterThanKey && + req.pageMeta.Comparator != readers.GreaterThanEqualKey { + return apiutil.ErrInvalidComparator + } + + if req.pageMeta.Aggregation == "AGGREGATION_UNSPECIFIED" { + req.pageMeta.Aggregation = "" + } + + if agg := strings.ToUpper(req.pageMeta.Aggregation); agg != "" && agg != "AGGREGATION_UNSPECIFIED" { + if req.pageMeta.From == 0 { + return apiutil.ErrMissingFrom + } + + if req.pageMeta.To == 0 { + return apiutil.ErrMissingTo + } + + if !slices.Contains(validAggregations, strings.ToUpper(req.pageMeta.Aggregation)) { + return apiutil.ErrInvalidAggregation + } + + if _, err := time.ParseDuration(req.pageMeta.Interval); err != nil { + return apiutil.ErrInvalidInterval + } + } + + return nil +} diff --git a/readers/api/grpc/responses.go b/readers/api/grpc/responses.go new file mode 100644 index 000000000..8181e7597 --- /dev/null +++ b/readers/api/grpc/responses.go @@ -0,0 +1,16 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package grpc + +import ( + "github.com/absmach/supermq/readers" +) + +type readMessagesRes struct { + Total uint64 + Messages []readers.Message + readers.PageMetadata +} + +type Message any diff --git a/readers/api/grpc/server.go b/readers/api/grpc/server.go new file mode 100644 index 000000000..72c22d858 --- /dev/null +++ b/readers/api/grpc/server.go @@ -0,0 +1,172 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package grpc + +import ( + "context" + "encoding/json" + + grpcReadersV1 "github.com/absmach/supermq/api/grpc/readers/v1" + grpcapi "github.com/absmach/supermq/auth/api/grpc" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/absmach/supermq/readers" + kitgrpc "github.com/go-kit/kit/transport/grpc" +) + +var _ grpcReadersV1.ReadersServiceServer = (*readersGrpcServer)(nil) + +type readersGrpcServer struct { + grpcReadersV1.UnimplementedReadersServiceServer + readMessages kitgrpc.Handler +} + +func NewReadersServer(svc readers.MessageRepository) grpcReadersV1.ReadersServiceServer { + return &readersGrpcServer{ + readMessages: kitgrpc.NewServer( + (readMessagesEndpoint(svc)), + decodeReadMessagesRequest, + encodeReadMessagesResponse, + ), + } +} + +func decodeReadMessagesRequest(_ context.Context, grpcReq any) (any, error) { + req := grpcReq.(*grpcReadersV1.ReadMessagesReq) + return readMessagesReq{ + chanID: req.GetChannelId(), + domain: req.GetDomainId(), + pageMeta: readers.PageMetadata{ + Offset: req.GetPageMetadata().GetOffset(), + Limit: req.GetPageMetadata().GetLimit(), + Comparator: req.GetPageMetadata().GetComparator(), + Aggregation: stringifyAggregation(req.GetPageMetadata().GetAggregation()), + From: req.GetPageMetadata().GetFrom(), + To: req.GetPageMetadata().GetTo(), + Interval: req.GetPageMetadata().GetInterval(), + Subtopic: req.GetPageMetadata().GetSubtopic(), + Publisher: req.GetPageMetadata().GetPublisher(), + Protocol: req.GetPageMetadata().GetProtocol(), + Name: req.GetPageMetadata().GetName(), + Value: req.GetPageMetadata().GetValue(), + BoolValue: req.GetPageMetadata().GetBoolValue(), + StringValue: req.GetPageMetadata().GetStringValue(), + DataValue: req.GetPageMetadata().GetDataValue(), + Format: req.GetPageMetadata().GetFormat(), + Order: req.GetPageMetadata().GetOrder(), + Dir: req.GetPageMetadata().GetDir(), + }, + }, nil +} + +func encodeReadMessagesResponse(_ context.Context, grpcRes any) (any, error) { + res := grpcRes.(readMessagesRes) + + resp := &grpcReadersV1.ReadMessagesRes{ + Total: res.Total, + Messages: toResponseMessages(res.Messages), + PageMetadata: &grpcReadersV1.PageMetadata{ + Offset: res.PageMetadata.Offset, + Limit: res.PageMetadata.Limit, + Order: res.PageMetadata.Order, + Dir: res.PageMetadata.Dir, + }, + } + return resp, nil +} + +func (s *readersGrpcServer) ReadMessages(ctx context.Context, req *grpcReadersV1.ReadMessagesReq) (*grpcReadersV1.ReadMessagesRes, error) { + _, res, err := s.readMessages.ServeGRPC(ctx, req) + if err != nil { + return nil, grpcapi.EncodeError(err) + } + return res.(*grpcReadersV1.ReadMessagesRes), nil +} + +func toResponseMessages(messages []readers.Message) []*grpcReadersV1.Message { + var res []*grpcReadersV1.Message + for _, m := range messages { + switch typed := m.(type) { + case senml.Message: + res = append(res, &grpcReadersV1.Message{ + Payload: &grpcReadersV1.Message_Senml{ + Senml: &grpcReadersV1.SenMLMessage{ + Base: &grpcReadersV1.BaseMessage{ + Channel: typed.Channel, + Subtopic: typed.Subtopic, + Publisher: typed.Publisher, + Protocol: typed.Protocol, + }, + Name: typed.Name, + Unit: typed.Unit, + Time: typed.Time, + UpdateTime: typed.UpdateTime, + Value: typed.Value, + StringValue: typed.StringValue, + DataValue: typed.DataValue, + BoolValue: typed.BoolValue, + Sum: typed.Sum, + }, + }, + }) + case map[string]any: + payload := typed["payload"] + data, err := json.Marshal(payload) + if err != nil { + continue + } + res = append(res, &grpcReadersV1.Message{ + Payload: &grpcReadersV1.Message_Json{ + Json: &grpcReadersV1.JsonMessage{ + Base: &grpcReadersV1.BaseMessage{ + Channel: safeString(typed["channel"]), + Subtopic: safeString(typed["subtopic"]), + Publisher: safeString(typed["publisher"]), + Protocol: safeString(typed["protocol"]), + }, + Created: safeInt64(typed["created"]), + Payload: data, + }, + }, + }) + } + } + return res +} + +func stringifyAggregation(agg grpcReadersV1.Aggregation) string { + switch agg { + case grpcReadersV1.Aggregation_AGGREGATION_UNSPECIFIED: + return "" + case grpcReadersV1.Aggregation_AGGREGATION_MAX: + return "MAX" + case grpcReadersV1.Aggregation_AGGREGATION_MIN: + return "MIN" + case grpcReadersV1.Aggregation_AGGREGATION_AVG: + return "AVG" + case grpcReadersV1.Aggregation_AGGREGATION_SUM: + return "SUM" + case grpcReadersV1.Aggregation_AGGREGATION_COUNT: + return "COUNT" + default: + return "" + } +} + +func safeString(v any) string { + if s, ok := v.(string); ok { + return s + } + return "" +} + +func safeInt64(v any) int64 { + switch v := v.(type) { + case float64: + return int64(v) + case int64: + return v + default: + return 0 + } +} diff --git a/readers/api/grpc/setup_test.go b/readers/api/grpc/setup_test.go new file mode 100644 index 000000000..001e6219f --- /dev/null +++ b/readers/api/grpc/setup_test.go @@ -0,0 +1,24 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package grpc_test + +import ( + "os" + "testing" + + "github.com/absmach/supermq/readers/mocks" +) + +var svc *mocks.MessageRepository + +func TestMain(m *testing.M) { + svc = new(mocks.MessageRepository) + server := startGRPCServer(svc, port) + + code := m.Run() + + server.GracefulStop() + + os.Exit(code) +} diff --git a/readers/api/http/endpoint.go b/readers/api/http/endpoint.go new file mode 100644 index 000000000..f9b22e2cf --- /dev/null +++ b/readers/api/http/endpoint.go @@ -0,0 +1,41 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import ( + "context" + + grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" + grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" + apiutil "github.com/absmach/supermq/api/http/util" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/readers" + "github.com/go-kit/kit/endpoint" +) + +func listMessagesEndpoint(svc readers.MessageRepository, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(listMessagesReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + if err := authnAuthz(ctx, req, authn, clients, channels); err != nil { + return nil, errors.Wrap(svcerr.ErrAuthorization, err) + } + + page, err := svc.ReadAll(req.chanID, req.pageMeta) + if err != nil { + return nil, err + } + + return pageRes{ + PageMetadata: page.PageMetadata, + Total: page.Total, + Messages: page.Messages, + }, nil + } +} diff --git a/readers/api/http/endpoint_test.go b/readers/api/http/endpoint_test.go new file mode 100644 index 000000000..92416da82 --- /dev/null +++ b/readers/api/http/endpoint_test.go @@ -0,0 +1,967 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http_test + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" + grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" + apiutil "github.com/absmach/supermq/api/http/util" + chmocks "github.com/absmach/supermq/channels/mocks" + climocks "github.com/absmach/supermq/clients/mocks" + "github.com/absmach/supermq/internal/testsutil" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/absmach/supermq/readers" + customhttp "github.com/absmach/supermq/readers/api/http" + "github.com/absmach/supermq/readers/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const ( + svcName = "test-service" + clientToken = "1" + userToken = "token" + invalidToken = "invalid" + numOfMessages = 100 + valueFields = 5 + subtopic = "topic" + mqttProt = "mqtt" + httpProt = "http" + msgName = "temperature" + instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002" + domainID = "b4d7d79e-fd99-4c2b-ac09-524e43df6888" +) + +var ( + v float64 = 5 + vs = "value" + vb = true + vd = "dataValue" + sum float64 = 42 + validSession = smqauthn.Session{UserID: testsutil.GenerateUUID(&testing.T{})} +) + +func newServer(repo *mocks.MessageRepository, authn *authnmocks.Authentication, clients *climocks.ClientsServiceClient, channels *chmocks.ChannelsServiceClient) *httptest.Server { + mux := customhttp.MakeHandler(repo, authn, clients, channels, svcName, instanceID) + return httptest.NewServer(mux) +} + +type testRequest struct { + client *http.Client + method string + url string + token string + key string +} + +func (tr testRequest) make() (*http.Response, error) { + req, err := http.NewRequest(tr.method, tr.url, http.NoBody) + if err != nil { + return nil, err + } + if tr.token != "" { + req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token) + } + if tr.key != "" { + req.Header.Set("Authorization", apiutil.ClientPrefix+tr.key) + } + + return tr.client.Do(req) +} + +func TestReadAll(t *testing.T) { + chanID := testsutil.GenerateUUID(t) + pubID := testsutil.GenerateUUID(t) + pubID2 := testsutil.GenerateUUID(t) + + now := time.Now().Unix() + + var messages []senml.Message + var queryMsgs []senml.Message + var valueMsgs []senml.Message + var boolMsgs []senml.Message + var stringMsgs []senml.Message + var dataMsgs []senml.Message + + for i := 0; i < numOfMessages; i++ { + // Mix possible values as well as value sum. + msg := senml.Message{ + Channel: chanID, + Publisher: pubID, + Protocol: mqttProt, + Time: float64(now - int64(i)), + Name: "name", + } + + count := i % valueFields + switch count { + case 0: + msg.Value = &v + valueMsgs = append(valueMsgs, msg) + case 1: + msg.BoolValue = &vb + boolMsgs = append(boolMsgs, msg) + case 2: + msg.StringValue = &vs + stringMsgs = append(stringMsgs, msg) + case 3: + msg.DataValue = &vd + dataMsgs = append(dataMsgs, msg) + case 4: + msg.Sum = &sum + msg.Subtopic = subtopic + msg.Protocol = httpProt + msg.Publisher = pubID2 + msg.Name = msgName + queryMsgs = append(queryMsgs, msg) + } + + messages = append(messages, msg) + } + + repo := new(mocks.MessageRepository) + authn := new(authnmocks.Authentication) + clients := new(climocks.ClientsServiceClient) + channels := new(chmocks.ChannelsServiceClient) + ts := newServer(repo, authn, clients, channels) + defer ts.Close() + + cases := []struct { + desc string + req string + url string + token string + key string + status int + res pageRes + authnErr error + authzRes *grpcChannelsV1.AuthzRes + authzErr error + err error + }{ + { + desc: "read page with valid offset and limit", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=10", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Order: "time", Dir: "desc"}, + Total: uint64(len(messages)), + Messages: messages[0:10], + }, + }, + { + desc: "read page with valid offset and limit as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=10", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Order: "time", Dir: "desc"}, + Total: uint64(len(messages)), + Messages: messages[0:10], + }, + }, + { + desc: "read page with negative offset as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=-1&limit=10", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with negative limit as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=-10", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with zero limit as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=0", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with non-integer offset as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=abc&limit=10", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with non-integer limit as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=abc", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with invalid channel id as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=10", ts.URL, domainID, ""), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with multiple offset as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&offset=1&limit=10", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with multiple limit as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=20&limit=10", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with empty token as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=10", ts.URL, domainID, chanID), + token: "", + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "read page with default offset as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?limit=10", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Order: "time", Dir: "desc"}, + Total: uint64(len(messages)), + Messages: messages[0:10], + }, + }, + { + desc: "read page with default limit as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Order: "time", Dir: "desc"}, + Total: uint64(len(messages)), + Messages: messages[0:10], + }, + }, + { + desc: "read page with senml format as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?format=messages", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Order: "time", Dir: "desc"}, + Total: uint64(len(messages)), + Messages: messages[0:10], + }, + }, + { + desc: "read page with subtopic as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?subtopic=%s&protocol=%s", ts.URL, domainID, chanID, subtopic, httpProt), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Subtopic: subtopic, Format: "messages", Protocol: httpProt, Order: "time", Dir: "desc"}, + Total: uint64(len(queryMsgs)), + Messages: queryMsgs[0:10], + }, + }, + { + desc: "read page with subtopic and protocol as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?subtopic=%s&protocol=%s", ts.URL, domainID, chanID, subtopic, httpProt), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Subtopic: subtopic, Format: "messages", Protocol: httpProt, Order: "time", Dir: "desc"}, + Total: uint64(len(queryMsgs)), + Messages: queryMsgs[0:10], + }, + }, + { + desc: "read page with publisher as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?publisher=%s", ts.URL, domainID, chanID, pubID2), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Publisher: pubID2, Order: "time", Dir: "desc"}, + Total: uint64(len(queryMsgs)), + Messages: queryMsgs[0:10], + }, + }, + { + desc: "read page with protocol as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?protocol=http", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Protocol: httpProt, Order: "time", Dir: "desc"}, + Total: uint64(len(queryMsgs)), + Messages: queryMsgs[0:10], + }, + }, + { + desc: "read page with name as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?name=%s", ts.URL, domainID, chanID, msgName), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Name: msgName, Order: "time", Dir: "desc"}, + Total: uint64(len(queryMsgs)), + Messages: queryMsgs[0:10], + }, + }, + { + desc: "read page with value as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f", ts.URL, domainID, chanID, v), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Value: v, Order: "time", Dir: "desc"}, + Total: uint64(len(valueMsgs)), + Messages: valueMsgs[0:10], + }, + }, + { + desc: "read page with value and equal comparator as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f&comparator=%s", ts.URL, domainID, chanID, v, readers.EqualKey), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Value: v, Comparator: readers.EqualKey, Order: "time", Dir: "desc"}, + Total: uint64(len(valueMsgs)), + Messages: valueMsgs[0:10], + }, + }, + { + desc: "read page with value and lower-than comparator as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f&comparator=%s", ts.URL, domainID, chanID, v+1, readers.LowerThanKey), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Value: v + 1, Comparator: readers.LowerThanKey, Order: "time", Dir: "desc"}, + Total: uint64(len(valueMsgs)), + Messages: valueMsgs[0:10], + }, + }, + { + desc: "read page with value and lower-than-or-equal comparator as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f&comparator=%s", ts.URL, domainID, chanID, v+1, readers.LowerThanEqualKey), + key: clientToken, + + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Value: v + 1, Comparator: readers.LowerThanEqualKey, Order: "time", Dir: "desc"}, + Total: uint64(len(valueMsgs)), + Messages: valueMsgs[0:10], + }, + }, + { + desc: "read page with value and greater-than comparator as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f&comparator=%s", ts.URL, domainID, chanID, v-1, readers.GreaterThanKey), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Value: v - 1, Comparator: readers.GreaterThanKey, Order: "time", Dir: "desc"}, + Total: uint64(len(valueMsgs)), + Messages: valueMsgs[0:10], + }, + }, + { + desc: "read page with value and greater-than-or-equal comparator as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f&comparator=%s", ts.URL, domainID, chanID, v-1, readers.GreaterThanEqualKey), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Value: v - 1, Comparator: readers.GreaterThanEqualKey, Order: "time", Dir: "desc"}, + Total: uint64(len(valueMsgs)), + Messages: valueMsgs[0:10], + }, + }, + { + desc: "read page with non-float value as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=ab01", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with value and wrong comparator as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f&comparator=wrong", ts.URL, domainID, chanID, v-1), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with boolean value as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?vb=true", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", BoolValue: true, Order: "time", Dir: "desc"}, + Total: uint64(len(boolMsgs)), + Messages: boolMsgs[0:10], + }, + }, + { + desc: "read page with non-boolean value as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?vb=yes", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with string value as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?vs=%s", ts.URL, domainID, chanID, vs), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", StringValue: vs, Order: "time", Dir: "desc"}, + Total: uint64(len(stringMsgs)), + Messages: stringMsgs[0:10], + }, + }, + { + desc: "read page with data value as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?vd=%s", ts.URL, domainID, chanID, vd), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", DataValue: vd, Order: "time", Dir: "desc"}, + Total: uint64(len(dataMsgs)), + Messages: dataMsgs[0:10], + }, + }, + { + desc: "read page with non-float from as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?from=ABCD", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with non-float to as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?to=ABCD", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with from/to as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?from=%f&to=%f", ts.URL, domainID, chanID, messages[19].Time, messages[4].Time), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", From: messages[19].Time, To: messages[4].Time, Order: "time", Dir: "desc"}, + Total: uint64(len(messages[5:20])), + Messages: messages[5:15], + }, + }, + { + desc: "read page with aggregation as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with interval as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?interval=10h", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Order: "time", Dir: "desc"}, + Total: uint64(len(messages)), + Messages: messages[0:10], + }, + }, + { + desc: "read page with aggregation and interval as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX&interval=10h", ts.URL, domainID, chanID), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with aggregation, interval, to and from as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX&interval=10h&from=%f&to=%f", ts.URL, domainID, chanID, messages[19].Time, messages[4].Time), + key: clientToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Aggregation: "MAX", Interval: "10h", From: messages[19].Time, To: messages[4].Time, Order: "time", Dir: "desc"}, + Total: uint64(len(messages[5:20])), + Messages: messages[5:15], + }, + }, + { + desc: "read page with invalid aggregation and valid interval, to and from as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=invalid&interval=10h&from=%f&to=%f", ts.URL, domainID, chanID, messages[19].Time, messages[4].Time), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with invalid interval and valid aggregation, to and from as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX&interval=10hrs&from=%f&to=%f", ts.URL, domainID, chanID, messages[19].Time, messages[4].Time), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with aggregation, interval and to with missing from as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX&interval=10h&to=%f", ts.URL, domainID, chanID, messages[4].Time), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with aggregation, interval and to with invalid from as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX&interval=10h&to=ABCD&from=%f", ts.URL, domainID, chanID, messages[4].Time), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with aggregation, interval and to with invalid to as client", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX&interval=10h&from=%f&to=ABCD", ts.URL, domainID, chanID, messages[4].Time), + key: clientToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with valid offset and limit as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=10", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Order: "time", Dir: "desc"}, + Total: uint64(len(messages)), + Messages: messages[0:10], + }, + }, + { + desc: "read page with invalid client key", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=10", ts.URL, domainID, chanID), + key: "invalid", + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "read page with unauthorized client key", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=10", ts.URL, domainID, chanID), + key: clientToken, + authnErr: nil, + authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "read page with negative offset as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=-1&limit=10", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with negative limit as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=-10", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with zero limit as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=0", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with non-integer offset as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=abc&limit=10", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with non-integer limit as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=abc", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with invalid channel id as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=10", ts.URL, domainID, ""), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with invalid token as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=10", ts.URL, domainID, chanID), + token: invalidToken, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthorization, + }, + { + desc: "read page with unauthorized as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=10", ts.URL, domainID, chanID), + token: userToken, + authzRes: &grpcChannelsV1.AuthzRes{Authorized: false}, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "read page with multiple offset as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&offset=1&limit=10", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with multiple limit as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=20&limit=10", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with empty token as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0&limit=10", ts.URL, domainID, chanID), + token: "", + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthorization, + }, + { + desc: "read page with default offset as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?limit=10", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Order: "time", Dir: "desc"}, + Total: uint64(len(messages)), + Messages: messages[0:10], + }, + }, + { + desc: "read page with default limit as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?offset=0", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Order: "time", Dir: "desc"}, + Total: uint64(len(messages)), + Messages: messages[0:10], + }, + }, + { + desc: "read page with senml format as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?format=messages", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Order: "time", Dir: "desc"}, + Total: uint64(len(messages)), + Messages: messages[0:10], + }, + }, + { + desc: "read page with subtopic as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?subtopic=%s&protocol=%s", ts.URL, domainID, chanID, subtopic, httpProt), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Subtopic: subtopic, Protocol: httpProt, Order: "time", Dir: "desc"}, + Total: uint64(len(queryMsgs)), + Messages: queryMsgs[0:10], + }, + }, + { + desc: "read page with subtopic and protocol as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?subtopic=%s&protocol=%s", ts.URL, domainID, chanID, subtopic, httpProt), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Subtopic: subtopic, Protocol: httpProt, Order: "time", Dir: "desc"}, + Total: uint64(len(queryMsgs)), + Messages: queryMsgs[0:10], + }, + }, + { + desc: "read page with publisher as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?publisher=%s", ts.URL, domainID, chanID, pubID2), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Publisher: pubID2, Order: "time", Dir: "desc"}, + Total: uint64(len(queryMsgs)), + Messages: queryMsgs[0:10], + }, + }, + { + desc: "read page with protocol as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?protocol=http", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Protocol: httpProt, Order: "time", Dir: "desc"}, + Total: uint64(len(queryMsgs)), + Messages: queryMsgs[0:10], + }, + }, + { + desc: "read page with name as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?name=%s", ts.URL, domainID, chanID, msgName), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Name: msgName, Order: "time", Dir: "desc"}, + Total: uint64(len(queryMsgs)), + Messages: queryMsgs[0:10], + }, + }, + { + desc: "read page with value as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f", ts.URL, domainID, chanID, v), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Value: v, Order: "time", Dir: "desc"}, + Total: uint64(len(valueMsgs)), + Messages: valueMsgs[0:10], + }, + }, + { + desc: "read page with value and equal comparator as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f&comparator=%s", ts.URL, domainID, chanID, v, readers.EqualKey), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Value: v, Comparator: readers.EqualKey, Order: "time", Dir: "desc"}, + Total: uint64(len(valueMsgs)), + Messages: valueMsgs[0:10], + }, + }, + { + desc: "read page with value and lower-than comparator as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f&comparator=%s", ts.URL, domainID, chanID, v+1, readers.LowerThanKey), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Value: v + 1, Comparator: readers.LowerThanKey, Order: "time", Dir: "desc"}, + Total: uint64(len(valueMsgs)), + Messages: valueMsgs[0:10], + }, + }, + { + desc: "read page with value and lower-than-or-equal comparator as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f&comparator=%s", ts.URL, domainID, chanID, v+1, readers.LowerThanEqualKey), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Value: v + 1, Comparator: readers.LowerThanEqualKey, Order: "time", Dir: "desc"}, + Total: uint64(len(valueMsgs)), + Messages: valueMsgs[0:10], + }, + }, + { + desc: "read page with value and greater-than comparator as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f&comparator=%s", ts.URL, domainID, chanID, v-1, readers.GreaterThanKey), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Order: "time", Dir: "desc", Format: "messages", Value: v - 1, Comparator: readers.GreaterThanKey}, + Total: uint64(len(valueMsgs)), + Messages: valueMsgs[0:10], + }, + }, + { + desc: "read page with value and greater-than-or-equal comparator as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f&comparator=%s", ts.URL, domainID, chanID, v-1, readers.GreaterThanEqualKey), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Order: "time", Dir: "desc", Limit: 10, Format: "messages", Value: v - 1, Comparator: readers.GreaterThanEqualKey}, + Total: uint64(len(valueMsgs)), + Messages: valueMsgs[0:10], + }, + }, + { + desc: "read page with non-float value as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=ab01", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with value and wrong comparator as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?v=%f&comparator=wrong", ts.URL, domainID, chanID, v-1), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with boolean value as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?vb=true", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", BoolValue: true, Order: "time", Dir: "desc"}, + Total: uint64(len(boolMsgs)), + Messages: boolMsgs[0:10], + }, + }, + { + desc: "read page with non-boolean value as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?vb=yes", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with string value as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?vs=%s", ts.URL, domainID, chanID, vs), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", StringValue: vs, Order: "time", Dir: "desc"}, + Total: uint64(len(stringMsgs)), + Messages: stringMsgs[0:10], + }, + }, + { + desc: "read page with data value as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?vd=%s", ts.URL, domainID, chanID, vd), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", DataValue: vd, Order: "time", Dir: "desc"}, + Total: uint64(len(dataMsgs)), + Messages: dataMsgs[0:10], + }, + }, + { + desc: "read page with non-float from as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?from=ABCD", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with non-float to as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?to=ABCD", ts.URL, domainID, chanID), + token: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with from/to as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?from=%f&to=%f", ts.URL, domainID, chanID, messages[19].Time, messages[4].Time), + token: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", From: messages[19].Time, To: messages[4].Time, Order: "time", Dir: "desc"}, + Total: uint64(len(messages[5:20])), + Messages: messages[5:15], + }, + }, + { + desc: "read page with aggregation as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX", ts.URL, domainID, chanID), + key: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with interval as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?interval=10h", ts.URL, domainID, chanID), + key: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Order: "time", Dir: "desc"}, + Total: uint64(len(messages)), + Messages: messages[0:10], + }, + }, + { + desc: "read page with aggregation and interval as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX&interval=10h", ts.URL, domainID, chanID), + key: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with aggregation, interval, to and from as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX&interval=10h&from=%f&to=%f", ts.URL, domainID, chanID, messages[19].Time, messages[4].Time), + key: userToken, + status: http.StatusOK, + res: pageRes{ + PageMetadata: readers.PageMetadata{Limit: 10, Format: "messages", Aggregation: "MAX", Interval: "10h", From: messages[19].Time, To: messages[4].Time, Order: "time", Dir: "desc"}, + Total: uint64(len(messages[5:20])), + Messages: messages[5:15], + }, + }, + { + desc: "read page with invalid aggregation and valid interval, to and from as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=invalid&interval=10h&from=%f&to=%f", ts.URL, domainID, chanID, messages[19].Time, messages[4].Time), + key: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with invalid interval and valid aggregation, to and from as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX&interval=10hrs&from=%f&to=%f", ts.URL, domainID, chanID, messages[19].Time, messages[4].Time), + key: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with aggregation, interval and to with missing from as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX&interval=10h&to=%f", ts.URL, domainID, chanID, messages[4].Time), + key: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with aggregation, interval and to with invalid from as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX&interval=10h&to=ABCD&from=%f", ts.URL, domainID, chanID, messages[4].Time), + key: userToken, + status: http.StatusBadRequest, + }, + { + desc: "read page with aggregation, interval and to with invalid to as user", + url: fmt.Sprintf("%s/%s/channels/%s/messages?aggregation=MAX&interval=10h&from=%f&to=ABCD", ts.URL, domainID, chanID, messages[4].Time), + key: userToken, + status: http.StatusBadRequest, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(validSession, tc.authnErr) + if tc.key != "" { + authnCall = clients.On("Authenticate", mock.Anything, &grpcClientsV1.AuthnReq{ + Token: smqauthn.AuthPack(smqauthn.DomainAuth, domainID, tc.key), + }).Return(&grpcClientsV1.AuthnRes{Id: testsutil.GenerateUUID(t), Authenticated: true}, tc.authnErr) + } + if tc.authzRes == nil { + tc.authzRes = &grpcChannelsV1.AuthzRes{Authorized: true} + } + authzCall := channels.On("Authorize", mock.Anything, mock.Anything).Return(tc.authzRes, tc.authzErr) + repoCall := repo.On("ReadAll", chanID, tc.res.PageMetadata).Return(readers.MessagesPage{Total: tc.res.Total, Messages: fromSenml(tc.res.Messages)}, nil) + req := testRequest{ + client: ts.Client(), + method: http.MethodGet, + url: tc.url, + token: tc.token, + key: tc.key, + } + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + + var page pageRes + err = json.NewDecoder(res.Body).Decode(&page) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected %d got %d", tc.desc, tc.status, res.StatusCode)) + assert.Equal(t, tc.res.Total, page.Total, fmt.Sprintf("%s: expected %d got %d", tc.desc, tc.res.Total, page.Total)) + assert.ElementsMatch(t, tc.res.Messages, page.Messages, fmt.Sprintf("%s: got incorrect body from response", tc.desc)) + authzCall.Unset() + authnCall.Unset() + repoCall.Unset() + }) + } +} + +type pageRes struct { + readers.PageMetadata + Total uint64 `json:"total"` + Messages []senml.Message `json:"messages"` +} + +func fromSenml(in []senml.Message) []readers.Message { + var ret []readers.Message + for _, m := range in { + ret = append(ret, m) + } + return ret +} diff --git a/readers/api/http/requests.go b/readers/api/http/requests.go new file mode 100644 index 000000000..15e99612b --- /dev/null +++ b/readers/api/http/requests.go @@ -0,0 +1,68 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import ( + "slices" + "strings" + "time" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/readers" +) + +const maxLimitSize = 1000 + +var validAggregations = []string{"MAX", "MIN", "AVG", "SUM", "COUNT"} + +type listMessagesReq struct { + chanID string + token string + domain string + key string + pageMeta readers.PageMetadata +} + +func (req listMessagesReq) validate() error { + if req.token == "" && req.key == "" { + return apiutil.ErrBearerToken + } + + if req.chanID == "" { + return apiutil.ErrMissingID + } + + if req.pageMeta.Limit < 1 || req.pageMeta.Limit > maxLimitSize { + return apiutil.ErrLimitSize + } + + if req.pageMeta.Comparator != "" && + req.pageMeta.Comparator != readers.EqualKey && + req.pageMeta.Comparator != readers.LowerThanKey && + req.pageMeta.Comparator != readers.LowerThanEqualKey && + req.pageMeta.Comparator != readers.GreaterThanKey && + req.pageMeta.Comparator != readers.GreaterThanEqualKey { + return apiutil.ErrInvalidComparator + } + + if req.pageMeta.Aggregation != "" { + if req.pageMeta.From == 0 { + return apiutil.ErrMissingFrom + } + + if req.pageMeta.To == 0 { + return apiutil.ErrMissingTo + } + + if !slices.Contains(validAggregations, strings.ToUpper(req.pageMeta.Aggregation)) { + return apiutil.ErrInvalidAggregation + } + + if _, err := time.ParseDuration(req.pageMeta.Interval); err != nil { + return apiutil.ErrInvalidInterval + } + } + + return nil +} diff --git a/readers/api/http/responses.go b/readers/api/http/responses.go new file mode 100644 index 000000000..2867239a3 --- /dev/null +++ b/readers/api/http/responses.go @@ -0,0 +1,31 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import ( + "net/http" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/readers" +) + +var _ supermq.Response = (*pageRes)(nil) + +type pageRes struct { + readers.PageMetadata + Total uint64 `json:"total"` + Messages []readers.Message `json:"messages"` +} + +func (res pageRes) Headers() map[string]string { + return map[string]string{} +} + +func (res pageRes) Code() int { + return http.StatusOK +} + +func (res pageRes) Empty() bool { + return false +} diff --git a/readers/api/http/transport.go b/readers/api/http/transport.go new file mode 100644 index 000000000..eb011285a --- /dev/null +++ b/readers/api/http/transport.go @@ -0,0 +1,266 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package http + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/absmach/supermq" + grpcChannelsV1 "github.com/absmach/supermq/api/grpc/channels/v1" + grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/connections" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/readers" + "github.com/go-chi/chi/v5" + kithttp "github.com/go-kit/kit/transport/http" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +const ( + contentType = "application/json" + offsetKey = "offset" + limitKey = "limit" + formatKey = "format" + subtopicKey = "subtopic" + publisherKey = "publisher" + protocolKey = "protocol" + nameKey = "name" + valueKey = "v" + stringValueKey = "vs" + dataValueKey = "vd" + boolValueKey = "vb" + comparatorKey = "comparator" + fromKey = "from" + toKey = "to" + aggregationKey = "aggregation" + intervalKey = "interval" + defInterval = "1s" + defLimit = 10 + defOffset = 0 + defFormat = "messages" +) + +// MakeHandler returns a HTTP handler for API endpoints. +func MakeHandler(svc readers.MessageRepository, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient, svcName, instanceID string) http.Handler { + opts := []kithttp.ServerOption{ + kithttp.ServerErrorEncoder(api.EncodeError), + } + + mux := chi.NewRouter() + mux.Get("/{domainID}/channels/{chanID}/messages", kithttp.NewServer( + listMessagesEndpoint(svc, authn, clients, channels), + decodeList, + encodeResponse, + opts..., + ).ServeHTTP) + + mux.Get("/health", supermq.Health(svcName, instanceID)) + mux.Handle("/metrics", promhttp.Handler()) + + return mux +} + +func decodeList(_ context.Context, r *http.Request) (any, error) { + offset, err := apiutil.ReadNumQuery[uint64](r, offsetKey, defOffset) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + limit, err := apiutil.ReadNumQuery[uint64](r, limitKey, defLimit) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + format, err := apiutil.ReadStringQuery(r, formatKey, defFormat) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + subtopic, err := apiutil.ReadStringQuery(r, subtopicKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + publisher, err := apiutil.ReadStringQuery(r, publisherKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + protocol, err := apiutil.ReadStringQuery(r, protocolKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + name, err := apiutil.ReadStringQuery(r, nameKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + v, err := apiutil.ReadNumQuery[float64](r, valueKey, 0) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + comparator, err := apiutil.ReadStringQuery(r, comparatorKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + vs, err := apiutil.ReadStringQuery(r, stringValueKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + vd, err := apiutil.ReadStringQuery(r, dataValueKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + vb, err := apiutil.ReadBoolQuery(r, boolValueKey, false) + if err != nil && err != apiutil.ErrNotFoundParam { + return nil, err + } + + from, err := apiutil.ReadNumQuery[float64](r, fromKey, 0) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + to, err := apiutil.ReadNumQuery[float64](r, toKey, 0) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + aggregation, err := apiutil.ReadStringQuery(r, aggregationKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + order, err := apiutil.ReadStringQuery(r, api.OrderKey, "time") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + dir, err := apiutil.ReadStringQuery(r, api.DirKey, "desc") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + var interval string + if aggregation != "" { + interval, err = apiutil.ReadStringQuery(r, intervalKey, defInterval) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + } + + req := listMessagesReq{ + chanID: chi.URLParam(r, "chanID"), + token: apiutil.ExtractBearerToken(r), + domain: chi.URLParam(r, "domainID"), + key: apiutil.ExtractClientSecret(r), + pageMeta: readers.PageMetadata{ + Offset: offset, + Limit: limit, + Format: format, + Subtopic: subtopic, + Publisher: publisher, + Protocol: protocol, + Name: name, + Value: v, + Comparator: comparator, + StringValue: vs, + DataValue: vd, + BoolValue: vb, + From: from, + To: to, + Aggregation: aggregation, + Interval: interval, + Order: order, + Dir: dir, + }, + } + return req, nil +} + +func encodeResponse(_ context.Context, w http.ResponseWriter, response any) error { + w.Header().Set("Content-Type", contentType) + + if ar, ok := response.(supermq.Response); ok { + for k, v := range ar.Headers() { + w.Header().Set(k, v) + } + + w.WriteHeader(ar.Code()) + + if ar.Empty() { + return nil + } + } + + return json.NewEncoder(w).Encode(response) +} + +func authnAuthz(ctx context.Context, req listMessagesReq, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient, channels grpcChannelsV1.ChannelsServiceClient) error { + clientID, clientType, err := authenticate(ctx, req, authn, clients) + if err != nil { + return err + } + if err := authorize(ctx, clientID, clientType, req.chanID, req.domain, channels); err != nil { + return err + } + return nil +} + +func authenticate(ctx context.Context, req listMessagesReq, authn smqauthn.Authentication, clients grpcClientsV1.ClientsServiceClient) (clientID string, clientType string, err error) { + switch { + case req.token != "": + session, err := authn.Authenticate(ctx, req.token) + if err != nil { + return "", "", err + } + if session.Role == smqauthn.SuperAdminRole { + return session.UserID, policies.UserType, nil + } + + return policies.EncodeDomainUserID(req.domain, session.UserID), policies.UserType, nil + case req.key != "": + res, err := clients.Authenticate(ctx, &grpcClientsV1.AuthnReq{ + Token: smqauthn.AuthPack(smqauthn.DomainAuth, req.domain, req.key), + }) + if err != nil { + return "", "", err + } + if !res.GetAuthenticated() { + return "", "", svcerr.ErrAuthentication + } + return res.GetId(), policies.ClientType, nil + default: + return "", "", svcerr.ErrAuthentication + } +} + +func authorize(ctx context.Context, clientID, clientType, chanID, domain string, channels grpcChannelsV1.ChannelsServiceClient) (err error) { + res, err := channels.Authorize(ctx, &grpcChannelsV1.AuthzReq{ + ClientId: clientID, + ClientType: clientType, + Type: uint32(connections.Subscribe), + ChannelId: chanID, + DomainId: domain, + }) + if err != nil { + return errors.Wrap(svcerr.ErrAuthorization, err) + } + if !res.GetAuthorized() { + return svcerr.ErrAuthorization + } + return nil +} diff --git a/readers/middleware/doc.go b/readers/middleware/doc.go new file mode 100644 index 000000000..78e1451d1 --- /dev/null +++ b/readers/middleware/doc.go @@ -0,0 +1,5 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package middleware provides middleware for Magistrala Readers service. +package middleware diff --git a/readers/middleware/logging.go b/readers/middleware/logging.go new file mode 100644 index 000000000..d7df40883 --- /dev/null +++ b/readers/middleware/logging.go @@ -0,0 +1,56 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build !test + +package middleware + +import ( + "log/slog" + "time" + + "github.com/absmach/supermq/readers" +) + +var _ readers.MessageRepository = (*loggingMiddleware)(nil) + +type loggingMiddleware struct { + logger *slog.Logger + svc readers.MessageRepository +} + +// LoggingMiddleware adds logging facilities to the core service. +func LoggingMiddleware(svc readers.MessageRepository, logger *slog.Logger) readers.MessageRepository { + return &loggingMiddleware{ + logger: logger, + svc: svc, + } +} + +func (lm *loggingMiddleware) ReadAll(chanID string, rpm readers.PageMetadata) (page readers.MessagesPage, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("channel_id", chanID), + slog.Group("page", + slog.Uint64("offset", rpm.Offset), + slog.Uint64("limit", rpm.Limit), + slog.Uint64("total", page.Total), + ), + } + if rpm.Subtopic != "" { + args = append(args, slog.String("subtopic", rpm.Subtopic)) + } + if rpm.Publisher != "" { + args = append(args, slog.String("publisher", rpm.Publisher)) + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Read all failed", args...) + return + } + lm.logger.Info("Read all completed successfully", args...) + }(time.Now()) + + return lm.svc.ReadAll(chanID, rpm) +} diff --git a/readers/middleware/metrics.go b/readers/middleware/metrics.go new file mode 100644 index 000000000..852113020 --- /dev/null +++ b/readers/middleware/metrics.go @@ -0,0 +1,39 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +//go:build !test + +package middleware + +import ( + "time" + + "github.com/absmach/supermq/readers" + "github.com/go-kit/kit/metrics" +) + +var _ readers.MessageRepository = (*metricsMiddleware)(nil) + +type metricsMiddleware struct { + counter metrics.Counter + latency metrics.Histogram + svc readers.MessageRepository +} + +// MetricsMiddleware instruments core service by tracking request count and latency. +func MetricsMiddleware(svc readers.MessageRepository, counter metrics.Counter, latency metrics.Histogram) readers.MessageRepository { + return &metricsMiddleware{ + counter: counter, + latency: latency, + svc: svc, + } +} + +func (mm *metricsMiddleware) ReadAll(chanID string, rpm readers.PageMetadata) (readers.MessagesPage, error) { + defer func(begin time.Time) { + mm.counter.With("method", "read_all").Add(1) + mm.latency.With("method", "read_all").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.ReadAll(chanID, rpm) +} diff --git a/readers/mocks/readers_client.go b/readers/mocks/readers_client.go new file mode 100644 index 000000000..e8915534b --- /dev/null +++ b/readers/mocks/readers_client.go @@ -0,0 +1,127 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/api/grpc/readers/v1" + mock "github.com/stretchr/testify/mock" + "google.golang.org/grpc" +) + +// NewReadersServiceClient creates a new instance of ReadersServiceClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewReadersServiceClient(t interface { + mock.TestingT + Cleanup(func()) +}) *ReadersServiceClient { + mock := &ReadersServiceClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// ReadersServiceClient is an autogenerated mock type for the ReadersServiceClient type +type ReadersServiceClient struct { + mock.Mock +} + +type ReadersServiceClient_Expecter struct { + mock *mock.Mock +} + +func (_m *ReadersServiceClient) EXPECT() *ReadersServiceClient_Expecter { + return &ReadersServiceClient_Expecter{mock: &_m.Mock} +} + +// ReadMessages provides a mock function for the type ReadersServiceClient +func (_mock *ReadersServiceClient) ReadMessages(ctx context.Context, in *v1.ReadMessagesReq, opts ...grpc.CallOption) (*v1.ReadMessagesRes, error) { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(ctx, in, opts) + } else { + tmpRet = _mock.Called(ctx, in) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for ReadMessages") + } + + var r0 *v1.ReadMessagesRes + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.ReadMessagesReq, ...grpc.CallOption) (*v1.ReadMessagesRes, error)); ok { + return returnFunc(ctx, in, opts...) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.ReadMessagesReq, ...grpc.CallOption) *v1.ReadMessagesRes); ok { + r0 = returnFunc(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.ReadMessagesRes) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.ReadMessagesReq, ...grpc.CallOption) error); ok { + r1 = returnFunc(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// ReadersServiceClient_ReadMessages_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadMessages' +type ReadersServiceClient_ReadMessages_Call struct { + *mock.Call +} + +// ReadMessages is a helper method to define mock.On call +// - ctx context.Context +// - in *v1.ReadMessagesReq +// - opts ...grpc.CallOption +func (_e *ReadersServiceClient_Expecter) ReadMessages(ctx interface{}, in interface{}, opts ...interface{}) *ReadersServiceClient_ReadMessages_Call { + return &ReadersServiceClient_ReadMessages_Call{Call: _e.mock.On("ReadMessages", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *ReadersServiceClient_ReadMessages_Call) Run(run func(ctx context.Context, in *v1.ReadMessagesReq, opts ...grpc.CallOption)) *ReadersServiceClient_ReadMessages_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *v1.ReadMessagesReq + if args[1] != nil { + arg1 = args[1].(*v1.ReadMessagesReq) + } + var arg2 []grpc.CallOption + var variadicArgs []grpc.CallOption + if len(args) > 2 { + variadicArgs = args[2].([]grpc.CallOption) + } + arg2 = variadicArgs + run( + arg0, + arg1, + arg2..., + ) + }) + return _c +} + +func (_c *ReadersServiceClient_ReadMessages_Call) Return(readMessagesRes *v1.ReadMessagesRes, err error) *ReadersServiceClient_ReadMessages_Call { + _c.Call.Return(readMessagesRes, err) + return _c +} + +func (_c *ReadersServiceClient_ReadMessages_Call) RunAndReturn(run func(ctx context.Context, in *v1.ReadMessagesReq, opts ...grpc.CallOption) (*v1.ReadMessagesRes, error)) *ReadersServiceClient_ReadMessages_Call { + _c.Call.Return(run) + return _c +} diff --git a/readers/postgres/README.md b/readers/postgres/README.md new file mode 100644 index 000000000..48e604980 --- /dev/null +++ b/readers/postgres/README.md @@ -0,0 +1,101 @@ +# Postgres reader + +Postgres reader provides message repository implementation for Postgres. + +## Configuration + +The service is configured using the environment variables presented in the +following table. Note that any unset variables will be replaced with their +default values. + +| Variable | Description | Default | +| ------------------------------------ | -------------------------------------------- | ---------------------------- | +| MG_POSTGRES_READER_LOG_LEVEL | Service log level | info | +| MG_POSTGRES_READER_HTTP_HOST | Service HTTP host | localhost | +| MG_POSTGRES_READER_HTTP_PORT | Service HTTP port | 9009 | +| MG_POSTGRES_READER_HTTP_SERVER_CERT | Service HTTP server cert | "" | +| MG_POSTGRES_READER_HTTP_SERVER_KEY | Service HTTP server key | "" | +| MG_POSTGRES_HOST | Postgres DB host | localhost | +| MG_POSTGRES_PORT | Postgres DB port | 5432 | +| MG_POSTGRES_USER | Postgres user | supermq | +| MG_POSTGRES_PASS | Postgres password | supermq | +| MG_POSTGRES_NAME | Postgres database name | messages | +| MG_POSTGRES_SSL_MODE | Postgres SSL mode | disabled | +| MG_POSTGRES_SSL_CERT | Postgres SSL certificate path | "" | +| MG_POSTGRES_SSL_KEY | Postgres SSL key | "" | +| MG_POSTGRES_SSL_ROOT_CERT | Postgres SSL root certificate path | "" | +| MG_CLIENTS_GRPC_URL | Clients service Auth gRPC URL | localhost:7000 | +| MG_CLIENTS_GRPC_TIMEOUT | Clients service Auth gRPC timeout in seconds | 1s | +| MG_CLIENTS_GRPC_CLIENT_TLS | Clients service Auth gRPC TLS mode flag | false | +| MG_CLIENTS_GRPC_CA_CERTS | Clients service Auth gRPC CA certificates | "" | +| MG_AUTH_GRPC_URL | Auth service gRPC URL | localhost:7001 | +| MG_AUTH_GRPC_TIMEOUT | Auth service gRPC request timeout in seconds | 1s | +| MG_AUTH_GRPC_CLIENT_TLS | Auth service gRPC TLS mode flag | false | +| MG_AUTH_GRPC_CA_CERTS | Auth service gRPC CA certificates | "" | +| MG_JAEGER_URL | Jaeger server URL | http://jaeger:4318/v1/traces | +| MG_SEND_TELEMETRY | Send telemetry to supermq call home server | true | +| MG_POSTGRES_READER_INSTANCE_ID | Postgres reader instance ID | | + +## Deployment + +The service itself is distributed as Docker container. Check the [`postgres-reader`](https://github.com/absmach/supermq/blob/main/docker/addons/postgres-reader/docker-compose.yaml#L17-L41) service section in +docker-compose file to see how service is deployed. + +To start the service, execute the following shell script: + +```bash +# download the latest version of the service +git clone https://github.com/absmach/supermq + +cd supermq + +# compile the postgres writer +make postgres-writer + +# copy binary to bin +make install + +# Set the environment variables and run the service +MG_POSTGRES_READER_LOG_LEVEL=[Service log level] \ +MG_POSTGRES_READER_HTTP_HOST=[Service HTTP host] \ +MG_POSTGRES_READER_HTTP_PORT=[Service HTTP port] \ +MG_POSTGRES_READER_HTTP_SERVER_CERT=[Service HTTPS server certificate path] \ +MG_POSTGRES_READER_HTTP_SERVER_KEY=[Service HTTPS server key path] \ +MG_POSTGRES_HOST=[Postgres host] \ +MG_POSTGRES_PORT=[Postgres port] \ +MG_POSTGRES_USER=[Postgres user] \ +MG_POSTGRES_PASS=[Postgres password] \ +MG_POSTGRES_NAME=[Postgres database name] \ +MG_POSTGRES_SSL_MODE=[Postgres SSL mode] \ +MG_POSTGRES_SSL_CERT=[Postgres SSL cert] \ +MG_POSTGRES_SSL_KEY=[Postgres SSL key] \ +MG_POSTGRES_SSL_ROOT_CERT=[Postgres SSL Root cert] \ +MG_CLIENTS_GRPC_URL=[Clients service Auth GRPC URL] \ +MG_CLIENTS_GRPC_TIMEOUT=[Clients service Auth gRPC request timeout in seconds] \ +MG_CLIENTS_GRPC_CLIENT_TLS=[Clients service Auth gRPC TLS mode flag] \ +MG_CLIENTS_GRPC_CA_CERTS=[Clients service Auth gRPC CA certificates] \ +MG_AUTH_GRPC_URL=[Auth service gRPC URL] \ +MG_AUTH_GRPC_TIMEOUT=[Auth service gRPC request timeout in seconds] \ +MG_AUTH_GRPC_CLIENT_TLS=[Auth service gRPC TLS mode flag] \ +MG_AUTH_GRPC_CA_CERTS=[Auth service gRPC CA certificates] \ +MG_JAEGER_URL=[Jaeger server URL] \ +MG_SEND_TELEMETRY=[Send telemetry to supermq call home server] \ +MG_POSTGRES_READER_INSTANCE_ID=[Postgres reader instance ID] \ +$GOBIN/supermq-postgres-reader +``` + +## Usage + +Starting service will start consuming normalized messages in SenML format. + +Comparator Usage Guide: + +| Comparator | Usage | Example | +| ---------- | --------------------------------------------------------------------------- | ---------------------------------- | +| eq | Return values that are equal to the query | eq["active"] -> "active" | +| ge | Return values that are substrings of the query | ge["tiv"] -> "active" and "tiv" | +| gt | Return values that are substrings of the query and not equal to the query | gt["tiv"] -> "active" | +| le | Return values that are superstrings of the query | le["active"] -> "tiv" | +| lt | Return values that are superstrings of the query and not equal to the query | lt["active"] -> "active" and "tiv" | + +Official docs can be found [here](https://docs.supermq.absmach.eu). diff --git a/readers/postgres/doc.go b/readers/postgres/doc.go new file mode 100644 index 000000000..a92d4f9b5 --- /dev/null +++ b/readers/postgres/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package postgres contains repository implementations using Postgres as +// the underlying database. +package postgres diff --git a/readers/postgres/init.go b/readers/postgres/init.go new file mode 100644 index 000000000..10bc5f1eb --- /dev/null +++ b/readers/postgres/init.go @@ -0,0 +1,80 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "fmt" + + "github.com/jmoiron/sqlx" + migrate "github.com/rubenv/sql-migrate" +) + +// Table for SenML messages. +const defTable = "messages" + +// Config defines the options that are used when connecting to a PostgreSQL instance. +type Config struct { + Host string + Port string + User string + Pass string + Name string + SSLMode string + SSLCert string + SSLKey string + SSLRootCert string +} + +// Connect creates a connection to the PostgreSQL instance and applies any +// unapplied database migrations. A non-nil error is returned to indicate +// failure. +func Connect(cfg Config) (*sqlx.DB, error) { + url := fmt.Sprintf("host=%s port=%s user=%s dbname=%s password=%s sslmode=%s sslcert=%s sslkey=%s sslrootcert=%s", cfg.Host, cfg.Port, cfg.User, cfg.Name, cfg.Pass, cfg.SSLMode, cfg.SSLCert, cfg.SSLKey, cfg.SSLRootCert) + + db, err := sqlx.Open("pgx", url) + if err != nil { + return nil, err + } + + if err := migrateDB(db); err != nil { + return nil, err + } + + return db, nil +} + +func migrateDB(db *sqlx.DB) error { + migrations := &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "messages_1", + Up: []string{ + `CREATE TABLE IF NOT EXISTS messages ( + id UUID, + channel UUID, + subtopic VARCHAR(254), + publisher UUID, + protocol TEXT, + name TEXT, + unit TEXT, + value FLOAT, + string_value TEXT, + bool_value BOOL, + data_value TEXT, + sum FLOAT, + time FlOAT, + update_time FLOAT, + PRIMARY KEY (id) + )`, + }, + Down: []string{ + "DROP TABLE messages", + }, + }, + }, + } + + _, err := migrate.Exec(db.DB, "postgres", migrations, migrate.Up) + return err +} diff --git a/readers/postgres/messages.go b/readers/postgres/messages.go new file mode 100644 index 000000000..2f09d450d --- /dev/null +++ b/readers/postgres/messages.go @@ -0,0 +1,202 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "encoding/json" + "fmt" + + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/absmach/supermq/readers" + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jmoiron/sqlx" +) + +var _ readers.MessageRepository = (*postgresRepository)(nil) + +type postgresRepository struct { + db *sqlx.DB +} + +// New returns new PostgreSQL writer. +func New(db *sqlx.DB) readers.MessageRepository { + return &postgresRepository{ + db: db, + } +} + +func (tr postgresRepository) ReadAll(chanID string, rpm readers.PageMetadata) (readers.MessagesPage, error) { + order := "time" + format := defTable + + if rpm.Format != "" && rpm.Format != defTable { + order = "created" + format = rpm.Format + } + cond := fmtCondition(chanID, rpm) + + q := fmt.Sprintf(`SELECT * FROM %s + WHERE %s ORDER BY %s DESC + LIMIT :limit OFFSET :offset;`, format, cond, order) + + params := map[string]any{ + "channel": chanID, + "limit": rpm.Limit, + "offset": rpm.Offset, + "subtopic": rpm.Subtopic, + "publisher": rpm.Publisher, + "name": rpm.Name, + "protocol": rpm.Protocol, + "value": rpm.Value, + "bool_value": rpm.BoolValue, + "string_value": rpm.StringValue, + "data_value": rpm.DataValue, + "from": rpm.From, + "to": rpm.To, + } + rows, err := tr.db.NamedQuery(q, params) + if err != nil { + if preErr, ok := err.(*pgconn.PrepareError); ok { + err = preErr.Unwrap() + } + if pgErr, ok := err.(*pgconn.PgError); ok { + if pgErr.Code == pgerrcode.UndefinedTable { + return readers.MessagesPage{}, nil + } + } + return readers.MessagesPage{}, errors.Wrap(readers.ErrReadMessages, err) + } + defer rows.Close() + + page := readers.MessagesPage{ + PageMetadata: rpm, + Messages: []readers.Message{}, + } + switch format { + case defTable: + for rows.Next() { + msg := senmlMessage{Message: senml.Message{}} + if err := rows.StructScan(&msg); err != nil { + return readers.MessagesPage{}, errors.Wrap(readers.ErrReadMessages, err) + } + + page.Messages = append(page.Messages, msg.Message) + } + default: + for rows.Next() { + msg := jsonMessage{} + if err := rows.StructScan(&msg); err != nil { + return readers.MessagesPage{}, errors.Wrap(readers.ErrReadMessages, err) + } + m, err := msg.toMap() + if err != nil { + return readers.MessagesPage{}, errors.Wrap(readers.ErrReadMessages, err) + } + page.Messages = append(page.Messages, m) + } + } + + q = fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE %s;`, format, cond) + rows, err = tr.db.NamedQuery(q, params) + if err != nil { + return readers.MessagesPage{}, errors.Wrap(readers.ErrReadMessages, err) + } + defer rows.Close() + + total := uint64(0) + if rows.Next() { + if err := rows.Scan(&total); err != nil { + return page, err + } + } + page.Total = total + + return page, nil +} + +func fmtCondition(chanID string, rpm readers.PageMetadata) string { + condition := `channel = :channel` + + var query map[string]any + meta, err := json.Marshal(rpm) + if err != nil { + return condition + } + if err := json.Unmarshal(meta, &query); err != nil { + return condition + } + + for name := range query { + switch name { + case + "subtopic", + "publisher", + "name", + "protocol": + condition = fmt.Sprintf(`%s AND %s = :%s`, condition, name, name) + case "v": + comparator := readers.ParseValueComparator(query) + condition = fmt.Sprintf(`%s AND value %s :value`, condition, comparator) + case "vb": + condition = fmt.Sprintf(`%s AND bool_value = :bool_value`, condition) + case "vs": + comparator := readers.ParseValueComparator(query) + switch comparator { + case "=": + condition = fmt.Sprintf("%s AND string_value = :string_value ", condition) + case ">": + condition = fmt.Sprintf("%s AND string_value LIKE '%%' || :string_value || '%%' AND string_value <> :string_value", condition) + case ">=": + condition = fmt.Sprintf("%s AND string_value LIKE '%%' || :string_value || '%%'", condition) + case "<=": + condition = fmt.Sprintf("%s AND :string_value LIKE '%%' || string_value || '%%'", condition) + case "<": + condition = fmt.Sprintf("%s AND :string_value LIKE '%%' || string_value || '%%' AND string_value <> :string_value", condition) + } + case "vd": + comparator := readers.ParseValueComparator(query) + condition = fmt.Sprintf(`%s AND data_value %s :data_value`, condition, comparator) + case "from": + condition = fmt.Sprintf(`%s AND time >= :from`, condition) + case "to": + condition = fmt.Sprintf(`%s AND time < :to`, condition) + } + } + return condition +} + +type senmlMessage struct { + ID string `db:"id"` + senml.Message +} + +type jsonMessage struct { + ID string `db:"id"` + Channel string `db:"channel"` + Created int64 `db:"created"` + Subtopic string `db:"subtopic"` + Publisher string `db:"publisher"` + Protocol string `db:"protocol"` + Payload []byte `db:"payload"` +} + +func (msg jsonMessage) toMap() (map[string]any, error) { + ret := map[string]any{ + "id": msg.ID, + "channel": msg.Channel, + "created": msg.Created, + "subtopic": msg.Subtopic, + "publisher": msg.Publisher, + "protocol": msg.Protocol, + "payload": map[string]any{}, + } + pld := make(map[string]any) + if err := json.Unmarshal(msg.Payload, &pld); err != nil { + return nil, err + } + ret["payload"] = pld + return ret, nil +} diff --git a/readers/postgres/messages_test.go b/readers/postgres/messages_test.go new file mode 100644 index 000000000..43181bd7c --- /dev/null +++ b/readers/postgres/messages_test.go @@ -0,0 +1,687 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "fmt" + "testing" + "time" + + pwriter "github.com/absmach/supermq/consumers/writers/postgres" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/transformers/json" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/absmach/supermq/readers" + preader "github.com/absmach/supermq/readers/postgres" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + subtopic = "subtopic" + msgsNum = 100 + limit = 10 + valueFields = 5 + mqttProt = "mqtt" + httpProt = "http" + msgName = "temperature" + format1 = "format1" + format2 = "format2" + wrongID = "0" +) + +var ( + v float64 = 5 + vs = "stringValue" + vb = true + vd = "dataValue" + sum float64 = 42 +) + +func TestReadSenml(t *testing.T) { + writer := pwriter.New(db) + + chanID := testsutil.GenerateUUID(t) + pubID := testsutil.GenerateUUID(t) + pubID2 := testsutil.GenerateUUID(t) + wrongID := testsutil.GenerateUUID(t) + + m := senml.Message{ + Channel: chanID, + Publisher: pubID, + Protocol: mqttProt, + } + + messages := []senml.Message{} + valueMsgs := []senml.Message{} + boolMsgs := []senml.Message{} + stringMsgs := []senml.Message{} + dataMsgs := []senml.Message{} + queryMsgs := []senml.Message{} + + now := float64(time.Now().Unix()) + for i := 0; i < msgsNum; i++ { + // Mix possible values as well as value sum. + msg := m + msg.Time = now - float64(i) + + count := i % valueFields + switch count { + case 0: + msg.Value = &v + valueMsgs = append(valueMsgs, msg) + case 1: + msg.BoolValue = &vb + boolMsgs = append(boolMsgs, msg) + case 2: + msg.StringValue = &vs + stringMsgs = append(stringMsgs, msg) + case 3: + msg.DataValue = &vd + dataMsgs = append(dataMsgs, msg) + case 4: + msg.Sum = &sum + msg.Subtopic = subtopic + msg.Protocol = httpProt + msg.Publisher = pubID2 + msg.Name = msgName + queryMsgs = append(queryMsgs, msg) + } + + messages = append(messages, msg) + } + + err := writer.ConsumeBlocking(context.TODO(), messages) + require.Nil(t, err, fmt.Sprintf("expected no error got %s\n", err)) + + reader := preader.New(db) + + // Since messages are not saved in natural order, + // cases that return subset of messages are only + // checking data result set size, but not content. + cases := []struct { + desc string + chanID string + pageMeta readers.PageMetadata + page readers.MessagesPage + }{ + { + desc: "read message page for existing channel", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: msgsNum, + }, + page: readers.MessagesPage{ + Total: msgsNum, + Messages: fromSenml(messages), + }, + }, + { + desc: "read message page for non-existent channel", + chanID: wrongID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: msgsNum, + }, + page: readers.MessagesPage{ + Messages: []readers.Message{}, + }, + }, + { + desc: "read message last page", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: msgsNum - 20, + Limit: msgsNum, + }, + page: readers.MessagesPage{ + Total: msgsNum, + Messages: fromSenml(messages[msgsNum-20 : msgsNum]), + }, + }, + { + desc: "read message with non-existent subtopic", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: msgsNum, + Subtopic: "not-present", + }, + page: readers.MessagesPage{ + Messages: []readers.Message{}, + }, + }, + { + desc: "read message with subtopic", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: uint64(len(queryMsgs)), + Subtopic: subtopic, + }, + page: readers.MessagesPage{ + Total: uint64(len(queryMsgs)), + Messages: fromSenml(queryMsgs), + }, + }, + { + desc: "read message with publisher", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: uint64(len(queryMsgs)), + Publisher: pubID2, + }, + page: readers.MessagesPage{ + Total: uint64(len(queryMsgs)), + Messages: fromSenml(queryMsgs), + }, + }, + { + desc: "read message with wrong format", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Format: "messagess", + Offset: 0, + Limit: uint64(len(queryMsgs)), + Publisher: pubID2, + }, + page: readers.MessagesPage{ + Total: 0, + Messages: []readers.Message{}, + }, + }, + { + desc: "read message with protocol", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: uint64(len(queryMsgs)), + Protocol: httpProt, + }, + page: readers.MessagesPage{ + Total: uint64(len(queryMsgs)), + Messages: fromSenml(queryMsgs), + }, + }, + { + desc: "read message with name", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Name: msgName, + }, + page: readers.MessagesPage{ + Total: uint64(len(queryMsgs)), + Messages: fromSenml(queryMsgs[0:limit]), + }, + }, + { + desc: "read message with value", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Value: v, + }, + page: readers.MessagesPage{ + Total: uint64(len(valueMsgs)), + Messages: fromSenml(valueMsgs[0:limit]), + }, + }, + { + desc: "read message with value and equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Value: v, + Comparator: readers.EqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(valueMsgs)), + Messages: fromSenml(valueMsgs[0:limit]), + }, + }, + { + desc: "read message with value and lower-than comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Value: v + 1, + Comparator: readers.LowerThanKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(valueMsgs)), + Messages: fromSenml(valueMsgs[0:limit]), + }, + }, + { + desc: "read message with value and lower-than-or-equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Value: v + 1, + Comparator: readers.LowerThanEqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(valueMsgs)), + Messages: fromSenml(valueMsgs[0:limit]), + }, + }, + { + desc: "read message with value and greater-than comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Value: v - 1, + Comparator: readers.GreaterThanKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(valueMsgs)), + Messages: fromSenml(valueMsgs[0:limit]), + }, + }, + { + desc: "read message with value and greater-than-or-equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Value: v - 1, + Comparator: readers.GreaterThanEqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(valueMsgs)), + Messages: fromSenml(valueMsgs[0:limit]), + }, + }, + { + desc: "read message with boolean value", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + BoolValue: vb, + }, + page: readers.MessagesPage{ + Total: uint64(len(boolMsgs)), + Messages: fromSenml(boolMsgs[0:limit]), + }, + }, + { + desc: "read message with string value", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + StringValue: vs, + }, + page: readers.MessagesPage{ + Total: uint64(len(stringMsgs)), + Messages: fromSenml(stringMsgs[0:limit]), + }, + }, + { + desc: "read message with string value and equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + StringValue: vs, + Comparator: readers.EqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(stringMsgs)), + Messages: fromSenml(stringMsgs[0:limit]), + }, + }, + { + desc: "read message with string value and lower-than comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + StringValue: "a stringValues b", + Comparator: readers.LowerThanKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(stringMsgs)), + Messages: fromSenml(stringMsgs[0:limit]), + }, + }, + { + desc: "read message with string value and lower-than-or-equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + StringValue: vs, + Comparator: readers.LowerThanEqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(stringMsgs)), + Messages: fromSenml(stringMsgs[0:limit]), + }, + }, + { + desc: "read message with string value and greater-than comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + StringValue: "alu", + Comparator: readers.GreaterThanKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(stringMsgs)), + Messages: fromSenml(stringMsgs[0:limit]), + }, + }, + { + desc: "read message with string value and greater-than-or-equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + StringValue: vs, + Comparator: readers.GreaterThanEqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(stringMsgs)), + Messages: fromSenml(stringMsgs[0:limit]), + }, + }, + { + desc: "read message with data value", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + DataValue: vd, + }, + page: readers.MessagesPage{ + Total: uint64(len(dataMsgs)), + Messages: fromSenml(dataMsgs[0:limit]), + }, + }, + { + desc: "read message with data value and lower-than comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + DataValue: vd + string(rune(1)), + Comparator: readers.LowerThanKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(dataMsgs)), + Messages: fromSenml(dataMsgs[0:limit]), + }, + }, + { + desc: "read message with data value and lower-than-or-equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + DataValue: vd + string(rune(1)), + Comparator: readers.LowerThanEqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(dataMsgs)), + Messages: fromSenml(dataMsgs[0:limit]), + }, + }, + { + desc: "read message with data value and greater-than comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + DataValue: vd[:len(vd)-1], + Comparator: readers.GreaterThanKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(dataMsgs)), + Messages: fromSenml(dataMsgs[0:limit]), + }, + }, + { + desc: "read message with data value and greater-than-or-equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + DataValue: vd[:len(vd)-1], + Comparator: readers.GreaterThanEqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(dataMsgs)), + Messages: fromSenml(dataMsgs[0:limit]), + }, + }, + { + desc: "read message with from", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: uint64(len(messages[0:21])), + From: messages[20].Time, + }, + page: readers.MessagesPage{ + Total: uint64(len(messages[0:21])), + Messages: fromSenml(messages[0:21]), + }, + }, + { + desc: "read message with to", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: uint64(len(messages[21:])), + To: messages[20].Time, + }, + page: readers.MessagesPage{ + Total: uint64(len(messages[21:])), + Messages: fromSenml(messages[21:]), + }, + }, + { + desc: "read message with from/to", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + From: messages[5].Time, + To: messages[0].Time, + }, + page: readers.MessagesPage{ + Total: 5, + Messages: fromSenml(messages[1:6]), + }, + }, + } + + for _, tc := range cases { + result, err := reader.ReadAll(tc.chanID, tc.pageMeta) + assert.Nil(t, err, fmt.Sprintf("%s: expected no error got %s", tc.desc, err)) + assert.ElementsMatch(t, tc.page.Messages, result.Messages, fmt.Sprintf("%s: got incorrect list of senml Messages from ReadAll()", tc.desc)) + assert.Equal(t, tc.page.Total, result.Total, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.page.Total, result.Total)) + } +} + +func TestReadJSON(t *testing.T) { + writer := pwriter.New(db) + + id1 := testsutil.GenerateUUID(t) + m := json.Message{ + Channel: id1, + Publisher: id1, + Created: time.Now().Unix(), + Subtopic: "subtopic/format/some_json", + Protocol: "coap", + Payload: map[string]any{ + "field_1": 123.0, + "field_2": "value", + "field_3": false, + "field_4": 12.344, + "field_5": map[string]any{ + "field_1": "value", + "field_2": 42.0, + }, + }, + } + messages1 := json.Messages{ + Format: format1, + } + msgs1 := []map[string]any{} + for i := 0; i < msgsNum; i++ { + msg := m + messages1.Data = append(messages1.Data, msg) + m := toMap(msg) + msgs1 = append(msgs1, m) + } + + err := writer.ConsumeBlocking(context.TODO(), messages1) + require.Nil(t, err, fmt.Sprintf("expected no error got %s\n", err)) + + id2 := testsutil.GenerateUUID(t) + m = json.Message{ + Channel: id2, + Publisher: id2, + Created: time.Now().Unix(), + Subtopic: "subtopic/other_format/some_other_json", + Protocol: "udp", + Payload: map[string]any{ + "field_1": "other_value", + "false_value": false, + "field_pi": 3.14159265, + }, + } + messages2 := json.Messages{ + Format: format2, + } + msgs2 := []map[string]any{} + for i := 0; i < msgsNum; i++ { + msg := m + if i%2 == 0 { + msg.Protocol = httpProt + } + messages2.Data = append(messages2.Data, msg) + m := toMap(msg) + msgs2 = append(msgs2, m) + } + + err = writer.ConsumeBlocking(context.TODO(), messages2) + require.Nil(t, err, fmt.Sprintf("expected no error got %s\n", err)) + + httpMsgs := []map[string]any{} + for i := 0; i < msgsNum; i += 2 { + httpMsgs = append(httpMsgs, msgs2[i]) + } + + reader := preader.New(db) + + cases := map[string]struct { + chanID string + pageMeta readers.PageMetadata + page readers.MessagesPage + }{ + "read message page for existing channel": { + chanID: id1, + pageMeta: readers.PageMetadata{ + Format: messages1.Format, + Offset: 0, + Limit: 10, + }, + page: readers.MessagesPage{ + Total: 100, + Messages: fromJSON(msgs1[:10]), + }, + }, + "read message page for non-existent channel": { + chanID: wrongID, + pageMeta: readers.PageMetadata{ + Format: messages1.Format, + Offset: 0, + Limit: 10, + }, + page: readers.MessagesPage{ + Messages: []readers.Message{}, + }, + }, + "read message last page": { + chanID: id2, + pageMeta: readers.PageMetadata{ + Format: messages2.Format, + Offset: msgsNum - 20, + Limit: msgsNum, + }, + page: readers.MessagesPage{ + Total: msgsNum, + Messages: fromJSON(msgs2[msgsNum-20 : msgsNum]), + }, + }, + "read message with protocol": { + chanID: id2, + pageMeta: readers.PageMetadata{ + Format: messages2.Format, + Offset: 0, + Limit: uint64(msgsNum / 2), + Protocol: httpProt, + }, + page: readers.MessagesPage{ + Total: uint64(msgsNum / 2), + Messages: fromJSON(httpMsgs), + }, + }, + } + + for desc, tc := range cases { + result, err := reader.ReadAll(tc.chanID, tc.pageMeta) + for i := 0; i < len(result.Messages); i++ { + m := result.Messages[i] + // Remove id as it is not sent by the client. + delete(m.(map[string]any), "id") + result.Messages[i] = m + } + assert.Nil(t, err, fmt.Sprintf("%s: expected no error got %s", desc, err)) + assert.ElementsMatch(t, tc.page.Messages, result.Messages, fmt.Sprintf("%s: got incorrect list of json Messages from ReadAll()", desc)) + assert.Equal(t, tc.page.Total, result.Total, fmt.Sprintf("%s: expected %v got %v", desc, tc.page.Total, result.Total)) + } +} + +func fromSenml(msg []senml.Message) []readers.Message { + var ret []readers.Message + for _, m := range msg { + ret = append(ret, m) + } + return ret +} + +func fromJSON(msg []map[string]any) []readers.Message { + var ret []readers.Message + for _, m := range msg { + ret = append(ret, m) + } + return ret +} + +func toMap(msg json.Message) map[string]any { + return map[string]any{ + "channel": msg.Channel, + "created": msg.Created, + "subtopic": msg.Subtopic, + "publisher": msg.Publisher, + "protocol": msg.Protocol, + "payload": map[string]any(msg.Payload), + } +} diff --git a/readers/postgres/setup_test.go b/readers/postgres/setup_test.go new file mode 100644 index 000000000..4636f6a28 --- /dev/null +++ b/readers/postgres/setup_test.go @@ -0,0 +1,83 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package postgres_test contains tests for PostgreSQL repository +// implementations. +package postgres_test + +import ( + "fmt" + "log" + "os" + "testing" + + "github.com/absmach/supermq/readers/postgres" + _ "github.com/jackc/pgx/v5/stdlib" // required for SQL access + "github.com/jmoiron/sqlx" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +var db *sqlx.DB + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16.2-alpine", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + + if err = pool.Retry(func() error { + db, err = sqlx.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + dbConfig := postgres.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + if db, err = postgres.Connect(dbConfig); err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + code := m.Run() + + // Defers will not be run when using os.Exit + db.Close() + if err = pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/readers/timescale/README.md b/readers/timescale/README.md new file mode 100644 index 000000000..5a401e623 --- /dev/null +++ b/readers/timescale/README.md @@ -0,0 +1,99 @@ +# Timescale reader + +Timescale reader provides message repository implementation for Timescale. + +## Configuration + +The service is configured using the environment variables presented in the +following table. Note that any unset variables will be replaced with their +default values. + +| Variable | Description | Default | +| ------------------------------------- | -------------------------------------------- | ---------------------------- | +| MG_TIMESCALE_READER_LOG_LEVEL | Service log level | info | +| MG_TIMESCALE_READER_HTTP_HOST | Service HTTP host | localhost | +| MG_TIMESCALE_READER_HTTP_PORT | Service HTTP port | 8180 | +| MG_TIMESCALE_READER_HTTP_SERVER_CERT | Service HTTP server certificate path | "" | +| MG_TIMESCALE_READER_HTTP_SERVER_KEY | Service HTTP server key path | "" | +| MG_TIMESCALE_HOST | Timescale DB host | localhost | +| MG_TIMESCALE_PORT | Timescale DB port | 5432 | +| MG_TIMESCALE_USER | Timescale user | supermq | +| MG_TIMESCALE_PASS | Timescale password | supermq | +| MG_TIMESCALE_NAME | Timescale database name | messages | +| MG_TIMESCALE_SSL_MODE | Timescale SSL mode | disabled | +| MG_TIMESCALE_SSL_CERT | Timescale SSL certificate path | "" | +| MG_TIMESCALE_SSL_KEY | Timescale SSL key | "" | +| MG_TIMESCALE_SSL_ROOT_CERT | Timescale SSL root certificate path | "" | +| MG_CLIENTS_GRPC_URL | Clients service Auth gRPC URL | localhost:7000 | +| MG_CLIENTS_GRPC_TIMEOUT | Clients service Auth gRPC timeout in seconds | 1s | +| MG_CLIENTS_GRPC_CLIENT_TLS | Clients service Auth gRPC TLS enabled flag | false | +| MG_CLIENTS_GRPC_CA_CERTS | Clients service Auth gRPC CA certificates | "" | +| MG_AUTH_GRPC_URL | Auth service gRPC URL | localhost:7001 | +| MG_AUTH_GRPC_TIMEOUT | Auth service gRPC timeout in seconds | 1s | +| MG_AUTH_GRPC_CLIENT_TLS | Auth service gRPC TLS enabled flag | false | +| MG_AUTH_GRPC_CA_CERT | Auth service gRPC CA certificate | "" | +| MG_JAEGER_URL | Jaeger server URL | http://jaeger:4318/v1/traces | +| MG_SEND_TELEMETRY | Send telemetry to supermq call home server | true | +| MG_TIMESCALE_READER_INSTANCE_ID | Timescale reader instance ID | "" | + +## Deployment + +The service itself is distributed as Docker container. Check the [`timescale-reader`](https://github.com/absmach/supermq/blob/main/docker/docker-compose.yaml) service section in the root docker-compose file to see how service is deployed. + +To start the service, execute the following shell script: + +```bash +# download the latest version of the service +git clone https://github.com/absmach/supermq + +cd supermq + +# compile the timescale writer +make timescale-writer + +# copy binary to bin +make install + +# Set the environment variables and run the service +MG_TIMESCALE_READER_LOG_LEVEL=[Service log level] \ +MG_TIMESCALE_READER_HTTP_HOST=[Service HTTP host] \ +MG_TIMESCALE_READER_HTTP_PORT=[Service HTTP port] \ +MG_TIMESCALE_READER_HTTP_SERVER_CERT=[Service HTTP server cert] \ +MG_TIMESCALE_READER_HTTP_SERVER_KEY=[Service HTTP server key] \ +MG_TIMESCALE_HOST=[Timescale host] \ +MG_TIMESCALE_PORT=[Timescale port] \ +MG_TIMESCALE_USER=[Timescale user] \ +MG_TIMESCALE_PASS=[Timescale password] \ +MG_TIMESCALE_NAME=[Timescale database name] \ +MG_TIMESCALE_SSL_MODE=[Timescale SSL mode] \ +MG_TIMESCALE_SSL_CERT=[Timescale SSL cert] \ +MG_TIMESCALE_SSL_KEY=[Timescale SSL key] \ +MG_TIMESCALE_SSL_ROOT_CERT=[Timescale SSL Root cert] \ +MG_CLIENTS_GRPC_URL=[Clients service Auth GRPC URL] \ +MG_CLIENTS_GRPC_TIMEOUT=[Clients service Auth gRPC request timeout in seconds] \ +MG_CLIENTS_GRPC_CLIENT_TLS=[Clients service Auth gRPC TLS enabled flag] \ +MG_CLIENTS_GRPC_CA_CERTS=[Clients service Auth gRPC CA certificates] \ +MG_AUTH_GRPC_URL=[Auth service Auth gRPC URL] \ +MG_AUTH_GRPC_TIMEOUT=[Auth service Auth gRPC request timeout in seconds] \ +MG_AUTH_GRPC_CLIENT_TLS=[Auth service Auth gRPC TLS enabled flag] \ +MG_AUTH_GRPC_CA_CERT=[Auth service Auth gRPC CA certificates] \ +MG_JAEGER_URL=[Jaeger server URL] \ +MG_SEND_TELEMETRY=[Send telemetry to supermq call home server] \ +MG_TIMESCALE_READER_INSTANCE_ID=[Timescale reader instance ID] \ +$GOBIN/supermq-timescale-reader +``` + +## Usage + +Starting service will start consuming normalized messages in SenML format. + +Comparator Usage Guide: +| Comparator | Usage | Example | +| ---------- | --------------------------------------------------------------------------- | ---------------------------------- | +| eq | Return values that are equal to the query | eq["active"] -> "active" | +| ge | Return values that are substrings of the query | ge["tiv"] -> "active" and "tiv" | +| gt | Return values that are substrings of the query and not equal to the query | gt["tiv"] -> "active" | +| le | Return values that are superstrings of the query | le["active"] -> "tiv" | +| lt | Return values that are superstrings of the query and not equal to the query | lt["active"] -> "active" and "tiv" | + +Official docs can be found [here](https://docs.supermq.absmach.eu). diff --git a/readers/timescale/doc.go b/readers/timescale/doc.go new file mode 100644 index 000000000..302be6ea5 --- /dev/null +++ b/readers/timescale/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package timescale contains repository implementations using Timescale as +// the underlying database. +package timescale diff --git a/readers/timescale/messages.go b/readers/timescale/messages.go new file mode 100644 index 000000000..d3067d180 --- /dev/null +++ b/readers/timescale/messages.go @@ -0,0 +1,352 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package timescale + +import ( + "encoding/json" + "fmt" + "strings" + + api "github.com/absmach/supermq/api/http" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/absmach/supermq/readers" + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jmoiron/sqlx" // required for DB access +) + +// Table for SenML messages. +const ( + defTable = "messages" + orderByTime = "time" + orderByCreated = "created" +) + +var _ readers.MessageRepository = (*timescaleRepository)(nil) + +type timescaleRepository struct { + db *sqlx.DB +} + +// New returns new TimescaleSQL writer. +func New(db *sqlx.DB) readers.MessageRepository { + return ×caleRepository{ + db: db, + } +} + +func (tr timescaleRepository) ReadAll(chanID string, rpm readers.PageMetadata) (readers.MessagesPage, error) { + format := defTable + + if rpm.Format != "" && rpm.Format != defTable { + format = rpm.Format + } + + isSenml := (format == defTable) + + // If aggregation is provided, add time_bucket and aggregation to the query + const timeDivisor = 1000000000 + isAggregated := isSenml && rpm.Aggregation != "" && rpm.Interval != "" + + if rpm.Order == "" { + switch { + case isSenml: + rpm.Order = orderByTime + default: + rpm.Order = orderByCreated + } + } + + orderClause := applyOrdering(rpm, isAggregated, isSenml) + + pgData := "" + if rpm.Limit != 0 { + pgData = "LIMIT :limit" + } + if rpm.Offset != 0 { + if pgData != "" { + pgData += " " + } + pgData += "OFFSET :offset" + } + + where := fmtCondition(rpm) + + var q string + totalQuery := fmt.Sprintf(`SELECT COUNT(*) FROM %s WHERE %s;`, format, where) + + if isAggregated { + q = fmt.Sprintf(` + SELECT + EXTRACT(epoch FROM time_bucket('%s', to_timestamp(time/%d))) *%d AS time, + %s(value) AS value, + FIRST(publisher, time) AS publisher, + FIRST(protocol, time) AS protocol, + FIRST(subtopic, time) AS subtopic, + FIRST(name,time) AS name, + FIRST(unit, time) AS unit + FROM + %s + WHERE + %s + GROUP BY 1 + %s + %s; + `, + rpm.Interval, timeDivisor, timeDivisor, rpm.Aggregation, format, where, orderClause, pgData) + + totalQuery = fmt.Sprintf(`SELECT COUNT(*) FROM (SELECT EXTRACT(epoch FROM time_bucket('%s', to_timestamp(time/%d))) AS time, %s(value) AS value FROM %s WHERE %s GROUP BY 1) AS subquery;`, rpm.Interval, timeDivisor, rpm.Aggregation, format, where) + } else { + q = fmt.Sprintf(`SELECT * FROM %s WHERE %s %s %s;`, format, where, orderClause, pgData) + } + + params := map[string]any{ + "channel": chanID, + "limit": rpm.Limit, + "offset": rpm.Offset, + "subtopic": rpm.Subtopic, + "publisher": rpm.Publisher, + "name": rpm.Name, + "protocol": rpm.Protocol, + "value": rpm.Value, + "bool_value": rpm.BoolValue, + "string_value": rpm.StringValue, + "data_value": rpm.DataValue, + "from": rpm.From, + "to": rpm.To, + } + + rows, err := tr.db.NamedQuery(q, params) + if err != nil { + if preErr, ok := err.(*pgconn.PrepareError); ok { + err = preErr.Unwrap() + } + if pgErr, ok := err.(*pgconn.PgError); ok { + if pgErr.Code == pgerrcode.UndefinedTable { + return readers.MessagesPage{}, nil + } + } + return readers.MessagesPage{}, errors.Wrap(readers.ErrReadMessages, err) + } + defer rows.Close() + + page := readers.MessagesPage{ + PageMetadata: rpm, + Messages: []readers.Message{}, + } + + switch format { + case defTable: + for rows.Next() { + msg := senmlMessage{Message: senml.Message{}} + if err := rows.StructScan(&msg); err != nil { + return readers.MessagesPage{}, errors.Wrap(readers.ErrReadMessages, err) + } + + page.Messages = append(page.Messages, msg.Message) + } + default: + for rows.Next() { + msg := jsonMessage{} + if err := rows.StructScan(&msg); err != nil { + return readers.MessagesPage{}, errors.Wrap(readers.ErrReadMessages, err) + } + m, err := msg.toMap() + if err != nil { + return readers.MessagesPage{}, errors.Wrap(readers.ErrReadMessages, err) + } + page.Messages = append(page.Messages, m) + } + } + + rows, err = tr.db.NamedQuery(totalQuery, params) + if err != nil { + return readers.MessagesPage{}, errors.Wrap(readers.ErrReadMessages, err) + } + defer rows.Close() + + total := uint64(0) + if rows.Next() { + if err := rows.Scan(&total); err != nil { + return page, err + } + } + page.Total = total + + return page, nil +} + +func fmtCondition(rpm readers.PageMetadata) string { + // Indexed columns conditions based on indices order. + chCondition := " channel = :channel " + + var query map[string]any + meta, err := json.Marshal(rpm) + if err != nil { + return chCondition + } + if err := json.Unmarshal(meta, &query); err != nil { + return chCondition + } + + conditions := []string{chCondition} + + if _, ok := query["subtopic"]; ok { + conditions = append(conditions, " subtopic = :subtopic ") + } + + if _, ok := query["publisher"]; ok { + conditions = append(conditions, " publisher = :publisher ") + } + + if _, ok := query["name"]; ok { + conditions = append(conditions, " name = :name ") + } + + if _, ok := query["from"]; ok { + conditions = append(conditions, " time >= :from ") + } + + if _, ok := query["to"]; ok { + conditions = append(conditions, " time < :to ") + } + + // Non Indexed columns conditions added after indexed columns conditions order. + if _, ok := query["protocol"]; ok { + conditions = append(conditions, " protocol = :protocol ") + } + + for name := range query { + switch name { + case "v": + comparator := readers.ParseValueComparator(query) + conditions = append(conditions, fmt.Sprintf(" value %s :value ", comparator)) + case "vb": + conditions = append(conditions, "bool_value = :bool_value") + case "vs": + comparator := readers.ParseValueComparator(query) + switch comparator { + case "=": + conditions = append(conditions, " string_value = :string_value ") + case ">": + conditions = append(conditions, " string_value LIKE '%%' || :string_value || '%%' AND string_value <> :string_value ") + case ">=": + conditions = append(conditions, " string_value LIKE '%%' || :string_value || '%%' ") + case "<=": + conditions = append(conditions, " :string_value LIKE '%%' || string_value || '%%' ") + case "<": + conditions = append(conditions, " :string_value LIKE '%%' || string_value || '%%' AND string_value <> :string_value ") + } + case "vd": + comparator := readers.ParseValueComparator(query) + conditions = append(conditions, fmt.Sprintf(" data_value %s :data_value ", comparator)) + } + } + + return strings.Join(conditions, " AND ") +} + +type senmlMessage struct { + ID string `db:"id"` + senml.Message +} + +type jsonMessage struct { + Channel string `db:"channel"` + Created int64 `db:"created"` + Subtopic string `db:"subtopic"` + Publisher string `db:"publisher"` + Protocol string `db:"protocol"` + Payload []byte `db:"payload"` +} + +func (msg jsonMessage) toMap() (map[string]any, error) { + ret := map[string]any{ + "channel": msg.Channel, + "created": msg.Created, + "subtopic": msg.Subtopic, + "publisher": msg.Publisher, + "protocol": msg.Protocol, + "payload": map[string]any{}, + } + pld := make(map[string]any) + if err := json.Unmarshal(msg.Payload, &pld); err != nil { + return nil, err + } + ret["payload"] = pld + return ret, nil +} + +func applyOrdering(pm readers.PageMetadata, isAggregated bool, isSenml bool) string { + timeCol := orderByTime + if !isSenml { + timeCol = orderByCreated + } + + dir := pm.Dir + if dir != api.AscDir && dir != api.DescDir { + dir = api.DescDir + } + + aggCols := map[string]bool{ + orderByTime: true, + "value": true, + "sum": true, + "publisher": true, + "protocol": true, + "subtopic": true, + "name": true, + "unit": true, + } + + senmlCols := map[string]bool{ + orderByTime: true, + "value": true, + "bool_value": true, + "string_value": true, + "data_value": true, + "publisher": true, + "name": true, + "protocol": true, + "channel": true, + "subtopic": true, + "unit": true, + } + + jsonCols := map[string]bool{ + orderByCreated: true, "publisher": true, "protocol": true, + "channel": true, "subtopic": true, + } + + if isAggregated { + col := pm.Order + if !aggCols[col] { + col = orderByTime + } + if col == orderByTime { + return fmt.Sprintf("ORDER BY time %s", dir) + } + return fmt.Sprintf("ORDER BY %s %s, time %s", col, dir, dir) + } + + col := pm.Order + switch { + case isSenml: + if !senmlCols[col] { + col = orderByTime + } + case !isSenml: + if !jsonCols[col] { + col = orderByCreated + } + } + + secondary := fmt.Sprintf("%s DESC", timeCol) + + if col == timeCol { + return fmt.Sprintf("ORDER BY %s %s", col, dir) + } + return fmt.Sprintf("ORDER BY %s %s, %s", col, dir, secondary) +} diff --git a/readers/timescale/messages_test.go b/readers/timescale/messages_test.go new file mode 100644 index 000000000..f55795027 --- /dev/null +++ b/readers/timescale/messages_test.go @@ -0,0 +1,810 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package timescale_test + +import ( + "context" + "fmt" + "testing" + "time" + + twriter "github.com/absmach/supermq/consumers/writers/timescale" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/transformers/json" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/absmach/supermq/readers" + treader "github.com/absmach/supermq/readers/timescale" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + subtopic = "subtopic" + msgsNum = 100 + limit = 10 + valueFields = 5 + mqttProt = "mqtt" + httpProt = "http" + msgName = "temperature" + format1 = "format1" + format2 = "format2" + wrongID = "0" +) + +var ( + v float64 = 5 + vs = "stringValue" + vb = true + vd = "dataValue" + sum float64 = 42 +) + +func TestReadSenml(t *testing.T) { + writer := twriter.New(db) + + chanID := testsutil.GenerateUUID(t) + pubID := testsutil.GenerateUUID(t) + pubID2 := testsutil.GenerateUUID(t) + wrongID := testsutil.GenerateUUID(t) + + m := senml.Message{ + Channel: chanID, + Publisher: pubID, + Protocol: mqttProt, + } + + messages := []senml.Message{} + valueMsgs := []senml.Message{} + boolMsgs := []senml.Message{} + stringMsgs := []senml.Message{} + dataMsgs := []senml.Message{} + queryMsgs := []senml.Message{} + + now := float64(time.Now().Unix()) + for i := 0; i < msgsNum; i++ { + // Mix possible values as well as value sum. + msg := m + msg.Time = now - float64(i) + + count := i % valueFields + switch count { + case 0: + msg.Value = &v + valueMsgs = append(valueMsgs, msg) + case 1: + msg.BoolValue = &vb + boolMsgs = append(boolMsgs, msg) + case 2: + msg.StringValue = &vs + stringMsgs = append(stringMsgs, msg) + case 3: + msg.DataValue = &vd + dataMsgs = append(dataMsgs, msg) + case 4: + msg.Sum = &sum + msg.Subtopic = subtopic + msg.Protocol = httpProt + msg.Publisher = pubID2 + msg.Name = msgName + queryMsgs = append(queryMsgs, msg) + } + + messages = append(messages, msg) + } + + err := writer.ConsumeBlocking(context.TODO(), messages) + require.Nil(t, err, fmt.Sprintf("expected no error got %s\n", err)) + + reader := treader.New(db) + + // Since messages are not saved in natural order, + // cases that return subset of messages are only + // checking data result set size, but not content. + cases := []struct { + desc string + chanID string + pageMeta readers.PageMetadata + page readers.MessagesPage + }{ + { + desc: "read message page for existing channel", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: msgsNum, + }, + page: readers.MessagesPage{ + Total: msgsNum, + Messages: fromSenml(messages), + }, + }, + { + desc: "read message page for non-existent channel", + chanID: wrongID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: msgsNum, + }, + page: readers.MessagesPage{ + Messages: []readers.Message{}, + }, + }, + { + desc: "read message last page", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: msgsNum - 20, + Limit: msgsNum, + }, + page: readers.MessagesPage{ + Total: msgsNum, + Messages: fromSenml(messages[msgsNum-20 : msgsNum]), + }, + }, + { + desc: "read message with non-existent subtopic", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: msgsNum, + Subtopic: "not-present", + }, + page: readers.MessagesPage{ + Messages: []readers.Message{}, + }, + }, + { + desc: "read message with subtopic", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: uint64(len(queryMsgs)), + Subtopic: subtopic, + }, + page: readers.MessagesPage{ + Total: uint64(len(queryMsgs)), + Messages: fromSenml(queryMsgs), + }, + }, + { + desc: "read message with publisher", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: uint64(len(queryMsgs)), + Publisher: pubID2, + }, + page: readers.MessagesPage{ + Total: uint64(len(queryMsgs)), + Messages: fromSenml(queryMsgs), + }, + }, + { + desc: "read message with wrong format", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Format: "messagess", + Offset: 0, + Limit: uint64(len(queryMsgs)), + Publisher: pubID2, + }, + page: readers.MessagesPage{ + Total: 0, + Messages: []readers.Message{}, + }, + }, + { + desc: "read message with protocol", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: uint64(len(queryMsgs)), + Protocol: httpProt, + }, + page: readers.MessagesPage{ + Total: uint64(len(queryMsgs)), + Messages: fromSenml(queryMsgs), + }, + }, + { + desc: "read message with name", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Name: msgName, + }, + page: readers.MessagesPage{ + Total: uint64(len(queryMsgs)), + Messages: fromSenml(queryMsgs[0:limit]), + }, + }, + { + desc: "read message with value", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Value: v, + }, + page: readers.MessagesPage{ + Total: uint64(len(valueMsgs)), + Messages: fromSenml(valueMsgs[0:limit]), + }, + }, + { + desc: "read message with value and equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Value: v, + Comparator: readers.EqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(valueMsgs)), + Messages: fromSenml(valueMsgs[0:limit]), + }, + }, + { + desc: "read message with value and lower-than comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Value: v + 1, + Comparator: readers.LowerThanKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(valueMsgs)), + Messages: fromSenml(valueMsgs[0:limit]), + }, + }, + { + desc: "read message with value and lower-than-or-equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Value: v + 1, + Comparator: readers.LowerThanEqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(valueMsgs)), + Messages: fromSenml(valueMsgs[0:limit]), + }, + }, + { + desc: "read message with value and greater-than comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Value: v - 1, + Comparator: readers.GreaterThanKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(valueMsgs)), + Messages: fromSenml(valueMsgs[0:limit]), + }, + }, + { + desc: "read message with value and greater-than-or-equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + Value: v - 1, + Comparator: readers.GreaterThanEqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(valueMsgs)), + Messages: fromSenml(valueMsgs[0:limit]), + }, + }, + { + desc: "read message with boolean value", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + BoolValue: vb, + }, + page: readers.MessagesPage{ + Total: uint64(len(boolMsgs)), + Messages: fromSenml(boolMsgs[0:limit]), + }, + }, + { + desc: "read message with string value", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + StringValue: vs, + }, + page: readers.MessagesPage{ + Total: uint64(len(stringMsgs)), + Messages: fromSenml(stringMsgs[0:limit]), + }, + }, + { + desc: "read message with string value and equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + StringValue: vs, + Comparator: readers.EqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(stringMsgs)), + Messages: fromSenml(stringMsgs[0:limit]), + }, + }, + { + desc: "read message with string value and lower-than comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + StringValue: "a stringValues b", + Comparator: readers.LowerThanKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(stringMsgs)), + Messages: fromSenml(stringMsgs[0:limit]), + }, + }, + { + desc: "read message with string value and lower-than-or-equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + StringValue: vs, + Comparator: readers.LowerThanEqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(stringMsgs)), + Messages: fromSenml(stringMsgs[0:limit]), + }, + }, + { + desc: "read message with string value and greater-than comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + StringValue: "alu", + Comparator: readers.GreaterThanKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(stringMsgs)), + Messages: fromSenml(stringMsgs[0:limit]), + }, + }, + { + desc: "read message with string value and greater-than-or-equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + StringValue: vs, + Comparator: readers.GreaterThanEqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(stringMsgs)), + Messages: fromSenml(stringMsgs[0:limit]), + }, + }, + { + desc: "read message with data value", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + DataValue: vd, + }, + page: readers.MessagesPage{ + Total: uint64(len(dataMsgs)), + Messages: fromSenml(dataMsgs[0:limit]), + }, + }, + { + desc: "read message with data value and lower-than comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + DataValue: vd + string(rune(1)), + Comparator: readers.LowerThanKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(dataMsgs)), + Messages: fromSenml(dataMsgs[0:limit]), + }, + }, + { + desc: "read message with data value and lower-than-or-equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + DataValue: vd + string(rune(1)), + Comparator: readers.LowerThanEqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(dataMsgs)), + Messages: fromSenml(dataMsgs[0:limit]), + }, + }, + { + desc: "read message with data value and greater-than comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + DataValue: vd[:len(vd)-1] + string(rune(1)), + Comparator: readers.GreaterThanKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(dataMsgs)), + Messages: fromSenml(dataMsgs[0:limit]), + }, + }, + { + desc: "read message with data value and greater-than-or-equal comparator", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + DataValue: vd[:len(vd)-1] + string(rune(1)), + Comparator: readers.GreaterThanEqualKey, + }, + page: readers.MessagesPage{ + Total: uint64(len(dataMsgs)), + Messages: fromSenml(dataMsgs[0:limit]), + }, + }, + { + desc: "read message with from", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: uint64(len(messages[0:21])), + From: messages[20].Time, + }, + page: readers.MessagesPage{ + Total: uint64(len(messages[0:21])), + Messages: fromSenml(messages[0:21]), + }, + }, + { + desc: "read message with to", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: uint64(len(messages[21:])), + To: messages[20].Time, + }, + page: readers.MessagesPage{ + Total: uint64(len(messages[21:])), + Messages: fromSenml(messages[21:]), + }, + }, + { + desc: "read message with from/to", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Offset: 0, + Limit: limit, + From: messages[5].Time, + To: messages[0].Time, + }, + page: readers.MessagesPage{ + Total: 5, + Messages: fromSenml(messages[1:6]), + }, + }, + } + + for _, tc := range cases { + result, err := reader.ReadAll(tc.chanID, tc.pageMeta) + assert.Nil(t, err, fmt.Sprintf("%s: expected no error got %s", tc.desc, err)) + assert.ElementsMatch(t, tc.page.Messages, result.Messages, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.page.Messages, result.Messages)) + assert.Equal(t, tc.page.Total, result.Total, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.page.Total, result.Total)) + } +} + +func TestReadMessagesWithAggregation(t *testing.T) { + writer := twriter.New(db) + + chanID := testsutil.GenerateUUID(t) + pubID := testsutil.GenerateUUID(t) + messages := []senml.Message{} + + now := float64(time.Now().UnixNano()) + value := 10.0 + for i := 0; i < 100; i++ { + if i%10 == 0 { + value += 10.0 + } + v := value + msg := senml.Message{ + Channel: chanID, + Publisher: pubID, + Time: now - float64(i*1000000000), // over 100 seconds + Value: &v, + Protocol: mqttProt, + } + messages = append(messages, msg) + } + + err := writer.ConsumeBlocking(context.TODO(), messages) + require.Nil(t, err, "expected no error got %s\n", err) + + reader := treader.New(db) + + // Set up cases for aggregation readAll + cases := []struct { + desc string + chanID string + pageMeta readers.PageMetadata + page readers.MessagesPage + }{ + { + desc: "read message page for existing channel with AVG aggregation over an hour", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Limit: 100, + Offset: 0, + Aggregation: "AVG", + Interval: "1 hour", + From: now - float64(100000000000), + To: now, + }, + page: readers.MessagesPage{ + Messages: fromSenml(messages), + }, + }, + { + desc: "read message page for existing channel with MAX aggregation over an hour", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Limit: 100, + Offset: 0, + Aggregation: "MAX", + Interval: "1 hour", + From: now - float64(100000000000), + To: now, + }, + page: readers.MessagesPage{ + Messages: fromSenml(messages), + }, + }, + { + desc: "read message page for existing channel with MIN aggregation over an hour", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Limit: 100, + Offset: 0, + Aggregation: "MIN", + Interval: "1 hour", + From: now - float64(100000000000), + To: now, + }, + page: readers.MessagesPage{ + Messages: fromSenml(messages), + }, + }, + { + desc: "read message page for existing channel with SUM aggregation over an hour", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Limit: 100, + Offset: 0, + Aggregation: "SUM", + Interval: "1 hour", + From: now - float64(100000000000), + To: now, + }, + page: readers.MessagesPage{ + Messages: fromSenml(messages), + }, + }, + { + desc: "read message page for existing channel with COUNT aggregation over an hour", + chanID: chanID, + pageMeta: readers.PageMetadata{ + Limit: 100, + Offset: 0, + Aggregation: "COUNT", + Interval: "1 hour", + From: now - float64(100000000000), + To: now, + }, + page: readers.MessagesPage{ + Messages: fromSenml(messages), + }, + }, + } + + for _, tc := range cases { + resultPage, err := reader.ReadAll(tc.chanID, tc.pageMeta) + assert.Nil(t, err, fmt.Sprintf("%s: expected no error got %s", tc.desc, err)) + assert.NotEmpty(t, resultPage.Messages, "expected non-empty result set") + for i := range resultPage.Messages { + msg, ok := resultPage.Messages[i].(senml.Message) + if ok && msg.Value != nil { + assert.GreaterOrEqual(t, *msg.Value, resultPage.Value, "expected aggregated value to be greater or equal to the expected value") + } + } + } +} + +func TestReadJSON(t *testing.T) { + writer := twriter.New(db) + + id1 := testsutil.GenerateUUID(t) + messages1 := json.Messages{ + Format: format1, + } + msgs1 := []map[string]any{} + timeNow := time.Now().UnixMilli() + for i := 0; i < msgsNum; i++ { + m := json.Message{ + Channel: id1, + Publisher: id1, + Created: timeNow - int64(i), + Subtopic: "subtopic/format/some_json", + Protocol: "coap", + Payload: map[string]any{ + "field_1": 123.0, + "field_2": "value", + "field_3": false, + "field_4": 12.344, + "field_5": map[string]any{ + "field_1": "value", + "field_2": 42.0, + }, + }, + } + + msg := m + messages1.Data = append(messages1.Data, msg) + mapped := toMap(msg) + msgs1 = append(msgs1, mapped) + } + + err := writer.ConsumeBlocking(context.TODO(), messages1) + require.Nil(t, err, fmt.Sprintf("expected no error got %s\n", err)) + + id2 := testsutil.GenerateUUID(t) + messages2 := json.Messages{ + Format: format2, + } + msgs2 := []map[string]any{} + for i := 0; i < msgsNum; i++ { + m := json.Message{ + Channel: id2, + Publisher: id2, + Created: timeNow - int64(i), + Subtopic: "subtopic/other_format/some_other_json", + Protocol: "udp", + Payload: map[string]any{ + "field_1": "other_value", + "false_value": false, + "field_pi": 3.14159265, + }, + } + + msg := m + if i%2 == 0 { + msg.Protocol = httpProt + } + messages2.Data = append(messages2.Data, msg) + mapped := toMap(msg) + msgs2 = append(msgs2, mapped) + } + + err = writer.ConsumeBlocking(context.TODO(), messages2) + require.Nil(t, err, fmt.Sprintf("expected no error got %s\n", err)) + + httpMsgs := []map[string]any{} + for i := 0; i < msgsNum; i += 2 { + httpMsgs = append(httpMsgs, msgs2[i]) + } + + reader := treader.New(db) + + cases := map[string]struct { + chanID string + pageMeta readers.PageMetadata + page readers.MessagesPage + }{ + "read message page for existing channel": { + chanID: id1, + pageMeta: readers.PageMetadata{ + Format: messages1.Format, + Offset: 0, + Limit: 10, + }, + page: readers.MessagesPage{ + Total: 100, + Messages: fromJSON(msgs1[:10]), + }, + }, + "read message page for non-existent channel": { + chanID: wrongID, + pageMeta: readers.PageMetadata{ + Format: messages1.Format, + Offset: 0, + Limit: 10, + }, + page: readers.MessagesPage{ + Messages: []readers.Message{}, + }, + }, + "read message last page": { + chanID: id2, + pageMeta: readers.PageMetadata{ + Format: messages2.Format, + Offset: msgsNum - 20, + Limit: msgsNum, + }, + page: readers.MessagesPage{ + Total: msgsNum, + Messages: fromJSON(msgs2[msgsNum-20 : msgsNum]), + }, + }, + "read message with protocol": { + chanID: id2, + pageMeta: readers.PageMetadata{ + Format: messages2.Format, + Offset: 0, + Limit: uint64(msgsNum / 2), + Protocol: httpProt, + }, + page: readers.MessagesPage{ + Total: uint64(msgsNum / 2), + Messages: fromJSON(httpMsgs), + }, + }, + } + + for desc, tc := range cases { + result, err := reader.ReadAll(tc.chanID, tc.pageMeta) + assert.Nil(t, err, fmt.Sprintf("%s: expected no error got %s", desc, err)) + assert.ElementsMatch(t, tc.page.Messages, result.Messages, fmt.Sprintf("%s: got incorrect list of json Messages from ReadAll()", desc)) + assert.Equal(t, tc.page.Total, result.Total, fmt.Sprintf("%s: expected %v got %v", desc, tc.page.Total, result.Total)) + } +} + +func fromSenml(msg []senml.Message) []readers.Message { + var ret []readers.Message + for _, m := range msg { + ret = append(ret, m) + } + return ret +} + +func fromJSON(msg []map[string]any) []readers.Message { + var ret []readers.Message + for _, m := range msg { + ret = append(ret, m) + } + return ret +} + +func toMap(msg json.Message) map[string]any { + return map[string]any{ + "channel": msg.Channel, + "created": msg.Created, + "subtopic": msg.Subtopic, + "publisher": msg.Publisher, + "protocol": msg.Protocol, + "payload": map[string]any(msg.Payload), + } +} diff --git a/readers/timescale/setup_test.go b/readers/timescale/setup_test.go new file mode 100644 index 000000000..8d66c5a3b --- /dev/null +++ b/readers/timescale/setup_test.go @@ -0,0 +1,88 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package timescale_test contains tests for PostgreSQL repository +// implementations. +package timescale_test + +import ( + "fmt" + "log" + "os" + "testing" + "time" + + tsWriter "github.com/absmach/supermq/consumers/writers/timescale" + pgclient "github.com/absmach/supermq/pkg/postgres" + _ "github.com/jackc/pgx/v5/stdlib" // required for SQL access + "github.com/jmoiron/sqlx" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +var db *sqlx.DB + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "timescale/timescaledb", + Tag: "2.19.3-pg16-oss", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + + // exponential backoff-retry, because the application in the container might not be ready to accept connections yet + pool.MaxWait = 120 * time.Second + if err = pool.Retry(func() error { + db, err = sqlx.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + dbConfig := pgclient.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + if db, err = pgclient.Setup(dbConfig, *tsWriter.Migration()); err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + code := m.Run() + + // Defers will not be run when using os.Exit + db.Close() + if err = pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/reports/README.md b/reports/README.md new file mode 100644 index 000000000..5c041aaef --- /dev/null +++ b/reports/README.md @@ -0,0 +1,335 @@ +# Reports + +The Reports service generates time-series reports from stored messages. It fetches data from the readers gRPC service, formats results as JSON, CSV, or PDF, optionally emails the report, and supports scheduled report delivery. + +## Configuration + +The service is configured using the following environment variables (values shown are from [docker/.env](https://github.com/absmach/magistrala/blob/main/docker/.env) where available, otherwise from service defaults): + +### Core service + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_REPORTS_LOG_LEVEL` | Log level for the service | `debug` | +| `MG_REPORTS_HTTP_HOST` | HTTP host to bind | `reports` | +| `MG_REPORTS_HTTP_PORT` | HTTP port to bind | `9017` | +| `MG_REPORTS_HTTP_SERVER_CERT` | Path to PEM-encoded HTTPS server certificate | "" | +| `MG_REPORTS_HTTP_SERVER_KEY` | Path to PEM-encoded HTTPS server key | "" | +| `MG_REPORTS_INSTANCE_ID` | Instance ID for tracing/health | "" | +| `MG_JAEGER_URL` | Jaeger collector endpoint | `http://jaeger:4318/v1/traces` | +| `MG_JAEGER_TRACE_RATIO` | Trace sampling ratio | `1.0` | +| `MG_SEND_TELEMETRY` | Send telemetry to Magistrala call-home server | `true` | +| `MG_MESSAGE_BROKER_URL` | Message broker URL (parsed, currently unused by reports) | `nats://nats:4222` | +| `MG_ES_URL` | Event store URL (parsed, currently unused by reports) | `nats://nats:4222` | + +### Database + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_REPORTS_DB_HOST` | PostgreSQL host | `reports-db` | +| `MG_REPORTS_DB_PORT` | PostgreSQL port | `5432` | +| `MG_REPORTS_DB_USER` | PostgreSQL user | `magistrala` | +| `MG_REPORTS_DB_PASS` | PostgreSQL password | `magistrala` | +| `MG_REPORTS_DB_NAME` | PostgreSQL database name | `reports` | +| `MG_REPORTS_DB_SSL_MODE` | PostgreSQL SSL mode | `disable` | +| `MG_REPORTS_DB_SSL_CERT` | PostgreSQL SSL client cert | "" | +| `MG_REPORTS_DB_SSL_KEY` | PostgreSQL SSL client key | "" | +| `MG_REPORTS_DB_SSL_ROOT_CERT` | PostgreSQL SSL root cert | "" | + +### Auth and domains gRPC + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_AUTH_GRPC_URL` | Auth gRPC endpoint | `auth:7001` | +| `MG_AUTH_GRPC_TIMEOUT` | Auth gRPC timeout | `300s` | +| `MG_AUTH_GRPC_CLIENT_CERT` | Auth gRPC client cert path | `${GRPC_MTLS:+./ssl/certs/auth-grpc-client.crt}` | +| `MG_AUTH_GRPC_CLIENT_KEY` | Auth gRPC client key path | `${GRPC_MTLS:+./ssl/certs/auth-grpc-client.key}` | +| `MG_AUTH_GRPC_SERVER_CA_CERTS` | Auth gRPC server CA path | `${GRPC_MTLS:+./ssl/certs/ca.crt}` | +| `MG_DOMAINS_GRPC_URL` | Domains gRPC endpoint | `domains:7003` | +| `MG_DOMAINS_GRPC_TIMEOUT` | Domains gRPC timeout | `300s` | +| `MG_DOMAINS_GRPC_CLIENT_CERT` | Domains gRPC client cert path | `${GRPC_MTLS:+./ssl/certs/domains-grpc-client.crt}` | +| `MG_DOMAINS_GRPC_CLIENT_KEY` | Domains gRPC client key path | `${GRPC_MTLS:+./ssl/certs/domains-grpc-client.key}` | +| `MG_DOMAINS_GRPC_SERVER_CA_CERTS` | Domains gRPC server CA path | `${GRPC_MTLS:+./ssl/certs/ca.crt}` | +| `MG_ALLOW_UNVERIFIED_USER` | Allow unverified users to access | `true` | +| `MG_SPICEDB_PRE_SHARED_KEY` | SpiceDB pre-shared key | `12345678` | +| `MG_SPICEDB_HOST` | SpiceDB host | `supermq-spicedb` | +| `MG_SPICEDB_PORT` | SpiceDB gRPC port | `50051` | + +### Readers gRPC + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_TIMESCALE_READER_GRPC_URL` | Readers gRPC endpoint | `timescale-reader:7011` | +| `MG_TIMESCALE_READER_GRPC_TIMEOUT` | Readers gRPC timeout | `300s` | +| `MG_TIMESCALE_READER_GRPC_CLIENT_CERT` | Readers gRPC client cert path | `${GRPC_MTLS:+./ssl/certs/reader-grpc-client.crt}` | +| `MG_TIMESCALE_READER_GRPC_CLIENT_CA_CERTS` | Readers gRPC server CA path | `${GRPC_MTLS:+./ssl/certs/ca.crt}` | +| `MG_TIMESCALE_READER_GRPC_CLIENT_KEY` | Readers gRPC client key path | `${GRPC_MTLS:+./ssl/certs/readers-grpc-client.key}` | + +### Email + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_EMAIL_HOST` | SMTP host | `smtp.mailtrap.io` | +| `MG_EMAIL_PORT` | SMTP port | `2525` | +| `MG_EMAIL_USERNAME` | SMTP username | `18bf7f70705139` | +| `MG_EMAIL_PASSWORD` | SMTP password | `2b0d302e775b1e` | +| `MG_EMAIL_FROM_ADDRESS` | Sender email address | `from@example.com` | +| `MG_EMAIL_FROM_NAME` | Sender display name | `Example` | +| `MG_EMAIL_TEMPLATE` | Email template path | `email.tmpl` | +| `MG_REPORTS_EMAIL_TEMPLATE` | Template file mounted by Docker Compose | `reports.tmpl` | + +### Templates and PDF conversion + +| Variable | Description | Default | +| --- | --- | --- | +| `MG_REPORTS_DEFAULT_TEMPLATE` | Use on-disk HTML template when non-empty | "" | +| `MG_PDF_CONVERTER_URL` | HTML-to-PDF conversion endpoint | `http://pdf-generator:3000/forms/chromium/convert/html` | + +## Features + +- **Report generation**: Build report data from time-series messages. +- **Multiple formats**: JSON responses, CSV exports, and PDF rendering. +- **Scheduling**: Periodic report delivery via email. +- **Template support**: Custom HTML templates for PDF reports. +- **Observability**: `/metrics` Prometheus endpoint and Jaeger tracing support. + +## Architecture + +### Runtime flow + +1. The Reports API receives a report request or a scheduled run triggers report generation. +2. The service expands requested metrics and fetches messages via the readers gRPC API in batches of 1000. +3. Results are grouped by publisher when `client_ids` are not specified. +4. Output is returned as JSON, rendered to CSV, or converted to PDF via `MG_PDF_CONVERTER_URL`. +5. For scheduled/email actions, the report is sent as an email attachment. + +### Scheduling + +The scheduler runs on a 30-second ticker and selects enabled report configs with `due` time earlier than now. It updates `due` using `Schedule.NextDue()` and generates a report with the `email` action. + +Recurring types are: `none`, `hourly`, `daily`, `weekly`, `monthly`. The `recurring_period` controls the interval (1 = every interval, 2 = every second interval, etc.). + +### Templates + +PDF templates are Go `html/template` documents. A template must include: + +- `{{$.Title}}` +- `{{range .Messages}}` or `{{range .Reports}}` +- `{{formatTime .Time}}` +- `{{formatValue .}}` +- `{{end}}` + +Helper functions include `formatTime`, `formatValue`, `add`, `sub`, `div`, `mod`, `iterate`, `eq`, `ge`, `lt`, `getStartRow`, and `getEndRow`. + +## Data model + +### report_config table + +Defined in `reports/postgres/init.go`: + +| Column | Type | Description | +| --- | --- | --- | +| `id` | `VARCHAR(36)` | Report config UUID (primary key) | +| `name` | `VARCHAR(1024)` | Report name | +| `description` | `TEXT` | Report description | +| `domain_id` | `VARCHAR(36)` | Domain ID | +| `status` | `SMALLINT` | 0 = enabled, 1 = disabled, 2 = deleted | +| `created_at` | `TIMESTAMP` | Creation timestamp | +| `created_by` | `VARCHAR(254)` | Creator user ID | +| `updated_at` | `TIMESTAMP` | Last update timestamp | +| `updated_by` | `VARCHAR(254)` | Last updater user ID | +| `due` | `TIMESTAMPTZ` | Next scheduled execution time | +| `recurring` | `SMALLINT` | Recurring type | +| `recurring_period` | `SMALLINT` | Recurring period | +| `start_datetime` | `TIMESTAMP` | Schedule start time | +| `config` | `JSONB` | Metric config (from/to/title/format/aggregation) | +| `email` | `JSONB` | Email settings | +| `metrics` | `JSONB` | Requested metrics list | +| `report_template` | `TEXT` | Custom HTML template | + +## Deployment + +### Build and run locally + +```bash +make reports + +MG_REPORTS_LOG_LEVEL=debug \ +MG_REPORTS_HTTP_PORT=9017 \ +MG_REPORTS_DB_HOST=localhost \ +MG_REPORTS_DB_PORT=5432 \ +MG_REPORTS_DB_USER=magistrala \ +MG_REPORTS_DB_PASS=magistrala \ +MG_REPORTS_DB_NAME=reports \ +MG_PDF_CONVERTER_URL=http://localhost:4000/forms/chromium/convert/html \ +MG_AUTH_GRPC_URL=localhost:7001 \ +MG_AUTH_GRPC_TIMEOUT=300s \ +MG_DOMAINS_GRPC_URL=localhost:7003 \ +MG_DOMAINS_GRPC_TIMEOUT=300s \ +MG_TIMESCALE_READER_GRPC_URL=localhost:7011 \ +MG_TIMESCALE_READER_GRPC_TIMEOUT=300s \ +./build/reports +``` + +### Docker Compose + +The service is available as a Docker container. Refer to [docker/docker-compose.yaml](https://github.com/absmach/magistrala/blob/main/docker/docker-compose.yaml) for the `reports`, `reports-db`, and `pdf-generator` services and their environment variables. For a full local stack, ensure auth, domains, readers, and the PDF generator are running. + +```bash +docker compose -f docker/docker-compose.yaml up reports reports-db pdf-generator +``` + +### Health check + +```bash +curl -X GET http://localhost:9017/health \ + -H "accept: application/health+json" +``` + +## Testing + +```bash +go test ./reports/... +``` + +## Usage + +The Reports service supports the following operations: + +| Operation | Method & Path | Description | +| --- | --- | --- | +| `generateReport` | `POST /{domainID}/reports` | Generate a report (`action` query param) | +| `addReportConfig` | `POST /{domainID}/reports/configs` | Create a report configuration | +| `listReportsConfig` | `GET /{domainID}/reports/configs` | List report configurations | +| `viewReportConfig` | `GET /{domainID}/reports/configs/{reportID}` | View a report configuration | +| `updateReportConfig` | `PATCH /{domainID}/reports/configs/{reportID}` | Update a report configuration | +| `updateReportSchedule` | `PATCH /{domainID}/reports/configs/{reportID}/schedule` | Update schedule | +| `enableReportConfig` | `POST /{domainID}/reports/configs/{reportID}/enable` | Enable a report configuration | +| `disableReportConfig` | `POST /{domainID}/reports/configs/{reportID}/disable` | Disable a report configuration | +| `deleteReportConfig` | `DELETE /{domainID}/reports/configs/{reportID}` | Delete a report configuration | +| `updateReportTemplate` | `PUT /{domainID}/reports/configs/{reportID}/template` | Update custom template | +| `viewReportTemplate` | `GET /{domainID}/reports/configs/{reportID}/template` | View custom template | +| `deleteReportTemplate` | `DELETE /{domainID}/reports/configs/{reportID}/template` | Delete custom template | +| `health` | `GET /health` | Service health check | + +List filters: `offset`, `limit`, `status`, `name`, `order` (`name`, `created_at`, `updated_at`), and `dir` (`asc`, `desc`). + +Time ranges use relative expressions parsed by `pkg/reltime`, such as `now()` or `now()-24h` (units: `s`, `m`, `h`, `d`, `w`). Aggregation intervals use Go duration strings like `15m` or `1h`. File output formats are `pdf` and `csv`. + +### Example: Generate a report + +```bash +curl -X POST "http://localhost:9017//reports?action=view" \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "name": "temperature-view", + "metrics": [ + { + "channel_id": "", + "client_ids": [""], + "name": "temperature", + "subtopic": "sensor" + } + ], + "config": { + "from": "now()-24h", + "to": "now()", + "title": "Temperature (last 24h)", + "timezone": "UTC", + "aggregation": { + "agg_type": "avg", + "interval": "1h" + } + } + }' +``` + +### Example: Generate and email a report + +```bash +curl -X POST "http://localhost:9017//reports?action=email" \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "name": "temperature-email", + "metrics": [ + { + "channel_id": "", + "name": "temperature" + } + ], + "config": { + "from": "now()-1d", + "to": "now()", + "title": "Daily Temperature", + "file_format": "csv" + }, + "email": { + "to": ["ops@example.com"], + "subject": "Daily temperature report", + "content": "Report attached." + } + }' + +``` + +### Example: Create a scheduled report config + +```bash +curl -X POST "http://localhost:9017//reports/configs" \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "name": "daily-temperature", + "description": "Daily temperature summary", + "metrics": [ + { + "channel_id": "", + "name": "temperature" + } + ], + "config": { + "from": "now()-1d", + "to": "now()", + "title": "Daily Temperature", + "file_format": "pdf", + "aggregation": { + "agg_type": "avg", + "interval": "1h" + } + }, + "email": { + "to": ["ops@example.com"], + "subject": "Daily temperature report", + "content": "Report attached." + }, + "schedule": { + "start_datetime": "2025-01-01T00:00:00Z", + "recurring": "daily", + "recurring_period": 1 + } + }' +``` + +### Example: Update a report template + +```bash +curl -X PUT "http://localhost:9017//reports/configs//template" \ + -H "Authorization: Bearer " \ + -H "Content-Type: application/json" \ + -d '{ + "report_template": "

{{$.Title}}

{{range .Reports}}{{range .Messages}}{{formatTime .Time}} {{formatValue .}}{{end}}{{end}}" + }' +``` + +### Example: Enable a report config + +```bash +curl -X POST "http://localhost:9017//reports/configs//enable" \ + -H "Authorization: Bearer " +``` + +For an in-depth explanation of our Reports Service, see the see the [official documentation][doc]. + +[doc]: https://docs.magistrala.absmach.eu/dev-guide/reports/ diff --git a/reports/api/doc.go b/reports/api/doc.go new file mode 100644 index 000000000..2424852cc --- /dev/null +++ b/reports/api/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package api contains API-related concerns: endpoint definitions, middlewares +// and all resource representations. +package api diff --git a/reports/api/endpoints.go b/reports/api/endpoints.go new file mode 100644 index 000000000..d5f87cb8d --- /dev/null +++ b/reports/api/endpoints.go @@ -0,0 +1,294 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + + "github.com/absmach/supermq/pkg/authn" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/reports" + "github.com/go-kit/kit/endpoint" +) + +func generateReportEndpoint(svc reports.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(generateReportReq) + if err := req.validate(); err != nil { + return generateReportResp{}, err + } + + res, err := svc.GenerateReport(ctx, session, req.ReportConfig, req.action) + if err != nil { + return generateReportResp{}, err + } + + switch req.action { + case reports.DownloadReport: + return downloadReportResp{ + File: res.File, + }, nil + case reports.EmailReport: + return emailReportResp{}, nil + default: + return generateReportResp{ + Total: res.Total, + From: res.From, + To: res.To, + Aggregation: res.Aggregation, + Reports: res.Reports, + }, nil + } + } +} + +func listReportsConfigEndpoint(svc reports.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(listReportsConfigReq) + if err := req.validate(); err != nil { + return listReportsConfigRes{}, err + } + + page, err := svc.ListReportsConfig(ctx, session, req.PageMeta) + if err != nil { + return listReportsConfigRes{}, err + } + + return listReportsConfigRes{ + pageRes: pageRes{ + Limit: page.Limit, + Offset: page.Offset, + Total: page.Total, + }, + ReportConfigs: page.ReportConfigs, + }, nil + } +} + +func deleteReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(deleteReportConfigReq) + if err := req.validate(); err != nil { + return deleteReportConfigRes{}, err + } + + err := svc.RemoveReportConfig(ctx, session, req.ID) + if err != nil { + return deleteReportConfigRes{false}, err + } + + return deleteReportConfigRes{true}, nil + } +} + +func updateReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(updateReportConfigReq) + if err := req.validate(); err != nil { + return updateReportConfigRes{}, err + } + + cfg, err := svc.UpdateReportConfig(ctx, session, req.ReportConfig) + if err != nil { + return updateReportConfigRes{}, err + } + + return updateReportConfigRes{ReportConfig: cfg}, nil + } +} + +func updateReportScheduleEndpoint(s reports.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(updateReportScheduleReq) + if err := req.validate(); err != nil { + return updateReportConfigRes{}, err + } + + rpt := reports.ReportConfig{ + ID: req.id, + Schedule: req.Schedule, + } + + updatedReport, err := s.UpdateReportSchedule(ctx, session, rpt) + if err != nil { + return updateReportConfigRes{}, err + } + return updateReportConfigRes{ReportConfig: updatedReport}, nil + } +} + +func viewReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(viewReportConfigReq) + if err := req.validate(); err != nil { + return viewReportConfigRes{}, err + } + + cfg, err := svc.ViewReportConfig(ctx, session, req.ID, req.withRoles) + if err != nil { + return viewReportConfigRes{}, err + } + + return viewReportConfigRes{ReportConfig: cfg}, nil + } +} + +func addReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(addReportConfigReq) + if err := req.validate(); err != nil { + return addReportConfigRes{}, err + } + + cfg, err := svc.AddReportConfig(ctx, session, req.ReportConfig) + if err != nil { + return addReportConfigRes{}, err + } + + return addReportConfigRes{ + ReportConfig: cfg, + created: true, + }, nil + } +} + +func enableReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(updateReportStatusReq) + if err := req.validate(); err != nil { + return updateReportConfigRes{}, err + } + + cfg, err := svc.EnableReportConfig(ctx, session, req.id) + if err != nil { + return updateReportConfigRes{}, err + } + + return updateReportConfigRes{ReportConfig: cfg}, nil + } +} + +func disableReportConfigEndpoint(svc reports.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(updateReportStatusReq) + if err := req.validate(); err != nil { + return updateReportConfigRes{}, err + } + + cfg, err := svc.DisableReportConfig(ctx, session, req.id) + if err != nil { + return updateReportConfigRes{}, err + } + + return updateReportConfigRes{ReportConfig: cfg}, nil + } +} + +func updateReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(updateReportTemplateReq) + if err := req.validate(); err != nil { + return updateReportTemplateRes{false}, err + } + + err := svc.UpdateReportTemplate(ctx, session, req.ReportConfig) + if err != nil { + return updateReportTemplateRes{false}, err + } + + return updateReportTemplateRes{true}, nil + } +} + +func viewReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(getReportTemplateReq) + if err := req.validate(); err != nil { + return viewReportTemplateRes{}, err + } + + template, err := svc.ViewReportTemplate(ctx, session, req.ID) + if err != nil { + return viewReportTemplateRes{}, err + } + + return viewReportTemplateRes{Template: template}, nil + } +} + +func deleteReportTemplateEndpoint(svc reports.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + req := request.(deleteReportTemplateReq) + if err := req.validate(); err != nil { + return deleteReportTemplateRes{false}, err + } + + err := svc.DeleteReportTemplate(ctx, session, req.ID) + if err != nil { + return deleteReportTemplateRes{false}, err + } + + return deleteReportTemplateRes{true}, nil + } +} diff --git a/reports/api/endpoints_test.go b/reports/api/endpoints_test.go new file mode 100644 index 000000000..fb2119319 --- /dev/null +++ b/reports/api/endpoints_test.go @@ -0,0 +1,1404 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api_test + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/0x6flab/namegenerator" + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/internal/testsutil" + smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" + authnmocks "github.com/absmach/supermq/pkg/authn/mocks" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + pkgSch "github.com/absmach/supermq/pkg/schedule" + "github.com/absmach/supermq/reports" + "github.com/absmach/supermq/reports/api" + "github.com/absmach/supermq/reports/mocks" + "github.com/go-chi/chi/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +const contentType = "application/json" + +var ( + namegen = namegenerator.NewGenerator() + domainID = testsutil.GenerateUUID(&testing.T{}) + userID = testsutil.GenerateUUID(&testing.T{}) + validID = testsutil.GenerateUUID(&testing.T{}) + validToken = "valid" + invalidToken = "invalid" + now = time.Now().UTC().Truncate(time.Minute) + future = now.Add(1 * time.Hour) + schedule = pkgSch.Schedule{ + StartDateTime: future, + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: future, + } + reportConfig = reports.ReportConfig{ + ID: validID, + Name: namegen.Generate(), + DomainID: domainID, + Schedule: schedule, + Status: reports.EnabledStatus, + Metrics: []reports.ReqMetric{ + { + ChannelID: "channel1", + ClientIDs: []string{"client1"}, + Name: "metric_name", + }, + }, + Config: &reports.MetricConfig{ + From: "now()-1h", + To: "now()", + Title: title, + Aggregation: reports.AggConfig{AggType: reports.AggregationAVG, Interval: "1h"}, + }, + Email: &reports.EmailSetting{ + To: []string{"test@example.com"}, + Subject: "Test Report", + }, + } + title = "test_title" +) + +type testRequest struct { + client *http.Client + method string + url string + contentType string + token string + body io.Reader +} + +func (tr testRequest) make() (*http.Response, error) { + req, err := http.NewRequest(tr.method, tr.url, tr.body) + if err != nil { + return nil, err + } + + if tr.token != "" { + req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token) + } + + if tr.contentType != "" { + req.Header.Set("Content-Type", tr.contentType) + } + + req.Header.Set("Referer", "http://localhost") + + return tr.client.Do(req) +} + +func newReportsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentication) { + svc := new(mocks.Service) + authn := new(authnmocks.Authentication) + + logger := smqlog.NewMock() + mux := chi.NewRouter() + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + + api.MakeHandler(svc, am, mux, logger, "") + + return httptest.NewServer(mux), svc, authn +} + +func toJSON(data any) string { + jsonData, err := json.Marshal(data) + if err != nil { + return "" + } + return string(jsonData) +} + +func TestAddReportConfigEndpoint(t *testing.T) { + ts, svc, authn := newReportsServer() + defer ts.Close() + + scheduleInPast := pkgSch.Schedule{ + StartDateTime: now, + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: now, + } + + reportInPast := reportConfig + reportInPast.Schedule = scheduleInPast + + cases := []struct { + desc string + cfg reports.ReportConfig + domainID string + token string + contentType string + status int + authnRes smqauthn.Session + authnErr error + svcRes reports.ReportConfig + svcErr error + err error + }{ + { + desc: "add report config successfully", + cfg: reportConfig, + token: validToken, + contentType: contentType, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + status: http.StatusCreated, + svcRes: reportConfig, + }, + { + desc: "add report config with invalid token", + cfg: reportConfig, + token: invalidToken, + authnRes: smqauthn.Session{}, + domainID: domainID, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "add report config with empty token", + token: "", + authnRes: smqauthn.Session{}, + domainID: domainID, + cfg: reportConfig, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "add report config with empty domainID", + token: validToken, + cfg: reportConfig, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "add report config with invalid content type", + token: validToken, + domainID: domainID, + cfg: reportConfig, + contentType: "application/xml", + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "add report config with startdatetime in past", + token: validToken, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + cfg: reportInPast, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "add report config with service error", + token: validToken, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + cfg: reportConfig, + contentType: contentType, + svcErr: svcerr.ErrCreateEntity, + status: http.StatusUnprocessableEntity, + err: svcerr.ErrCreateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + data := toJSON(tc.cfg) + req := testRequest{ + client: ts.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/reports/configs", ts.URL, tc.domainID), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(data), + } + + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + svcCall := svc.On("AddReportConfig", mock.Anything, tc.authnRes, mock.Anything).Return(tc.svcRes, tc.svcErr) + res, err := req.make() + + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestViewReportConfigEndpoint(t *testing.T) { + ts, svc, authn := newReportsServer() + defer ts.Close() + + cases := []struct { + desc string + id string + domainID string + token string + contentType string + status int + authnRes smqauthn.Session + authnErr error + svcRes reports.ReportConfig + svcErr error + err error + }{ + { + desc: "view report config successfully", + id: validID, + token: validToken, + contentType: contentType, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + status: http.StatusOK, + svcRes: reportConfig, + }, + { + desc: "view report config with invalid token", + id: validID, + token: invalidToken, + authnRes: smqauthn.Session{}, + domainID: domainID, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "view report config with empty token", + token: "", + authnRes: smqauthn.Session{}, + domainID: domainID, + id: validID, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "view report config with empty domainID", + token: validToken, + id: validID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "view report config with service error", + token: validToken, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + id: validID, + contentType: contentType, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: ts.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/%s/reports/configs/%s", ts.URL, tc.domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + } + + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + svcCall := svc.On("ViewReportConfig", mock.Anything, tc.authnRes, tc.id, false).Return(tc.svcRes, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestListReportsConfigEndpoint(t *testing.T) { + ts, svc, authn := newReportsServer() + defer ts.Close() + + cases := []struct { + desc string + query string + domainID string + token string + session smqauthn.Session + listReportsResponse reports.ReportConfigPage + status int + authnErr error + err error + }{ + { + desc: "list reports config successfully", + domainID: domainID, + token: validToken, + status: http.StatusOK, + listReportsResponse: reports.ReportConfigPage{ + ReportConfigs: []reports.ReportConfig{reportConfig}, + PageMeta: reports.PageMeta{Total: 1}, + }, + err: nil, + }, + { + desc: "list reports config with empty token", + domainID: domainID, + token: "", + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "list reports config with invalid token", + domainID: domainID, + token: invalidToken, + status: http.StatusUnauthorized, + authnErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: ts.Client(), + method: http.MethodGet, + url: ts.URL + "/" + tc.domainID + "/reports/configs?" + tc.query, + contentType: contentType, + token: tc.token, + } + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("ListReportsConfig", mock.Anything, tc.session, mock.Anything).Return(tc.listReportsResponse, tc.err) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var bodyRes respBody + err = json.NewDecoder(res.Body).Decode(&bodyRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if bodyRes.Err != "" || bodyRes.Message != "" { + err = errors.Wrap(errors.New(bodyRes.Err), errors.New(bodyRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestUpdateReportConfigEndpoint(t *testing.T) { + ts, svc, authn := newReportsServer() + defer ts.Close() + + cases := []struct { + desc string + token string + id string + domainID string + updateReq reports.ReportConfig + contentType string + session smqauthn.Session + svcResp reports.ReportConfig + svcErr error + status int + authnErr error + err error + }{ + { + desc: "update report config successfully", + token: validToken, + domainID: domainID, + id: validID, + updateReq: reportConfig, + contentType: contentType, + svcResp: reportConfig, + status: http.StatusOK, + err: nil, + }, + { + desc: "update report config with invalid token", + token: invalidToken, + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + updateReq: reportConfig, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "update report config with empty token", + token: "", + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + updateReq: reportConfig, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "update report config with empty domainID", + token: validToken, + id: validID, + updateReq: reportConfig, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "update report config with invalid content type", + token: validToken, + id: validID, + domainID: domainID, + updateReq: reportConfig, + contentType: "application/xml", + svcResp: reportConfig, + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "update report config with service error", + token: validToken, + id: validID, + domainID: domainID, + updateReq: reportConfig, + contentType: contentType, + svcResp: reports.ReportConfig{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + data := toJSON(tc.updateReq) + req := testRequest{ + client: ts.Client(), + method: http.MethodPatch, + url: fmt.Sprintf("%s/%s/reports/configs/%s", ts.URL, tc.domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(data), + } + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("UpdateReportConfig", mock.Anything, tc.session, mock.Anything).Return(tc.svcResp, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDeleteReportConfigEndpoint(t *testing.T) { + ts, svc, authn := newReportsServer() + defer ts.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session smqauthn.Session + svcErr error + status int + authnErr error + err error + }{ + { + desc: "delete report config successfully", + token: validToken, + domainID: domainID, + id: validID, + svcErr: nil, + status: http.StatusNoContent, + err: nil, + }, + { + desc: "delete report config with invalid token", + token: invalidToken, + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "delete report config with empty token", + token: "", + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "delete report config with empty domainID", + token: validToken, + id: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "delete report config with service error", + token: validToken, + id: validID, + domainID: domainID, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: ts.Client(), + method: http.MethodDelete, + url: fmt.Sprintf("%s/%s/reports/configs/%s", ts.URL, tc.domainID, tc.id), + token: tc.token, + } + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("RemoveReportConfig", mock.Anything, tc.session, tc.id).Return(tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestEnableReportConfigEndpoint(t *testing.T) { + ts, svc, authn := newReportsServer() + defer ts.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session smqauthn.Session + svcResp reports.ReportConfig + svcErr error + status int + authnErr error + err error + }{ + { + desc: "enable report config successfully", + token: validToken, + domainID: domainID, + id: validID, + svcResp: reportConfig, + svcErr: nil, + status: http.StatusOK, + err: nil, + }, + { + desc: "enable report config with invalid token", + token: invalidToken, + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "enable report config with empty token", + token: "", + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "enable report config with empty domainID", + token: validToken, + id: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "enable report config with service error", + token: validToken, + id: validID, + domainID: domainID, + svcResp: reports.ReportConfig{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "enable report config with empty id", + token: validToken, + id: "", + domainID: domainID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: ts.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/reports/configs/%s/enable", ts.URL, tc.domainID, tc.id), + token: tc.token, + } + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("EnableReportConfig", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDisableReportConfigEndpoint(t *testing.T) { + ts, svc, authn := newReportsServer() + defer ts.Close() + + cases := []struct { + desc string + token string + id string + domainID string + session smqauthn.Session + svcResp reports.ReportConfig + svcErr error + status int + authnErr error + err error + }{ + { + desc: "disable report config successfully", + token: validToken, + domainID: domainID, + id: validID, + svcResp: reportConfig, + svcErr: nil, + status: http.StatusOK, + err: nil, + }, + { + desc: "disable report config with invalid token", + token: invalidToken, + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "disable report config with empty token", + token: "", + session: smqauthn.Session{}, + domainID: domainID, + id: validID, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "disable report config with empty domainID", + token: validToken, + id: validID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "disable report config with service error", + token: validToken, + id: validID, + domainID: domainID, + svcResp: reports.ReportConfig{}, + svcErr: svcerr.ErrAuthorization, + status: http.StatusForbidden, + err: svcerr.ErrAuthorization, + }, + { + desc: "disable report config with empty id", + token: validToken, + id: "", + domainID: domainID, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: ts.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/reports/configs/%s/disable", ts.URL, tc.domainID, tc.id), + token: tc.token, + } + if tc.token == validToken { + tc.session = smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID} + } + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + svcCall := svc.On("DisableReportConfig", mock.Anything, tc.session, tc.id).Return(tc.svcResp, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +type respBody struct { + Err string `json:"error"` + Message string `json:"message"` + Total uint64 `json:"total"` + ID string `json:"id"` + Status reports.Status `json:"status"` +} + +const ( + validTemplate = ` + + + {{$.Title}} + + + +
+

{{$.Title}}

+

Generated on: {{$.GeneratedDate}}

+
+
+

Messages

+ {{range .Messages}} +
+

Time: {{formatTime .Time}}

+

Value: {{formatValue .}}

+
+ {{end}} +
+ +` + + templateWithoutTitle = ` + + + Report + + + +

Report

+ {{range .Messages}} +

Time: {{formatTime .Time}}

+

Value: {{formatValue .}}

+ {{end}} + +` + + templateWithSyntaxError = ` + + + {{$.Title}} + + +

{{$.Title}}

+ {{range .Messages}} +

Time: {{formatTime .Time}}

+

Value: {{formatValue .}}

+ {{end + +` +) + +func TestUpdateReportTemplateEndpoint(t *testing.T) { + ts, svc, authn := newReportsServer() + defer ts.Close() + + cases := []struct { + desc string + id string + template reports.ReportTemplate + domainID string + token string + contentType string + status int + authnRes smqauthn.Session + authnErr error + svcErr error + err error + }{ + { + desc: "update report template successfully", + id: validID, + template: reports.ReportTemplate(validTemplate), + token: validToken, + contentType: contentType, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + status: http.StatusNoContent, + }, + { + desc: "update report template with invalid token", + id: validID, + template: reports.ReportTemplate(validTemplate), + token: invalidToken, + authnRes: smqauthn.Session{}, + domainID: domainID, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "update report template with empty token", + id: validID, + template: reports.ReportTemplate(validTemplate), + token: "", + authnRes: smqauthn.Session{}, + domainID: domainID, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "update report template with empty domainID", + id: validID, + template: reports.ReportTemplate(validTemplate), + token: validToken, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "update report template with invalid content type", + id: validID, + template: reports.ReportTemplate(validTemplate), + token: validToken, + domainID: domainID, + contentType: "application/xml", + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "update report template with empty ID", + id: "", + template: reports.ReportTemplate(validTemplate), + token: validToken, + domainID: domainID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + { + desc: "update report template with empty template", + id: validID, + token: validToken, + domainID: domainID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "update report template without title field", + id: validID, + template: reports.ReportTemplate(templateWithoutTitle), + token: validToken, + domainID: domainID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "update report template with syntax error", + id: validID, + template: reports.ReportTemplate(templateWithSyntaxError), + token: validToken, + domainID: domainID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "update report template with service error", + id: validID, + template: reports.ReportTemplate(validTemplate), + token: validToken, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + contentType: contentType, + svcErr: svcerr.ErrUpdateEntity, + status: http.StatusUnprocessableEntity, + err: svcerr.ErrUpdateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + data := toJSON(map[string]any{ + "report_template": tc.template, + }) + req := testRequest{ + client: ts.Client(), + method: http.MethodPut, + url: fmt.Sprintf("%s/%s/reports/configs/%s/template", ts.URL, tc.domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(data), + } + + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + svcCall := svc.On("UpdateReportTemplate", mock.Anything, tc.authnRes, mock.Anything).Return(tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + if res.StatusCode != http.StatusNoContent { + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + } + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestViewReportTemplateEndpoint(t *testing.T) { + ts, svc, authn := newReportsServer() + defer ts.Close() + + cases := []struct { + desc string + id string + domainID string + token string + contentType string + status int + authnRes smqauthn.Session + authnErr error + svcRes reports.ReportTemplate + svcErr error + err error + }{ + { + desc: "view report template successfully", + id: validID, + token: validToken, + contentType: contentType, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + status: http.StatusOK, + svcRes: reports.ReportTemplate(validTemplate), + }, + { + desc: "view report template with invalid token", + id: validID, + token: invalidToken, + authnRes: smqauthn.Session{}, + domainID: domainID, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "view report template with empty token", + token: "", + authnRes: smqauthn.Session{}, + domainID: domainID, + id: validID, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "view report template with empty domainID", + token: validToken, + id: validID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "view report template with empty ID", + token: validToken, + id: "", + domainID: domainID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + { + desc: "view report template with service error", + token: validToken, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + id: validID, + contentType: contentType, + svcErr: svcerr.ErrViewEntity, + status: http.StatusUnprocessableEntity, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: ts.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/%s/reports/configs/%s/template", ts.URL, tc.domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + } + + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + svcCall := svc.On("ViewReportTemplate", mock.Anything, tc.authnRes, tc.id).Return(tc.svcRes, tc.svcErr) + res, err := req.make() + + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestDeleteReportTemplateEndpoint(t *testing.T) { + ts, svc, authn := newReportsServer() + defer ts.Close() + + cases := []struct { + desc string + id string + domainID string + token string + contentType string + status int + authnRes smqauthn.Session + authnErr error + svcErr error + err error + }{ + { + desc: "delete report template successfully", + id: validID, + token: validToken, + contentType: contentType, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + status: http.StatusNoContent, + }, + { + desc: "delete report template with invalid token", + id: validID, + token: invalidToken, + authnRes: smqauthn.Session{}, + domainID: domainID, + contentType: contentType, + authnErr: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + err: svcerr.ErrAuthentication, + }, + { + desc: "delete report template with empty token", + token: "", + authnRes: smqauthn.Session{}, + domainID: domainID, + id: validID, + contentType: contentType, + status: http.StatusUnauthorized, + err: apiutil.ErrBearerToken, + }, + { + desc: "delete report template with empty domainID", + token: validToken, + id: validID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingDomainID, + }, + { + desc: "delete report template with empty ID", + token: validToken, + id: "", + domainID: domainID, + contentType: contentType, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + { + desc: "delete report template with service error", + token: validToken, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + id: validID, + contentType: contentType, + svcErr: svcerr.ErrRemoveEntity, + status: http.StatusUnprocessableEntity, + err: svcerr.ErrRemoveEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + client: ts.Client(), + method: http.MethodDelete, + url: fmt.Sprintf("%s/%s/reports/configs/%s/template", ts.URL, tc.domainID, tc.id), + contentType: tc.contentType, + token: tc.token, + } + + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + svcCall := svc.On("DeleteReportTemplate", mock.Anything, tc.authnRes, tc.id).Return(tc.svcErr) + res, err := req.make() + + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + + if res.StatusCode != http.StatusNoContent { + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + } + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} + +func TestGenerateReportWithTemplateValidation(t *testing.T) { + ts, svc, authn := newReportsServer() + defer ts.Close() + + cases := []struct { + desc string + cfg reports.ReportConfig + action string + domainID string + token string + contentType string + status int + authnRes smqauthn.Session + authnErr error + svcRes reports.ReportPage + svcErr error + err error + }{ + { + desc: "generate report with valid template successfully", + cfg: reports.ReportConfig{ + ID: validID, + Name: namegen.Generate(), + DomainID: domainID, + Metrics: []reports.ReqMetric{ + { + ChannelID: "channel1", + ClientIDs: []string{"client1"}, + Name: "metric_name", + }, + }, + Config: &reports.MetricConfig{ + From: "now()-1h", + To: "now()", + Title: title, + Aggregation: reports.AggConfig{AggType: reports.AggregationAVG, Interval: "1h"}, + }, + ReportTemplate: reports.ReportTemplate(validTemplate), + }, + action: "view", + token: validToken, + contentType: contentType, + domainID: domainID, + authnRes: smqauthn.Session{DomainUserID: auth.EncodeDomainUserID(domainID, userID), UserID: userID, DomainID: domainID}, + status: http.StatusOK, + svcRes: reports.ReportPage{}, + }, + { + desc: "generate report with invalid template", + cfg: reports.ReportConfig{ + ID: validID, + Name: namegen.Generate(), + DomainID: domainID, + Metrics: []reports.ReqMetric{ + { + ChannelID: "channel1", + ClientIDs: []string{"client1"}, + Name: "metric_name", + }, + }, + Config: &reports.MetricConfig{ + From: "now()-1h", + To: "now()", + Title: title, + Aggregation: reports.AggConfig{AggType: reports.AggregationAVG, Interval: "1h"}, + }, + ReportTemplate: reports.ReportTemplate(templateWithoutTitle), + }, + action: "view", + token: validToken, + contentType: contentType, + domainID: domainID, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "generate report with template syntax error", + cfg: reports.ReportConfig{ + ID: validID, + Name: namegen.Generate(), + DomainID: domainID, + Metrics: []reports.ReqMetric{ + { + ChannelID: "channel1", + ClientIDs: []string{"client1"}, + Name: "metric_name", + }, + }, + Config: &reports.MetricConfig{ + From: "now()-1h", + To: "now()", + Title: title, + Aggregation: reports.AggConfig{AggType: reports.AggregationAVG, Interval: "1h"}, + }, + ReportTemplate: reports.ReportTemplate(templateWithSyntaxError), + }, + action: "view", + token: validToken, + contentType: contentType, + domainID: domainID, + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + data := toJSON(tc.cfg) + req := testRequest{ + client: ts.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/%s/reports?action=%s", ts.URL, tc.domainID, tc.action), + contentType: tc.contentType, + token: tc.token, + body: strings.NewReader(data), + } + + authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + svcCall := svc.On("GenerateReport", mock.Anything, tc.authnRes, mock.Anything, mock.Anything).Return(tc.svcRes, tc.svcErr) + res, err := req.make() + + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + var errRes respBody + err = json.NewDecoder(res.Body).Decode(&errRes) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authCall.Unset() + }) + } +} diff --git a/reports/api/request.go b/reports/api/request.go new file mode 100644 index 000000000..f44a918fc --- /dev/null +++ b/reports/api/request.go @@ -0,0 +1,242 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "fmt" + + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/schedule" + "github.com/absmach/supermq/reports" +) + +const ( + maxLimitSize = 1000 + MaxNameSize = 1024 + MaxTitleSize = 37 + + errInvalidMetric = "invalid metric[%d]: %w" +) + +var ( + errInvalidReportAction = errors.New("invalid report action") + errMetricsNotProvided = errors.New("metrics not provided") + errMissingReportConfig = errors.New("missing report config") + errMissingReportEmailConfig = errors.New("missing report email config") + errInvalidRecurringPeriod = errors.New("invalid recurring period") + errMissingReportTemplate = errors.New("missing report template") + errTitleSize = errors.New("invalid title size") +) + +type addReportConfigReq struct { + reports.ReportConfig `json:",inline"` +} + +func (req addReportConfigReq) validate() error { + if req.Name == "" { + return apiutil.ErrMissingName + } + if err := req.Schedule.Validate(); err != nil { + return errors.Wrap(err, apiutil.ErrValidation) + } + if req.ReportTemplate.String() != "" { + if err := req.ReportTemplate.Validate(); err != nil { + return errors.Wrap(err, apiutil.ErrValidation) + } + } + return validateReportConfig(req.ReportConfig, false, false) +} + +type viewReportConfigReq struct { + ID string `json:"id"` + withRoles bool +} + +func (req viewReportConfigReq) validate() error { + if req.ID == "" { + return apiutil.ErrMissingID + } + return nil +} + +type listReportsConfigReq struct { + reports.PageMeta `json:",inline"` +} + +func (req listReportsConfigReq) validate() error { + if req.Limit > maxLimitSize { + return svcerr.ErrMalformedEntity + } + + switch req.Order { + case "", api.NameKey, api.CreatedAtOrder, api.UpdatedAtOrder: + default: + return apiutil.ErrInvalidOrder + } + + if req.Dir != api.AscDir && req.Dir != api.DescDir { + return apiutil.ErrInvalidDirection + } + + return nil +} + +type updateReportConfigReq struct { + reports.ReportConfig `json:",inline"` +} + +func (req updateReportConfigReq) validate() error { + if req.ID == "" { + return apiutil.ErrMissingID + } + return validateReportConfig(req.ReportConfig, false, false) +} + +type updateReportScheduleReq struct { + id string + Schedule schedule.Schedule `json:"schedule,omitempty"` +} + +func (req updateReportScheduleReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + + if err := req.Schedule.Validate(); err != nil { + return errors.Wrap(err, apiutil.ErrValidation) + } + + return nil +} + +type deleteReportConfigReq struct { + ID string `json:"id"` +} + +func (req deleteReportConfigReq) validate() error { + if req.ID == "" { + return apiutil.ErrMissingID + } + return nil +} + +type generateReportReq struct { + reports.ReportConfig + action reports.ReportAction +} + +func (req generateReportReq) validate() error { + if len(req.Config.Title) > MaxTitleSize { + return errors.Wrap(apiutil.ErrValidation, errTitleSize) + } + + if req.ReportTemplate.String() != "" { + if err := req.ReportTemplate.Validate(); err != nil { + return errors.Wrap(err, apiutil.ErrValidation) + } + } + + switch req.action { + case reports.ViewReport, reports.DownloadReport: + return validateReportConfig(req.ReportConfig, true, true) + case reports.EmailReport: + return validateReportConfig(req.ReportConfig, false, true) + default: + return errors.Wrap(apiutil.ErrValidation, errInvalidReportAction) + } +} + +type updateReportStatusReq struct { + id string +} + +func (req updateReportStatusReq) validate() error { + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +func validateReportConfig(req reports.ReportConfig, skipEmailValidation bool, skipSchedularValidation bool) error { + if len(req.Metrics) == 0 { + return errors.Wrap(apiutil.ErrValidation, errMetricsNotProvided) + } + for i, metric := range req.Metrics { + if err := metric.Validate(); err != nil { + return errors.Wrap(apiutil.ErrValidation, fmt.Errorf(errInvalidMetric, i+1, err)) + } + } + + if req.Config == nil { + return errors.Wrap(errMissingReportConfig, apiutil.ErrValidation) + } + if err := req.Config.Validate(); err != nil { + return errors.Wrap(err, apiutil.ErrValidation) + } + + if skipEmailValidation { + return nil + } + if req.Email == nil { + return errors.Wrap(errMissingReportEmailConfig, apiutil.ErrValidation) + } + if err := req.Email.Validate(); err != nil { + return errors.Wrap(apiutil.ErrValidation, err) + } + + if skipSchedularValidation { + return nil + } + + return validateScheduler(req.Schedule) +} + +func validateScheduler(sch schedule.Schedule) error { + if sch.Recurring != schedule.None && sch.RecurringPeriod < 1 { + return errInvalidRecurringPeriod + } + return nil +} + +type updateReportTemplateReq struct { + reports.ReportConfig `json:",inline"` +} + +func (req updateReportTemplateReq) validate() error { + if req.ID == "" { + return apiutil.ErrMissingID + } + if req.ReportTemplate == "" { + return errors.Wrap(errMissingReportTemplate, apiutil.ErrValidation) + } + if err := req.ReportTemplate.Validate(); err != nil { + return errors.Wrap(err, apiutil.ErrValidation) + } + return nil +} + +type getReportTemplateReq struct { + ID string `json:"id"` +} + +func (req getReportTemplateReq) validate() error { + if req.ID == "" { + return apiutil.ErrMissingID + } + return nil +} + +type deleteReportTemplateReq struct { + ID string `json:"id"` +} + +func (req deleteReportTemplateReq) validate() error { + if req.ID == "" { + return apiutil.ErrMissingID + } + return nil +} diff --git a/reports/api/response.go b/reports/api/response.go new file mode 100644 index 000000000..34bc4243b --- /dev/null +++ b/reports/api/response.go @@ -0,0 +1,221 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "net/http" + "time" + + "github.com/absmach/supermq" + "github.com/absmach/supermq/reports" +) + +var ( + _ supermq.Response = (*addReportConfigRes)(nil) + _ supermq.Response = (*viewReportConfigRes)(nil) + _ supermq.Response = (*updateReportConfigRes)(nil) + _ supermq.Response = (*deleteReportConfigRes)(nil) + _ supermq.Response = (*listReportsConfigRes)(nil) +) + +type pageRes struct { + Limit uint64 `json:"limit,omitempty"` + Offset uint64 `json:"offset"` + Total uint64 `json:"total"` +} + +type generateReportResp struct { + Total uint64 `json:"total"` + From time.Time `json:"from,omitempty"` + To time.Time `json:"to,omitempty"` + Aggregation reports.AggConfig `json:"aggregation,omitempty"` + Reports []reports.Report `json:"reports,omitempty"` +} + +func (res generateReportResp) Code() int { + return http.StatusOK +} + +func (res generateReportResp) Headers() map[string]string { + return map[string]string{} +} + +func (res generateReportResp) Empty() bool { + return false +} + +type addReportConfigRes struct { + reports.ReportConfig `json:",inline"` + created bool +} + +func (res addReportConfigRes) Code() int { + if res.created { + return http.StatusCreated + } + return http.StatusOK +} + +func (res addReportConfigRes) Headers() map[string]string { + if res.created { + return map[string]string{} + } + return map[string]string{} +} + +func (res addReportConfigRes) Empty() bool { + return false +} + +type viewReportConfigRes struct { + reports.ReportConfig `json:",inline"` +} + +func (res viewReportConfigRes) Code() int { + return http.StatusOK +} + +func (res viewReportConfigRes) Headers() map[string]string { + return map[string]string{} +} + +func (res viewReportConfigRes) Empty() bool { + return false +} + +type updateReportConfigRes struct { + reports.ReportConfig `json:",inline"` +} + +func (res updateReportConfigRes) Code() int { + return http.StatusOK +} + +func (res updateReportConfigRes) Headers() map[string]string { + return map[string]string{} +} + +func (res updateReportConfigRes) Empty() bool { + return false +} + +type deleteReportConfigRes struct { + deleted bool +} + +func (res deleteReportConfigRes) Code() int { + if res.deleted { + return http.StatusNoContent + } + return http.StatusOK +} + +func (res deleteReportConfigRes) Headers() map[string]string { + return map[string]string{} +} + +func (res deleteReportConfigRes) Empty() bool { + return true +} + +type listReportsConfigRes struct { + pageRes + ReportConfigs []reports.ReportConfig `json:"report_configs"` +} + +func (res listReportsConfigRes) Code() int { + return http.StatusOK +} + +func (res listReportsConfigRes) Headers() map[string]string { + return map[string]string{} +} + +func (res listReportsConfigRes) Empty() bool { + return false +} + +type downloadReportResp struct { + File reports.ReportFile +} + +func (res downloadReportResp) Code() int { + return http.StatusOK +} + +func (res downloadReportResp) Headers() map[string]string { + return map[string]string{} +} + +func (res downloadReportResp) Empty() bool { + return false +} + +type emailReportResp struct{} + +func (res emailReportResp) Code() int { + return http.StatusOK +} + +func (res emailReportResp) Headers() map[string]string { + return map[string]string{} +} + +func (res emailReportResp) Empty() bool { + return true +} + +type viewReportTemplateRes struct { + Template reports.ReportTemplate `json:"html_template"` +} + +func (res viewReportTemplateRes) Code() int { + return http.StatusOK +} + +func (res viewReportTemplateRes) Headers() map[string]string { + return map[string]string{} +} + +func (res viewReportTemplateRes) Empty() bool { + return false +} + +type updateReportTemplateRes struct { + updated bool +} + +func (res updateReportTemplateRes) Code() int { + if res.updated { + return http.StatusNoContent + } + return http.StatusOK +} + +func (res updateReportTemplateRes) Headers() map[string]string { + return map[string]string{} +} + +func (res updateReportTemplateRes) Empty() bool { + return true +} + +type deleteReportTemplateRes struct { + deleted bool +} + +func (res deleteReportTemplateRes) Code() int { + if res.deleted { + return http.StatusNoContent + } + return http.StatusOK +} + +func (res deleteReportTemplateRes) Headers() map[string]string { + return map[string]string{} +} + +func (res deleteReportTemplateRes) Empty() bool { + return true +} diff --git a/reports/api/transport.go b/reports/api/transport.go new file mode 100644 index 000000000..bc74fec90 --- /dev/null +++ b/reports/api/transport.go @@ -0,0 +1,313 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "strings" + + "github.com/absmach/supermq" + api "github.com/absmach/supermq/api/http" + apiutil "github.com/absmach/supermq/api/http/util" + smqauthn "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + roleManagerHttp "github.com/absmach/supermq/pkg/roles/rolemanager/api" + "github.com/absmach/supermq/reports" + "github.com/go-chi/chi/v5" + kithttp "github.com/go-kit/kit/transport/http" + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" +) + +const ( + reportIdKey = "reportID" + actionKey = "action" + defAction = "view" +) + +// MakeHandler creates an HTTP handler for the service endpoints. +func MakeHandler(svc reports.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler { + opts := []kithttp.ServerOption{ + kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), + } + mux.Group(func(r chi.Router) { + r.Use(authn.WithOptions(smqauthn.WithDomainCheck(true)).Middleware()) + r.Route("/{domainID}", func(r chi.Router) { + r.Route("/reports", func(r chi.Router) { + d := roleManagerHttp.NewDecoder("reportID") + + r.Post("/", otelhttp.NewHandler(kithttp.NewServer( + generateReportEndpoint(svc), + decodeGenerateReportRequest, + encodeFileDownloadResponse, + opts..., + ), "generate_report").ServeHTTP) + + r = roleManagerHttp.EntityAvailableActionsRouter(svc, d, r, opts) + + r.Route("/configs", func(r chi.Router) { + r.Post("/", otelhttp.NewHandler(kithttp.NewServer( + addReportConfigEndpoint(svc), + decodeAddReportConfigRequest, + api.EncodeResponse, + opts..., + ), "add_report_config").ServeHTTP) + + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + listReportsConfigEndpoint(svc), + decodeListReportsConfigRequest, + api.EncodeResponse, + opts..., + ), "list_reports_config").ServeHTTP) + + r.Route("/{reportID}", func(r chi.Router) { + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + viewReportConfigEndpoint(svc), + decodeViewReportConfigRequest, + api.EncodeResponse, + opts..., + ), "view_report_config").ServeHTTP) + + r.Patch("/", otelhttp.NewHandler(kithttp.NewServer( + updateReportConfigEndpoint(svc), + decodeUpdateReportConfigRequest, + api.EncodeResponse, + opts..., + ), "update_report_config").ServeHTTP) + + r.Patch("/schedule", otelhttp.NewHandler(kithttp.NewServer( + updateReportScheduleEndpoint(svc), + decodeUpdateReportScheduleRequest, + api.EncodeResponse, + opts..., + ), "update_report_scheduler").ServeHTTP) + + r.Delete("/", otelhttp.NewHandler(kithttp.NewServer( + deleteReportConfigEndpoint(svc), + decodeDeleteReportConfigRequest, + api.EncodeResponse, + opts..., + ), "delete_report_config").ServeHTTP) + + r.Post("/enable", otelhttp.NewHandler(kithttp.NewServer( + enableReportConfigEndpoint(svc), + decodeUpdateReportStatusRequest, + api.EncodeResponse, + opts..., + ), "enable_report_config").ServeHTTP) + + r.Post("/disable", otelhttp.NewHandler(kithttp.NewServer( + disableReportConfigEndpoint(svc), + decodeUpdateReportStatusRequest, + api.EncodeResponse, + opts..., + ), "disable_report_config").ServeHTTP) + + r.Put("/template", otelhttp.NewHandler(kithttp.NewServer( + updateReportTemplateEndpoint(svc), + decodeUpdateReportTemplateRequest, + api.EncodeResponse, + opts..., + ), "update_report_template").ServeHTTP) + + r.Get("/template", otelhttp.NewHandler(kithttp.NewServer( + viewReportTemplateEndpoint(svc), + decodeGetReportTemplateRequest, + api.EncodeResponse, + opts..., + ), "get_report_template").ServeHTTP) + + r.Delete("/template", otelhttp.NewHandler(kithttp.NewServer( + deleteReportTemplateEndpoint(svc), + decodeDeleteReportTemplateRequest, + api.EncodeResponse, + opts..., + ), "delete_report_template").ServeHTTP) + + roleManagerHttp.EntityRoleMangerRouter(svc, d, r, opts) + }) + }) + }) + }) + }) + + mux.Get("/health", supermq.Health("reports", instanceID)) + mux.Handle("/metrics", promhttp.Handler()) + + return mux +} + +func decodeGenerateReportRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + a, err := apiutil.ReadStringQuery(r, actionKey, defAction) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + action, err := reports.ToReportAction(a) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + req := generateReportReq{ + action: action, + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(err, apiutil.ErrValidation) + } + + return req, nil +} + +func decodeAddReportConfigRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, apiutil.ErrUnsupportedContentType + } + var config reports.ReportConfig + if err := json.NewDecoder(r.Body).Decode(&config); err != nil { + return nil, errors.Wrap(err, apiutil.ErrValidation) + } + return addReportConfigReq{ReportConfig: config}, nil +} + +func decodeViewReportConfigRequest(_ context.Context, r *http.Request) (any, error) { + id := chi.URLParam(r, reportIdKey) + withRoles, err := apiutil.ReadBoolQuery(r, api.RolesKey, false) + if err != nil { + return nil, err + } + return viewReportConfigReq{ID: id, withRoles: withRoles}, nil +} + +func decodeUpdateReportConfigRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, apiutil.ErrUnsupportedContentType + } + var config reports.ReportConfig + if err := json.NewDecoder(r.Body).Decode(&config); err != nil { + return nil, errors.Wrap(err, apiutil.ErrValidation) + } + config.ID = chi.URLParam(r, reportIdKey) + return updateReportConfigReq{ReportConfig: config}, nil +} + +func decodeUpdateReportScheduleRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := updateReportScheduleReq{ + id: chi.URLParam(r, reportIdKey), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + return req, nil +} + +func decodeUpdateReportStatusRequest(_ context.Context, r *http.Request) (any, error) { + req := updateReportStatusReq{ + id: chi.URLParam(r, reportIdKey), + } + return req, nil +} + +func decodeDeleteReportConfigRequest(_ context.Context, r *http.Request) (any, error) { + id := chi.URLParam(r, reportIdKey) + return deleteReportConfigReq{ID: id}, nil +} + +func decodeUpdateReportTemplateRequest(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := updateReportTemplateReq{} + req.ID = chi.URLParam(r, reportIdKey) + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(err, apiutil.ErrValidation) + } + + return req, nil +} + +func decodeGetReportTemplateRequest(_ context.Context, r *http.Request) (any, error) { + return getReportTemplateReq{ID: chi.URLParam(r, reportIdKey)}, nil +} + +func decodeDeleteReportTemplateRequest(_ context.Context, r *http.Request) (any, error) { + return deleteReportTemplateReq{ID: chi.URLParam(r, reportIdKey)}, nil +} + +func decodeListReportsConfigRequest(_ context.Context, r *http.Request) (any, error) { + offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + status, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefStatus) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + st, err := reports.ToStatus(status) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + name, err := apiutil.ReadStringQuery(r, api.NameKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + dir, err := apiutil.ReadStringQuery(r, api.DirKey, "desc") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + order, err := apiutil.ReadStringQuery(r, api.OrderKey, api.DefOrder) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + return listReportsConfigReq{ + PageMeta: reports.PageMeta{ + Offset: offset, + Limit: limit, + Status: st, + Name: name, + Dir: dir, + Order: order, + }, + }, nil +} + +func encodeFileDownloadResponse(_ context.Context, w http.ResponseWriter, response any) error { + switch resp := response.(type) { + case downloadReportResp: + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", resp.File.Name)) + w.Header().Set("Content-Type", resp.File.Format.ContentType()) + _, err := w.Write(resp.File.Data) + return err + default: + if ar, ok := response.(supermq.Response); ok { + for k, v := range ar.Headers() { + w.Header().Set(k, v) + } + w.Header().Set("Content-Type", api.ContentType) + w.WriteHeader(ar.Code()) + + if ar.Empty() { + return nil + } + } + return json.NewEncoder(w).Encode(response) + } +} diff --git a/reports/builtinroles.go b/reports/builtinroles.go new file mode 100644 index 000000000..9ce5064d0 --- /dev/null +++ b/reports/builtinroles.go @@ -0,0 +1,8 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package reports + +import "github.com/absmach/supermq/pkg/roles" + +const BuiltInRoleAdmin roles.BuiltInRoleName = "admin" diff --git a/reports/generator.go b/reports/generator.go new file mode 100644 index 000000000..b7f39c4c2 --- /dev/null +++ b/reports/generator.go @@ -0,0 +1,342 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package reports + +import ( + "bytes" + "context" + "encoding/csv" + "fmt" + "html/template" + "io" + "log/slog" + "mime/multipart" + "net/http" + "sort" + "strings" + "time" + _ "time/tzdata" // Embed timezone database + + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + pkglog "github.com/absmach/supermq/pkg/logger" + "github.com/absmach/supermq/pkg/transformers/senml" +) + +const nanosecondThreshold = float64(10 * time.Second / time.Nanosecond) + +type ReportData struct { + Title string + GeneratedTime string + GeneratedDate string + Reports []Report + Timezone string +} + +func (r *report) generatePDFReport(ctx context.Context, title string, reports []Report, template ReportTemplate, timezone string) ([]byte, error) { + for i := range reports { + sort.Slice(reports[i].Messages, func(j, k int) bool { + return reports[i].Messages[j].Time < reports[i].Messages[k].Time + }) + } + + loc, err := resolveTimezone(timezone) + if err != nil { + r.runInfo <- pkglog.RunInfo{ + Level: slog.LevelWarn, + Message: fmt.Sprintf("failed to resolve timezone '%s', falling back to UTC: %s", timezone, err), + Details: []slog.Attr{ + slog.String("report_title", title), + slog.Time("time", time.Now().UTC()), + }, + } + } + + now := time.Now().In(loc) + displayTZ := timezone + if strings.TrimSpace(displayTZ) == "" { + displayTZ = "UTC" + } + + data := ReportData{ + Title: title, + GeneratedTime: now.Format("15:04:05"), + GeneratedDate: now.Format("02 Jan 2006"), + Reports: reports, + Timezone: displayTZ, + } + + templateContent := r.defaultTemplate.String() + if template.String() != "" { + templateContent = template.String() + } + return r.generate(ctx, templateContent, data) +} + +func (r *report) generate(ctx context.Context, templateContent string, data ReportData) ([]byte, error) { + tmpl := template.New("report").Funcs(template.FuncMap{ + "formatTime": func(t float64) string { return r.formatTimeWithTimezone(t, data.Timezone) }, + "formatValue": formatValue, + "add": func(a, b int) int { return a + b }, + "sub": func(a, b int) int { return a - b }, + "iterate": func(count int) []int { return makeRange(count) }, + "ge": func(a, b int) bool { return a >= b }, + "lt": func(a, b int) bool { return a < b }, + "eq": func(a, b int) bool { return a == b }, + "div": func(a, b int) int { + if b == 0 { + return 0 + } + return a / b + }, + "mod": func(a, b int) int { + if b == 0 { + return 0 + } + return a % b + }, + "getStartRow": getStartRow, + "getEndRow": getEndRow, + }) + + tmpl, err := tmpl.Parse(templateContent) + if err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + var htmlBuf bytes.Buffer + if err := tmpl.Execute(&htmlBuf, data); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + htmlContent := htmlBuf.String() + pdfBytes, err := r.htmlToPDF(ctx, htmlContent) + if err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + return pdfBytes, nil +} + +func (r *report) htmlToPDF(ctx context.Context, htmlContent string) ([]byte, error) { + var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + + htmlPart, err := writer.CreateFormFile("files", "index.html") + if err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + if _, err := htmlPart.Write([]byte(htmlContent)); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + if err := writer.WriteField("marginTop", "0"); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + if err := writer.WriteField("marginBottom", "0"); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + if err := writer.WriteField("marginLeft", "0"); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + if err := writer.WriteField("marginRight", "0"); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + if err := writer.WriteField("printBackground", "true"); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + if err := writer.WriteField("preferCSSPageSize", "true"); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + if err := writer.WriteField("emulatedMediaType", "print"); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + if err := writer.WriteField("waitForSelector", "body"); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + if err := writer.Close(); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, r.converterURL, &requestBody) + if err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + req.Header.Set("Content-Type", writer.FormDataContentType()) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + defer resp.Body.Close() + + pdfBytes, err := io.ReadAll(resp.Body) + if err != nil || resp.StatusCode != http.StatusOK { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + return pdfBytes, nil +} + +func (r *report) formatTimeWithTimezone(t float64, timezone string) string { + loc, err := resolveTimezone(timezone) + if err != nil { + r.runInfo <- pkglog.RunInfo{ + Level: slog.LevelWarn, + Message: fmt.Sprintf("failed to resolve timezone '%s', falling back to UTC: %s", timezone, err), + Details: []slog.Attr{slog.Time("time", time.Now().UTC())}, + } + } + + var timeVal time.Time + switch { + case t > nanosecondThreshold: + timeVal = time.Unix(0, int64(t)).In(loc) + default: + timeVal = time.Unix(int64(t), 0).In(loc) + } + + return timeVal.Format("2006-01-02 15:04:05") +} + +func formatValue(msg senml.Message) string { + switch { + case msg.Value != nil: + return fmt.Sprintf("%.2f", *msg.Value) + case msg.StringValue != nil: + return *msg.StringValue + case msg.BoolValue != nil: + return fmt.Sprintf("%t", *msg.BoolValue) + case msg.DataValue != nil: + return *msg.DataValue + default: + return "N/A" + } +} + +func makeRange(n int) []int { + result := make([]int, n) + for i := range result { + result[i] = i + } + return result +} + +func getStartRow(pageNum, firstPageRows, continuationPageRows int) int { + if pageNum == 0 { + return 0 + } + return firstPageRows + (pageNum-1)*continuationPageRows +} + +func getEndRow(pageNum, firstPageRows, continuationPageRows, totalMessages int) int { + var end int + if pageNum == 0 { + end = firstPageRows + } else { + start := firstPageRows + (pageNum-1)*continuationPageRows + end = start + continuationPageRows + } + + if end > totalMessages { + end = totalMessages + } + return end +} + +func (r *report) generateCSVReport(_ context.Context, title string, reports []Report, timezone string) ([]byte, error) { + var buf bytes.Buffer + writer := csv.NewWriter(&buf) + + headers := []string{"Time", "Value", "Unit", "Protocol", "Subtopic"} + + for i, report := range reports { + if i > 0 { + if err := writer.Write([]string{""}); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + if err := writer.Write([]string{"=== NEW REPORT ==="}); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + if err := writer.Write([]string{""}); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + } else { + if err := writer.Write([]string{title}); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + if err := writer.Write([]string{""}); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + } + + if err := writer.Write([]string{"Report Information:"}); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + if err := writer.Write([]string{"Name", report.Metric.Name}); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + if report.Metric.ClientID != "" { + if err := writer.Write([]string{"Device ID", report.Metric.ClientID}); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + } + if err := writer.Write([]string{"Channel ID", report.Metric.ChannelID}); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + if err := writer.Write([]string{""}); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + if err := writer.Write(headers); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + sort.Slice(report.Messages, func(i, j int) bool { + return report.Messages[i].Time < report.Messages[j].Time + }) + + for _, msg := range report.Messages { + timeStr := r.formatTimeWithTimezone(msg.Time, timezone) + + var valueStr string + if msg.Value != nil { + valueStr = fmt.Sprintf("%.2f", *msg.Value) + } else if msg.StringValue != nil { + valueStr = *msg.StringValue + } else if msg.BoolValue != nil { + valueStr = fmt.Sprintf("%v", *msg.BoolValue) + } else if msg.DataValue != nil { + valueStr = *msg.DataValue + } else { + valueStr = "N/A" + } + + row := []string{ + timeStr, + valueStr, + msg.Unit, + msg.Protocol, + msg.Subtopic, + } + + if err := writer.Write(row); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + } + } + + writer.Flush() + if err := writer.Error(); err != nil { + return nil, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + return buf.Bytes(), nil +} diff --git a/reports/handler.go b/reports/handler.go new file mode 100644 index 000000000..4457d7ed6 --- /dev/null +++ b/reports/handler.go @@ -0,0 +1,64 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package reports + +import ( + "context" + "fmt" + "log/slog" + "time" + + pkglog "github.com/absmach/supermq/pkg/logger" +) + +func (r *report) StartScheduler(ctx context.Context) error { + defer r.ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-r.ticker.Tick(): + due := time.Now().UTC() + + pm := PageMeta{ + Status: EnabledStatus, + ScheduledBefore: &due, + } + + reportConfigs, err := r.repo.ListAllReportsConfig(ctx, pm) + if err != nil { + r.runInfo <- pkglog.RunInfo{ + Level: slog.LevelError, + Message: fmt.Sprintf("failed to list reports : %s", err), + Details: []slog.Attr{slog.Time("due", due)}, + } + continue + } + + for _, c := range reportConfigs.ReportConfigs { + go func(cfg ReportConfig) { + if _, err := r.repo.UpdateReportDue(ctx, cfg.ID, cfg.Schedule.NextDue()); err != nil { + r.runInfo <- pkglog.RunInfo{Level: slog.LevelError, Message: fmt.Sprintf("failed to update report: %s", err), Details: []slog.Attr{slog.Time("time", time.Now().UTC())}} + return + } + _, err := r.generateReport(ctx, cfg, EmailReport) + ret := pkglog.RunInfo{ + Details: []slog.Attr{ + slog.String("domain_id", cfg.DomainID), + slog.String("report_id", cfg.ID), + slog.String("report_name", cfg.Name), + slog.Time("exec_time", time.Now().UTC()), + }, + } + if err != nil { + ret.Level = slog.LevelError + ret.Message = fmt.Sprintf("failed to generate report: %s", err) + } + r.runInfo <- ret + }(c) + } + } + } +} diff --git a/reports/middleware/authorization.go b/reports/middleware/authorization.go new file mode 100644 index 000000000..8aa618852 --- /dev/null +++ b/reports/middleware/authorization.go @@ -0,0 +1,210 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/absmach/supermq/pkg/authn" + smqauthz "github.com/absmach/supermq/pkg/authz" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + "github.com/absmach/supermq/pkg/permissions" + "github.com/absmach/supermq/pkg/policies" + rolemgr "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" + "github.com/absmach/supermq/reports" + "github.com/absmach/supermq/reports/operations" +) + +var ( + errDomainCreateConfigs = errors.New("not authorized to create report configs in domain") + errDomainViewConfigs = errors.New("not authorized to view report configs in domain") + errDomainUpdateConfigs = errors.New("not authorized to update report configs in domain") + errDomainDeleteConfigs = errors.New("not authorized to delete report configs in domain") + errDomainGenerateReports = errors.New("not authorized to generate reports in domain") + + errDomainUpdateTemplates = errors.New("not authorized to update report templates in domain") + errDomainRemoveTemplates = errors.New("not authorized to delete report templates in domain") + errDomainViewTemplates = errors.New("not authorized to view report templates in domain") +) + +type authorizationMiddleware struct { + svc reports.Service + authz smqauthz.Authorization + entitiesOps permissions.EntitiesOperations[permissions.Operation] + rolemgr.RoleManagerAuthorizationMiddleware +} + +// AuthorizationMiddleware adds authorization to the reports service. +func AuthorizationMiddleware(svc reports.Service, authz smqauthz.Authorization, entitiesOps permissions.EntitiesOperations[permissions.Operation], roleOps permissions.Operations[permissions.RoleOperation]) (reports.Service, error) { + if err := entitiesOps.Validate(); err != nil { + return nil, err + } + ram, err := rolemgr.NewAuthorization(operations.EntityType, svc, authz, roleOps) + if err != nil { + return nil, err + } + return &authorizationMiddleware{ + svc: svc, + authz: authz, + entitiesOps: entitiesOps, + RoleManagerAuthorizationMiddleware: ram, + }, nil +} + +func (am *authorizationMiddleware) AddReportConfig(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + if err := am.authorize(ctx, operations.OpAddReportConfig, session, policies.DomainType, session.DomainID); err != nil { + return reports.ReportConfig{}, errors.Wrap(errDomainCreateConfigs, err) + } + + return am.svc.AddReportConfig(ctx, session, cfg) +} + +func (am *authorizationMiddleware) ViewReportConfig(ctx context.Context, session authn.Session, id string, withRoles bool) (reports.ReportConfig, error) { + if err := am.authorize(ctx, operations.OpViewReportConfig, session, operations.EntityType, id); err != nil { + return reports.ReportConfig{}, errors.Wrap(errDomainViewConfigs, err) + } + + return am.svc.ViewReportConfig(ctx, session, id, withRoles) +} + +func (am *authorizationMiddleware) UpdateReportConfig(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + if err := am.authorize(ctx, operations.OpUpdateReportConfig, session, operations.EntityType, cfg.ID); err != nil { + return reports.ReportConfig{}, errors.Wrap(errDomainUpdateConfigs, err) + } + + return am.svc.UpdateReportConfig(ctx, session, cfg) +} + +func (am *authorizationMiddleware) UpdateReportSchedule(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + if err := am.authorize(ctx, operations.OpUpdateReportSchedule, session, operations.EntityType, cfg.ID); err != nil { + return reports.ReportConfig{}, errors.Wrap(errDomainUpdateConfigs, err) + } + + return am.svc.UpdateReportSchedule(ctx, session, cfg) +} + +func (am *authorizationMiddleware) RemoveReportConfig(ctx context.Context, session authn.Session, id string) error { + if err := am.authorize(ctx, operations.OpRemoveReportConfig, session, operations.EntityType, id); err != nil { + return errors.Wrap(errDomainDeleteConfigs, err) + } + + return am.svc.RemoveReportConfig(ctx, session, id) +} + +func (am *authorizationMiddleware) ListReportsConfig(ctx context.Context, session authn.Session, pm reports.PageMeta) (reports.ReportConfigPage, error) { + switch err := am.checkSuperAdmin(ctx, session); { + case err == nil: + session.SuperAdmin = true + case errors.Contains(err, svcerr.ErrSuperAdminAction): + default: + return reports.ReportConfigPage{}, err + } + + return am.svc.ListReportsConfig(ctx, session, pm) +} + +func (am *authorizationMiddleware) EnableReportConfig(ctx context.Context, session authn.Session, id string) (reports.ReportConfig, error) { + if err := am.authorize(ctx, operations.OpEnableReportConfig, session, operations.EntityType, id); err != nil { + return reports.ReportConfig{}, errors.Wrap(errDomainUpdateConfigs, err) + } + + return am.svc.EnableReportConfig(ctx, session, id) +} + +func (am *authorizationMiddleware) DisableReportConfig(ctx context.Context, session authn.Session, id string) (reports.ReportConfig, error) { + if err := am.authorize(ctx, operations.OpDisableReportConfig, session, operations.EntityType, id); err != nil { + return reports.ReportConfig{}, errors.Wrap(errDomainUpdateConfigs, err) + } + + return am.svc.DisableReportConfig(ctx, session, id) +} + +func (am *authorizationMiddleware) GenerateReport(ctx context.Context, session authn.Session, config reports.ReportConfig, action reports.ReportAction) (reports.ReportPage, error) { + if err := am.authorize(ctx, operations.OpGenerateReport, session, policies.DomainType, session.DomainID); err != nil { + return reports.ReportPage{}, errors.Wrap(errDomainGenerateReports, err) + } + + return am.svc.GenerateReport(ctx, session, config, action) +} + +func (am *authorizationMiddleware) UpdateReportTemplate(ctx context.Context, session authn.Session, cfg reports.ReportConfig) error { + if err := am.authorize(ctx, operations.OpUpdateReportTemplate, session, operations.EntityType, cfg.ID); err != nil { + return errors.Wrap(errDomainUpdateTemplates, err) + } + + return am.svc.UpdateReportTemplate(ctx, session, cfg) +} + +func (am *authorizationMiddleware) ViewReportTemplate(ctx context.Context, session authn.Session, id string) (reports.ReportTemplate, error) { + if err := am.authorize(ctx, operations.OpViewReportTemplate, session, operations.EntityType, id); err != nil { + return "", errors.Wrap(errDomainViewTemplates, err) + } + + return am.svc.ViewReportTemplate(ctx, session, id) +} + +func (am *authorizationMiddleware) DeleteReportTemplate(ctx context.Context, session authn.Session, id string) error { + if err := am.authorize(ctx, operations.OpDeleteReportTemplate, session, operations.EntityType, id); err != nil { + return errors.Wrap(errDomainRemoveTemplates, err) + } + + return am.svc.DeleteReportTemplate(ctx, session, id) +} + +func (am *authorizationMiddleware) StartScheduler(ctx context.Context) error { + return am.svc.StartScheduler(ctx) +} + +func (am *authorizationMiddleware) authorize(ctx context.Context, op permissions.Operation, session authn.Session, objType, obj string) error { + perm, err := am.entitiesOps.GetPermission(operations.EntityType, op) + if err != nil { + return err + } + + pr := smqauthz.PolicyReq{ + Domain: session.DomainID, + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Object: obj, + ObjectType: objType, + Permission: perm.String(), + } + + var pat *smqauthz.PATReq + if session.PatID != "" { + opName := am.entitiesOps.OperationName(operations.EntityType, op) + pat = &smqauthz.PATReq{ + UserID: session.UserID, + PatID: session.PatID, + EntityID: session.DomainID, + EntityType: operations.EntityType, + Operation: opName, + Domain: session.DomainID, + } + } + + if err := am.authz.Authorize(ctx, pr, pat); err != nil { + return err + } + + return nil +} + +func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session authn.Session) error { + if session.Role != authn.SuperAdminRole { + return svcerr.ErrSuperAdminAction + } + if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{ + SubjectType: policies.UserType, + Subject: session.UserID, + Permission: policies.AdminPermission, + ObjectType: policies.PlatformType, + Object: policies.MagistralaObject, + }, nil); err != nil { + return err + } + return nil +} diff --git a/reports/middleware/callout.go b/reports/middleware/callout.go new file mode 100644 index 000000000..81c058dd3 --- /dev/null +++ b/reports/middleware/callout.go @@ -0,0 +1,222 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "time" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/callout" + "github.com/absmach/supermq/pkg/permissions" + "github.com/absmach/supermq/pkg/policies" + mgPolicies "github.com/absmach/supermq/pkg/policies" + rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" + "github.com/absmach/supermq/reports" + "github.com/absmach/supermq/reports/operations" +) + +var _ reports.Service = (*calloutMiddleware)(nil) + +type calloutMiddleware struct { + svc reports.Service + callout callout.Callout + entitiesOps permissions.EntitiesOperations[permissions.Operation] + rolemw.RoleManagerCalloutMiddleware +} + +const entityType = "report" + +func NewCallout(svc reports.Service, callout callout.Callout, entitiesOps permissions.EntitiesOperations[permissions.Operation], roleOps permissions.Operations[permissions.RoleOperation]) (reports.Service, error) { + call, err := rolemw.NewCallout(mgPolicies.ReportsType, svc, callout, roleOps) + if err != nil { + return nil, err + } + + if err := entitiesOps.Validate(); err != nil { + return nil, err + } + + return &calloutMiddleware{ + svc: svc, + callout: callout, + entitiesOps: entitiesOps, + RoleManagerCalloutMiddleware: call, + }, nil +} + +func (cm *calloutMiddleware) AddReportConfig(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + params := map[string]any{ + "entities": cfg, + "count": 1, + } + + if err := cm.callOut(ctx, session, operations.OpAddReportConfig, params); err != nil { + return reports.ReportConfig{}, err + } + + return cm.svc.AddReportConfig(ctx, session, cfg) +} + +func (cm *calloutMiddleware) ViewReportConfig(ctx context.Context, session authn.Session, id string, withRoles bool) (reports.ReportConfig, error) { + params := map[string]any{ + "entity_id": id, + } + + if err := cm.callOut(ctx, session, operations.OpViewReportConfig, params); err != nil { + return reports.ReportConfig{}, err + } + + return cm.svc.ViewReportConfig(ctx, session, id, withRoles) +} + +func (cm *calloutMiddleware) UpdateReportConfig(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + params := map[string]any{ + "entity_id": cfg.ID, + } + + if err := cm.callOut(ctx, session, operations.OpUpdateReportConfig, params); err != nil { + return reports.ReportConfig{}, err + } + + return cm.svc.UpdateReportConfig(ctx, session, cfg) +} + +func (cm *calloutMiddleware) UpdateReportSchedule(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + params := map[string]any{ + "entity_id": cfg.ID, + } + + if err := cm.callOut(ctx, session, operations.OpUpdateReportSchedule, params); err != nil { + return reports.ReportConfig{}, err + } + + return cm.svc.UpdateReportSchedule(ctx, session, cfg) +} + +func (cm *calloutMiddleware) RemoveReportConfig(ctx context.Context, session authn.Session, id string) error { + params := map[string]any{ + "entity_id": id, + } + + if err := cm.callOut(ctx, session, operations.OpRemoveReportConfig, params); err != nil { + return err + } + + return cm.svc.RemoveReportConfig(ctx, session, id) +} + +func (cm *calloutMiddleware) ListReportsConfig(ctx context.Context, session authn.Session, pm reports.PageMeta) (reports.ReportConfigPage, error) { + params := map[string]any{ + "pagemeta": pm, + } + + if err := cm.callOut(ctx, session, operations.OpListReportsConfig, params); err != nil { + return reports.ReportConfigPage{}, err + } + + return cm.svc.ListReportsConfig(ctx, session, pm) +} + +func (cm *calloutMiddleware) EnableReportConfig(ctx context.Context, session authn.Session, id string) (reports.ReportConfig, error) { + params := map[string]any{ + "entity_id": id, + } + + if err := cm.callOut(ctx, session, operations.OpEnableReportConfig, params); err != nil { + return reports.ReportConfig{}, err + } + + return cm.svc.EnableReportConfig(ctx, session, id) +} + +func (cm *calloutMiddleware) DisableReportConfig(ctx context.Context, session authn.Session, id string) (reports.ReportConfig, error) { + params := map[string]any{ + "entity_id": id, + } + + if err := cm.callOut(ctx, session, operations.OpDisableReportConfig, params); err != nil { + return reports.ReportConfig{}, err + } + + return cm.svc.DisableReportConfig(ctx, session, id) +} + +func (cm *calloutMiddleware) GenerateReport(ctx context.Context, session authn.Session, config reports.ReportConfig, action reports.ReportAction) (reports.ReportPage, error) { + params := map[string]any{ + "entity_id": config.ID, + } + + if err := cm.callOut(ctx, session, operations.OpGenerateReport, params); err != nil { + return reports.ReportPage{}, err + } + + return cm.svc.GenerateReport(ctx, session, config, action) +} + +func (cm *calloutMiddleware) UpdateReportTemplate(ctx context.Context, session authn.Session, cfg reports.ReportConfig) error { + params := map[string]any{ + "entity_id": cfg.ID, + } + + if err := cm.callOut(ctx, session, operations.OpUpdateReportTemplate, params); err != nil { + return err + } + + return cm.svc.UpdateReportTemplate(ctx, session, cfg) +} + +func (cm *calloutMiddleware) ViewReportTemplate(ctx context.Context, session authn.Session, id string) (reports.ReportTemplate, error) { + params := map[string]any{ + "entity_id": id, + } + + if err := cm.callOut(ctx, session, operations.OpViewReportTemplate, params); err != nil { + return "", err + } + + return cm.svc.ViewReportTemplate(ctx, session, id) +} + +func (cm *calloutMiddleware) DeleteReportTemplate(ctx context.Context, session authn.Session, id string) error { + params := map[string]any{ + "entity_id": id, + } + + if err := cm.callOut(ctx, session, operations.OpDeleteReportTemplate, params); err != nil { + return err + } + + return cm.svc.DeleteReportTemplate(ctx, session, id) +} + +func (cm *calloutMiddleware) StartScheduler(ctx context.Context) error { + return cm.svc.StartScheduler(ctx) +} + +func (cm *calloutMiddleware) callOut(ctx context.Context, session authn.Session, op permissions.Operation, pld map[string]any) error { + var entityID string + if id, ok := pld["entity_id"].(string); ok { + entityID = id + } + + req := callout.Request{ + BaseRequest: callout.BaseRequest{ + Operation: cm.entitiesOps.OperationName(entityType, op), + EntityType: entityType, + EntityID: entityID, + CallerID: session.UserID, + CallerType: policies.UserType, + DomainID: session.DomainID, + Time: time.Now().UTC(), + }, + Payload: pld, + } + + if err := cm.callout.Callout(ctx, req); err != nil { + return err + } + + return nil +} diff --git a/reports/middleware/logging.go b/reports/middleware/logging.go new file mode 100644 index 000000000..3da275c3c --- /dev/null +++ b/reports/middleware/logging.go @@ -0,0 +1,270 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "log/slog" + "time" + + "github.com/absmach/supermq/pkg/authn" + rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" + "github.com/absmach/supermq/reports" +) + +var _ reports.Service = (*loggingMiddleware)(nil) + +type loggingMiddleware struct { + logger *slog.Logger + svc reports.Service + rolemw.RoleManagerLoggingMiddleware +} + +func LoggingMiddleware(svc reports.Service, logger *slog.Logger) reports.Service { + return &loggingMiddleware{ + logger: logger, + svc: svc, + RoleManagerLoggingMiddleware: rolemw.NewLogging("reports", svc, logger), + } +} + +func (lm *loggingMiddleware) StartScheduler(ctx context.Context) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Start scheduler failed", args...) + return + } + lm.logger.Info("Start scheduler completed successfully", args...) + }(time.Now()) + return lm.svc.StartScheduler(ctx) +} + +func (lm *loggingMiddleware) GenerateReport(ctx context.Context, session authn.Session, config reports.ReportConfig, action reports.ReportAction) (page reports.ReportPage, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Generate report failed", args...) + return + } + lm.logger.Info("Generate report completed", args...) + }(time.Now()) + + return lm.svc.GenerateReport(ctx, session, config, action) +} + +func (lm *loggingMiddleware) AddReportConfig(ctx context.Context, session authn.Session, config reports.ReportConfig) (res reports.ReportConfig, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.String("report_name", config.Name), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Add report config failed", args...) + return + } + lm.logger.Info("Add report config completed successfully", args...) + }(time.Now()) + return lm.svc.AddReportConfig(ctx, session, config) +} + +func (lm *loggingMiddleware) ViewReportConfig(ctx context.Context, session authn.Session, id string, withRoles bool) (res reports.ReportConfig, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("report_config", + slog.String("id", res.ID), + slog.String("name", res.Name), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("View report config failed", args...) + return + } + lm.logger.Info("View report config completed successfully", args...) + }(time.Now()) + return lm.svc.ViewReportConfig(ctx, session, id, withRoles) +} + +func (lm *loggingMiddleware) UpdateReportConfig(ctx context.Context, session authn.Session, config reports.ReportConfig) (res reports.ReportConfig, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("report_config", + slog.String("id", config.ID), + slog.String("name", config.Name), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Update report config failed", args...) + return + } + lm.logger.Info("Update report config completed successfully", args...) + }(time.Now()) + return lm.svc.UpdateReportConfig(ctx, session, config) +} + +func (lm *loggingMiddleware) UpdateReportSchedule(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (res reports.ReportConfig, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("report", + slog.String("id", cfg.ID), + slog.Any("schedule", cfg.Schedule), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Update report schedule failed", args...) + return + } + lm.logger.Info("Update report schedule completed successfully", args...) + }(time.Now()) + return lm.svc.UpdateReportSchedule(ctx, session, cfg) +} + +func (lm *loggingMiddleware) ListReportsConfig(ctx context.Context, session authn.Session, pm reports.PageMeta) (pg reports.ReportConfigPage, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("page", + slog.Uint64("offset", pm.Offset), + slog.Uint64("limit", pm.Limit), + slog.Uint64("total", pg.Total), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("List reports config failed", args...) + return + } + lm.logger.Info("List reports config completed successfully", args...) + }(time.Now()) + return lm.svc.ListReportsConfig(ctx, session, pm) +} + +func (lm *loggingMiddleware) DisableReportConfig(ctx context.Context, session authn.Session, id string) (res reports.ReportConfig, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("report_config", + slog.String("id", res.ID), + slog.String("name", res.Name), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Disable report config failed", args...) + return + } + lm.logger.Info("Disable report config completed successfully", args...) + }(time.Now()) + return lm.svc.DisableReportConfig(ctx, session, id) +} + +func (lm *loggingMiddleware) EnableReportConfig(ctx context.Context, session authn.Session, id string) (res reports.ReportConfig, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.Group("report_config", + slog.String("id", res.ID), + slog.String("name", res.Name), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Enable report config failed", args...) + return + } + lm.logger.Info("Enable report config completed successfully", args...) + }(time.Now()) + return lm.svc.EnableReportConfig(ctx, session, id) +} + +func (lm *loggingMiddleware) RemoveReportConfig(ctx context.Context, session authn.Session, id string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.String("report_config_id", id), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Remove report config failed", args...) + return + } + lm.logger.Info("Remove report config completed successfully", args...) + }(time.Now()) + return lm.svc.RemoveReportConfig(ctx, session, id) +} + +func (lm *loggingMiddleware) UpdateReportTemplate(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.String("report_config_id", cfg.ID), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Update report template failed", args...) + return + } + lm.logger.Info("Update report template completed successfully", args...) + }(time.Now()) + + return lm.svc.UpdateReportTemplate(ctx, session, cfg) +} + +func (lm *loggingMiddleware) ViewReportTemplate(ctx context.Context, session authn.Session, id string) (t reports.ReportTemplate, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.String("report_config_id", id), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("View report template failed", args...) + return + } + lm.logger.Info("View report template completed successfully", args...) + }(time.Now()) + + return lm.svc.ViewReportTemplate(ctx, session, id) +} + +func (lm *loggingMiddleware) DeleteReportTemplate(ctx context.Context, session authn.Session, id string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", session.DomainID), + slog.String("report_config_id", id), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Delete report template failed", args...) + return + } + lm.logger.Info("Delete report template completed successfully", args...) + }(time.Now()) + + return lm.svc.DeleteReportTemplate(ctx, session, id) +} diff --git a/reports/middleware/metrics.go b/reports/middleware/metrics.go new file mode 100644 index 000000000..f6c7b0a6c --- /dev/null +++ b/reports/middleware/metrics.go @@ -0,0 +1,149 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + "time" + + "github.com/absmach/supermq/pkg/authn" + rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" + "github.com/absmach/supermq/reports" + "github.com/go-kit/kit/metrics" +) + +type metricsMiddleware struct { + counter metrics.Counter + latency metrics.Histogram + service reports.Service + rolemw.RoleManagerMetricsMiddleware +} + +var _ reports.Service = (*metricsMiddleware)(nil) + +func NewMetricsMiddleware(counter metrics.Counter, latency metrics.Histogram, service reports.Service) reports.Service { + return &metricsMiddleware{ + counter: counter, + latency: latency, + service: service, + RoleManagerMetricsMiddleware: rolemw.NewMetrics("reports", service, counter, latency), + } +} + +func (mm *metricsMiddleware) AddReportConfig(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + defer func(begin time.Time) { + mm.counter.With("method", "add_report_config").Add(1) + mm.latency.With("method", "add_report_config").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.AddReportConfig(ctx, session, cfg) +} + +func (mm *metricsMiddleware) ViewReportConfig(ctx context.Context, session authn.Session, id string, withRoles bool) (reports.ReportConfig, error) { + defer func(begin time.Time) { + mm.counter.With("method", "view_report_config").Add(1) + mm.latency.With("method", "view_report_config").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.ViewReportConfig(ctx, session, id, withRoles) +} + +func (mm *metricsMiddleware) UpdateReportConfig(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + defer func(begin time.Time) { + mm.counter.With("method", "update_report_config").Add(1) + mm.latency.With("method", "update_report_config").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.UpdateReportConfig(ctx, session, cfg) +} + +func (mm *metricsMiddleware) UpdateReportSchedule(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + defer func(begin time.Time) { + mm.counter.With("method", "update_report_schedule").Add(1) + mm.latency.With("method", "update_report_schedule").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.UpdateReportSchedule(ctx, session, cfg) +} + +func (mm *metricsMiddleware) RemoveReportConfig(ctx context.Context, session authn.Session, id string) error { + defer func(begin time.Time) { + mm.counter.With("method", "remove_report_config").Add(1) + mm.latency.With("method", "remove_report_config").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.RemoveReportConfig(ctx, session, id) +} + +func (mm *metricsMiddleware) ListReportsConfig(ctx context.Context, session authn.Session, pm reports.PageMeta) (reports.ReportConfigPage, error) { + defer func(begin time.Time) { + mm.counter.With("method", "list_reports_config").Add(1) + mm.latency.With("method", "list_reports_config").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.ListReportsConfig(ctx, session, pm) +} + +func (mm *metricsMiddleware) EnableReportConfig(ctx context.Context, session authn.Session, id string) (reports.ReportConfig, error) { + defer func(begin time.Time) { + mm.counter.With("method", "enable_report_config").Add(1) + mm.latency.With("method", "enable_report_config").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.EnableReportConfig(ctx, session, id) +} + +func (mm *metricsMiddleware) DisableReportConfig(ctx context.Context, session authn.Session, id string) (reports.ReportConfig, error) { + defer func(begin time.Time) { + mm.counter.With("method", "disable_report_config").Add(1) + mm.latency.With("method", "disable_report_config").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.DisableReportConfig(ctx, session, id) +} + +func (mm *metricsMiddleware) UpdateReportTemplate(ctx context.Context, session authn.Session, cfg reports.ReportConfig) error { + defer func(begin time.Time) { + mm.counter.With("method", "update_report_template").Add(1) + mm.latency.With("method", "update_report_template").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.UpdateReportTemplate(ctx, session, cfg) +} + +func (mm *metricsMiddleware) ViewReportTemplate(ctx context.Context, session authn.Session, id string) (reports.ReportTemplate, error) { + defer func(begin time.Time) { + mm.counter.With("method", "view_report_template").Add(1) + mm.latency.With("method", "view_report_template").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.ViewReportTemplate(ctx, session, id) +} + +func (mm *metricsMiddleware) DeleteReportTemplate(ctx context.Context, session authn.Session, id string) error { + defer func(begin time.Time) { + mm.counter.With("method", "delete_report_template").Add(1) + mm.latency.With("method", "delete_report_template").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.DeleteReportTemplate(ctx, session, id) +} + +func (mm *metricsMiddleware) GenerateReport(ctx context.Context, session authn.Session, config reports.ReportConfig, action reports.ReportAction) (reports.ReportPage, error) { + defer func(begin time.Time) { + mm.counter.With("method", "generate_report").Add(1) + mm.latency.With("method", "generate_report").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.GenerateReport(ctx, session, config, action) +} + +func (mm *metricsMiddleware) StartScheduler(ctx context.Context) error { + defer func(begin time.Time) { + mm.counter.With("method", "start_scheduler").Add(1) + mm.latency.With("method", "start_scheduler").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.service.StartScheduler(ctx) +} diff --git a/reports/middleware/tracing.go b/reports/middleware/tracing.go new file mode 100644 index 000000000..05f0f126f --- /dev/null +++ b/reports/middleware/tracing.go @@ -0,0 +1,149 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/absmach/supermq/pkg/authn" + rolemw "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" + smqTracing "github.com/absmach/supermq/pkg/tracing" + "github.com/absmach/supermq/reports" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" +) + +type tracingMiddleware struct { + tracer trace.Tracer + svc reports.Service + rolemw.RoleManagerTracing +} + +var _ reports.Service = (*tracingMiddleware)(nil) + +func NewTracingMiddleware(tracer trace.Tracer, svc reports.Service) reports.Service { + return &tracingMiddleware{ + tracer: tracer, + svc: svc, + RoleManagerTracing: rolemw.NewTracing("reports", svc, tracer), + } +} + +func (tm *tracingMiddleware) AddReportConfig(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "add_report_config", trace.WithAttributes( + attribute.String("name", cfg.Name), + attribute.String("domain_id", cfg.DomainID), + )) + defer span.End() + + return tm.svc.AddReportConfig(ctx, session, cfg) +} + +func (tm *tracingMiddleware) ViewReportConfig(ctx context.Context, session authn.Session, id string, withRoles bool) (reports.ReportConfig, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "view_report_config", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.ViewReportConfig(ctx, session, id, withRoles) +} + +func (tm *tracingMiddleware) UpdateReportConfig(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_report_config", trace.WithAttributes( + attribute.String("id", cfg.ID), + )) + defer span.End() + + return tm.svc.UpdateReportConfig(ctx, session, cfg) +} + +func (tm *tracingMiddleware) UpdateReportSchedule(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_report_schedule", trace.WithAttributes( + attribute.String("id", cfg.ID), + )) + defer span.End() + + return tm.svc.UpdateReportSchedule(ctx, session, cfg) +} + +func (tm *tracingMiddleware) RemoveReportConfig(ctx context.Context, session authn.Session, id string) error { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "remove_report_config", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.RemoveReportConfig(ctx, session, id) +} + +func (tm *tracingMiddleware) ListReportsConfig(ctx context.Context, session authn.Session, pm reports.PageMeta) (reports.ReportConfigPage, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "list_reports_config", trace.WithAttributes( + attribute.Int("offset", int(pm.Offset)), + attribute.Int("limit", int(pm.Limit)), + )) + defer span.End() + + return tm.svc.ListReportsConfig(ctx, session, pm) +} + +func (tm *tracingMiddleware) EnableReportConfig(ctx context.Context, session authn.Session, id string) (reports.ReportConfig, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "enable_report_config", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.EnableReportConfig(ctx, session, id) +} + +func (tm *tracingMiddleware) DisableReportConfig(ctx context.Context, session authn.Session, id string) (reports.ReportConfig, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "disable_report_config", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.DisableReportConfig(ctx, session, id) +} + +func (tm *tracingMiddleware) UpdateReportTemplate(ctx context.Context, session authn.Session, cfg reports.ReportConfig) error { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "update_report_template", trace.WithAttributes( + attribute.String("id", cfg.ID), + )) + defer span.End() + + return tm.svc.UpdateReportTemplate(ctx, session, cfg) +} + +func (tm *tracingMiddleware) ViewReportTemplate(ctx context.Context, session authn.Session, id string) (reports.ReportTemplate, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "view_report_template", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.ViewReportTemplate(ctx, session, id) +} + +func (tm *tracingMiddleware) DeleteReportTemplate(ctx context.Context, session authn.Session, id string) error { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "delete_report_template", trace.WithAttributes( + attribute.String("id", id), + )) + defer span.End() + + return tm.svc.DeleteReportTemplate(ctx, session, id) +} + +func (tm *tracingMiddleware) GenerateReport(ctx context.Context, session authn.Session, config reports.ReportConfig, action reports.ReportAction) (reports.ReportPage, error) { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "generate_report", trace.WithAttributes( + attribute.String("config_id", config.ID), + attribute.String("action", string(action)), + )) + defer span.End() + + return tm.svc.GenerateReport(ctx, session, config, action) +} + +func (tm *tracingMiddleware) StartScheduler(ctx context.Context) error { + ctx, span := smqTracing.StartSpan(ctx, tm.tracer, "start_scheduler") + defer span.End() + + return tm.svc.StartScheduler(ctx) +} diff --git a/reports/mocks/repository.go b/reports/mocks/repository.go new file mode 100644 index 000000000..947725c2f --- /dev/null +++ b/reports/mocks/repository.go @@ -0,0 +1,2271 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + "time" + + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/reports" + mock "github.com/stretchr/testify/mock" +) + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +type Repository_Expecter struct { + mock *mock.Mock +} + +func (_m *Repository) EXPECT() *Repository_Expecter { + return &Repository_Expecter{mock: &_m.Mock} +} + +// AddReportConfig provides a mock function for the type Repository +func (_mock *Repository) AddReportConfig(ctx context.Context, cfg reports.ReportConfig) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, cfg) + + if len(ret) == 0 { + panic("no return value specified for AddReportConfig") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, reports.ReportConfig) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, cfg) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, reports.ReportConfig) reports.ReportConfig); ok { + r0 = returnFunc(ctx, cfg) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, reports.ReportConfig) error); ok { + r1 = returnFunc(ctx, cfg) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_AddReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddReportConfig' +type Repository_AddReportConfig_Call struct { + *mock.Call +} + +// AddReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - cfg reports.ReportConfig +func (_e *Repository_Expecter) AddReportConfig(ctx interface{}, cfg interface{}) *Repository_AddReportConfig_Call { + return &Repository_AddReportConfig_Call{Call: _e.mock.On("AddReportConfig", ctx, cfg)} +} + +func (_c *Repository_AddReportConfig_Call) Run(run func(ctx context.Context, cfg reports.ReportConfig)) *Repository_AddReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 reports.ReportConfig + if args[1] != nil { + arg1 = args[1].(reports.ReportConfig) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_AddReportConfig_Call) Return(reportConfig reports.ReportConfig, err error) *Repository_AddReportConfig_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Repository_AddReportConfig_Call) RunAndReturn(run func(ctx context.Context, cfg reports.ReportConfig) (reports.ReportConfig, error)) *Repository_AddReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// AddRoles provides a mock function for the type Repository +func (_mock *Repository) AddRoles(ctx context.Context, rps []roles.RoleProvision) ([]roles.RoleProvision, error) { + ret := _mock.Called(ctx, rps) + + if len(ret) == 0 { + panic("no return value specified for AddRoles") + } + + var r0 []roles.RoleProvision + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, []roles.RoleProvision) ([]roles.RoleProvision, error)); ok { + return returnFunc(ctx, rps) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, []roles.RoleProvision) []roles.RoleProvision); ok { + r0 = returnFunc(ctx, rps) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]roles.RoleProvision) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, []roles.RoleProvision) error); ok { + r1 = returnFunc(ctx, rps) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_AddRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddRoles' +type Repository_AddRoles_Call struct { + *mock.Call +} + +// AddRoles is a helper method to define mock.On call +// - ctx context.Context +// - rps []roles.RoleProvision +func (_e *Repository_Expecter) AddRoles(ctx interface{}, rps interface{}) *Repository_AddRoles_Call { + return &Repository_AddRoles_Call{Call: _e.mock.On("AddRoles", ctx, rps)} +} + +func (_c *Repository_AddRoles_Call) Run(run func(ctx context.Context, rps []roles.RoleProvision)) *Repository_AddRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 []roles.RoleProvision + if args[1] != nil { + arg1 = args[1].([]roles.RoleProvision) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_AddRoles_Call) Return(roleProvisions []roles.RoleProvision, err error) *Repository_AddRoles_Call { + _c.Call.Return(roleProvisions, err) + return _c +} + +func (_c *Repository_AddRoles_Call) RunAndReturn(run func(ctx context.Context, rps []roles.RoleProvision) ([]roles.RoleProvision, error)) *Repository_AddRoles_Call { + _c.Call.Return(run) + return _c +} + +// DeleteReportTemplate provides a mock function for the type Repository +func (_mock *Repository) DeleteReportTemplate(ctx context.Context, domainID string, reportID string) error { + ret := _mock.Called(ctx, domainID, reportID) + + if len(ret) == 0 { + panic("no return value specified for DeleteReportTemplate") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = returnFunc(ctx, domainID, reportID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_DeleteReportTemplate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteReportTemplate' +type Repository_DeleteReportTemplate_Call struct { + *mock.Call +} + +// DeleteReportTemplate is a helper method to define mock.On call +// - ctx context.Context +// - domainID string +// - reportID string +func (_e *Repository_Expecter) DeleteReportTemplate(ctx interface{}, domainID interface{}, reportID interface{}) *Repository_DeleteReportTemplate_Call { + return &Repository_DeleteReportTemplate_Call{Call: _e.mock.On("DeleteReportTemplate", ctx, domainID, reportID)} +} + +func (_c *Repository_DeleteReportTemplate_Call) Run(run func(ctx context.Context, domainID string, reportID string)) *Repository_DeleteReportTemplate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_DeleteReportTemplate_Call) Return(err error) *Repository_DeleteReportTemplate_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_DeleteReportTemplate_Call) RunAndReturn(run func(ctx context.Context, domainID string, reportID string) error) *Repository_DeleteReportTemplate_Call { + _c.Call.Return(run) + return _c +} + +// ListAllReportsConfig provides a mock function for the type Repository +func (_mock *Repository) ListAllReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) { + ret := _mock.Called(ctx, pm) + + if len(ret) == 0 { + panic("no return value specified for ListAllReportsConfig") + } + + var r0 reports.ReportConfigPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) (reports.ReportConfigPage, error)); ok { + return returnFunc(ctx, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, reports.PageMeta) reports.ReportConfigPage); ok { + r0 = returnFunc(ctx, pm) + } else { + r0 = ret.Get(0).(reports.ReportConfigPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, reports.PageMeta) error); ok { + r1 = returnFunc(ctx, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListAllReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAllReportsConfig' +type Repository_ListAllReportsConfig_Call struct { + *mock.Call +} + +// ListAllReportsConfig is a helper method to define mock.On call +// - ctx context.Context +// - pm reports.PageMeta +func (_e *Repository_Expecter) ListAllReportsConfig(ctx interface{}, pm interface{}) *Repository_ListAllReportsConfig_Call { + return &Repository_ListAllReportsConfig_Call{Call: _e.mock.On("ListAllReportsConfig", ctx, pm)} +} + +func (_c *Repository_ListAllReportsConfig_Call) Run(run func(ctx context.Context, pm reports.PageMeta)) *Repository_ListAllReportsConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 reports.PageMeta + if args[1] != nil { + arg1 = args[1].(reports.PageMeta) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_ListAllReportsConfig_Call) Return(reportConfigPage reports.ReportConfigPage, err error) *Repository_ListAllReportsConfig_Call { + _c.Call.Return(reportConfigPage, err) + return _c +} + +func (_c *Repository_ListAllReportsConfig_Call) RunAndReturn(run func(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error)) *Repository_ListAllReportsConfig_Call { + _c.Call.Return(run) + return _c +} + +// ListEntityMembers provides a mock function for the type Repository +func (_mock *Repository) ListEntityMembers(ctx context.Context, entityID string, pageQuery roles.MembersRolePageQuery) (roles.MembersRolePage, error) { + ret := _mock.Called(ctx, entityID, pageQuery) + + if len(ret) == 0 { + panic("no return value specified for ListEntityMembers") + } + + var r0 roles.MembersRolePage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, roles.MembersRolePageQuery) (roles.MembersRolePage, error)); ok { + return returnFunc(ctx, entityID, pageQuery) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, roles.MembersRolePageQuery) roles.MembersRolePage); ok { + r0 = returnFunc(ctx, entityID, pageQuery) + } else { + r0 = ret.Get(0).(roles.MembersRolePage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, roles.MembersRolePageQuery) error); ok { + r1 = returnFunc(ctx, entityID, pageQuery) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListEntityMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListEntityMembers' +type Repository_ListEntityMembers_Call struct { + *mock.Call +} + +// ListEntityMembers is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - pageQuery roles.MembersRolePageQuery +func (_e *Repository_Expecter) ListEntityMembers(ctx interface{}, entityID interface{}, pageQuery interface{}) *Repository_ListEntityMembers_Call { + return &Repository_ListEntityMembers_Call{Call: _e.mock.On("ListEntityMembers", ctx, entityID, pageQuery)} +} + +func (_c *Repository_ListEntityMembers_Call) Run(run func(ctx context.Context, entityID string, pageQuery roles.MembersRolePageQuery)) *Repository_ListEntityMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 roles.MembersRolePageQuery + if args[2] != nil { + arg2 = args[2].(roles.MembersRolePageQuery) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_ListEntityMembers_Call) Return(membersRolePage roles.MembersRolePage, err error) *Repository_ListEntityMembers_Call { + _c.Call.Return(membersRolePage, err) + return _c +} + +func (_c *Repository_ListEntityMembers_Call) RunAndReturn(run func(ctx context.Context, entityID string, pageQuery roles.MembersRolePageQuery) (roles.MembersRolePage, error)) *Repository_ListEntityMembers_Call { + _c.Call.Return(run) + return _c +} + +// ListUserReportsConfig provides a mock function for the type Repository +func (_mock *Repository) ListUserReportsConfig(ctx context.Context, userID string, pm reports.PageMeta) (reports.ReportConfigPage, error) { + ret := _mock.Called(ctx, userID, pm) + + if len(ret) == 0 { + panic("no return value specified for ListUserReportsConfig") + } + + var r0 reports.ReportConfigPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, reports.PageMeta) (reports.ReportConfigPage, error)); ok { + return returnFunc(ctx, userID, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, reports.PageMeta) reports.ReportConfigPage); ok { + r0 = returnFunc(ctx, userID, pm) + } else { + r0 = ret.Get(0).(reports.ReportConfigPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, reports.PageMeta) error); ok { + r1 = returnFunc(ctx, userID, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ListUserReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserReportsConfig' +type Repository_ListUserReportsConfig_Call struct { + *mock.Call +} + +// ListUserReportsConfig is a helper method to define mock.On call +// - ctx context.Context +// - userID string +// - pm reports.PageMeta +func (_e *Repository_Expecter) ListUserReportsConfig(ctx interface{}, userID interface{}, pm interface{}) *Repository_ListUserReportsConfig_Call { + return &Repository_ListUserReportsConfig_Call{Call: _e.mock.On("ListUserReportsConfig", ctx, userID, pm)} +} + +func (_c *Repository_ListUserReportsConfig_Call) Run(run func(ctx context.Context, userID string, pm reports.PageMeta)) *Repository_ListUserReportsConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 reports.PageMeta + if args[2] != nil { + arg2 = args[2].(reports.PageMeta) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_ListUserReportsConfig_Call) Return(reportConfigPage reports.ReportConfigPage, err error) *Repository_ListUserReportsConfig_Call { + _c.Call.Return(reportConfigPage, err) + return _c +} + +func (_c *Repository_ListUserReportsConfig_Call) RunAndReturn(run func(ctx context.Context, userID string, pm reports.PageMeta) (reports.ReportConfigPage, error)) *Repository_ListUserReportsConfig_Call { + _c.Call.Return(run) + return _c +} + +// RemoveEntityMembers provides a mock function for the type Repository +func (_mock *Repository) RemoveEntityMembers(ctx context.Context, entityID string, members []string) error { + ret := _mock.Called(ctx, entityID, members) + + if len(ret) == 0 { + panic("no return value specified for RemoveEntityMembers") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string) error); ok { + r0 = returnFunc(ctx, entityID, members) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RemoveEntityMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveEntityMembers' +type Repository_RemoveEntityMembers_Call struct { + *mock.Call +} + +// RemoveEntityMembers is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - members []string +func (_e *Repository_Expecter) RemoveEntityMembers(ctx interface{}, entityID interface{}, members interface{}) *Repository_RemoveEntityMembers_Call { + return &Repository_RemoveEntityMembers_Call{Call: _e.mock.On("RemoveEntityMembers", ctx, entityID, members)} +} + +func (_c *Repository_RemoveEntityMembers_Call) Run(run func(ctx context.Context, entityID string, members []string)) *Repository_RemoveEntityMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RemoveEntityMembers_Call) Return(err error) *Repository_RemoveEntityMembers_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RemoveEntityMembers_Call) RunAndReturn(run func(ctx context.Context, entityID string, members []string) error) *Repository_RemoveEntityMembers_Call { + _c.Call.Return(run) + return _c +} + +// RemoveMemberFromAllRoles provides a mock function for the type Repository +func (_mock *Repository) RemoveMemberFromAllRoles(ctx context.Context, memberID string) error { + ret := _mock.Called(ctx, memberID) + + if len(ret) == 0 { + panic("no return value specified for RemoveMemberFromAllRoles") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, memberID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RemoveMemberFromAllRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveMemberFromAllRoles' +type Repository_RemoveMemberFromAllRoles_Call struct { + *mock.Call +} + +// RemoveMemberFromAllRoles is a helper method to define mock.On call +// - ctx context.Context +// - memberID string +func (_e *Repository_Expecter) RemoveMemberFromAllRoles(ctx interface{}, memberID interface{}) *Repository_RemoveMemberFromAllRoles_Call { + return &Repository_RemoveMemberFromAllRoles_Call{Call: _e.mock.On("RemoveMemberFromAllRoles", ctx, memberID)} +} + +func (_c *Repository_RemoveMemberFromAllRoles_Call) Run(run func(ctx context.Context, memberID string)) *Repository_RemoveMemberFromAllRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RemoveMemberFromAllRoles_Call) Return(err error) *Repository_RemoveMemberFromAllRoles_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RemoveMemberFromAllRoles_Call) RunAndReturn(run func(ctx context.Context, memberID string) error) *Repository_RemoveMemberFromAllRoles_Call { + _c.Call.Return(run) + return _c +} + +// RemoveReportConfig provides a mock function for the type Repository +func (_mock *Repository) RemoveReportConfig(ctx context.Context, id string) error { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for RemoveReportConfig") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RemoveReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveReportConfig' +type Repository_RemoveReportConfig_Call struct { + *mock.Call +} + +// RemoveReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) RemoveReportConfig(ctx interface{}, id interface{}) *Repository_RemoveReportConfig_Call { + return &Repository_RemoveReportConfig_Call{Call: _e.mock.On("RemoveReportConfig", ctx, id)} +} + +func (_c *Repository_RemoveReportConfig_Call) Run(run func(ctx context.Context, id string)) *Repository_RemoveReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RemoveReportConfig_Call) Return(err error) *Repository_RemoveReportConfig_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RemoveReportConfig_Call) RunAndReturn(run func(ctx context.Context, id string) error) *Repository_RemoveReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// RemoveRoles provides a mock function for the type Repository +func (_mock *Repository) RemoveRoles(ctx context.Context, roleIDs []string) error { + ret := _mock.Called(ctx, roleIDs) + + if len(ret) == 0 { + panic("no return value specified for RemoveRoles") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, []string) error); ok { + r0 = returnFunc(ctx, roleIDs) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RemoveRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveRoles' +type Repository_RemoveRoles_Call struct { + *mock.Call +} + +// RemoveRoles is a helper method to define mock.On call +// - ctx context.Context +// - roleIDs []string +func (_e *Repository_Expecter) RemoveRoles(ctx interface{}, roleIDs interface{}) *Repository_RemoveRoles_Call { + return &Repository_RemoveRoles_Call{Call: _e.mock.On("RemoveRoles", ctx, roleIDs)} +} + +func (_c *Repository_RemoveRoles_Call) Run(run func(ctx context.Context, roleIDs []string)) *Repository_RemoveRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 []string + if args[1] != nil { + arg1 = args[1].([]string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RemoveRoles_Call) Return(err error) *Repository_RemoveRoles_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RemoveRoles_Call) RunAndReturn(run func(ctx context.Context, roleIDs []string) error) *Repository_RemoveRoles_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveAllRoles provides a mock function for the type Repository +func (_mock *Repository) RetrieveAllRoles(ctx context.Context, entityID string, limit uint64, offset uint64) (roles.RolePage, error) { + ret := _mock.Called(ctx, entityID, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for RetrieveAllRoles") + } + + var r0 roles.RolePage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) (roles.RolePage, error)); ok { + return returnFunc(ctx, entityID, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) roles.RolePage); ok { + r0 = returnFunc(ctx, entityID, limit, offset) + } else { + r0 = ret.Get(0).(roles.RolePage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint64, uint64) error); ok { + r1 = returnFunc(ctx, entityID, limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RetrieveAllRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveAllRoles' +type Repository_RetrieveAllRoles_Call struct { + *mock.Call +} + +// RetrieveAllRoles is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - limit uint64 +// - offset uint64 +func (_e *Repository_Expecter) RetrieveAllRoles(ctx interface{}, entityID interface{}, limit interface{}, offset interface{}) *Repository_RetrieveAllRoles_Call { + return &Repository_RetrieveAllRoles_Call{Call: _e.mock.On("RetrieveAllRoles", ctx, entityID, limit, offset)} +} + +func (_c *Repository_RetrieveAllRoles_Call) Run(run func(ctx context.Context, entityID string, limit uint64, offset uint64)) *Repository_RetrieveAllRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 uint64 + if args[2] != nil { + arg2 = args[2].(uint64) + } + var arg3 uint64 + if args[3] != nil { + arg3 = args[3].(uint64) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Repository_RetrieveAllRoles_Call) Return(rolePage roles.RolePage, err error) *Repository_RetrieveAllRoles_Call { + _c.Call.Return(rolePage, err) + return _c +} + +func (_c *Repository_RetrieveAllRoles_Call) RunAndReturn(run func(ctx context.Context, entityID string, limit uint64, offset uint64) (roles.RolePage, error)) *Repository_RetrieveAllRoles_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveByIDWithRoles provides a mock function for the type Repository +func (_mock *Repository) RetrieveByIDWithRoles(ctx context.Context, id string, memberID string) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, id, memberID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveByIDWithRoles") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, id, memberID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) reports.ReportConfig); ok { + r0 = returnFunc(ctx, id, memberID) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, id, memberID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RetrieveByIDWithRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveByIDWithRoles' +type Repository_RetrieveByIDWithRoles_Call struct { + *mock.Call +} + +// RetrieveByIDWithRoles is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - memberID string +func (_e *Repository_Expecter) RetrieveByIDWithRoles(ctx interface{}, id interface{}, memberID interface{}) *Repository_RetrieveByIDWithRoles_Call { + return &Repository_RetrieveByIDWithRoles_Call{Call: _e.mock.On("RetrieveByIDWithRoles", ctx, id, memberID)} +} + +func (_c *Repository_RetrieveByIDWithRoles_Call) Run(run func(ctx context.Context, id string, memberID string)) *Repository_RetrieveByIDWithRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RetrieveByIDWithRoles_Call) Return(reportConfig reports.ReportConfig, err error) *Repository_RetrieveByIDWithRoles_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Repository_RetrieveByIDWithRoles_Call) RunAndReturn(run func(ctx context.Context, id string, memberID string) (reports.ReportConfig, error)) *Repository_RetrieveByIDWithRoles_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveEntitiesRolesActionsMembers provides a mock function for the type Repository +func (_mock *Repository) RetrieveEntitiesRolesActionsMembers(ctx context.Context, entityIDs []string) ([]roles.EntityActionRole, []roles.EntityMemberRole, error) { + ret := _mock.Called(ctx, entityIDs) + + if len(ret) == 0 { + panic("no return value specified for RetrieveEntitiesRolesActionsMembers") + } + + var r0 []roles.EntityActionRole + var r1 []roles.EntityMemberRole + var r2 error + if returnFunc, ok := ret.Get(0).(func(context.Context, []string) ([]roles.EntityActionRole, []roles.EntityMemberRole, error)); ok { + return returnFunc(ctx, entityIDs) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, []string) []roles.EntityActionRole); ok { + r0 = returnFunc(ctx, entityIDs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]roles.EntityActionRole) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, []string) []roles.EntityMemberRole); ok { + r1 = returnFunc(ctx, entityIDs) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]roles.EntityMemberRole) + } + } + if returnFunc, ok := ret.Get(2).(func(context.Context, []string) error); ok { + r2 = returnFunc(ctx, entityIDs) + } else { + r2 = ret.Error(2) + } + return r0, r1, r2 +} + +// Repository_RetrieveEntitiesRolesActionsMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveEntitiesRolesActionsMembers' +type Repository_RetrieveEntitiesRolesActionsMembers_Call struct { + *mock.Call +} + +// RetrieveEntitiesRolesActionsMembers is a helper method to define mock.On call +// - ctx context.Context +// - entityIDs []string +func (_e *Repository_Expecter) RetrieveEntitiesRolesActionsMembers(ctx interface{}, entityIDs interface{}) *Repository_RetrieveEntitiesRolesActionsMembers_Call { + return &Repository_RetrieveEntitiesRolesActionsMembers_Call{Call: _e.mock.On("RetrieveEntitiesRolesActionsMembers", ctx, entityIDs)} +} + +func (_c *Repository_RetrieveEntitiesRolesActionsMembers_Call) Run(run func(ctx context.Context, entityIDs []string)) *Repository_RetrieveEntitiesRolesActionsMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 []string + if args[1] != nil { + arg1 = args[1].([]string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RetrieveEntitiesRolesActionsMembers_Call) Return(entityActionRoles []roles.EntityActionRole, entityMemberRoles []roles.EntityMemberRole, err error) *Repository_RetrieveEntitiesRolesActionsMembers_Call { + _c.Call.Return(entityActionRoles, entityMemberRoles, err) + return _c +} + +func (_c *Repository_RetrieveEntitiesRolesActionsMembers_Call) RunAndReturn(run func(ctx context.Context, entityIDs []string) ([]roles.EntityActionRole, []roles.EntityMemberRole, error)) *Repository_RetrieveEntitiesRolesActionsMembers_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveEntityRole provides a mock function for the type Repository +func (_mock *Repository) RetrieveEntityRole(ctx context.Context, entityID string, roleID string) (roles.Role, error) { + ret := _mock.Called(ctx, entityID, roleID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveEntityRole") + } + + var r0 roles.Role + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (roles.Role, error)); ok { + return returnFunc(ctx, entityID, roleID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) roles.Role); ok { + r0 = returnFunc(ctx, entityID, roleID) + } else { + r0 = ret.Get(0).(roles.Role) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, entityID, roleID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RetrieveEntityRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveEntityRole' +type Repository_RetrieveEntityRole_Call struct { + *mock.Call +} + +// RetrieveEntityRole is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - roleID string +func (_e *Repository_Expecter) RetrieveEntityRole(ctx interface{}, entityID interface{}, roleID interface{}) *Repository_RetrieveEntityRole_Call { + return &Repository_RetrieveEntityRole_Call{Call: _e.mock.On("RetrieveEntityRole", ctx, entityID, roleID)} +} + +func (_c *Repository_RetrieveEntityRole_Call) Run(run func(ctx context.Context, entityID string, roleID string)) *Repository_RetrieveEntityRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RetrieveEntityRole_Call) Return(role roles.Role, err error) *Repository_RetrieveEntityRole_Call { + _c.Call.Return(role, err) + return _c +} + +func (_c *Repository_RetrieveEntityRole_Call) RunAndReturn(run func(ctx context.Context, entityID string, roleID string) (roles.Role, error)) *Repository_RetrieveEntityRole_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveRole provides a mock function for the type Repository +func (_mock *Repository) RetrieveRole(ctx context.Context, roleID string) (roles.Role, error) { + ret := _mock.Called(ctx, roleID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveRole") + } + + var r0 roles.Role + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (roles.Role, error)); ok { + return returnFunc(ctx, roleID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) roles.Role); ok { + r0 = returnFunc(ctx, roleID) + } else { + r0 = ret.Get(0).(roles.Role) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, roleID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RetrieveRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveRole' +type Repository_RetrieveRole_Call struct { + *mock.Call +} + +// RetrieveRole is a helper method to define mock.On call +// - ctx context.Context +// - roleID string +func (_e *Repository_Expecter) RetrieveRole(ctx interface{}, roleID interface{}) *Repository_RetrieveRole_Call { + return &Repository_RetrieveRole_Call{Call: _e.mock.On("RetrieveRole", ctx, roleID)} +} + +func (_c *Repository_RetrieveRole_Call) Run(run func(ctx context.Context, roleID string)) *Repository_RetrieveRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RetrieveRole_Call) Return(role roles.Role, err error) *Repository_RetrieveRole_Call { + _c.Call.Return(role, err) + return _c +} + +func (_c *Repository_RetrieveRole_Call) RunAndReturn(run func(ctx context.Context, roleID string) (roles.Role, error)) *Repository_RetrieveRole_Call { + _c.Call.Return(run) + return _c +} + +// RoleAddActions provides a mock function for the type Repository +func (_mock *Repository) RoleAddActions(ctx context.Context, role roles.Role, actions []string) ([]string, error) { + ret := _mock.Called(ctx, role, actions) + + if len(ret) == 0 { + panic("no return value specified for RoleAddActions") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role, []string) ([]string, error)); ok { + return returnFunc(ctx, role, actions) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role, []string) []string); ok { + r0 = returnFunc(ctx, role, actions) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, roles.Role, []string) error); ok { + r1 = returnFunc(ctx, role, actions) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RoleAddActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleAddActions' +type Repository_RoleAddActions_Call struct { + *mock.Call +} + +// RoleAddActions is a helper method to define mock.On call +// - ctx context.Context +// - role roles.Role +// - actions []string +func (_e *Repository_Expecter) RoleAddActions(ctx interface{}, role interface{}, actions interface{}) *Repository_RoleAddActions_Call { + return &Repository_RoleAddActions_Call{Call: _e.mock.On("RoleAddActions", ctx, role, actions)} +} + +func (_c *Repository_RoleAddActions_Call) Run(run func(ctx context.Context, role roles.Role, actions []string)) *Repository_RoleAddActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RoleAddActions_Call) Return(ops []string, err error) *Repository_RoleAddActions_Call { + _c.Call.Return(ops, err) + return _c +} + +func (_c *Repository_RoleAddActions_Call) RunAndReturn(run func(ctx context.Context, role roles.Role, actions []string) ([]string, error)) *Repository_RoleAddActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleAddMembers provides a mock function for the type Repository +func (_mock *Repository) RoleAddMembers(ctx context.Context, role roles.Role, members []string) ([]string, error) { + ret := _mock.Called(ctx, role, members) + + if len(ret) == 0 { + panic("no return value specified for RoleAddMembers") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role, []string) ([]string, error)); ok { + return returnFunc(ctx, role, members) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role, []string) []string); ok { + r0 = returnFunc(ctx, role, members) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, roles.Role, []string) error); ok { + r1 = returnFunc(ctx, role, members) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RoleAddMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleAddMembers' +type Repository_RoleAddMembers_Call struct { + *mock.Call +} + +// RoleAddMembers is a helper method to define mock.On call +// - ctx context.Context +// - role roles.Role +// - members []string +func (_e *Repository_Expecter) RoleAddMembers(ctx interface{}, role interface{}, members interface{}) *Repository_RoleAddMembers_Call { + return &Repository_RoleAddMembers_Call{Call: _e.mock.On("RoleAddMembers", ctx, role, members)} +} + +func (_c *Repository_RoleAddMembers_Call) Run(run func(ctx context.Context, role roles.Role, members []string)) *Repository_RoleAddMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RoleAddMembers_Call) Return(strings []string, err error) *Repository_RoleAddMembers_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *Repository_RoleAddMembers_Call) RunAndReturn(run func(ctx context.Context, role roles.Role, members []string) ([]string, error)) *Repository_RoleAddMembers_Call { + _c.Call.Return(run) + return _c +} + +// RoleCheckActionsExists provides a mock function for the type Repository +func (_mock *Repository) RoleCheckActionsExists(ctx context.Context, roleID string, actions []string) (bool, error) { + ret := _mock.Called(ctx, roleID, actions) + + if len(ret) == 0 { + panic("no return value specified for RoleCheckActionsExists") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string) (bool, error)); ok { + return returnFunc(ctx, roleID, actions) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string) bool); ok { + r0 = returnFunc(ctx, roleID, actions) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { + r1 = returnFunc(ctx, roleID, actions) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RoleCheckActionsExists_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleCheckActionsExists' +type Repository_RoleCheckActionsExists_Call struct { + *mock.Call +} + +// RoleCheckActionsExists is a helper method to define mock.On call +// - ctx context.Context +// - roleID string +// - actions []string +func (_e *Repository_Expecter) RoleCheckActionsExists(ctx interface{}, roleID interface{}, actions interface{}) *Repository_RoleCheckActionsExists_Call { + return &Repository_RoleCheckActionsExists_Call{Call: _e.mock.On("RoleCheckActionsExists", ctx, roleID, actions)} +} + +func (_c *Repository_RoleCheckActionsExists_Call) Run(run func(ctx context.Context, roleID string, actions []string)) *Repository_RoleCheckActionsExists_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RoleCheckActionsExists_Call) Return(b bool, err error) *Repository_RoleCheckActionsExists_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *Repository_RoleCheckActionsExists_Call) RunAndReturn(run func(ctx context.Context, roleID string, actions []string) (bool, error)) *Repository_RoleCheckActionsExists_Call { + _c.Call.Return(run) + return _c +} + +// RoleCheckMembersExists provides a mock function for the type Repository +func (_mock *Repository) RoleCheckMembersExists(ctx context.Context, roleID string, members []string) (bool, error) { + ret := _mock.Called(ctx, roleID, members) + + if len(ret) == 0 { + panic("no return value specified for RoleCheckMembersExists") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string) (bool, error)); ok { + return returnFunc(ctx, roleID, members) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, []string) bool); ok { + r0 = returnFunc(ctx, roleID, members) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, []string) error); ok { + r1 = returnFunc(ctx, roleID, members) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RoleCheckMembersExists_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleCheckMembersExists' +type Repository_RoleCheckMembersExists_Call struct { + *mock.Call +} + +// RoleCheckMembersExists is a helper method to define mock.On call +// - ctx context.Context +// - roleID string +// - members []string +func (_e *Repository_Expecter) RoleCheckMembersExists(ctx interface{}, roleID interface{}, members interface{}) *Repository_RoleCheckMembersExists_Call { + return &Repository_RoleCheckMembersExists_Call{Call: _e.mock.On("RoleCheckMembersExists", ctx, roleID, members)} +} + +func (_c *Repository_RoleCheckMembersExists_Call) Run(run func(ctx context.Context, roleID string, members []string)) *Repository_RoleCheckMembersExists_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RoleCheckMembersExists_Call) Return(b bool, err error) *Repository_RoleCheckMembersExists_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *Repository_RoleCheckMembersExists_Call) RunAndReturn(run func(ctx context.Context, roleID string, members []string) (bool, error)) *Repository_RoleCheckMembersExists_Call { + _c.Call.Return(run) + return _c +} + +// RoleListActions provides a mock function for the type Repository +func (_mock *Repository) RoleListActions(ctx context.Context, roleID string) ([]string, error) { + ret := _mock.Called(ctx, roleID) + + if len(ret) == 0 { + panic("no return value specified for RoleListActions") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) ([]string, error)); ok { + return returnFunc(ctx, roleID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) []string); ok { + r0 = returnFunc(ctx, roleID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, roleID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RoleListActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleListActions' +type Repository_RoleListActions_Call struct { + *mock.Call +} + +// RoleListActions is a helper method to define mock.On call +// - ctx context.Context +// - roleID string +func (_e *Repository_Expecter) RoleListActions(ctx interface{}, roleID interface{}) *Repository_RoleListActions_Call { + return &Repository_RoleListActions_Call{Call: _e.mock.On("RoleListActions", ctx, roleID)} +} + +func (_c *Repository_RoleListActions_Call) Run(run func(ctx context.Context, roleID string)) *Repository_RoleListActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RoleListActions_Call) Return(strings []string, err error) *Repository_RoleListActions_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *Repository_RoleListActions_Call) RunAndReturn(run func(ctx context.Context, roleID string) ([]string, error)) *Repository_RoleListActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleListMembers provides a mock function for the type Repository +func (_mock *Repository) RoleListMembers(ctx context.Context, roleID string, limit uint64, offset uint64) (roles.MembersPage, error) { + ret := _mock.Called(ctx, roleID, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for RoleListMembers") + } + + var r0 roles.MembersPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) (roles.MembersPage, error)); ok { + return returnFunc(ctx, roleID, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, uint64, uint64) roles.MembersPage); ok { + r0 = returnFunc(ctx, roleID, limit, offset) + } else { + r0 = ret.Get(0).(roles.MembersPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, uint64, uint64) error); ok { + r1 = returnFunc(ctx, roleID, limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RoleListMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleListMembers' +type Repository_RoleListMembers_Call struct { + *mock.Call +} + +// RoleListMembers is a helper method to define mock.On call +// - ctx context.Context +// - roleID string +// - limit uint64 +// - offset uint64 +func (_e *Repository_Expecter) RoleListMembers(ctx interface{}, roleID interface{}, limit interface{}, offset interface{}) *Repository_RoleListMembers_Call { + return &Repository_RoleListMembers_Call{Call: _e.mock.On("RoleListMembers", ctx, roleID, limit, offset)} +} + +func (_c *Repository_RoleListMembers_Call) Run(run func(ctx context.Context, roleID string, limit uint64, offset uint64)) *Repository_RoleListMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 uint64 + if args[2] != nil { + arg2 = args[2].(uint64) + } + var arg3 uint64 + if args[3] != nil { + arg3 = args[3].(uint64) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Repository_RoleListMembers_Call) Return(membersPage roles.MembersPage, err error) *Repository_RoleListMembers_Call { + _c.Call.Return(membersPage, err) + return _c +} + +func (_c *Repository_RoleListMembers_Call) RunAndReturn(run func(ctx context.Context, roleID string, limit uint64, offset uint64) (roles.MembersPage, error)) *Repository_RoleListMembers_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveActions provides a mock function for the type Repository +func (_mock *Repository) RoleRemoveActions(ctx context.Context, role roles.Role, actions []string) error { + ret := _mock.Called(ctx, role, actions) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveActions") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role, []string) error); ok { + r0 = returnFunc(ctx, role, actions) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RoleRemoveActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveActions' +type Repository_RoleRemoveActions_Call struct { + *mock.Call +} + +// RoleRemoveActions is a helper method to define mock.On call +// - ctx context.Context +// - role roles.Role +// - actions []string +func (_e *Repository_Expecter) RoleRemoveActions(ctx interface{}, role interface{}, actions interface{}) *Repository_RoleRemoveActions_Call { + return &Repository_RoleRemoveActions_Call{Call: _e.mock.On("RoleRemoveActions", ctx, role, actions)} +} + +func (_c *Repository_RoleRemoveActions_Call) Run(run func(ctx context.Context, role roles.Role, actions []string)) *Repository_RoleRemoveActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RoleRemoveActions_Call) Return(err error) *Repository_RoleRemoveActions_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RoleRemoveActions_Call) RunAndReturn(run func(ctx context.Context, role roles.Role, actions []string) error) *Repository_RoleRemoveActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveAllActions provides a mock function for the type Repository +func (_mock *Repository) RoleRemoveAllActions(ctx context.Context, role roles.Role) error { + ret := _mock.Called(ctx, role) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveAllActions") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role) error); ok { + r0 = returnFunc(ctx, role) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RoleRemoveAllActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveAllActions' +type Repository_RoleRemoveAllActions_Call struct { + *mock.Call +} + +// RoleRemoveAllActions is a helper method to define mock.On call +// - ctx context.Context +// - role roles.Role +func (_e *Repository_Expecter) RoleRemoveAllActions(ctx interface{}, role interface{}) *Repository_RoleRemoveAllActions_Call { + return &Repository_RoleRemoveAllActions_Call{Call: _e.mock.On("RoleRemoveAllActions", ctx, role)} +} + +func (_c *Repository_RoleRemoveAllActions_Call) Run(run func(ctx context.Context, role roles.Role)) *Repository_RoleRemoveAllActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RoleRemoveAllActions_Call) Return(err error) *Repository_RoleRemoveAllActions_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RoleRemoveAllActions_Call) RunAndReturn(run func(ctx context.Context, role roles.Role) error) *Repository_RoleRemoveAllActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveAllMembers provides a mock function for the type Repository +func (_mock *Repository) RoleRemoveAllMembers(ctx context.Context, role roles.Role) error { + ret := _mock.Called(ctx, role) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveAllMembers") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role) error); ok { + r0 = returnFunc(ctx, role) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RoleRemoveAllMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveAllMembers' +type Repository_RoleRemoveAllMembers_Call struct { + *mock.Call +} + +// RoleRemoveAllMembers is a helper method to define mock.On call +// - ctx context.Context +// - role roles.Role +func (_e *Repository_Expecter) RoleRemoveAllMembers(ctx interface{}, role interface{}) *Repository_RoleRemoveAllMembers_Call { + return &Repository_RoleRemoveAllMembers_Call{Call: _e.mock.On("RoleRemoveAllMembers", ctx, role)} +} + +func (_c *Repository_RoleRemoveAllMembers_Call) Run(run func(ctx context.Context, role roles.Role)) *Repository_RoleRemoveAllMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_RoleRemoveAllMembers_Call) Return(err error) *Repository_RoleRemoveAllMembers_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RoleRemoveAllMembers_Call) RunAndReturn(run func(ctx context.Context, role roles.Role) error) *Repository_RoleRemoveAllMembers_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveMembers provides a mock function for the type Repository +func (_mock *Repository) RoleRemoveMembers(ctx context.Context, role roles.Role, members []string) error { + ret := _mock.Called(ctx, role, members) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveMembers") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role, []string) error); ok { + r0 = returnFunc(ctx, role, members) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_RoleRemoveMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveMembers' +type Repository_RoleRemoveMembers_Call struct { + *mock.Call +} + +// RoleRemoveMembers is a helper method to define mock.On call +// - ctx context.Context +// - role roles.Role +// - members []string +func (_e *Repository_Expecter) RoleRemoveMembers(ctx interface{}, role interface{}, members interface{}) *Repository_RoleRemoveMembers_Call { + return &Repository_RoleRemoveMembers_Call{Call: _e.mock.On("RoleRemoveMembers", ctx, role, members)} +} + +func (_c *Repository_RoleRemoveMembers_Call) Run(run func(ctx context.Context, role roles.Role, members []string)) *Repository_RoleRemoveMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + var arg2 []string + if args[2] != nil { + arg2 = args[2].([]string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RoleRemoveMembers_Call) Return(err error) *Repository_RoleRemoveMembers_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_RoleRemoveMembers_Call) RunAndReturn(run func(ctx context.Context, role roles.Role, members []string) error) *Repository_RoleRemoveMembers_Call { + _c.Call.Return(run) + return _c +} + +// UpdateReportConfig provides a mock function for the type Repository +func (_mock *Repository) UpdateReportConfig(ctx context.Context, cfg reports.ReportConfig) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, cfg) + + if len(ret) == 0 { + panic("no return value specified for UpdateReportConfig") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, reports.ReportConfig) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, cfg) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, reports.ReportConfig) reports.ReportConfig); ok { + r0 = returnFunc(ctx, cfg) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, reports.ReportConfig) error); ok { + r1 = returnFunc(ctx, cfg) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateReportConfig' +type Repository_UpdateReportConfig_Call struct { + *mock.Call +} + +// UpdateReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - cfg reports.ReportConfig +func (_e *Repository_Expecter) UpdateReportConfig(ctx interface{}, cfg interface{}) *Repository_UpdateReportConfig_Call { + return &Repository_UpdateReportConfig_Call{Call: _e.mock.On("UpdateReportConfig", ctx, cfg)} +} + +func (_c *Repository_UpdateReportConfig_Call) Run(run func(ctx context.Context, cfg reports.ReportConfig)) *Repository_UpdateReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 reports.ReportConfig + if args[1] != nil { + arg1 = args[1].(reports.ReportConfig) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateReportConfig_Call) Return(reportConfig reports.ReportConfig, err error) *Repository_UpdateReportConfig_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Repository_UpdateReportConfig_Call) RunAndReturn(run func(ctx context.Context, cfg reports.ReportConfig) (reports.ReportConfig, error)) *Repository_UpdateReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// UpdateReportConfigStatus provides a mock function for the type Repository +func (_mock *Repository) UpdateReportConfigStatus(ctx context.Context, cfg reports.ReportConfig) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, cfg) + + if len(ret) == 0 { + panic("no return value specified for UpdateReportConfigStatus") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, reports.ReportConfig) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, cfg) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, reports.ReportConfig) reports.ReportConfig); ok { + r0 = returnFunc(ctx, cfg) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, reports.ReportConfig) error); ok { + r1 = returnFunc(ctx, cfg) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateReportConfigStatus_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateReportConfigStatus' +type Repository_UpdateReportConfigStatus_Call struct { + *mock.Call +} + +// UpdateReportConfigStatus is a helper method to define mock.On call +// - ctx context.Context +// - cfg reports.ReportConfig +func (_e *Repository_Expecter) UpdateReportConfigStatus(ctx interface{}, cfg interface{}) *Repository_UpdateReportConfigStatus_Call { + return &Repository_UpdateReportConfigStatus_Call{Call: _e.mock.On("UpdateReportConfigStatus", ctx, cfg)} +} + +func (_c *Repository_UpdateReportConfigStatus_Call) Run(run func(ctx context.Context, cfg reports.ReportConfig)) *Repository_UpdateReportConfigStatus_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 reports.ReportConfig + if args[1] != nil { + arg1 = args[1].(reports.ReportConfig) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateReportConfigStatus_Call) Return(reportConfig reports.ReportConfig, err error) *Repository_UpdateReportConfigStatus_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Repository_UpdateReportConfigStatus_Call) RunAndReturn(run func(ctx context.Context, cfg reports.ReportConfig) (reports.ReportConfig, error)) *Repository_UpdateReportConfigStatus_Call { + _c.Call.Return(run) + return _c +} + +// UpdateReportDue provides a mock function for the type Repository +func (_mock *Repository) UpdateReportDue(ctx context.Context, id string, due time.Time) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, id, due) + + if len(ret) == 0 { + panic("no return value specified for UpdateReportDue") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, time.Time) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, id, due) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, time.Time) reports.ReportConfig); ok { + r0 = returnFunc(ctx, id, due) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, time.Time) error); ok { + r1 = returnFunc(ctx, id, due) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateReportDue_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateReportDue' +type Repository_UpdateReportDue_Call struct { + *mock.Call +} + +// UpdateReportDue is a helper method to define mock.On call +// - ctx context.Context +// - id string +// - due time.Time +func (_e *Repository_Expecter) UpdateReportDue(ctx interface{}, id interface{}, due interface{}) *Repository_UpdateReportDue_Call { + return &Repository_UpdateReportDue_Call{Call: _e.mock.On("UpdateReportDue", ctx, id, due)} +} + +func (_c *Repository_UpdateReportDue_Call) Run(run func(ctx context.Context, id string, due time.Time)) *Repository_UpdateReportDue_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 time.Time + if args[2] != nil { + arg2 = args[2].(time.Time) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_UpdateReportDue_Call) Return(reportConfig reports.ReportConfig, err error) *Repository_UpdateReportDue_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Repository_UpdateReportDue_Call) RunAndReturn(run func(ctx context.Context, id string, due time.Time) (reports.ReportConfig, error)) *Repository_UpdateReportDue_Call { + _c.Call.Return(run) + return _c +} + +// UpdateReportSchedule provides a mock function for the type Repository +func (_mock *Repository) UpdateReportSchedule(ctx context.Context, cfg reports.ReportConfig) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, cfg) + + if len(ret) == 0 { + panic("no return value specified for UpdateReportSchedule") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, reports.ReportConfig) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, cfg) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, reports.ReportConfig) reports.ReportConfig); ok { + r0 = returnFunc(ctx, cfg) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, reports.ReportConfig) error); ok { + r1 = returnFunc(ctx, cfg) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateReportSchedule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateReportSchedule' +type Repository_UpdateReportSchedule_Call struct { + *mock.Call +} + +// UpdateReportSchedule is a helper method to define mock.On call +// - ctx context.Context +// - cfg reports.ReportConfig +func (_e *Repository_Expecter) UpdateReportSchedule(ctx interface{}, cfg interface{}) *Repository_UpdateReportSchedule_Call { + return &Repository_UpdateReportSchedule_Call{Call: _e.mock.On("UpdateReportSchedule", ctx, cfg)} +} + +func (_c *Repository_UpdateReportSchedule_Call) Run(run func(ctx context.Context, cfg reports.ReportConfig)) *Repository_UpdateReportSchedule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 reports.ReportConfig + if args[1] != nil { + arg1 = args[1].(reports.ReportConfig) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateReportSchedule_Call) Return(reportConfig reports.ReportConfig, err error) *Repository_UpdateReportSchedule_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Repository_UpdateReportSchedule_Call) RunAndReturn(run func(ctx context.Context, cfg reports.ReportConfig) (reports.ReportConfig, error)) *Repository_UpdateReportSchedule_Call { + _c.Call.Return(run) + return _c +} + +// UpdateReportTemplate provides a mock function for the type Repository +func (_mock *Repository) UpdateReportTemplate(ctx context.Context, domainID string, reportID string, template reports.ReportTemplate) error { + ret := _mock.Called(ctx, domainID, reportID, template) + + if len(ret) == 0 { + panic("no return value specified for UpdateReportTemplate") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, reports.ReportTemplate) error); ok { + r0 = returnFunc(ctx, domainID, reportID, template) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_UpdateReportTemplate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateReportTemplate' +type Repository_UpdateReportTemplate_Call struct { + *mock.Call +} + +// UpdateReportTemplate is a helper method to define mock.On call +// - ctx context.Context +// - domainID string +// - reportID string +// - template reports.ReportTemplate +func (_e *Repository_Expecter) UpdateReportTemplate(ctx interface{}, domainID interface{}, reportID interface{}, template interface{}) *Repository_UpdateReportTemplate_Call { + return &Repository_UpdateReportTemplate_Call{Call: _e.mock.On("UpdateReportTemplate", ctx, domainID, reportID, template)} +} + +func (_c *Repository_UpdateReportTemplate_Call) Run(run func(ctx context.Context, domainID string, reportID string, template reports.ReportTemplate)) *Repository_UpdateReportTemplate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 reports.ReportTemplate + if args[3] != nil { + arg3 = args[3].(reports.ReportTemplate) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Repository_UpdateReportTemplate_Call) Return(err error) *Repository_UpdateReportTemplate_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_UpdateReportTemplate_Call) RunAndReturn(run func(ctx context.Context, domainID string, reportID string, template reports.ReportTemplate) error) *Repository_UpdateReportTemplate_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRole provides a mock function for the type Repository +func (_mock *Repository) UpdateRole(ctx context.Context, ro roles.Role) (roles.Role, error) { + ret := _mock.Called(ctx, ro) + + if len(ret) == 0 { + panic("no return value specified for UpdateRole") + } + + var r0 roles.Role + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role) (roles.Role, error)); ok { + return returnFunc(ctx, ro) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, roles.Role) roles.Role); ok { + r0 = returnFunc(ctx, ro) + } else { + r0 = ret.Get(0).(roles.Role) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, roles.Role) error); ok { + r1 = returnFunc(ctx, ro) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRole' +type Repository_UpdateRole_Call struct { + *mock.Call +} + +// UpdateRole is a helper method to define mock.On call +// - ctx context.Context +// - ro roles.Role +func (_e *Repository_Expecter) UpdateRole(ctx interface{}, ro interface{}) *Repository_UpdateRole_Call { + return &Repository_UpdateRole_Call{Call: _e.mock.On("UpdateRole", ctx, ro)} +} + +func (_c *Repository_UpdateRole_Call) Run(run func(ctx context.Context, ro roles.Role)) *Repository_UpdateRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 roles.Role + if args[1] != nil { + arg1 = args[1].(roles.Role) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateRole_Call) Return(role roles.Role, err error) *Repository_UpdateRole_Call { + _c.Call.Return(role, err) + return _c +} + +func (_c *Repository_UpdateRole_Call) RunAndReturn(run func(ctx context.Context, ro roles.Role) (roles.Role, error)) *Repository_UpdateRole_Call { + _c.Call.Return(run) + return _c +} + +// ViewReportConfig provides a mock function for the type Repository +func (_mock *Repository) ViewReportConfig(ctx context.Context, id string) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for ViewReportConfig") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) reports.ReportConfig); ok { + r0 = returnFunc(ctx, id) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ViewReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewReportConfig' +type Repository_ViewReportConfig_Call struct { + *mock.Call +} + +// ViewReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - id string +func (_e *Repository_Expecter) ViewReportConfig(ctx interface{}, id interface{}) *Repository_ViewReportConfig_Call { + return &Repository_ViewReportConfig_Call{Call: _e.mock.On("ViewReportConfig", ctx, id)} +} + +func (_c *Repository_ViewReportConfig_Call) Run(run func(ctx context.Context, id string)) *Repository_ViewReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_ViewReportConfig_Call) Return(reportConfig reports.ReportConfig, err error) *Repository_ViewReportConfig_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Repository_ViewReportConfig_Call) RunAndReturn(run func(ctx context.Context, id string) (reports.ReportConfig, error)) *Repository_ViewReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// ViewReportTemplate provides a mock function for the type Repository +func (_mock *Repository) ViewReportTemplate(ctx context.Context, domainID string, reportID string) (reports.ReportTemplate, error) { + ret := _mock.Called(ctx, domainID, reportID) + + if len(ret) == 0 { + panic("no return value specified for ViewReportTemplate") + } + + var r0 reports.ReportTemplate + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (reports.ReportTemplate, error)); ok { + return returnFunc(ctx, domainID, reportID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) reports.ReportTemplate); ok { + r0 = returnFunc(ctx, domainID, reportID) + } else { + r0 = ret.Get(0).(reports.ReportTemplate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, domainID, reportID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_ViewReportTemplate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewReportTemplate' +type Repository_ViewReportTemplate_Call struct { + *mock.Call +} + +// ViewReportTemplate is a helper method to define mock.On call +// - ctx context.Context +// - domainID string +// - reportID string +func (_e *Repository_Expecter) ViewReportTemplate(ctx interface{}, domainID interface{}, reportID interface{}) *Repository_ViewReportTemplate_Call { + return &Repository_ViewReportTemplate_Call{Call: _e.mock.On("ViewReportTemplate", ctx, domainID, reportID)} +} + +func (_c *Repository_ViewReportTemplate_Call) Run(run func(ctx context.Context, domainID string, reportID string)) *Repository_ViewReportTemplate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_ViewReportTemplate_Call) Return(reportTemplate reports.ReportTemplate, err error) *Repository_ViewReportTemplate_Call { + _c.Call.Return(reportTemplate, err) + return _c +} + +func (_c *Repository_ViewReportTemplate_Call) RunAndReturn(run func(ctx context.Context, domainID string, reportID string) (reports.ReportTemplate, error)) *Repository_ViewReportTemplate_Call { + _c.Call.Return(run) + return _c +} diff --git a/reports/mocks/service.go b/reports/mocks/service.go new file mode 100644 index 000000000..2fd3d0071 --- /dev/null +++ b/reports/mocks/service.go @@ -0,0 +1,2426 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/reports" + mock "github.com/stretchr/testify/mock" +) + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewService(t interface { + mock.TestingT + Cleanup(func()) +}) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +type Service_Expecter struct { + mock *mock.Mock +} + +func (_m *Service) EXPECT() *Service_Expecter { + return &Service_Expecter{mock: &_m.Mock} +} + +// AddReportConfig provides a mock function for the type Service +func (_mock *Service) AddReportConfig(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, session, cfg) + + if len(ret) == 0 { + panic("no return value specified for AddReportConfig") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, reports.ReportConfig) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, session, cfg) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, reports.ReportConfig) reports.ReportConfig); ok { + r0 = returnFunc(ctx, session, cfg) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, reports.ReportConfig) error); ok { + r1 = returnFunc(ctx, session, cfg) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_AddReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddReportConfig' +type Service_AddReportConfig_Call struct { + *mock.Call +} + +// AddReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - cfg reports.ReportConfig +func (_e *Service_Expecter) AddReportConfig(ctx interface{}, session interface{}, cfg interface{}) *Service_AddReportConfig_Call { + return &Service_AddReportConfig_Call{Call: _e.mock.On("AddReportConfig", ctx, session, cfg)} +} + +func (_c *Service_AddReportConfig_Call) Run(run func(ctx context.Context, session authn.Session, cfg reports.ReportConfig)) *Service_AddReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 reports.ReportConfig + if args[2] != nil { + arg2 = args[2].(reports.ReportConfig) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_AddReportConfig_Call) Return(reportConfig reports.ReportConfig, err error) *Service_AddReportConfig_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Service_AddReportConfig_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error)) *Service_AddReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// AddRole provides a mock function for the type Service +func (_mock *Service) AddRole(ctx context.Context, session authn.Session, entityID string, roleName string, optionalActions []string, optionalMembers []string) (roles.RoleProvision, error) { + ret := _mock.Called(ctx, session, entityID, roleName, optionalActions, optionalMembers) + + if len(ret) == 0 { + panic("no return value specified for AddRole") + } + + var r0 roles.RoleProvision + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string, []string) (roles.RoleProvision, error)); ok { + return returnFunc(ctx, session, entityID, roleName, optionalActions, optionalMembers) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string, []string) roles.RoleProvision); ok { + r0 = returnFunc(ctx, session, entityID, roleName, optionalActions, optionalMembers) + } else { + r0 = ret.Get(0).(roles.RoleProvision) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, []string, []string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleName, optionalActions, optionalMembers) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_AddRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddRole' +type Service_AddRole_Call struct { + *mock.Call +} + +// AddRole is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleName string +// - optionalActions []string +// - optionalMembers []string +func (_e *Service_Expecter) AddRole(ctx interface{}, session interface{}, entityID interface{}, roleName interface{}, optionalActions interface{}, optionalMembers interface{}) *Service_AddRole_Call { + return &Service_AddRole_Call{Call: _e.mock.On("AddRole", ctx, session, entityID, roleName, optionalActions, optionalMembers)} +} + +func (_c *Service_AddRole_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleName string, optionalActions []string, optionalMembers []string)) *Service_AddRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + var arg5 []string + if args[5] != nil { + arg5 = args[5].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + ) + }) + return _c +} + +func (_c *Service_AddRole_Call) Return(roleProvision roles.RoleProvision, err error) *Service_AddRole_Call { + _c.Call.Return(roleProvision, err) + return _c +} + +func (_c *Service_AddRole_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleName string, optionalActions []string, optionalMembers []string) (roles.RoleProvision, error)) *Service_AddRole_Call { + _c.Call.Return(run) + return _c +} + +// DeleteReportTemplate provides a mock function for the type Service +func (_mock *Service) DeleteReportTemplate(ctx context.Context, session authn.Session, id string) error { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteReportTemplate") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_DeleteReportTemplate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteReportTemplate' +type Service_DeleteReportTemplate_Call struct { + *mock.Call +} + +// DeleteReportTemplate is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) DeleteReportTemplate(ctx interface{}, session interface{}, id interface{}) *Service_DeleteReportTemplate_Call { + return &Service_DeleteReportTemplate_Call{Call: _e.mock.On("DeleteReportTemplate", ctx, session, id)} +} + +func (_c *Service_DeleteReportTemplate_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_DeleteReportTemplate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_DeleteReportTemplate_Call) Return(err error) *Service_DeleteReportTemplate_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_DeleteReportTemplate_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) error) *Service_DeleteReportTemplate_Call { + _c.Call.Return(run) + return _c +} + +// DisableReportConfig provides a mock function for the type Service +func (_mock *Service) DisableReportConfig(ctx context.Context, session authn.Session, id string) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for DisableReportConfig") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, session, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) reports.ReportConfig); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = returnFunc(ctx, session, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_DisableReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DisableReportConfig' +type Service_DisableReportConfig_Call struct { + *mock.Call +} + +// DisableReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) DisableReportConfig(ctx interface{}, session interface{}, id interface{}) *Service_DisableReportConfig_Call { + return &Service_DisableReportConfig_Call{Call: _e.mock.On("DisableReportConfig", ctx, session, id)} +} + +func (_c *Service_DisableReportConfig_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_DisableReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_DisableReportConfig_Call) Return(reportConfig reports.ReportConfig, err error) *Service_DisableReportConfig_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Service_DisableReportConfig_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) (reports.ReportConfig, error)) *Service_DisableReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// EnableReportConfig provides a mock function for the type Service +func (_mock *Service) EnableReportConfig(ctx context.Context, session authn.Session, id string) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for EnableReportConfig") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, session, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) reports.ReportConfig); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = returnFunc(ctx, session, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_EnableReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EnableReportConfig' +type Service_EnableReportConfig_Call struct { + *mock.Call +} + +// EnableReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) EnableReportConfig(ctx interface{}, session interface{}, id interface{}) *Service_EnableReportConfig_Call { + return &Service_EnableReportConfig_Call{Call: _e.mock.On("EnableReportConfig", ctx, session, id)} +} + +func (_c *Service_EnableReportConfig_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_EnableReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_EnableReportConfig_Call) Return(reportConfig reports.ReportConfig, err error) *Service_EnableReportConfig_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Service_EnableReportConfig_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) (reports.ReportConfig, error)) *Service_EnableReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// GenerateReport provides a mock function for the type Service +func (_mock *Service) GenerateReport(ctx context.Context, session authn.Session, config reports.ReportConfig, action reports.ReportAction) (reports.ReportPage, error) { + ret := _mock.Called(ctx, session, config, action) + + if len(ret) == 0 { + panic("no return value specified for GenerateReport") + } + + var r0 reports.ReportPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, reports.ReportConfig, reports.ReportAction) (reports.ReportPage, error)); ok { + return returnFunc(ctx, session, config, action) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, reports.ReportConfig, reports.ReportAction) reports.ReportPage); ok { + r0 = returnFunc(ctx, session, config, action) + } else { + r0 = ret.Get(0).(reports.ReportPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, reports.ReportConfig, reports.ReportAction) error); ok { + r1 = returnFunc(ctx, session, config, action) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_GenerateReport_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GenerateReport' +type Service_GenerateReport_Call struct { + *mock.Call +} + +// GenerateReport is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - config reports.ReportConfig +// - action reports.ReportAction +func (_e *Service_Expecter) GenerateReport(ctx interface{}, session interface{}, config interface{}, action interface{}) *Service_GenerateReport_Call { + return &Service_GenerateReport_Call{Call: _e.mock.On("GenerateReport", ctx, session, config, action)} +} + +func (_c *Service_GenerateReport_Call) Run(run func(ctx context.Context, session authn.Session, config reports.ReportConfig, action reports.ReportAction)) *Service_GenerateReport_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 reports.ReportConfig + if args[2] != nil { + arg2 = args[2].(reports.ReportConfig) + } + var arg3 reports.ReportAction + if args[3] != nil { + arg3 = args[3].(reports.ReportAction) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_GenerateReport_Call) Return(reportPage reports.ReportPage, err error) *Service_GenerateReport_Call { + _c.Call.Return(reportPage, err) + return _c +} + +func (_c *Service_GenerateReport_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, config reports.ReportConfig, action reports.ReportAction) (reports.ReportPage, error)) *Service_GenerateReport_Call { + _c.Call.Return(run) + return _c +} + +// ListAvailableActions provides a mock function for the type Service +func (_mock *Service) ListAvailableActions(ctx context.Context, session authn.Session) ([]string, error) { + ret := _mock.Called(ctx, session) + + if len(ret) == 0 { + panic("no return value specified for ListAvailableActions") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) ([]string, error)); ok { + return returnFunc(ctx, session) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) []string); ok { + r0 = returnFunc(ctx, session) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session) error); ok { + r1 = returnFunc(ctx, session) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListAvailableActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListAvailableActions' +type Service_ListAvailableActions_Call struct { + *mock.Call +} + +// ListAvailableActions is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +func (_e *Service_Expecter) ListAvailableActions(ctx interface{}, session interface{}) *Service_ListAvailableActions_Call { + return &Service_ListAvailableActions_Call{Call: _e.mock.On("ListAvailableActions", ctx, session)} +} + +func (_c *Service_ListAvailableActions_Call) Run(run func(ctx context.Context, session authn.Session)) *Service_ListAvailableActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_ListAvailableActions_Call) Return(strings []string, err error) *Service_ListAvailableActions_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *Service_ListAvailableActions_Call) RunAndReturn(run func(ctx context.Context, session authn.Session) ([]string, error)) *Service_ListAvailableActions_Call { + _c.Call.Return(run) + return _c +} + +// ListEntityMembers provides a mock function for the type Service +func (_mock *Service) ListEntityMembers(ctx context.Context, session authn.Session, entityID string, pq roles.MembersRolePageQuery) (roles.MembersRolePage, error) { + ret := _mock.Called(ctx, session, entityID, pq) + + if len(ret) == 0 { + panic("no return value specified for ListEntityMembers") + } + + var r0 roles.MembersRolePage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, roles.MembersRolePageQuery) (roles.MembersRolePage, error)); ok { + return returnFunc(ctx, session, entityID, pq) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, roles.MembersRolePageQuery) roles.MembersRolePage); ok { + r0 = returnFunc(ctx, session, entityID, pq) + } else { + r0 = ret.Get(0).(roles.MembersRolePage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, roles.MembersRolePageQuery) error); ok { + r1 = returnFunc(ctx, session, entityID, pq) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListEntityMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListEntityMembers' +type Service_ListEntityMembers_Call struct { + *mock.Call +} + +// ListEntityMembers is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - pq roles.MembersRolePageQuery +func (_e *Service_Expecter) ListEntityMembers(ctx interface{}, session interface{}, entityID interface{}, pq interface{}) *Service_ListEntityMembers_Call { + return &Service_ListEntityMembers_Call{Call: _e.mock.On("ListEntityMembers", ctx, session, entityID, pq)} +} + +func (_c *Service_ListEntityMembers_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, pq roles.MembersRolePageQuery)) *Service_ListEntityMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 roles.MembersRolePageQuery + if args[3] != nil { + arg3 = args[3].(roles.MembersRolePageQuery) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_ListEntityMembers_Call) Return(membersRolePage roles.MembersRolePage, err error) *Service_ListEntityMembers_Call { + _c.Call.Return(membersRolePage, err) + return _c +} + +func (_c *Service_ListEntityMembers_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, pq roles.MembersRolePageQuery) (roles.MembersRolePage, error)) *Service_ListEntityMembers_Call { + _c.Call.Return(run) + return _c +} + +// ListReportsConfig provides a mock function for the type Service +func (_mock *Service) ListReportsConfig(ctx context.Context, session authn.Session, pm reports.PageMeta) (reports.ReportConfigPage, error) { + ret := _mock.Called(ctx, session, pm) + + if len(ret) == 0 { + panic("no return value specified for ListReportsConfig") + } + + var r0 reports.ReportConfigPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, reports.PageMeta) (reports.ReportConfigPage, error)); ok { + return returnFunc(ctx, session, pm) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, reports.PageMeta) reports.ReportConfigPage); ok { + r0 = returnFunc(ctx, session, pm) + } else { + r0 = ret.Get(0).(reports.ReportConfigPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, reports.PageMeta) error); ok { + r1 = returnFunc(ctx, session, pm) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListReportsConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListReportsConfig' +type Service_ListReportsConfig_Call struct { + *mock.Call +} + +// ListReportsConfig is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - pm reports.PageMeta +func (_e *Service_Expecter) ListReportsConfig(ctx interface{}, session interface{}, pm interface{}) *Service_ListReportsConfig_Call { + return &Service_ListReportsConfig_Call{Call: _e.mock.On("ListReportsConfig", ctx, session, pm)} +} + +func (_c *Service_ListReportsConfig_Call) Run(run func(ctx context.Context, session authn.Session, pm reports.PageMeta)) *Service_ListReportsConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 reports.PageMeta + if args[2] != nil { + arg2 = args[2].(reports.PageMeta) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_ListReportsConfig_Call) Return(reportConfigPage reports.ReportConfigPage, err error) *Service_ListReportsConfig_Call { + _c.Call.Return(reportConfigPage, err) + return _c +} + +func (_c *Service_ListReportsConfig_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, pm reports.PageMeta) (reports.ReportConfigPage, error)) *Service_ListReportsConfig_Call { + _c.Call.Return(run) + return _c +} + +// RemoveEntityMembers provides a mock function for the type Service +func (_mock *Service) RemoveEntityMembers(ctx context.Context, session authn.Session, entityID string, members []string) error { + ret := _mock.Called(ctx, session, entityID, members) + + if len(ret) == 0 { + panic("no return value specified for RemoveEntityMembers") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, []string) error); ok { + r0 = returnFunc(ctx, session, entityID, members) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RemoveEntityMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveEntityMembers' +type Service_RemoveEntityMembers_Call struct { + *mock.Call +} + +// RemoveEntityMembers is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - members []string +func (_e *Service_Expecter) RemoveEntityMembers(ctx interface{}, session interface{}, entityID interface{}, members interface{}) *Service_RemoveEntityMembers_Call { + return &Service_RemoveEntityMembers_Call{Call: _e.mock.On("RemoveEntityMembers", ctx, session, entityID, members)} +} + +func (_c *Service_RemoveEntityMembers_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, members []string)) *Service_RemoveEntityMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 []string + if args[3] != nil { + arg3 = args[3].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_RemoveEntityMembers_Call) Return(err error) *Service_RemoveEntityMembers_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RemoveEntityMembers_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, members []string) error) *Service_RemoveEntityMembers_Call { + _c.Call.Return(run) + return _c +} + +// RemoveMemberFromAllRoles provides a mock function for the type Service +func (_mock *Service) RemoveMemberFromAllRoles(ctx context.Context, session authn.Session, memberID string) error { + ret := _mock.Called(ctx, session, memberID) + + if len(ret) == 0 { + panic("no return value specified for RemoveMemberFromAllRoles") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, memberID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RemoveMemberFromAllRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveMemberFromAllRoles' +type Service_RemoveMemberFromAllRoles_Call struct { + *mock.Call +} + +// RemoveMemberFromAllRoles is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - memberID string +func (_e *Service_Expecter) RemoveMemberFromAllRoles(ctx interface{}, session interface{}, memberID interface{}) *Service_RemoveMemberFromAllRoles_Call { + return &Service_RemoveMemberFromAllRoles_Call{Call: _e.mock.On("RemoveMemberFromAllRoles", ctx, session, memberID)} +} + +func (_c *Service_RemoveMemberFromAllRoles_Call) Run(run func(ctx context.Context, session authn.Session, memberID string)) *Service_RemoveMemberFromAllRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_RemoveMemberFromAllRoles_Call) Return(err error) *Service_RemoveMemberFromAllRoles_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RemoveMemberFromAllRoles_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, memberID string) error) *Service_RemoveMemberFromAllRoles_Call { + _c.Call.Return(run) + return _c +} + +// RemoveReportConfig provides a mock function for the type Service +func (_mock *Service) RemoveReportConfig(ctx context.Context, session authn.Session, id string) error { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for RemoveReportConfig") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RemoveReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveReportConfig' +type Service_RemoveReportConfig_Call struct { + *mock.Call +} + +// RemoveReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) RemoveReportConfig(ctx interface{}, session interface{}, id interface{}) *Service_RemoveReportConfig_Call { + return &Service_RemoveReportConfig_Call{Call: _e.mock.On("RemoveReportConfig", ctx, session, id)} +} + +func (_c *Service_RemoveReportConfig_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_RemoveReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_RemoveReportConfig_Call) Return(err error) *Service_RemoveReportConfig_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RemoveReportConfig_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) error) *Service_RemoveReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// RemoveRole provides a mock function for the type Service +func (_mock *Service) RemoveRole(ctx context.Context, session authn.Session, entityID string, roleID string) error { + ret := _mock.Called(ctx, session, entityID, roleID) + + if len(ret) == 0 { + panic("no return value specified for RemoveRole") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) error); ok { + r0 = returnFunc(ctx, session, entityID, roleID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RemoveRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveRole' +type Service_RemoveRole_Call struct { + *mock.Call +} + +// RemoveRole is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +func (_e *Service_Expecter) RemoveRole(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}) *Service_RemoveRole_Call { + return &Service_RemoveRole_Call{Call: _e.mock.On("RemoveRole", ctx, session, entityID, roleID)} +} + +func (_c *Service_RemoveRole_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string)) *Service_RemoveRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_RemoveRole_Call) Return(err error) *Service_RemoveRole_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RemoveRole_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string) error) *Service_RemoveRole_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveAllRoles provides a mock function for the type Service +func (_mock *Service) RetrieveAllRoles(ctx context.Context, session authn.Session, entityID string, limit uint64, offset uint64) (roles.RolePage, error) { + ret := _mock.Called(ctx, session, entityID, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for RetrieveAllRoles") + } + + var r0 roles.RolePage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, uint64, uint64) (roles.RolePage, error)); ok { + return returnFunc(ctx, session, entityID, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, uint64, uint64) roles.RolePage); ok { + r0 = returnFunc(ctx, session, entityID, limit, offset) + } else { + r0 = ret.Get(0).(roles.RolePage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, uint64, uint64) error); ok { + r1 = returnFunc(ctx, session, entityID, limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RetrieveAllRoles_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveAllRoles' +type Service_RetrieveAllRoles_Call struct { + *mock.Call +} + +// RetrieveAllRoles is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - limit uint64 +// - offset uint64 +func (_e *Service_Expecter) RetrieveAllRoles(ctx interface{}, session interface{}, entityID interface{}, limit interface{}, offset interface{}) *Service_RetrieveAllRoles_Call { + return &Service_RetrieveAllRoles_Call{Call: _e.mock.On("RetrieveAllRoles", ctx, session, entityID, limit, offset)} +} + +func (_c *Service_RetrieveAllRoles_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, limit uint64, offset uint64)) *Service_RetrieveAllRoles_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 uint64 + if args[3] != nil { + arg3 = args[3].(uint64) + } + var arg4 uint64 + if args[4] != nil { + arg4 = args[4].(uint64) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RetrieveAllRoles_Call) Return(rolePage roles.RolePage, err error) *Service_RetrieveAllRoles_Call { + _c.Call.Return(rolePage, err) + return _c +} + +func (_c *Service_RetrieveAllRoles_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, limit uint64, offset uint64) (roles.RolePage, error)) *Service_RetrieveAllRoles_Call { + _c.Call.Return(run) + return _c +} + +// RetrieveRole provides a mock function for the type Service +func (_mock *Service) RetrieveRole(ctx context.Context, session authn.Session, entityID string, roleID string) (roles.Role, error) { + ret := _mock.Called(ctx, session, entityID, roleID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveRole") + } + + var r0 roles.Role + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) (roles.Role, error)); ok { + return returnFunc(ctx, session, entityID, roleID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) roles.Role); ok { + r0 = returnFunc(ctx, session, entityID, roleID) + } else { + r0 = ret.Get(0).(roles.Role) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RetrieveRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveRole' +type Service_RetrieveRole_Call struct { + *mock.Call +} + +// RetrieveRole is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +func (_e *Service_Expecter) RetrieveRole(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}) *Service_RetrieveRole_Call { + return &Service_RetrieveRole_Call{Call: _e.mock.On("RetrieveRole", ctx, session, entityID, roleID)} +} + +func (_c *Service_RetrieveRole_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string)) *Service_RetrieveRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_RetrieveRole_Call) Return(role roles.Role, err error) *Service_RetrieveRole_Call { + _c.Call.Return(role, err) + return _c +} + +func (_c *Service_RetrieveRole_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string) (roles.Role, error)) *Service_RetrieveRole_Call { + _c.Call.Return(run) + return _c +} + +// RoleAddActions provides a mock function for the type Service +func (_mock *Service) RoleAddActions(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string) ([]string, error) { + ret := _mock.Called(ctx, session, entityID, roleID, actions) + + if len(ret) == 0 { + panic("no return value specified for RoleAddActions") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) ([]string, error)); ok { + return returnFunc(ctx, session, entityID, roleID, actions) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) []string); ok { + r0 = returnFunc(ctx, session, entityID, roleID, actions) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, []string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID, actions) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RoleAddActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleAddActions' +type Service_RoleAddActions_Call struct { + *mock.Call +} + +// RoleAddActions is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - actions []string +func (_e *Service_Expecter) RoleAddActions(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, actions interface{}) *Service_RoleAddActions_Call { + return &Service_RoleAddActions_Call{Call: _e.mock.On("RoleAddActions", ctx, session, entityID, roleID, actions)} +} + +func (_c *Service_RoleAddActions_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string)) *Service_RoleAddActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RoleAddActions_Call) Return(ops []string, err error) *Service_RoleAddActions_Call { + _c.Call.Return(ops, err) + return _c +} + +func (_c *Service_RoleAddActions_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string) ([]string, error)) *Service_RoleAddActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleAddMembers provides a mock function for the type Service +func (_mock *Service) RoleAddMembers(ctx context.Context, session authn.Session, entityID string, roleID string, members []string) ([]string, error) { + ret := _mock.Called(ctx, session, entityID, roleID, members) + + if len(ret) == 0 { + panic("no return value specified for RoleAddMembers") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) ([]string, error)); ok { + return returnFunc(ctx, session, entityID, roleID, members) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) []string); ok { + r0 = returnFunc(ctx, session, entityID, roleID, members) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, []string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID, members) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RoleAddMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleAddMembers' +type Service_RoleAddMembers_Call struct { + *mock.Call +} + +// RoleAddMembers is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - members []string +func (_e *Service_Expecter) RoleAddMembers(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, members interface{}) *Service_RoleAddMembers_Call { + return &Service_RoleAddMembers_Call{Call: _e.mock.On("RoleAddMembers", ctx, session, entityID, roleID, members)} +} + +func (_c *Service_RoleAddMembers_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, members []string)) *Service_RoleAddMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RoleAddMembers_Call) Return(strings []string, err error) *Service_RoleAddMembers_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *Service_RoleAddMembers_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, members []string) ([]string, error)) *Service_RoleAddMembers_Call { + _c.Call.Return(run) + return _c +} + +// RoleCheckActionsExists provides a mock function for the type Service +func (_mock *Service) RoleCheckActionsExists(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string) (bool, error) { + ret := _mock.Called(ctx, session, entityID, roleID, actions) + + if len(ret) == 0 { + panic("no return value specified for RoleCheckActionsExists") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) (bool, error)); ok { + return returnFunc(ctx, session, entityID, roleID, actions) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) bool); ok { + r0 = returnFunc(ctx, session, entityID, roleID, actions) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, []string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID, actions) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RoleCheckActionsExists_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleCheckActionsExists' +type Service_RoleCheckActionsExists_Call struct { + *mock.Call +} + +// RoleCheckActionsExists is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - actions []string +func (_e *Service_Expecter) RoleCheckActionsExists(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, actions interface{}) *Service_RoleCheckActionsExists_Call { + return &Service_RoleCheckActionsExists_Call{Call: _e.mock.On("RoleCheckActionsExists", ctx, session, entityID, roleID, actions)} +} + +func (_c *Service_RoleCheckActionsExists_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string)) *Service_RoleCheckActionsExists_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RoleCheckActionsExists_Call) Return(b bool, err error) *Service_RoleCheckActionsExists_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *Service_RoleCheckActionsExists_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string) (bool, error)) *Service_RoleCheckActionsExists_Call { + _c.Call.Return(run) + return _c +} + +// RoleCheckMembersExists provides a mock function for the type Service +func (_mock *Service) RoleCheckMembersExists(ctx context.Context, session authn.Session, entityID string, roleID string, members []string) (bool, error) { + ret := _mock.Called(ctx, session, entityID, roleID, members) + + if len(ret) == 0 { + panic("no return value specified for RoleCheckMembersExists") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) (bool, error)); ok { + return returnFunc(ctx, session, entityID, roleID, members) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) bool); ok { + r0 = returnFunc(ctx, session, entityID, roleID, members) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, []string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID, members) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RoleCheckMembersExists_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleCheckMembersExists' +type Service_RoleCheckMembersExists_Call struct { + *mock.Call +} + +// RoleCheckMembersExists is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - members []string +func (_e *Service_Expecter) RoleCheckMembersExists(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, members interface{}) *Service_RoleCheckMembersExists_Call { + return &Service_RoleCheckMembersExists_Call{Call: _e.mock.On("RoleCheckMembersExists", ctx, session, entityID, roleID, members)} +} + +func (_c *Service_RoleCheckMembersExists_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, members []string)) *Service_RoleCheckMembersExists_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RoleCheckMembersExists_Call) Return(b bool, err error) *Service_RoleCheckMembersExists_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *Service_RoleCheckMembersExists_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, members []string) (bool, error)) *Service_RoleCheckMembersExists_Call { + _c.Call.Return(run) + return _c +} + +// RoleListActions provides a mock function for the type Service +func (_mock *Service) RoleListActions(ctx context.Context, session authn.Session, entityID string, roleID string) ([]string, error) { + ret := _mock.Called(ctx, session, entityID, roleID) + + if len(ret) == 0 { + panic("no return value specified for RoleListActions") + } + + var r0 []string + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) ([]string, error)); ok { + return returnFunc(ctx, session, entityID, roleID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) []string); ok { + r0 = returnFunc(ctx, session, entityID, roleID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RoleListActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleListActions' +type Service_RoleListActions_Call struct { + *mock.Call +} + +// RoleListActions is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +func (_e *Service_Expecter) RoleListActions(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}) *Service_RoleListActions_Call { + return &Service_RoleListActions_Call{Call: _e.mock.On("RoleListActions", ctx, session, entityID, roleID)} +} + +func (_c *Service_RoleListActions_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string)) *Service_RoleListActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_RoleListActions_Call) Return(strings []string, err error) *Service_RoleListActions_Call { + _c.Call.Return(strings, err) + return _c +} + +func (_c *Service_RoleListActions_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string) ([]string, error)) *Service_RoleListActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleListMembers provides a mock function for the type Service +func (_mock *Service) RoleListMembers(ctx context.Context, session authn.Session, entityID string, roleID string, limit uint64, offset uint64) (roles.MembersPage, error) { + ret := _mock.Called(ctx, session, entityID, roleID, limit, offset) + + if len(ret) == 0 { + panic("no return value specified for RoleListMembers") + } + + var r0 roles.MembersPage + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, uint64, uint64) (roles.MembersPage, error)); ok { + return returnFunc(ctx, session, entityID, roleID, limit, offset) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, uint64, uint64) roles.MembersPage); ok { + r0 = returnFunc(ctx, session, entityID, roleID, limit, offset) + } else { + r0 = ret.Get(0).(roles.MembersPage) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, uint64, uint64) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID, limit, offset) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_RoleListMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleListMembers' +type Service_RoleListMembers_Call struct { + *mock.Call +} + +// RoleListMembers is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - limit uint64 +// - offset uint64 +func (_e *Service_Expecter) RoleListMembers(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, limit interface{}, offset interface{}) *Service_RoleListMembers_Call { + return &Service_RoleListMembers_Call{Call: _e.mock.On("RoleListMembers", ctx, session, entityID, roleID, limit, offset)} +} + +func (_c *Service_RoleListMembers_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, limit uint64, offset uint64)) *Service_RoleListMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 uint64 + if args[4] != nil { + arg4 = args[4].(uint64) + } + var arg5 uint64 + if args[5] != nil { + arg5 = args[5].(uint64) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + arg5, + ) + }) + return _c +} + +func (_c *Service_RoleListMembers_Call) Return(membersPage roles.MembersPage, err error) *Service_RoleListMembers_Call { + _c.Call.Return(membersPage, err) + return _c +} + +func (_c *Service_RoleListMembers_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, limit uint64, offset uint64) (roles.MembersPage, error)) *Service_RoleListMembers_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveActions provides a mock function for the type Service +func (_mock *Service) RoleRemoveActions(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string) error { + ret := _mock.Called(ctx, session, entityID, roleID, actions) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveActions") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) error); ok { + r0 = returnFunc(ctx, session, entityID, roleID, actions) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RoleRemoveActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveActions' +type Service_RoleRemoveActions_Call struct { + *mock.Call +} + +// RoleRemoveActions is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - actions []string +func (_e *Service_Expecter) RoleRemoveActions(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, actions interface{}) *Service_RoleRemoveActions_Call { + return &Service_RoleRemoveActions_Call{Call: _e.mock.On("RoleRemoveActions", ctx, session, entityID, roleID, actions)} +} + +func (_c *Service_RoleRemoveActions_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string)) *Service_RoleRemoveActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RoleRemoveActions_Call) Return(err error) *Service_RoleRemoveActions_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RoleRemoveActions_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, actions []string) error) *Service_RoleRemoveActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveAllActions provides a mock function for the type Service +func (_mock *Service) RoleRemoveAllActions(ctx context.Context, session authn.Session, entityID string, roleID string) error { + ret := _mock.Called(ctx, session, entityID, roleID) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveAllActions") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) error); ok { + r0 = returnFunc(ctx, session, entityID, roleID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RoleRemoveAllActions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveAllActions' +type Service_RoleRemoveAllActions_Call struct { + *mock.Call +} + +// RoleRemoveAllActions is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +func (_e *Service_Expecter) RoleRemoveAllActions(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}) *Service_RoleRemoveAllActions_Call { + return &Service_RoleRemoveAllActions_Call{Call: _e.mock.On("RoleRemoveAllActions", ctx, session, entityID, roleID)} +} + +func (_c *Service_RoleRemoveAllActions_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string)) *Service_RoleRemoveAllActions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_RoleRemoveAllActions_Call) Return(err error) *Service_RoleRemoveAllActions_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RoleRemoveAllActions_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string) error) *Service_RoleRemoveAllActions_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveAllMembers provides a mock function for the type Service +func (_mock *Service) RoleRemoveAllMembers(ctx context.Context, session authn.Session, entityID string, roleID string) error { + ret := _mock.Called(ctx, session, entityID, roleID) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveAllMembers") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) error); ok { + r0 = returnFunc(ctx, session, entityID, roleID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RoleRemoveAllMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveAllMembers' +type Service_RoleRemoveAllMembers_Call struct { + *mock.Call +} + +// RoleRemoveAllMembers is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +func (_e *Service_Expecter) RoleRemoveAllMembers(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}) *Service_RoleRemoveAllMembers_Call { + return &Service_RoleRemoveAllMembers_Call{Call: _e.mock.On("RoleRemoveAllMembers", ctx, session, entityID, roleID)} +} + +func (_c *Service_RoleRemoveAllMembers_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string)) *Service_RoleRemoveAllMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_RoleRemoveAllMembers_Call) Return(err error) *Service_RoleRemoveAllMembers_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RoleRemoveAllMembers_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string) error) *Service_RoleRemoveAllMembers_Call { + _c.Call.Return(run) + return _c +} + +// RoleRemoveMembers provides a mock function for the type Service +func (_mock *Service) RoleRemoveMembers(ctx context.Context, session authn.Session, entityID string, roleID string, members []string) error { + ret := _mock.Called(ctx, session, entityID, roleID, members) + + if len(ret) == 0 { + panic("no return value specified for RoleRemoveMembers") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, []string) error); ok { + r0 = returnFunc(ctx, session, entityID, roleID, members) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RoleRemoveMembers_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RoleRemoveMembers' +type Service_RoleRemoveMembers_Call struct { + *mock.Call +} + +// RoleRemoveMembers is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - members []string +func (_e *Service_Expecter) RoleRemoveMembers(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, members interface{}) *Service_RoleRemoveMembers_Call { + return &Service_RoleRemoveMembers_Call{Call: _e.mock.On("RoleRemoveMembers", ctx, session, entityID, roleID, members)} +} + +func (_c *Service_RoleRemoveMembers_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, members []string)) *Service_RoleRemoveMembers_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 []string + if args[4] != nil { + arg4 = args[4].([]string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_RoleRemoveMembers_Call) Return(err error) *Service_RoleRemoveMembers_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RoleRemoveMembers_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, members []string) error) *Service_RoleRemoveMembers_Call { + _c.Call.Return(run) + return _c +} + +// StartScheduler provides a mock function for the type Service +func (_mock *Service) StartScheduler(ctx context.Context) error { + ret := _mock.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for StartScheduler") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = returnFunc(ctx) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_StartScheduler_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StartScheduler' +type Service_StartScheduler_Call struct { + *mock.Call +} + +// StartScheduler is a helper method to define mock.On call +// - ctx context.Context +func (_e *Service_Expecter) StartScheduler(ctx interface{}) *Service_StartScheduler_Call { + return &Service_StartScheduler_Call{Call: _e.mock.On("StartScheduler", ctx)} +} + +func (_c *Service_StartScheduler_Call) Run(run func(ctx context.Context)) *Service_StartScheduler_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *Service_StartScheduler_Call) Return(err error) *Service_StartScheduler_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_StartScheduler_Call) RunAndReturn(run func(ctx context.Context) error) *Service_StartScheduler_Call { + _c.Call.Return(run) + return _c +} + +// UpdateReportConfig provides a mock function for the type Service +func (_mock *Service) UpdateReportConfig(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, session, cfg) + + if len(ret) == 0 { + panic("no return value specified for UpdateReportConfig") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, reports.ReportConfig) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, session, cfg) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, reports.ReportConfig) reports.ReportConfig); ok { + r0 = returnFunc(ctx, session, cfg) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, reports.ReportConfig) error); ok { + r1 = returnFunc(ctx, session, cfg) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_UpdateReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateReportConfig' +type Service_UpdateReportConfig_Call struct { + *mock.Call +} + +// UpdateReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - cfg reports.ReportConfig +func (_e *Service_Expecter) UpdateReportConfig(ctx interface{}, session interface{}, cfg interface{}) *Service_UpdateReportConfig_Call { + return &Service_UpdateReportConfig_Call{Call: _e.mock.On("UpdateReportConfig", ctx, session, cfg)} +} + +func (_c *Service_UpdateReportConfig_Call) Run(run func(ctx context.Context, session authn.Session, cfg reports.ReportConfig)) *Service_UpdateReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 reports.ReportConfig + if args[2] != nil { + arg2 = args[2].(reports.ReportConfig) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_UpdateReportConfig_Call) Return(reportConfig reports.ReportConfig, err error) *Service_UpdateReportConfig_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Service_UpdateReportConfig_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error)) *Service_UpdateReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// UpdateReportSchedule provides a mock function for the type Service +func (_mock *Service) UpdateReportSchedule(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, session, cfg) + + if len(ret) == 0 { + panic("no return value specified for UpdateReportSchedule") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, reports.ReportConfig) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, session, cfg) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, reports.ReportConfig) reports.ReportConfig); ok { + r0 = returnFunc(ctx, session, cfg) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, reports.ReportConfig) error); ok { + r1 = returnFunc(ctx, session, cfg) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_UpdateReportSchedule_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateReportSchedule' +type Service_UpdateReportSchedule_Call struct { + *mock.Call +} + +// UpdateReportSchedule is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - cfg reports.ReportConfig +func (_e *Service_Expecter) UpdateReportSchedule(ctx interface{}, session interface{}, cfg interface{}) *Service_UpdateReportSchedule_Call { + return &Service_UpdateReportSchedule_Call{Call: _e.mock.On("UpdateReportSchedule", ctx, session, cfg)} +} + +func (_c *Service_UpdateReportSchedule_Call) Run(run func(ctx context.Context, session authn.Session, cfg reports.ReportConfig)) *Service_UpdateReportSchedule_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 reports.ReportConfig + if args[2] != nil { + arg2 = args[2].(reports.ReportConfig) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_UpdateReportSchedule_Call) Return(reportConfig reports.ReportConfig, err error) *Service_UpdateReportSchedule_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Service_UpdateReportSchedule_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, cfg reports.ReportConfig) (reports.ReportConfig, error)) *Service_UpdateReportSchedule_Call { + _c.Call.Return(run) + return _c +} + +// UpdateReportTemplate provides a mock function for the type Service +func (_mock *Service) UpdateReportTemplate(ctx context.Context, session authn.Session, cfg reports.ReportConfig) error { + ret := _mock.Called(ctx, session, cfg) + + if len(ret) == 0 { + panic("no return value specified for UpdateReportTemplate") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, reports.ReportConfig) error); ok { + r0 = returnFunc(ctx, session, cfg) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_UpdateReportTemplate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateReportTemplate' +type Service_UpdateReportTemplate_Call struct { + *mock.Call +} + +// UpdateReportTemplate is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - cfg reports.ReportConfig +func (_e *Service_Expecter) UpdateReportTemplate(ctx interface{}, session interface{}, cfg interface{}) *Service_UpdateReportTemplate_Call { + return &Service_UpdateReportTemplate_Call{Call: _e.mock.On("UpdateReportTemplate", ctx, session, cfg)} +} + +func (_c *Service_UpdateReportTemplate_Call) Run(run func(ctx context.Context, session authn.Session, cfg reports.ReportConfig)) *Service_UpdateReportTemplate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 reports.ReportConfig + if args[2] != nil { + arg2 = args[2].(reports.ReportConfig) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_UpdateReportTemplate_Call) Return(err error) *Service_UpdateReportTemplate_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_UpdateReportTemplate_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, cfg reports.ReportConfig) error) *Service_UpdateReportTemplate_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRoleName provides a mock function for the type Service +func (_mock *Service) UpdateRoleName(ctx context.Context, session authn.Session, entityID string, roleID string, newRoleName string) (roles.Role, error) { + ret := _mock.Called(ctx, session, entityID, roleID, newRoleName) + + if len(ret) == 0 { + panic("no return value specified for UpdateRoleName") + } + + var r0 roles.Role + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string) (roles.Role, error)); ok { + return returnFunc(ctx, session, entityID, roleID, newRoleName) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, string, string) roles.Role); ok { + r0 = returnFunc(ctx, session, entityID, roleID, newRoleName) + } else { + r0 = ret.Get(0).(roles.Role) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, string, string) error); ok { + r1 = returnFunc(ctx, session, entityID, roleID, newRoleName) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_UpdateRoleName_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRoleName' +type Service_UpdateRoleName_Call struct { + *mock.Call +} + +// UpdateRoleName is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - entityID string +// - roleID string +// - newRoleName string +func (_e *Service_Expecter) UpdateRoleName(ctx interface{}, session interface{}, entityID interface{}, roleID interface{}, newRoleName interface{}) *Service_UpdateRoleName_Call { + return &Service_UpdateRoleName_Call{Call: _e.mock.On("UpdateRoleName", ctx, session, entityID, roleID, newRoleName)} +} + +func (_c *Service_UpdateRoleName_Call) Run(run func(ctx context.Context, session authn.Session, entityID string, roleID string, newRoleName string)) *Service_UpdateRoleName_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 string + if args[4] != nil { + arg4 = args[4].(string) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *Service_UpdateRoleName_Call) Return(role roles.Role, err error) *Service_UpdateRoleName_Call { + _c.Call.Return(role, err) + return _c +} + +func (_c *Service_UpdateRoleName_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, entityID string, roleID string, newRoleName string) (roles.Role, error)) *Service_UpdateRoleName_Call { + _c.Call.Return(run) + return _c +} + +// ViewReportConfig provides a mock function for the type Service +func (_mock *Service) ViewReportConfig(ctx context.Context, session authn.Session, id string, withRoles bool) (reports.ReportConfig, error) { + ret := _mock.Called(ctx, session, id, withRoles) + + if len(ret) == 0 { + panic("no return value specified for ViewReportConfig") + } + + var r0 reports.ReportConfig + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, bool) (reports.ReportConfig, error)); ok { + return returnFunc(ctx, session, id, withRoles) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string, bool) reports.ReportConfig); ok { + r0 = returnFunc(ctx, session, id, withRoles) + } else { + r0 = ret.Get(0).(reports.ReportConfig) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string, bool) error); ok { + r1 = returnFunc(ctx, session, id, withRoles) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ViewReportConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewReportConfig' +type Service_ViewReportConfig_Call struct { + *mock.Call +} + +// ViewReportConfig is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +// - withRoles bool +func (_e *Service_Expecter) ViewReportConfig(ctx interface{}, session interface{}, id interface{}, withRoles interface{}) *Service_ViewReportConfig_Call { + return &Service_ViewReportConfig_Call{Call: _e.mock.On("ViewReportConfig", ctx, session, id, withRoles)} +} + +func (_c *Service_ViewReportConfig_Call) Run(run func(ctx context.Context, session authn.Session, id string, withRoles bool)) *Service_ViewReportConfig_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 bool + if args[3] != nil { + arg3 = args[3].(bool) + } + run( + arg0, + arg1, + arg2, + arg3, + ) + }) + return _c +} + +func (_c *Service_ViewReportConfig_Call) Return(reportConfig reports.ReportConfig, err error) *Service_ViewReportConfig_Call { + _c.Call.Return(reportConfig, err) + return _c +} + +func (_c *Service_ViewReportConfig_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string, withRoles bool) (reports.ReportConfig, error)) *Service_ViewReportConfig_Call { + _c.Call.Return(run) + return _c +} + +// ViewReportTemplate provides a mock function for the type Service +func (_mock *Service) ViewReportTemplate(ctx context.Context, session authn.Session, id string) (reports.ReportTemplate, error) { + ret := _mock.Called(ctx, session, id) + + if len(ret) == 0 { + panic("no return value specified for ViewReportTemplate") + } + + var r0 reports.ReportTemplate + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) (reports.ReportTemplate, error)); ok { + return returnFunc(ctx, session, id) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) reports.ReportTemplate); ok { + r0 = returnFunc(ctx, session, id) + } else { + r0 = ret.Get(0).(reports.ReportTemplate) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session, string) error); ok { + r1 = returnFunc(ctx, session, id) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ViewReportTemplate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewReportTemplate' +type Service_ViewReportTemplate_Call struct { + *mock.Call +} + +// ViewReportTemplate is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - id string +func (_e *Service_Expecter) ViewReportTemplate(ctx interface{}, session interface{}, id interface{}) *Service_ViewReportTemplate_Call { + return &Service_ViewReportTemplate_Call{Call: _e.mock.On("ViewReportTemplate", ctx, session, id)} +} + +func (_c *Service_ViewReportTemplate_Call) Run(run func(ctx context.Context, session authn.Session, id string)) *Service_ViewReportTemplate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_ViewReportTemplate_Call) Return(reportTemplate reports.ReportTemplate, err error) *Service_ViewReportTemplate_Call { + _c.Call.Return(reportTemplate, err) + return _c +} + +func (_c *Service_ViewReportTemplate_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, id string) (reports.ReportTemplate, error)) *Service_ViewReportTemplate_Call { + _c.Call.Return(run) + return _c +} diff --git a/reports/operations/operations.go b/reports/operations/operations.go new file mode 100644 index 000000000..69cd5934a --- /dev/null +++ b/reports/operations/operations.go @@ -0,0 +1,77 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package operations + +import "github.com/absmach/supermq/pkg/permissions" + +const EntityType = "report" + +// Report Operations. +const ( + OpAddReportConfig permissions.Operation = iota + OpViewReportConfig + OpUpdateReportConfig + OpUpdateReportSchedule + OpRemoveReportConfig + OpListReportsConfig + OpEnableReportConfig + OpDisableReportConfig + OpGenerateReport + OpUpdateReportTemplate + OpViewReportTemplate + OpDeleteReportTemplate +) + +func OperationDetails() map[permissions.Operation]permissions.OperationDetails { + return map[permissions.Operation]permissions.OperationDetails{ + OpAddReportConfig: { + Name: "add", + PermissionRequired: true, + }, + OpViewReportConfig: { + Name: "view", + PermissionRequired: true, + }, + OpUpdateReportConfig: { + Name: "update", + PermissionRequired: true, + }, + OpUpdateReportSchedule: { + Name: "update_schedule", + PermissionRequired: true, + }, + OpRemoveReportConfig: { + Name: "delete", + PermissionRequired: true, + }, + OpListReportsConfig: { + Name: "list", + PermissionRequired: true, + }, + OpEnableReportConfig: { + Name: "enable", + PermissionRequired: true, + }, + OpDisableReportConfig: { + Name: "disable", + PermissionRequired: true, + }, + OpGenerateReport: { + Name: "generate", + PermissionRequired: true, + }, + OpUpdateReportTemplate: { + Name: "update_template", + PermissionRequired: true, + }, + OpViewReportTemplate: { + Name: "view_template", + PermissionRequired: true, + }, + OpDeleteReportTemplate: { + Name: "delete_template", + PermissionRequired: true, + }, + } +} diff --git a/reports/postgres/errors.go b/reports/postgres/errors.go new file mode 100644 index 000000000..80a3489b6 --- /dev/null +++ b/reports/postgres/errors.go @@ -0,0 +1,27 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" +) + +var _ errors.Mapper = (*duplicateErrors)(nil) + +type duplicateErrors struct{} + +// GetError maps constraint names to known errors. +func (d duplicateErrors) GetError(constraint string) (error, bool) { + switch constraint { + case "report_config_pkey": + return repoerr.ErrConflict, true + default: + return nil, false + } +} + +func NewDuplicateErrors() errors.Mapper { + return duplicateErrors{} +} diff --git a/reports/postgres/init.go b/reports/postgres/init.go new file mode 100644 index 000000000..15329caf6 --- /dev/null +++ b/reports/postgres/init.go @@ -0,0 +1,69 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + dpostgres "github.com/absmach/supermq/domains/postgres" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres" + _ "github.com/jackc/pgx/v5/stdlib" // required for SQL access + migrate "github.com/rubenv/sql-migrate" +) + +func Migration() (*migrate.MemoryMigrationSource, error) { + rolesMigration, err := rolesPostgres.Migration(rolesTableNamePrefix, entityTableName, entityIDColumnName) + if err != nil { + return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err) + } + reportsMigration := &migrate.MemoryMigrationSource{ + Migrations: []*migrate.Migration{ + { + Id: "reports_01", + Up: []string{ + `CREATE TABLE IF NOT EXISTS report_config ( + id VARCHAR(36) PRIMARY KEY, + name VARCHAR(1024), + description TEXT, + domain_id VARCHAR(36) NOT NULL, + status SMALLINT NOT NULL DEFAULT 0 CHECK (status >= 0), + created_at TIMESTAMP, + created_by VARCHAR(254), + updated_at TIMESTAMP, + updated_by VARCHAR(254), + due TIMESTAMPTZ, + recurring SMALLINT, + recurring_period SMALLINT, + start_datetime TIMESTAMP, + config JSONB, + email JSONB, + metrics JSONB + );`, + }, + Down: []string{ + `DROP TABLE IF EXISTS report_config;`, + }, + }, + { + Id: "reports_02", + Up: []string{ + `ALTER TABLE report_config ADD COLUMN report_template TEXT;`, + }, + Down: []string{ + `ALTER TABLE report_config DROP COLUMN report_template;`, + }, + }, + }, + } + + reportsMigration.Migrations = append(reportsMigration.Migrations, rolesMigration.Migrations...) + + domainsMigration, err := dpostgres.Migration() + if err != nil { + return &migrate.MemoryMigrationSource{}, errors.Wrap(repoerr.ErrRoleMigration, err) + } + reportsMigration.Migrations = append(reportsMigration.Migrations, domainsMigration.Migrations...) + + return reportsMigration, nil +} diff --git a/reports/postgres/reports.go b/reports/postgres/reports.go new file mode 100644 index 000000000..d12914e7a --- /dev/null +++ b/reports/postgres/reports.go @@ -0,0 +1,150 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "database/sql" + "encoding/json" + "time" + + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/pkg/schedule" + "github.com/absmach/supermq/reports" +) + +// dbReport represents the database structure for a Report. +type dbReport struct { + ID string `db:"id"` + Name string `db:"name"` + Description string `db:"description"` + DomainID string `db:"domain_id"` + StartDateTime sql.NullTime `db:"start_datetime"` + Due sql.NullTime `db:"due"` + Recurring schedule.Recurring `db:"recurring"` + RecurringPeriod uint `db:"recurring_period"` + Status reports.Status `db:"status"` + CreatedAt time.Time `db:"created_at"` + CreatedBy string `db:"created_by"` + UpdatedAt time.Time `db:"updated_at"` + UpdatedBy string `db:"updated_by"` + Config []byte `db:"config,omitempty"` + Metrics []byte `db:"metrics"` + Email []byte `db:"email"` + ReportTemplate reports.ReportTemplate `db:"report_template"` + MemberID string `db:"member_id,omitempty"` + Roles json.RawMessage `db:"roles,omitempty"` +} + +func reportToDb(r reports.ReportConfig) (dbReport, error) { + config := []byte("{}") + if r.Config != nil { + b, err := json.Marshal(r.Config) + if err != nil { + return dbReport{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + config = b + } + + metrics := []byte("{}") + if r.Metrics != nil { + m, err := json.Marshal(r.Metrics) + if err != nil { + return dbReport{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + metrics = m + } + + email := []byte("{}") + if r.Email != nil { + e, err := json.Marshal(r.Email) + if err != nil { + return dbReport{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + email = e + } + start := sql.NullTime{Time: r.Schedule.StartDateTime} + if !r.Schedule.StartDateTime.IsZero() { + start.Valid = true + } + t := sql.NullTime{Time: r.Schedule.Time} + if !r.Schedule.Time.IsZero() { + t.Valid = true + } + + return dbReport{ + ID: r.ID, + Name: r.Name, + Description: r.Description, + DomainID: r.DomainID, + StartDateTime: start, + Due: t, + Recurring: r.Schedule.Recurring, + RecurringPeriod: r.Schedule.RecurringPeriod, + Status: r.Status, + CreatedAt: r.CreatedAt, + CreatedBy: r.CreatedBy, + UpdatedAt: r.UpdatedAt, + UpdatedBy: r.UpdatedBy, + Config: config, + Metrics: metrics, + Email: email, + ReportTemplate: r.ReportTemplate, + }, nil +} + +func dbToReport(dto dbReport) (reports.ReportConfig, error) { + var config reports.MetricConfig + if dto.Config != nil { + if err := json.Unmarshal(dto.Config, &config); err != nil { + return reports.ReportConfig{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + } + + var email reports.EmailSetting + if dto.Email != nil { + if err := json.Unmarshal(dto.Email, &email); err != nil { + return reports.ReportConfig{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + } + + var metrics []reports.ReqMetric + if dto.Metrics != nil { + if err := json.Unmarshal(dto.Metrics, &metrics); err != nil { + return reports.ReportConfig{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + } + + var roles []roles.MemberRoleActions + if dto.Roles != nil { + if err := json.Unmarshal(dto.Roles, &roles); err != nil { + return reports.ReportConfig{}, errors.Wrap(errors.ErrMalformedEntity, err) + } + } + + rpt := reports.ReportConfig{ + ID: dto.ID, + Name: dto.Name, + Description: dto.Description, + DomainID: dto.DomainID, + Config: &config, + Metrics: metrics, + Schedule: schedule.Schedule{ + StartDateTime: dto.StartDateTime.Time, + Time: dto.Due.Time, + Recurring: dto.Recurring, + RecurringPeriod: dto.RecurringPeriod, + }, + Email: &email, + Status: dto.Status, + CreatedAt: dto.CreatedAt, + CreatedBy: dto.CreatedBy, + UpdatedAt: dto.UpdatedAt, + UpdatedBy: dto.UpdatedBy, + ReportTemplate: dto.ReportTemplate, + Roles: roles, + } + + return rpt, nil +} diff --git a/reports/postgres/repository.go b/reports/postgres/repository.go new file mode 100644 index 000000000..6dfcd4999 --- /dev/null +++ b/reports/postgres/repository.go @@ -0,0 +1,663 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + api "github.com/absmach/supermq/api/http" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + mgPolicies "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/pkg/postgres" + rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres" + "github.com/absmach/supermq/reports" +) + +const ( + rolesTableNamePrefix = "reports" + entityTableName = "report_config" + entityIDColumnName = "id" +) + +type PostgresRepository struct { + DB postgres.Database + eh errors.Handler + rolesPostgres.Repository +} + +func NewRepository(db postgres.Database) reports.Repository { + rolesRepo := rolesPostgres.NewRepository(db, mgPolicies.ReportsType, rolesTableNamePrefix, entityTableName, entityIDColumnName) + errHandlerOptions := []errors.HandlerOption{ + postgres.WithDuplicateErrors(NewDuplicateErrors()), + } + return &PostgresRepository{ + DB: db, + eh: postgres.NewErrorHandler(errHandlerOptions...), + Repository: rolesRepo, + } +} + +func (repo *PostgresRepository) AddReportConfig(ctx context.Context, cfg reports.ReportConfig) (reports.ReportConfig, error) { + q := ` + INSERT INTO report_config (id, name, description, domain_id, config, metrics, + email, start_datetime, due, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, report_template) + VALUES (:id, :name, :description, :domain_id, :config, :metrics, + :email, :start_datetime, :due, :recurring, :recurring_period, :created_at, :created_by, :updated_at, :updated_by, :status, :report_template) + RETURNING id, name, description, domain_id, config, metrics, + email, start_datetime, due, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status, report_template; + ` + dbr, err := reportToDb(cfg) + if err != nil { + return reports.ReportConfig{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err) + } + row, err := repo.DB.NamedQueryContext(ctx, q, dbr) + if err != nil { + return reports.ReportConfig{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err) + } + defer row.Close() + + var dbReport dbReport + if row.Next() { + if err := row.StructScan(&dbReport); err != nil { + return reports.ReportConfig{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err) + } + } + + report, err := dbToReport(dbReport) + if err != nil { + return reports.ReportConfig{}, repo.eh.HandleError(repoerr.ErrCreateEntity, err) + } + + return report, nil +} + +func (repo *PostgresRepository) ViewReportConfig(ctx context.Context, id string) (reports.ReportConfig, error) { + q := ` + SELECT id, name, description, domain_id, config, metrics, report_template, + email, start_datetime, due, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status + FROM report_config + WHERE id = $1; + ` + row := repo.DB.QueryRowxContext(ctx, q, id) + if err := row.Err(); err != nil { + return reports.ReportConfig{}, err + } + var dbr dbReport + if err := row.StructScan(&dbr); err != nil { + if err == sql.ErrNoRows { + return reports.ReportConfig{}, repoerr.ErrNotFound + } + return reports.ReportConfig{}, err + } + rpt, err := dbToReport(dbr) + if err != nil { + return reports.ReportConfig{}, err + } + + return rpt, nil +} + +func (repo *PostgresRepository) RetrieveByIDWithRoles(ctx context.Context, id, memberID string) (reports.ReportConfig, error) { + query := ` + WITH selected_report AS ( + SELECT + r.id, + r.domain_id + FROM + report_config r + WHERE + r.id = :id + LIMIT 1 + ), + selected_report_roles AS ( + SELECT + rr.entity_id AS report_id, + rrm.member_id AS member_id, + rr.id AS role_id, + rr."name" AS role_name, + jsonb_agg(DISTINCT rra."action") AS actions, + 'direct' AS access_type, + '' AS access_provider_id + FROM + reports_roles rr + JOIN + reports_role_members rrm ON rr.id = rrm.role_id + JOIN + reports_role_actions rra ON rr.id = rra.role_id + JOIN + selected_report sr ON sr.id = rr.entity_id + AND rrm.member_id = :member_id + GROUP BY + rr.entity_id, rr.id, rr.name, rrm.member_id + ), + selected_domain_roles AS ( + SELECT + sr.id AS report_id, + drm.member_id AS member_id, + dr.id AS role_id, + dr."name" AS role_name, + jsonb_agg(DISTINCT all_actions."action") AS actions, + 'domain' AS access_type, + dr.entity_id AS access_provider_id + FROM + domains d + JOIN + selected_report sr ON sr.domain_id = d.id + JOIN + domains_roles dr ON dr.entity_id = d.id + JOIN + domains_role_members drm ON dr.id = drm.role_id + JOIN + domains_role_actions dra ON dr.id = dra.role_id + JOIN + domains_role_actions all_actions ON dr.id = all_actions.role_id + WHERE + drm.member_id = :member_id + AND dra."action" LIKE 'report%' + GROUP BY + sr.id, dr.entity_id, dr.id, dr."name", drm.member_id + ), + all_roles AS ( + SELECT + srr.report_id, + srr.member_id, + srr.role_id, + srr.role_name, + srr.actions, + srr.access_type, + srr.access_provider_id + FROM + selected_report_roles srr + UNION + SELECT + sdr.report_id, + sdr.member_id, + sdr.role_id, + sdr.role_name, + sdr.actions, + sdr.access_type, + sdr.access_provider_id + FROM + selected_domain_roles sdr + ), + final_roles AS ( + SELECT + ar.report_id, + ar.member_id, + jsonb_agg( + jsonb_build_object( + 'role_id', ar.role_id, + 'role_name', ar.role_name, + 'actions', ar.actions, + 'access_type', ar.access_type, + 'access_provider_id', ar.access_provider_id + ) + ) AS roles + FROM all_roles ar + GROUP BY + ar.report_id, ar.member_id + ) + SELECT + r2.id, + r2."name", + r2.description, + r2.domain_id, + r2.status, + r2.created_at, + r2.created_by, + r2.updated_at, + r2.updated_by, + r2.due, + r2.recurring, + r2.recurring_period, + r2.start_datetime, + r2.config, + r2.email, + r2.metrics, + r2.report_template, + fr.member_id, + fr.roles + FROM report_config r2 + JOIN final_roles fr ON fr.report_id = r2.id + ` + parameters := map[string]any{ + "id": id, + "member_id": memberID, + } + row, err := repo.DB.NamedQueryContext(ctx, query, parameters) + if err != nil { + return reports.ReportConfig{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer row.Close() + + dbreport := dbReport{} + if !row.Next() { + return reports.ReportConfig{}, repoerr.ErrNotFound + } + + if err := row.StructScan(&dbreport); err != nil { + return reports.ReportConfig{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + cfg, err := dbToReport(dbreport) + if err != nil { + return reports.ReportConfig{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + return cfg, nil +} + +func (repo *PostgresRepository) UpdateReportConfigStatus(ctx context.Context, cfg reports.ReportConfig) (reports.ReportConfig, error) { + q := `UPDATE report_config SET status = :status, updated_at = :updated_at, updated_by = :updated_by + WHERE id = :id + RETURNING id, name, description, domain_id, metrics, email, config, + start_datetime, due, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status;` + + dbRpt, err := reportToDb(cfg) + if err != nil { + return reports.ReportConfig{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + row, err := repo.DB.NamedQueryContext(ctx, q, dbRpt) + if err != nil { + return reports.ReportConfig{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + defer row.Close() + + dbr := dbReport{} + if row.Next() { + if err := row.StructScan(&dbr); err != nil { + return reports.ReportConfig{}, err + } + + res, err := dbToReport(dbr) + if err != nil { + return reports.ReportConfig{}, err + } + return res, err + } + + return reports.ReportConfig{}, repoerr.ErrNotFound +} + +func (repo *PostgresRepository) UpdateReportConfig(ctx context.Context, cfg reports.ReportConfig) (reports.ReportConfig, error) { + var query []string + + if cfg.Name != "" { + query = append(query, "name = :name") + } + + if cfg.Description != "" { + query = append(query, "description = :description") + } + + if len(cfg.Metrics) > 0 { + query = append(query, "metrics = :metrics") + } + + if cfg.Email != nil { + query = append(query, "email = :email") + } + + if cfg.Config != nil { + query = append(query, "config = :config") + } + + var q string + if len(query) > 0 { + q = strings.Join(query, ", ") + } + + q = fmt.Sprintf(` + UPDATE report_config + SET %s, + updated_at = :updated_at, updated_by = :updated_by + WHERE id = :id + RETURNING id, name, description, domain_id, config, metrics, + email, start_datetime, due, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + `, q) + + dbr, err := reportToDb(cfg) + if err != nil { + return reports.ReportConfig{}, err + } + row, err := repo.DB.NamedQueryContext(ctx, q, dbr) + if err != nil { + return reports.ReportConfig{}, err + } + defer row.Close() + + var dbReport dbReport + if !row.Next() { + if err := row.Err(); err != nil { + return reports.ReportConfig{}, err + } + return reports.ReportConfig{}, repoerr.ErrNotFound + } + if err := row.StructScan(&dbReport); err != nil { + return reports.ReportConfig{}, err + } + rpt, err := dbToReport(dbReport) + if err != nil { + return reports.ReportConfig{}, err + } + + return rpt, nil +} + +func (repo *PostgresRepository) UpdateReportSchedule(ctx context.Context, cfg reports.ReportConfig) (reports.ReportConfig, error) { + q := ` + UPDATE report_config + SET start_datetime = :start_datetime, due = :due, recurring = :recurring, + recurring_period = :recurring_period, updated_at = :updated_at, updated_by = :updated_by WHERE id = :id + RETURNING id, name, description, domain_id, config, metrics, + email, start_datetime, due, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + ` + + dbr, err := reportToDb(cfg) + if err != nil { + return reports.ReportConfig{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + row, err := repo.DB.NamedQueryContext(ctx, q, dbr) + if err != nil { + return reports.ReportConfig{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + defer row.Close() + + var dbReport dbReport + if !row.Next() { + if err := row.Err(); err != nil { + return reports.ReportConfig{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + return reports.ReportConfig{}, repoerr.ErrNotFound + } + if err := row.StructScan(&dbReport); err != nil { + return reports.ReportConfig{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + report, err := dbToReport(dbReport) + if err != nil { + return reports.ReportConfig{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return report, nil +} + +func (repo *PostgresRepository) RemoveReportConfig(ctx context.Context, id string) error { + q := ` + DELETE FROM report_config + WHERE id = $1; + ` + + result, err := repo.DB.ExecContext(ctx, q, id) + if err != nil { + return err + } + + if _, err := result.RowsAffected(); err != nil { + return repoerr.ErrNotFound + } + + return nil +} + +func (repo *PostgresRepository) ListAllReportsConfig(ctx context.Context, pm reports.PageMeta) (reports.ReportConfigPage, error) { + listReportsQuery := ` + SELECT id, name, description, domain_id, metrics, email, config, + start_datetime, due, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status + FROM report_config rc %s %s %s; + ` + + pq := pageReportQuery(pm) + orderClause := reportsOrderClause(pm) + pgData := reportsPageData(pm) + + q := fmt.Sprintf(listReportsQuery, pq, orderClause, pgData) + rows, err := repo.DB.NamedQueryContext(ctx, q, pm) + if err != nil { + return reports.ReportConfigPage{}, err + } + defer rows.Close() + + cfgs := []reports.ReportConfig{} + for rows.Next() { + var r dbReport + if err := rows.StructScan(&r); err != nil { + return reports.ReportConfigPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + rpt, err := dbToReport(r) + if err != nil { + return reports.ReportConfigPage{}, err + } + cfgs = append(cfgs, rpt) + } + + cq := fmt.Sprintf(`SELECT COUNT(*) FROM report_config rc %s;`, pq) + + total, err := postgres.Total(ctx, repo.DB, cq, pm) + if err != nil { + return reports.ReportConfigPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + pm.Total = total + ret := reports.ReportConfigPage{ + PageMeta: pm, + ReportConfigs: cfgs, + } + + return ret, nil +} + +func (repo *PostgresRepository) ListUserReportsConfig(ctx context.Context, userID string, pm reports.PageMeta) (reports.ReportConfigPage, error) { + pq := pageReportQuery(pm) + orderClause := reportsOrderClause(pm) + pgData := reportsPageData(pm) + + pm.UserID = userID + userJoin := ` + INNER JOIN reports_roles rr ON rr.entity_id = rc.id + INNER JOIN reports_role_members rrm ON rrm.role_id = rr.id AND rrm.member_id = :user_id + ` + + whereClause := pq + + innerQ := fmt.Sprintf(` + SELECT DISTINCT rc.id, rc.name, rc.description, rc.domain_id, rc.metrics, rc.email, rc.config, + rc.start_datetime, rc.due, rc.recurring, rc.recurring_period, rc.created_at, rc.created_by, rc.updated_at, rc.updated_by, rc.status + FROM report_config rc + %s + %s + `, userJoin, whereClause) + + q := fmt.Sprintf(` + SELECT * FROM (%s) AS sub %s %s; + `, innerQ, orderClause, pgData) + + rows, err := repo.DB.NamedQueryContext(ctx, q, pm) + if err != nil { + return reports.ReportConfigPage{}, err + } + defer rows.Close() + + cfgs := []reports.ReportConfig{} + for rows.Next() { + var r dbReport + if err := rows.StructScan(&r); err != nil { + return reports.ReportConfigPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + rpt, err := dbToReport(r) + if err != nil { + return reports.ReportConfigPage{}, err + } + cfgs = append(cfgs, rpt) + } + + cq := fmt.Sprintf(`SELECT COUNT(*) FROM (%s) AS count_sub;`, innerQ) + total, err := postgres.Total(ctx, repo.DB, cq, pm) + if err != nil { + return reports.ReportConfigPage{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + pm.Total = total + + return reports.ReportConfigPage{ + PageMeta: pm, + ReportConfigs: cfgs, + }, nil +} + +func (repo *PostgresRepository) UpdateReportDue(ctx context.Context, id string, due time.Time) (reports.ReportConfig, error) { + q := ` + UPDATE report_config + SET due = :due, updated_at = :updated_at WHERE id = :id + RETURNING id, name, description, domain_id, config, metrics, + email, start_datetime, due, recurring, recurring_period, created_at, created_by, updated_at, updated_by, status; + ` + + dbr := dbReport{ + ID: id, + UpdatedAt: time.Now().UTC(), + Due: sql.NullTime{Time: due}, + } + if !due.IsZero() { + dbr.Due.Valid = true + } + + row, err := repo.DB.NamedQueryContext(ctx, q, dbr) + if err != nil { + return reports.ReportConfig{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + defer row.Close() + + var dbReport dbReport + if !row.Next() { + if err := row.Err(); err != nil { + return reports.ReportConfig{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + return reports.ReportConfig{}, repoerr.ErrNotFound + } + if err := row.StructScan(&dbReport); err != nil { + return reports.ReportConfig{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + report, err := dbToReport(dbReport) + if err != nil { + return reports.ReportConfig{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return report, nil +} + +func (repo *PostgresRepository) UpdateReportTemplate(ctx context.Context, domainID, reportID string, template reports.ReportTemplate) error { + q := ` + UPDATE report_config + SET report_template = :report_template, updated_at = :updated_at + WHERE id = :id AND domain_id = :domain_id` + + dbr := dbReport{ + ID: reportID, + DomainID: domainID, + UpdatedAt: time.Now().UTC(), + ReportTemplate: template, + } + + row, err := repo.DB.NamedQueryContext(ctx, q, dbr) + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + defer row.Close() + + return nil +} + +func (repo *PostgresRepository) ViewReportTemplate(ctx context.Context, domainID, reportID string) (reports.ReportTemplate, error) { + q := ` + SELECT COALESCE(report_template, '') as report_template + FROM report_config + WHERE id = $1 AND domain_id = $2` + + var template reports.ReportTemplate + err := repo.DB.QueryRowxContext(ctx, q, reportID, domainID).Scan(&template) + if err != nil { + if err == sql.ErrNoRows { + return "", repoerr.ErrNotFound + } + return "", errors.Wrap(repoerr.ErrViewEntity, err) + } + + return template, nil +} + +func (repo *PostgresRepository) DeleteReportTemplate(ctx context.Context, domainID, reportID string) error { + q := ` + UPDATE report_config + SET report_template = '', updated_at = :updated_at + WHERE id = :id AND domain_id = :domain_id` + + dbr := dbReport{ + ID: reportID, + DomainID: domainID, + UpdatedAt: time.Now().UTC(), + } + row, err := repo.DB.NamedQueryContext(ctx, q, dbr) + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + defer row.Close() + + return nil +} + +func reportsOrderClause(pm reports.PageMeta) string { + dir := api.DescDir + if pm.Dir == api.AscDir { + dir = api.AscDir + } + switch pm.Order { + case api.NameKey: + return fmt.Sprintf("ORDER BY name %s, id %s", dir, dir) + case api.CreatedAtOrder: + return fmt.Sprintf("ORDER BY created_at %s, id %s", dir, dir) + default: + return fmt.Sprintf("ORDER BY COALESCE(updated_at, created_at) %s, id %s", dir, dir) + } +} + +func reportsPageData(pm reports.PageMeta) string { + pgData := "" + if pm.Limit != 0 { + pgData = "LIMIT :limit" + } + if pm.Offset != 0 { + pgData += " OFFSET :offset" + } + return pgData +} + +func pageReportQuery(pm reports.PageMeta) string { + var query []string + if pm.Status != reports.AllStatus { + query = append(query, "rc.status = :status") + } + if pm.Domain != "" { + query = append(query, "rc.domain_id = :domain_id") + } + if pm.ScheduledBefore != nil { + query = append(query, "rc.due < :scheduled_before") + } + if pm.ScheduledAfter != nil { + query = append(query, "rc.due > :scheduled_after") + } + if pm.Name != "" { + query = append(query, "rc.name ILIKE '%' || :name || '%'") + } + + var q string + if len(query) > 0 { + q = fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")) + } + + return q +} diff --git a/reports/postgres/repository_test.go b/reports/postgres/repository_test.go new file mode 100644 index 000000000..6b648bd76 --- /dev/null +++ b/reports/postgres/repository_test.go @@ -0,0 +1,951 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/0x6flab/namegenerator" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/schedule" + "github.com/absmach/supermq/pkg/uuid" + "github.com/absmach/supermq/reports" + "github.com/absmach/supermq/reports/postgres" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var ( + namegen = namegenerator.NewGenerator() + idProvider = uuid.New() +) + +func generateUUID(t *testing.T) string { + id, err := idProvider.ID() + require.Nil(t, err, fmt.Sprintf("generate uuid unexpected error: %s", err)) + return id +} + +func TestAddReportConfig(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + reportConfig := reports.ReportConfig{ + ID: generateUUID(t), + Name: namegen.Generate(), + Description: namegen.Generate(), + DomainID: generateUUID(t), + Config: &reports.MetricConfig{ + From: "now-1h", + To: "now", + Title: "Test Report", + }, + Metrics: []reports.ReqMetric{ + { + ChannelID: generateUUID(t), + Name: "temperature", + }, + }, + Email: &reports.EmailSetting{ + To: []string{"test@example.com"}, + Subject: "Test Report", + Content: "Report content", + }, + Schedule: schedule.Schedule{ + StartDateTime: time.Now().UTC(), + Time: time.Now().UTC().Add(time.Hour), + Recurring: schedule.Daily, + RecurringPeriod: 1, + }, + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC(), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC(), + UpdatedBy: generateUUID(t), + } + + cases := []struct { + desc string + report reports.ReportConfig + err error + }{ + { + desc: "add valid report config", + report: reportConfig, + err: nil, + }, + { + desc: "add duplicate report config", + report: reportConfig, + err: repoerr.ErrConflict, + }, + { + desc: "add report config with empty ID", + report: reports.ReportConfig{ + Name: namegen.Generate(), + DomainID: generateUUID(t), + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + }, + err: repoerr.ErrCreateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + rpt, err := repo.AddReportConfig(context.Background(), tc.report) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.NotEmpty(t, rpt.ID) + require.Equal(t, tc.report.Name, rpt.Name) + require.Equal(t, tc.report.DomainID, rpt.DomainID) + require.Equal(t, tc.report.Status, rpt.Status) + }) + } +} + +func TestViewReportConfig(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + reportConfig := reports.ReportConfig{ + ID: generateUUID(t), + Name: namegen.Generate(), + Description: namegen.Generate(), + DomainID: generateUUID(t), + Config: &reports.MetricConfig{ + From: "now-1h", + To: "now", + Title: "Test Report", + }, + Metrics: []reports.ReqMetric{ + { + ChannelID: generateUUID(t), + Name: "temperature", + }, + }, + Email: &reports.EmailSetting{ + To: []string{"test@example.com"}, + Subject: "Test Report", + }, + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC(), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC(), + UpdatedBy: generateUUID(t), + } + + saved, err := repo.AddReportConfig(context.Background(), reportConfig) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "view existing report config", + id: saved.ID, + err: nil, + }, + { + desc: "view non-existing report config", + id: generateUUID(t), + err: repoerr.ErrNotFound, + }, + { + desc: "view with empty id", + id: "", + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + rpt, err := repo.ViewReportConfig(context.Background(), tc.id) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.Equal(t, saved.ID, rpt.ID) + require.Equal(t, saved.Name, rpt.Name) + require.Equal(t, saved.DomainID, rpt.DomainID) + }) + } +} + +func TestUpdateReportConfig(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + reportConfig := reports.ReportConfig{ + ID: generateUUID(t), + Name: namegen.Generate(), + Description: namegen.Generate(), + DomainID: generateUUID(t), + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC(), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC(), + UpdatedBy: generateUUID(t), + Metrics: []reports.ReqMetric{ + { + ChannelID: generateUUID(t), + Name: "temperature", + }, + }, + } + + saved, err := repo.AddReportConfig(context.Background(), reportConfig) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + report reports.ReportConfig + err error + }{ + { + desc: "update report name", + report: reports.ReportConfig{ + ID: saved.ID, + Name: "Updated Name", + UpdatedAt: time.Now().UTC(), + UpdatedBy: generateUUID(t), + }, + err: nil, + }, + { + desc: "update report description", + report: reports.ReportConfig{ + ID: saved.ID, + Description: "Updated Description", + UpdatedAt: time.Now().UTC(), + UpdatedBy: generateUUID(t), + }, + err: nil, + }, + { + desc: "update non-existing report", + report: reports.ReportConfig{ + ID: generateUUID(t), + Name: "New Name", + UpdatedAt: time.Now().UTC(), + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + rpt, err := repo.UpdateReportConfig(context.Background(), tc.report) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.Equal(t, tc.report.ID, rpt.ID) + if tc.report.Name != "" { + require.Equal(t, tc.report.Name, rpt.Name) + } + if tc.report.Description != "" { + require.Equal(t, tc.report.Description, rpt.Description) + } + }) + } +} + +func TestUpdateReportConfigStatus(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + reportConfig := reports.ReportConfig{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC(), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC(), + UpdatedBy: generateUUID(t), + Metrics: []reports.ReqMetric{}, + } + + saved, err := repo.AddReportConfig(context.Background(), reportConfig) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + report reports.ReportConfig + err error + }{ + { + desc: "disable report", + report: reports.ReportConfig{ + ID: saved.ID, + Status: reports.DisabledStatus, + UpdatedAt: time.Now().UTC(), + UpdatedBy: generateUUID(t), + }, + err: nil, + }, + { + desc: "enable report", + report: reports.ReportConfig{ + ID: saved.ID, + Status: reports.EnabledStatus, + UpdatedAt: time.Now().UTC(), + UpdatedBy: generateUUID(t), + }, + err: nil, + }, + { + desc: "update status of non-existing report", + report: reports.ReportConfig{ + ID: generateUUID(t), + Status: reports.DisabledStatus, + UpdatedAt: time.Now().UTC(), + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + rpt, err := repo.UpdateReportConfigStatus(context.Background(), tc.report) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.Equal(t, tc.report.Status, rpt.Status) + }) + } +} + +func TestRemoveReportConfig(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + reportConfig := reports.ReportConfig{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + Metrics: []reports.ReqMetric{}, + } + + saved, err := repo.AddReportConfig(context.Background(), reportConfig) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "remove existing report", + id: saved.ID, + err: nil, + }, + { + desc: "remove non-existing report", + id: generateUUID(t), + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.RemoveReportConfig(context.Background(), tc.id) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + }) + } +} + +func TestListReportsConfig(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + domainID := generateUUID(t) + + num := uint64(10) + for i := uint64(0); i < num; i++ { + reportConfig := reports.ReportConfig{ + ID: generateUUID(t), + Name: fmt.Sprintf("Report-%d", i), + DomainID: domainID, + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + Metrics: []reports.ReqMetric{}, + } + _, err := repo.AddReportConfig(context.Background(), reportConfig) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + } + + cases := []struct { + desc string + pageMeta reports.PageMeta + size uint64 + err error + }{ + { + desc: "list all reports", + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: num, + Offset: 0, + }, + size: num, + err: nil, + }, + { + desc: "list with limit", + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: 5, + Offset: 0, + }, + size: 5, + err: nil, + }, + { + desc: "list with offset", + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: num, + Offset: 5, + }, + size: 5, + err: nil, + }, + { + desc: "list enabled reports", + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: num, + Status: reports.EnabledStatus, + }, + size: num, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + page, err := repo.ListAllReportsConfig(context.Background(), tc.pageMeta) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.Equal(t, tc.size, uint64(len(page.ReportConfigs))) + }) + } +} + +func TestListUserReportsConfig(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + domainID := generateUUID(t) + userID := generateUUID(t) + otherUserID := generateUUID(t) + + num := 10 + var allCfgs []reports.ReportConfig + for i := range num { + cfg := reports.ReportConfig{ + ID: generateUUID(t), + Name: fmt.Sprintf("Report-%d", i), + DomainID: domainID, + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute), + UpdatedAt: time.Now().UTC().Add(time.Duration(i) * time.Minute), + Metrics: []reports.ReqMetric{}, + } + cfg, err := repo.AddReportConfig(context.Background(), cfg) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + allCfgs = append(allCfgs, cfg) + } + + // Assign userID to the first 5 report configs via direct role INSERT. + for i := range 5 { + roleID := generateUUID(t) + _, err := db.Exec(`INSERT INTO reports_roles (id, name, entity_id) VALUES ($1, $2, $3)`, roleID, "admin", allCfgs[i].ID) + require.Nil(t, err, fmt.Sprintf("insert reports_roles unexpected error: %s", err)) + _, err = db.Exec(`INSERT INTO reports_role_members (role_id, member_id, entity_id) VALUES ($1, $2, $3)`, roleID, userID, allCfgs[i].ID) + require.Nil(t, err, fmt.Sprintf("insert reports_role_members unexpected error: %s", err)) + } + + cases := []struct { + desc string + userID string + pageMeta reports.PageMeta + size int + err error + }{ + { + desc: "list user reports returns only accessible reports", + userID: userID, + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: 100, + Offset: 0, + }, + size: 5, + err: nil, + }, + { + desc: "list user reports with limit", + userID: userID, + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: 3, + Offset: 0, + }, + size: 3, + err: nil, + }, + { + desc: "list user reports with offset", + userID: userID, + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: 100, + Offset: 3, + }, + size: 2, + err: nil, + }, + { + desc: "list user reports with enabled status filter", + userID: userID, + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: 100, + Status: reports.EnabledStatus, + }, + size: 5, + err: nil, + }, + { + desc: "list reports for user with no role assignments returns 0", + userID: otherUserID, + pageMeta: reports.PageMeta{ + Domain: domainID, + Limit: 100, + Offset: 0, + }, + size: 0, + err: nil, + }, + { + desc: "list user reports with non-existing domain returns 0", + userID: userID, + pageMeta: reports.PageMeta{ + Domain: generateUUID(t), + Limit: 100, + Offset: 0, + }, + size: 0, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + page, err := repo.ListUserReportsConfig(context.Background(), tc.userID, tc.pageMeta) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.Equal(t, tc.size, len(page.ReportConfigs), fmt.Sprintf("%s: expected %d reports, got %d", tc.desc, tc.size, len(page.ReportConfigs))) + }) + } +} + +func TestUpdateReportSchedule(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + reportConfig := reports.ReportConfig{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC(), + CreatedBy: generateUUID(t), + UpdatedAt: time.Now().UTC(), + UpdatedBy: generateUUID(t), + Metrics: []reports.ReqMetric{ + { + ChannelID: generateUUID(t), + Name: "temperature", + }, + }, + Schedule: schedule.Schedule{ + StartDateTime: time.Now().UTC(), + Time: time.Now().UTC().Add(time.Hour), + Recurring: schedule.Daily, + RecurringPeriod: 1, + }, + } + + saved, err := repo.AddReportConfig(context.Background(), reportConfig) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + newSchedule := schedule.Schedule{ + StartDateTime: time.Now().UTC().Add(24 * time.Hour), + Time: time.Now().UTC().Add(25 * time.Hour), + Recurring: schedule.Weekly, + RecurringPeriod: 2, + } + + cases := []struct { + desc string + report reports.ReportConfig + expected schedule.Schedule + err error + }{ + { + desc: "update schedule", + report: reports.ReportConfig{ + ID: saved.ID, + Schedule: newSchedule, + UpdatedAt: time.Now().UTC(), + UpdatedBy: generateUUID(t), + }, + expected: newSchedule, + err: nil, + }, + { + desc: "update schedule of non-existing report", + report: reports.ReportConfig{ + ID: generateUUID(t), + Schedule: newSchedule, + UpdatedAt: time.Now().UTC(), + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + rpt, err := repo.UpdateReportSchedule(context.Background(), tc.report) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.Equal(t, tc.expected.Recurring, rpt.Schedule.Recurring) + require.Equal(t, tc.expected.RecurringPeriod, rpt.Schedule.RecurringPeriod) + }) + } +} + +func TestUpdateReportDue(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + reportConfig := reports.ReportConfig{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: generateUUID(t), + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + Metrics: []reports.ReqMetric{ + { + ChannelID: generateUUID(t), + Name: "temperature", + }, + }, + } + + saved, err := repo.AddReportConfig(context.Background(), reportConfig) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + newDue := time.Now().UTC().Add(24 * time.Hour).Truncate(time.Microsecond) + + cases := []struct { + desc string + id string + due time.Time + err error + }{ + { + desc: "update due time", + id: saved.ID, + due: newDue, + err: nil, + }, + { + desc: "update due time of non-existing report", + id: generateUUID(t), + due: newDue, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + rpt, err := repo.UpdateReportDue(context.Background(), tc.id, tc.due) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.True(t, tc.due.Equal(rpt.Schedule.Time)) + }) + } +} + +func TestUpdateReportTemplate(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + domainID := generateUUID(t) + reportConfig := reports.ReportConfig{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + Metrics: []reports.ReqMetric{ + { + ChannelID: generateUUID(t), + Name: "temperature", + }, + }, + } + + saved, err := repo.AddReportConfig(context.Background(), reportConfig) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + template := reports.ReportTemplate("Test Template") + + cases := []struct { + desc string + domainID string + reportID string + template reports.ReportTemplate + err error + }{ + { + desc: "update template", + domainID: domainID, + reportID: saved.ID, + template: template, + err: nil, + }, + { + desc: "update template for non-existing report", + domainID: domainID, + reportID: generateUUID(t), + template: template, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.UpdateReportTemplate(context.Background(), tc.domainID, tc.reportID, tc.template) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + }) + } +} + +func TestViewReportTemplate(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + domainID := generateUUID(t) + template := reports.ReportTemplate("Test Template") + + reportConfig := reports.ReportConfig{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + ReportTemplate: template, + Metrics: []reports.ReqMetric{ + { + ChannelID: generateUUID(t), + Name: "temperature", + }, + }, + } + + saved, err := repo.AddReportConfig(context.Background(), reportConfig) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + domainID string + reportID string + expected reports.ReportTemplate + err error + }{ + { + desc: "view existing template", + domainID: domainID, + reportID: saved.ID, + expected: template, + err: nil, + }, + { + desc: "view template for non-existing report", + domainID: domainID, + reportID: generateUUID(t), + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + tmpl, err := repo.ViewReportTemplate(context.Background(), tc.domainID, tc.reportID) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + require.Equal(t, tc.expected, tmpl) + }) + } +} + +func TestDeleteReportTemplate(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM report_config") + require.Nil(t, err, fmt.Sprintf("clean report_config unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + domainID := generateUUID(t) + template := reports.ReportTemplate("Test Template") + + reportConfig := reports.ReportConfig{ + ID: generateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + Status: reports.EnabledStatus, + CreatedAt: time.Now().UTC(), + UpdatedAt: time.Now().UTC(), + ReportTemplate: template, + Metrics: []reports.ReqMetric{ + { + ChannelID: generateUUID(t), + Name: "temperature", + }, + }, + } + + saved, err := repo.AddReportConfig(context.Background(), reportConfig) + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + cases := []struct { + desc string + domainID string + reportID string + err error + }{ + { + desc: "delete existing template", + domainID: domainID, + reportID: saved.ID, + err: nil, + }, + { + desc: "delete template for non-existing report", + domainID: domainID, + reportID: generateUUID(t), + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.DeleteReportTemplate(context.Background(), tc.domainID, tc.reportID) + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + return + } + require.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) + + if tc.reportID == saved.ID { + tmpl, err := repo.ViewReportTemplate(context.Background(), tc.domainID, tc.reportID) + require.Nil(t, err) + require.Empty(t, tmpl) + } + }) + } +} diff --git a/reports/postgres/setup_test.go b/reports/postgres/setup_test.go new file mode 100644 index 000000000..3549e478f --- /dev/null +++ b/reports/postgres/setup_test.go @@ -0,0 +1,95 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "database/sql" + "fmt" + "log" + "os" + "testing" + "time" + + "github.com/absmach/supermq/pkg/postgres" + rpostgres "github.com/absmach/supermq/reports/postgres" + "github.com/jmoiron/sqlx" + dockertest "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "go.opentelemetry.io/otel" +) + +var ( + db *sqlx.DB + database postgres.Database + tracer = otel.Tracer("repo_tests") +) + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16.2-alpine", + Env: []string{ + "POSTGRES_USER=test", + "POSTGRES_PASSWORD=test", + "POSTGRES_DB=test", + "listen_addresses = '*'", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + port := container.GetPort("5432/tcp") + + pool.MaxWait = 120 * time.Second + if err := pool.Retry(func() error { + url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) + db, err := sql.Open("pgx", url) + if err != nil { + return err + } + return db.Ping() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + dbConfig := postgres.Config{ + Host: "localhost", + Port: port, + User: "test", + Pass: "test", + Name: "test", + SSLMode: "disable", + SSLCert: "", + SSLKey: "", + SSLRootCert: "", + } + + migration, err := rpostgres.Migration() + if err != nil { + log.Fatalf("Could not get migration: %s", err) + } + if db, err = postgres.Setup(dbConfig, *migration); err != nil { + log.Fatalf("Could not setup test DB connection: %s", err) + } + + database = postgres.NewDatabase(db, dbConfig, tracer) + + code := m.Run() + + db.Close() + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} diff --git a/reports/reports.go b/reports/reports.go new file mode 100644 index 000000000..c3d292ec1 --- /dev/null +++ b/reports/reports.go @@ -0,0 +1,436 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package reports + +import ( + "context" + "encoding/json" + "fmt" + "net/mail" + "strings" + "time" + + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/reltime" + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/pkg/schedule" + "github.com/absmach/supermq/pkg/transformers/senml" +) + +var ( + errFromTimeNotProvided = errors.New("\"from time\" not provided") + errInvalidFromTime = errors.New("invalid \"from time\" ") + errToTimeNotProvided = errors.New("\"to time\" not provided") + errTitleNotProvided = errors.New("title not provided") + errInvalidToTime = errors.New("invalid \"to time\"") + errAggIntervalTimeNotProvided = errors.New("aggregation interval time not provided") + errInvalidAggInterval = errors.New("invalid aggregation interval time") + errNoToEmail = errors.New("no \"To\" email address found") + errChannelIDNotProvided = errors.New("channel id not provided") + errNameNotProvided = errors.New("name not provided") +) + +const ( + errInvalidFormatFmt = "invalid format %s" + errInvalidReportActionFmt = "invalid action %s" + errInvalidToEmail = "invalid \"To\" email %s" + + errUnknownAggregationFmt = "unknown aggregation type %d" + errUnknownAggregationStringFmt = "unknown aggregation type %s" +) + +type Report struct { + Metric Metric `json:"metric,omitempty"` + Messages []senml.Message `json:"messages,omitempty"` +} + +type ReportPage struct { + Total uint64 `json:"total"` + From time.Time `json:"from,omitempty"` + To time.Time `json:"to,omitempty"` + Aggregation AggConfig `json:"aggregation,omitempty"` + Reports []Report `json:"reports,omitempty"` + File ReportFile `json:"file,omitempty"` +} + +type ReportFile struct { + Name string `json:"name,omitempty"` + Data []byte `json:"data,omitempty"` + Format Format `json:"format,omitempty"` +} + +type AggConfig struct { + AggType Aggregation `json:"agg_type,omitempty"` // Optional field + Interval string `json:"interval,omitempty"` // Mandatory field if "AggType" field is set MAX, MIN, COUNT, SUM, AVG +} + +func (ac AggConfig) Validate() error { + if ac.AggType != AggregationNONE { + if ac.Interval == "" { + return errAggIntervalTimeNotProvided + } + + if _, err := time.ParseDuration(ac.Interval); err != nil { + return errInvalidAggInterval + } + } + return nil +} + +type MetricConfig struct { + From string `json:"from,omitempty"` // Mandatory field + To string `json:"to,omitempty"` // Mandatory field + Title string `json:"title,omitempty"` // Mandatory field + + FileFormat Format `json:"file_format"` // Optional field + Timezone string `json:"timezone,omitempty"` // Optional field, defaults to UTC + + Aggregation AggConfig `json:"aggregation,omitempty"` // Optional field +} + +func (mc MetricConfig) Validate() error { + if mc.From == "" { + return errFromTimeNotProvided + } + + if _, err := reltime.Parse(mc.From); err != nil { + return errInvalidFromTime + } + + if mc.To == "" { + return errToTimeNotProvided + } + + if _, err := reltime.Parse(mc.To); err != nil { + return errInvalidToTime + } + + if mc.Title == "" { + return errTitleNotProvided + } + + if err := mc.Aggregation.Validate(); err != nil { + return err + } + + if tz := strings.TrimSpace(mc.Timezone); tz != "" { + if _, err := time.LoadLocation(tz); err != nil { + return errors.Wrap(fmt.Errorf("invalid timezone: %s", tz), err) + } + } + + return nil +} + +type Metric struct { + ChannelID string `json:"channel_id,omitempty"` // Mandatory field + ClientID string `json:"client_id,omitempty"` // Optional field + Name string `json:"name,omitempty"` // Mandatory field + Subtopic string `json:"subtopic,omitempty"` // Optional field + Protocol string `json:"protocol,omitempty"` // Optional field + Format string `json:"format,omitempty"` // Optional field +} + +type ReqMetric struct { + ChannelID string `json:"channel_id,omitempty"` // Mandatory field + ClientIDs []string `json:"client_ids,omitempty"` // Optional field + Name string `json:"name,omitempty"` // Mandatory field + Subtopic string `json:"subtopic,omitempty"` // Optional field + Protocol string `json:"protocol,omitempty"` // Optional field + Format string `json:"format,omitempty"` // Optional field +} + +func (rm ReqMetric) Validate() error { + if rm.ChannelID == "" { + return errChannelIDNotProvided + } + if rm.Name == "" { + return errNameNotProvided + } + return nil +} + +type ReportConfig struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + DomainID string `json:"domain_id"` + Schedule schedule.Schedule `json:"schedule,omitempty"` + Config *MetricConfig `json:"config,omitempty"` + Email *EmailSetting `json:"email,omitempty"` + Metrics []ReqMetric `json:"metrics,omitempty"` + ReportTemplate ReportTemplate `json:"report_template,omitempty"` + Status Status `json:"status"` + CreatedAt time.Time `json:"created_at"` + CreatedBy string `json:"created_by,omitempty"` + UpdatedAt time.Time `json:"updated_at"` + UpdatedBy string `json:"updated_by,omitempty"` + Roles []roles.MemberRoleActions `json:"roles,omitempty"` +} + +type ReportConfigPage struct { + PageMeta + ReportConfigs []ReportConfig `json:"report_configs"` +} + +type EmailSetting struct { + To []string `json:"to,omitempty"` + Subject string `json:"subject,omitempty"` + Content string `json:"content,omitempty"` +} + +func (es *EmailSetting) Validate() error { + if len(es.To) == 0 { + return errNoToEmail + } + for _, to := range es.To { + if _, err := mail.ParseAddress(to); err != nil { + return errors.Wrap(fmt.Errorf(errInvalidToEmail, to), err) + } + } + return nil +} + +type Format uint8 + +const ( + PDF = iota + CSV + AllFormats +) + +const ( + PdfFormat = "pdf" + CsvFormat = "csv" + All_Formats = "AllFormats" +) + +func (f Format) String() string { + switch f { + case PDF: + return PdfFormat + case CSV: + return CsvFormat + case AllFormats: + return All_Formats + default: + return Unknown + } +} + +func (f Format) Extension() string { + switch f { + case PDF: + return PdfFormat + case CSV: + return CsvFormat + default: + return Unknown + } +} + +func (f Format) ContentType() string { + switch f { + case PDF: + return "application/pdf" + case CSV: + return "text/csv" + default: + return Unknown + } +} + +func ToFormat(format string) (Format, error) { + switch format { + case "", PdfFormat: + return PDF, nil + case CsvFormat: + return CSV, nil + case All_Formats: + return AllFormats, nil + } + return Format(0), fmt.Errorf(errInvalidFormatFmt, format) +} + +func (f Format) MarshalJSON() ([]byte, error) { + return json.Marshal(f.String()) +} + +func (f *Format) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), "\"") + val, err := ToFormat(str) + *f = val + return err +} + +type ReportAction uint8 + +const ( + ViewReport = iota + DownloadReport + EmailReport +) + +const ( + ViewReportAction = "view" + DownloadReportAction = "download" + EmailReportAction = "email" +) + +func (ra ReportAction) String() string { + switch ra { + case ViewReport: + return ViewReportAction + case DownloadReport: + return DownloadReportAction + case EmailReport: + return EmailReportAction + default: + return Unknown + } +} + +func ToReportAction(action string) (ReportAction, error) { + switch action { + case "", ViewReportAction: + return ViewReport, nil + case DownloadReportAction: + return DownloadReport, nil + case EmailReportAction: + return EmailReport, nil + } + return ReportAction(0), fmt.Errorf(errInvalidReportActionFmt, action) +} + +func (ra ReportAction) MarshalJSON() ([]byte, error) { + return json.Marshal(ra.String()) +} + +func (ra *ReportAction) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), "\"") + val, err := ToReportAction(str) + *ra = val + return err +} + +type Aggregation uint8 + +const ( + AggregationNONE = iota + AggregationMAX + AggregationMIN + AggregationSUM + AggregationCOUNT + AggregationAVG +) + +const ( + aggregationNONE = "none" + aggregationMAX = "max" + aggregationMIN = "min" + aggregationSUM = "sum" + aggregationCOUNT = "count" + aggregationAVG = "avg" +) + +func (a Aggregation) String() string { + switch a { + case AggregationNONE: + return aggregationNONE + case AggregationMAX: + return aggregationMAX + case AggregationMIN: + return aggregationMIN + case AggregationSUM: + return aggregationSUM + case AggregationCOUNT: + return aggregationCOUNT + case AggregationAVG: + return aggregationAVG + default: + return fmt.Sprintf(errUnknownAggregationFmt, a) + } +} + +func ToAggregation(agg string) (Aggregation, error) { + switch strings.ToLower(agg) { + case "", aggregationNONE: + return AggregationNONE, nil + case aggregationMAX: + return AggregationMAX, nil + case aggregationMIN: + return AggregationMIN, nil + case aggregationSUM: + return AggregationSUM, nil + case aggregationCOUNT: + return AggregationCOUNT, nil + case aggregationAVG: + return AggregationAVG, nil + default: + return Aggregation(0), fmt.Errorf(errUnknownAggregationStringFmt, agg) + } +} + +func (a Aggregation) MarshalJSON() ([]byte, error) { + return json.Marshal(a.String()) +} + +func (a *Aggregation) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), "\"") + val, err := ToAggregation(str) + *a = val + return err +} + +type PageMeta struct { + Total uint64 `json:"total" db:"total"` + Offset uint64 `json:"offset" db:"offset"` + Limit uint64 `json:"limit" db:"limit"` + Name string `json:"name" db:"name"` + Dir string `json:"dir" db:"dir"` + Order string `json:"order" db:"order"` + Status Status `json:"status,omitempty" db:"status"` + Domain string `json:"domain_id,omitempty" db:"domain_id"` + ScheduledBefore *time.Time `json:"scheduled_before,omitempty" db:"scheduled_before"` // Filter rules scheduled before this time + ScheduledAfter *time.Time `json:"scheduled_after,omitempty" db:"scheduled_after"` // Filter rules scheduled after this time + UserID string `json:"user_id,omitempty" db:"user_id"` +} + +type Repository interface { + AddReportConfig(ctx context.Context, cfg ReportConfig) (ReportConfig, error) + ViewReportConfig(ctx context.Context, id string) (ReportConfig, error) + RetrieveByIDWithRoles(ctx context.Context, id, memberID string) (ReportConfig, error) + UpdateReportConfig(ctx context.Context, cfg ReportConfig) (ReportConfig, error) + UpdateReportSchedule(ctx context.Context, cfg ReportConfig) (ReportConfig, error) + RemoveReportConfig(ctx context.Context, id string) error + UpdateReportConfigStatus(ctx context.Context, cfg ReportConfig) (ReportConfig, error) + ListAllReportsConfig(ctx context.Context, pm PageMeta) (ReportConfigPage, error) + ListUserReportsConfig(ctx context.Context, userID string, pm PageMeta) (ReportConfigPage, error) + UpdateReportDue(ctx context.Context, id string, due time.Time) (ReportConfig, error) + + UpdateReportTemplate(ctx context.Context, domainID, reportID string, template ReportTemplate) error + ViewReportTemplate(ctx context.Context, domainID, reportID string) (ReportTemplate, error) + DeleteReportTemplate(ctx context.Context, domainID, reportID string) error + roles.Repository +} + +type Service interface { + AddReportConfig(ctx context.Context, session authn.Session, cfg ReportConfig) (ReportConfig, error) + ViewReportConfig(ctx context.Context, session authn.Session, id string, withRoles bool) (ReportConfig, error) + UpdateReportConfig(ctx context.Context, session authn.Session, cfg ReportConfig) (ReportConfig, error) + UpdateReportSchedule(ctx context.Context, session authn.Session, cfg ReportConfig) (ReportConfig, error) + RemoveReportConfig(ctx context.Context, session authn.Session, id string) error + ListReportsConfig(ctx context.Context, session authn.Session, pm PageMeta) (ReportConfigPage, error) + EnableReportConfig(ctx context.Context, session authn.Session, id string) (ReportConfig, error) + DisableReportConfig(ctx context.Context, session authn.Session, id string) (ReportConfig, error) + + UpdateReportTemplate(ctx context.Context, session authn.Session, cfg ReportConfig) error + ViewReportTemplate(ctx context.Context, session authn.Session, id string) (ReportTemplate, error) + DeleteReportTemplate(ctx context.Context, session authn.Session, id string) error + + GenerateReport(ctx context.Context, session authn.Session, config ReportConfig, action ReportAction) (ReportPage, error) + StartScheduler(ctx context.Context) error + roles.RoleManager +} diff --git a/reports/service.go b/reports/service.go new file mode 100644 index 000000000..5d283527a --- /dev/null +++ b/reports/service.go @@ -0,0 +1,530 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package reports + +import ( + "context" + "fmt" + "log/slog" + "strings" + "time" + + "github.com/absmach/supermq" + grpcReadersV1 "github.com/absmach/supermq/api/grpc/readers/v1" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/emailer" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + pkglog "github.com/absmach/supermq/pkg/logger" + "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/pkg/reltime" + "github.com/absmach/supermq/pkg/roles" + "github.com/absmach/supermq/pkg/ticker" + "github.com/absmach/supermq/pkg/transformers/senml" + "github.com/absmach/supermq/reports/operations" +) + +const limit = 1000 + +type report struct { + repo Repository + runInfo chan pkglog.RunInfo + idp supermq.IDProvider + email emailer.Emailer + ticker ticker.Ticker + readers grpcReadersV1.ReadersServiceClient + defaultTemplate ReportTemplate + converterURL string + roles.ProvisionManageService +} + +func NewService(repo Repository, runInfo chan pkglog.RunInfo, policy policies.Service, idp supermq.IDProvider, tck ticker.Ticker, emailer emailer.Emailer, readers grpcReadersV1.ReadersServiceClient, template ReportTemplate, converterURL string, availableActions []roles.Action, builtInRoles map[roles.BuiltInRoleName][]roles.Action) (Service, error) { + rpms, err := roles.NewProvisionManageService(operations.EntityType, repo, policy, idp, availableActions, builtInRoles) + if err != nil { + return nil, err + } + return &report{ + repo: repo, + idp: idp, + runInfo: runInfo, + email: emailer, + ticker: tck, + readers: readers, + defaultTemplate: template, + converterURL: converterURL, + ProvisionManageService: rpms, + }, nil +} + +func (r *report) AddReportConfig(ctx context.Context, session authn.Session, cfg ReportConfig) (retCfg ReportConfig, retErr error) { + id, err := r.idp.ID() + if err != nil { + return ReportConfig{}, err + } + + now := time.Now().UTC() + cfg.ID = id + cfg.CreatedAt = now + cfg.CreatedBy = session.UserID + cfg.DomainID = session.DomainID + cfg.Status = EnabledStatus + + if cfg.Schedule.StartDateTime.IsZero() { + cfg.Schedule.StartDateTime = now + } + cfg.Schedule.Time = cfg.Schedule.StartDateTime + + reportConfig, err := r.repo.AddReportConfig(ctx, cfg) + if err != nil { + return ReportConfig{}, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + defer func() { + if retErr != nil { + if errRollBack := r.repo.RemoveReportConfig(ctx, reportConfig.ID); errRollBack != nil { + retErr = errors.Wrap(retErr, errors.Wrap(svcerr.ErrRollbackRepo, errRollBack)) + } + } + }() + + newBuiltInRoleMembers := map[roles.BuiltInRoleName][]roles.Member{ + BuiltInRoleAdmin: {roles.Member(session.UserID)}, + } + + optionalPolicies := []policies.Policy{ + { + SubjectType: policies.DomainType, + Subject: session.DomainID, + Relation: policies.DomainRelation, + ObjectType: operations.EntityType, + Object: reportConfig.ID, + }, + } + + _, err = r.AddNewEntitiesRoles(ctx, session.DomainID, session.UserID, []string{reportConfig.ID}, optionalPolicies, newBuiltInRoleMembers) + if err != nil { + return ReportConfig{}, errors.Wrap(svcerr.ErrAddPolicies, err) + } + + return reportConfig, nil +} + +func (r *report) ViewReportConfig(ctx context.Context, session authn.Session, id string, withRoles bool) (ReportConfig, error) { + var cfg ReportConfig + var err error + switch withRoles { + case true: + cfg, err = r.repo.RetrieveByIDWithRoles(ctx, id, session.UserID) + default: + cfg, err = r.repo.ViewReportConfig(ctx, id) + } + if err != nil { + return ReportConfig{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + + return cfg, nil +} + +func (r *report) UpdateReportConfig(ctx context.Context, session authn.Session, cfg ReportConfig) (ReportConfig, error) { + cfg.UpdatedAt = time.Now().UTC() + cfg.UpdatedBy = session.UserID + reportConfig, err := r.repo.UpdateReportConfig(ctx, cfg) + if err != nil { + return ReportConfig{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + return reportConfig, nil +} + +func (r *report) UpdateReportSchedule(ctx context.Context, session authn.Session, cfg ReportConfig) (ReportConfig, error) { + cfg.UpdatedAt = time.Now().UTC() + cfg.UpdatedBy = session.UserID + cfg.Schedule.Time = cfg.Schedule.StartDateTime + c, err := r.repo.UpdateReportSchedule(ctx, cfg) + if err != nil { + return ReportConfig{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + return c, nil +} + +func (r *report) RemoveReportConfig(ctx context.Context, session authn.Session, id string) error { + if err := r.repo.RemoveReportConfig(ctx, id); err != nil { + return errors.Wrap(svcerr.ErrRemoveEntity, err) + } + + return nil +} + +func (r *report) ListReportsConfig(ctx context.Context, session authn.Session, pm PageMeta) (ReportConfigPage, error) { + pm.Domain = session.DomainID + if session.SuperAdmin { + page, err := r.repo.ListAllReportsConfig(ctx, pm) + if err != nil { + return ReportConfigPage{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + return page, nil + } + page, err := r.repo.ListUserReportsConfig(ctx, session.UserID, pm) + if err != nil { + return ReportConfigPage{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + return page, nil +} + +func (r *report) EnableReportConfig(ctx context.Context, session authn.Session, id string) (ReportConfig, error) { + status, err := ToStatus(Enabled) + if err != nil { + return ReportConfig{}, err + } + cfg := ReportConfig{ + ID: id, + UpdatedAt: time.Now().UTC(), + UpdatedBy: session.UserID, + Status: status, + } + cfg, err = r.repo.UpdateReportConfigStatus(ctx, cfg) + if err != nil { + return ReportConfig{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + return cfg, nil +} + +func (r *report) DisableReportConfig(ctx context.Context, session authn.Session, id string) (ReportConfig, error) { + status, err := ToStatus(Disabled) + if err != nil { + return ReportConfig{}, err + } + cfg := ReportConfig{ + ID: id, + UpdatedAt: time.Now().UTC(), + UpdatedBy: session.UserID, + Status: status, + } + cfg, err = r.repo.UpdateReportConfigStatus(ctx, cfg) + if err != nil { + return ReportConfig{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + return cfg, nil +} + +func (r *report) GenerateReport(ctx context.Context, session authn.Session, config ReportConfig, action ReportAction) (ReportPage, error) { + config.DomainID = session.DomainID + + if action != ViewReport && action != DownloadReport && action != EmailReport { + if config.Status != EnabledStatus { + return ReportPage{}, svcerr.ErrInvalidStatus + } + } + + reportPage, err := r.generateReport(ctx, config, action) + if err != nil { + return ReportPage{}, err + } + + return reportPage, nil +} + +func (r *report) generateReport(ctx context.Context, cfg ReportConfig, action ReportAction) (ReportPage, error) { + genReportFile, err := r.generateFileFunc(ctx, action, cfg.Config.FileFormat, cfg.ReportTemplate, cfg.Config.Timezone) + if err != nil { + return ReportPage{}, err + } + + agg := grpcReadersV1.Aggregation_AGGREGATION_UNSPECIFIED + switch cfg.Config.Aggregation.AggType { + case AggregationMAX: + agg = grpcReadersV1.Aggregation_AGGREGATION_MAX + case AggregationMIN: + agg = grpcReadersV1.Aggregation_AGGREGATION_MIN + case AggregationCOUNT: + agg = grpcReadersV1.Aggregation_AGGREGATION_COUNT + case AggregationAVG: + agg = grpcReadersV1.Aggregation_AGGREGATION_AVG + case AggregationSUM: + agg = grpcReadersV1.Aggregation_AGGREGATION_SUM + } + + loc, err := resolveTimezone(cfg.Config.Timezone) + if err != nil { + r.runInfo <- pkglog.RunInfo{ + Level: slog.LevelWarn, + Message: fmt.Sprintf("failed to resolve timezone '%s', falling back to UTC: %s", cfg.Config.Timezone, err), + Details: []slog.Attr{ + slog.String("report_name", cfg.Name), + slog.String("timezone", cfg.Config.Timezone), + }, + } + } + + from, err := reltime.Parse(cfg.Config.From) + if err != nil { + return ReportPage{}, err + } + + to, err := reltime.Parse(cfg.Config.To) + if err != nil { + return ReportPage{}, err + } + + fromDisplay := from.In(loc) + toDisplay := to.In(loc) + + pm := &grpcReadersV1.PageMetadata{ + Aggregation: agg, + Limit: limit, + From: float64(from.UnixNano()), + To: float64(to.UnixNano()), + Interval: cfg.Config.Aggregation.Interval, + } + + var mets []Metric + var reports []Report + for _, metric := range cfg.Metrics { + switch { + case len(metric.ClientIDs) != 0: + for _, clientID := range metric.ClientIDs { + mets = append(mets, Metric{ + ChannelID: metric.ChannelID, + ClientID: clientID, + Name: metric.Name, + Subtopic: metric.Subtopic, + Protocol: metric.Protocol, + Format: metric.Format, + }) + } + default: + mets = append(mets, Metric{ + ChannelID: metric.ChannelID, + Name: metric.Name, + Subtopic: metric.Subtopic, + Protocol: metric.Protocol, + Format: metric.Format, + }) + } + } + + for _, metric := range mets { + sMsgs := []senml.Message{} + + pm.Offset = uint64(0) + pm.Name = metric.Name + if metric.ClientID != "" { + pm.Publisher = metric.ClientID + } + if metric.Subtopic != "" { + pm.Subtopic = metric.Subtopic + } + if metric.Protocol != "" { + pm.Protocol = metric.Protocol + } + if metric.Format != "" { + pm.Format = metric.Format + } + + msgs, err := r.readers.ReadMessages(ctx, &grpcReadersV1.ReadMessagesReq{ + ChannelId: metric.ChannelID, + DomainId: cfg.DomainID, + PageMetadata: pm, + }) + if err != nil { + return ReportPage{}, err + } + for _, msg := range msgs.Messages { + sMsgs = append(sMsgs, convertToSenml(msg.GetSenml())) + } + + for msgs.GetTotal() > (pm.Offset + pm.Limit) { + pm.Offset = pm.Offset + pm.Limit + msgs, err := r.readers.ReadMessages(ctx, &grpcReadersV1.ReadMessagesReq{ + ChannelId: metric.ChannelID, + DomainId: cfg.DomainID, + PageMetadata: pm, + }) + if err != nil { + return ReportPage{}, err + } + for _, msg := range msgs.Messages { + sMsgs = append(sMsgs, convertToSenml(msg.GetSenml())) + } + } + + reports = append(reports, convertToReports(metric, sMsgs)...) + } + + switch { + case genReportFile != nil: + data, err := genReportFile(ctx, cfg.Config.Title, reports) + if err != nil { + return ReportPage{}, err + } + timeStr := strings.ReplaceAll(time.Now().Format(time.RFC3339), ":", "") + filePrefix := cfg.Name + if filePrefix == "" { + filePrefix = "report" + } + fileName := fmt.Sprintf("%s_%s.%s", filePrefix, timeStr, cfg.Config.FileFormat.Extension()) + + file := ReportFile{ + Name: fileName, + Data: data, + Format: cfg.Config.FileFormat, + } + + switch action { + case EmailReport: + if err := r.emailReports(*cfg.Email, file); err != nil { + return ReportPage{}, errors.Wrap(err, svcerr.ErrCreateEntity) + } + + return ReportPage{}, nil + default: + return ReportPage{ + File: file, + }, nil + } + + default: + return ReportPage{ + From: fromDisplay, + To: toDisplay, + Aggregation: cfg.Config.Aggregation, + Total: uint64(len(reports)), + Reports: reports, + }, nil + } +} + +func (r *report) generateFileFunc(_ context.Context, action ReportAction, format Format, customTemplate ReportTemplate, timezone string) (func(context.Context, string, []Report) ([]byte, error), error) { + switch action { + case DownloadReport, EmailReport: + switch format { + case PDF: + return func(ctx context.Context, title string, reports []Report) ([]byte, error) { + return r.generatePDFReport(ctx, title, reports, customTemplate, timezone) + }, nil + case CSV: + return func(ctx context.Context, title string, reports []Report) ([]byte, error) { + return r.generateCSVReport(ctx, title, reports, timezone) + }, nil + default: + return nil, errors.New("file format not supported") + } + default: + return nil, nil + } +} + +func (r *report) emailReports(es EmailSetting, file ReportFile) error { + if err := es.Validate(); err != nil { + return errors.Wrap(svcerr.ErrMalformedEntity, err) + } + + attachments := map[string][]byte{ + file.Name: file.Data, + } + + if err := r.email.SendEmailNotification( + es.To, + "", + es.Subject, + "", + "", + es.Content, + "", + attachments, + ); err != nil { + return err + } + return nil +} + +func convertToSenml(g *grpcReadersV1.SenMLMessage) senml.Message { + if g == nil { + return senml.Message{} + } + return senml.Message{ + Protocol: g.Base.GetProtocol(), + Subtopic: g.Base.GetSubtopic(), + Publisher: g.Base.GetPublisher(), + Channel: g.Base.GetChannel(), + Name: g.GetName(), + Unit: g.GetUnit(), + Time: g.GetTime(), + UpdateTime: g.GetUpdateTime(), + Value: g.Value, + StringValue: g.StringValue, + DataValue: g.DataValue, + BoolValue: g.BoolValue, + Sum: g.Sum, + } +} + +func convertToReports(metric Metric, senmlMsgs []senml.Message) []Report { + if metric.ClientID != "" { + return []Report{ + { + Metric: metric, + Messages: senmlMsgs, + }, + } + } + + return groupReportsByPublisher(metric, senmlMsgs) +} + +func groupReportsByPublisher(metric Metric, sMsgs []senml.Message) []Report { + publishers := map[string][]senml.Message{} + + for _, msg := range sMsgs { + publishers[msg.Publisher] = append(publishers[msg.Publisher], msg) + } + + var groupedReports []Report + for publisher, messages := range publishers { + gMetric := metric + gMetric.ClientID = publisher + groupedReports = append(groupedReports, Report{ + Metric: gMetric, + Messages: messages, + }) + } + + if len(groupedReports) == 0 { + groupedReports = append(groupedReports, Report{ + Metric: metric, + Messages: []senml.Message{}, + }) + } + + return groupedReports +} + +func (r *report) UpdateReportTemplate(ctx context.Context, session authn.Session, cfg ReportConfig) error { + err := r.repo.UpdateReportTemplate(ctx, session.DomainID, cfg.ID, cfg.ReportTemplate) + if err != nil { + return errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + return nil +} + +func (r *report) ViewReportTemplate(ctx context.Context, session authn.Session, id string) (ReportTemplate, error) { + template, err := r.repo.ViewReportTemplate(ctx, session.DomainID, id) + if err != nil { + return "", errors.Wrap(svcerr.ErrCreateEntity, err) + } + + return template, nil +} + +func (r *report) DeleteReportTemplate(ctx context.Context, session authn.Session, id string) error { + err := r.repo.DeleteReportTemplate(ctx, session.DomainID, id) + if err != nil { + return errors.Wrap(svcerr.ErrRemoveEntity, err) + } + + return nil +} diff --git a/reports/service_test.go b/reports/service_test.go new file mode 100644 index 000000000..4a6dcf511 --- /dev/null +++ b/reports/service_test.go @@ -0,0 +1,702 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package reports_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/0x6flab/namegenerator" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/authn" + emocks "github.com/absmach/supermq/pkg/emailer/mocks" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + svcerr "github.com/absmach/supermq/pkg/errors/service" + pkglog "github.com/absmach/supermq/pkg/logger" + policymocks "github.com/absmach/supermq/pkg/policies/mocks" + "github.com/absmach/supermq/pkg/roles" + pkgSch "github.com/absmach/supermq/pkg/schedule" + tmocks "github.com/absmach/supermq/pkg/ticker/mocks" + "github.com/absmach/supermq/pkg/uuid" + readmocks "github.com/absmach/supermq/readers/mocks" + "github.com/absmach/supermq/reports" + "github.com/absmach/supermq/reports/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var ( + namegen = namegenerator.NewGenerator() + userID = testsutil.GenerateUUID(&testing.T{}) + domainID = testsutil.GenerateUUID(&testing.T{}) + now = time.Now().UTC() + template = reports.ReportTemplate("") + schedule = pkgSch.Schedule{ + StartDateTime: now, + Recurring: pkgSch.Daily, + RecurringPeriod: 1, + Time: time.Now().Add(-time.Hour), + } + reportName = namegen.Generate() + rptConfig = reports.ReportConfig{ + ID: testsutil.GenerateUUID(&testing.T{}), + Name: reportName, + DomainID: domainID, + Status: reports.EnabledStatus, + Schedule: schedule, + CreatedBy: userID, + UpdatedBy: userID, + UpdatedAt: time.Now(), + } +) + +func newService(t *testing.T, runInfo chan pkglog.RunInfo) (reports.Service, *mocks.Repository, *tmocks.Ticker, *policymocks.Service) { + repo := new(mocks.Repository) + mockTicker := new(tmocks.Ticker) + idProvider := uuid.NewMock() + readersSvc := new(readmocks.ReadersServiceClient) + e := new(emocks.Emailer) + policy := new(policymocks.Service) + + availableActions := []roles.Action{} + builtInRoles := map[roles.BuiltInRoleName][]roles.Action{ + "admin": availableActions, + } + + svc, err := reports.NewService(repo, runInfo, policy, idProvider, mockTicker, e, readersSvc, template, "", availableActions, builtInRoles) + if err != nil { + t.Fatalf("Failed to create service: %v", err) + } + return svc, repo, mockTicker, policy +} + +func TestAddReportConfig(t *testing.T) { + svc, repo, _, policies := newService(t, make(chan pkglog.RunInfo)) + + cases := []struct { + desc string + session authn.Session + cfg reports.ReportConfig + res reports.ReportConfig + err error + addPoliciesErr error + deletePolicies error + addRoleErr error + deleteErr error + }{ + { + desc: "Add report config successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + cfg: reports.ReportConfig{ + Name: reportName, + Schedule: schedule, + }, + res: rptConfig, + err: nil, + addPoliciesErr: nil, + addRoleErr: nil, + deleteErr: nil, + }, + { + desc: "Add report config with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + cfg: reports.ReportConfig{ + Name: reportName, + Schedule: schedule, + }, + err: repoerr.ErrCreateEntity, + addPoliciesErr: nil, + deletePolicies: nil, + addRoleErr: nil, + deleteErr: nil, + }, + { + desc: "Add report config with failed to add policies", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + cfg: reports.ReportConfig{ + Name: reportName, + Schedule: schedule, + }, + res: rptConfig, + addPoliciesErr: svcerr.ErrAuthorization, + err: svcerr.ErrAddPolicies, + }, + { + desc: "Add report config with failed to add policies and failed rollback", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + cfg: reports.ReportConfig{ + Name: reportName, + Schedule: schedule, + }, + res: rptConfig, + addPoliciesErr: svcerr.ErrAuthorization, + deleteErr: svcerr.ErrRemoveEntity, + err: svcerr.ErrRollbackRepo, + }, + { + desc: "Add report config with failed to add roles", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + cfg: reports.ReportConfig{ + Name: reportName, + Schedule: schedule, + }, + res: rptConfig, + addRoleErr: svcerr.ErrCreateEntity, + err: svcerr.ErrAddPolicies, + }, + { + desc: "Add report config with failed to add roles and failed to delete policies", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + cfg: reports.ReportConfig{ + Name: reportName, + Schedule: schedule, + }, + res: rptConfig, + addRoleErr: svcerr.ErrCreateEntity, + deletePolicies: svcerr.ErrRemoveEntity, + err: svcerr.ErrRemoveEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("AddReportConfig", mock.Anything, mock.Anything).Return(tc.res, tc.err) + policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesErr) + policyCall2 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePolicies).Maybe() + repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.RoleProvision{}, tc.addRoleErr) + repoCall2 := repo.On("Remove", context.Background(), mock.Anything).Return(tc.deleteErr).Maybe() + res, err := svc.AddReportConfig(context.Background(), tc.session, tc.cfg) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.NotEmpty(t, res.ID, "expected non-empty result in ID") + assert.Equal(t, tc.cfg.Name, res.Name) + assert.Equal(t, tc.cfg.Schedule, res.Schedule) + } + policyCall.Unset() + policyCall2.Unset() + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + }) + } +} + +func TestViewReportConfig(t *testing.T) { + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + + cases := []struct { + desc string + session authn.Session + id string + res reports.ReportConfig + err error + }{ + { + desc: "view report config successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: rptConfig.ID, + res: rptConfig, + err: nil, + }, + { + desc: "view report config with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: rptConfig.ID, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("ViewReportConfig", mock.Anything, mock.Anything).Return(tc.res, tc.err) + res, err := svc.ViewReportConfig(context.Background(), tc.session, tc.id, false) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.res, res) + } + defer repoCall.Unset() + }) + } +} + +func TestUpdateReportConfig(t *testing.T) { + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + + newName := namegen.Generate() + now := time.Now().Add(time.Hour) + cases := []struct { + desc string + session authn.Session + cfg reports.ReportConfig + res reports.ReportConfig + err error + }{ + { + desc: "update report config successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + cfg: reports.ReportConfig{ + Name: newName, + ID: rptConfig.ID, + Schedule: schedule, + }, + res: reports.ReportConfig{ + Name: newName, + ID: rptConfig.ID, + DomainID: rptConfig.DomainID, + Status: rptConfig.Status, + Schedule: rptConfig.Schedule, + UpdatedAt: now, + UpdatedBy: userID, + }, + err: nil, + }, + { + desc: "update report config with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + cfg: reports.ReportConfig{ + Name: rptConfig.Name, + ID: rptConfig.ID, + Schedule: schedule, + }, + err: svcerr.ErrUpdateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("UpdateReportConfig", mock.Anything, mock.Anything).Return(tc.res, tc.err) + res, err := svc.UpdateReportConfig(context.Background(), tc.session, tc.cfg) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.res, res) + } + defer repoCall.Unset() + }) + } +} + +func TestListReportsConfig(t *testing.T) { + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + numConfigs := 50 + now := time.Now().Add(time.Hour) + var configs []reports.ReportConfig + for i := 0; i < numConfigs; i++ { + c := reports.ReportConfig{ + ID: testsutil.GenerateUUID(t), + Name: namegen.Generate(), + DomainID: domainID, + Status: reports.EnabledStatus, + CreatedAt: now, + CreatedBy: userID, + Schedule: schedule, + } + configs = append(configs, c) + } + + cases := []struct { + desc string + session authn.Session + pageMeta reports.PageMeta + res reports.ReportConfigPage + err error + superAdmin bool + }{ + { + desc: "list report configs successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + pageMeta: reports.PageMeta{}, + res: reports.ReportConfigPage{ + PageMeta: reports.PageMeta{ + Total: uint64(numConfigs), + Offset: 0, + Limit: 10, + }, + ReportConfigs: configs[0:10], + }, + err: nil, + }, + { + desc: "list report configs successfully with limit", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + pageMeta: reports.PageMeta{ + Limit: 100, + }, + res: reports.ReportConfigPage{ + PageMeta: reports.PageMeta{ + Total: uint64(numConfigs), + Offset: 0, + Limit: 100, + }, + ReportConfigs: configs[0:numConfigs], + }, + err: nil, + }, + { + desc: "list report configs successfully with offset", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + pageMeta: reports.PageMeta{ + Offset: 20, + Limit: 10, + }, + res: reports.ReportConfigPage{ + PageMeta: reports.PageMeta{ + Total: uint64(numConfigs), + Offset: 20, + Limit: 10, + }, + ReportConfigs: configs[20:30], + }, + err: nil, + }, + { + desc: "list report configs with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + pageMeta: reports.PageMeta{}, + err: svcerr.ErrViewEntity, + }, + { + desc: "list report configs as super admin successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + SuperAdmin: true, + }, + pageMeta: reports.PageMeta{}, + res: reports.ReportConfigPage{ + PageMeta: reports.PageMeta{ + Total: uint64(numConfigs), + Offset: 0, + Limit: 10, + }, + ReportConfigs: configs[0:10], + }, + superAdmin: true, + err: nil, + }, + { + desc: "list report configs as super admin with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + SuperAdmin: true, + }, + pageMeta: reports.PageMeta{}, + superAdmin: true, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + var repoCall *mock.Call + if tc.superAdmin { + repoCall = repo.On("ListAllReportsConfig", mock.Anything, mock.Anything).Return(tc.res, tc.err) + } else { + repoCall = repo.On("ListUserReportsConfig", mock.Anything, mock.Anything, mock.Anything).Return(tc.res, tc.err) + } + res, err := svc.ListReportsConfig(context.Background(), tc.session, tc.pageMeta) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.res, res) + } + defer repoCall.Unset() + }) + } +} + +func TestRemoveReportConfig(t *testing.T) { + svc, repo, _, policies := newService(t, make(chan pkglog.RunInfo)) + + cases := []struct { + desc string + session authn.Session + id string + err error + deletePoliciesErr error + }{ + { + desc: "remove report config successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: rptConfig.ID, + err: nil, + deletePoliciesErr: nil, + }, + { + desc: "remove report config with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: rptConfig.ID, + err: svcerr.ErrRemoveEntity, + deletePoliciesErr: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("RemoveReportConfig", mock.Anything, mock.Anything).Return(tc.err) + policyCall := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesErr) + err := svc.RemoveReportConfig(context.Background(), tc.session, tc.id) + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + policyCall.Unset() + repoCall.Unset() + }) + } +} + +func TestEnableReportConfig(t *testing.T) { + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + + cases := []struct { + desc string + session authn.Session + id string + status reports.Status + res reports.ReportConfig + err error + }{ + { + desc: "enable report config successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: rptConfig.ID, + status: reports.EnabledStatus, + res: rptConfig, + err: nil, + }, + { + desc: "enable report config with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: rptConfig.ID, + status: reports.EnabledStatus, + err: svcerr.ErrUpdateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("UpdateReportConfigStatus", context.Background(), mock.Anything).Return(tc.res, tc.err) + res, err := svc.EnableReportConfig(context.Background(), tc.session, tc.id) + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.res, res) + } + defer repoCall.Unset() + }) + } +} + +func TestDisableReportConfig(t *testing.T) { + svc, repo, _, _ := newService(t, make(chan pkglog.RunInfo)) + + cases := []struct { + desc string + session authn.Session + id string + status reports.Status + res reports.ReportConfig + err error + }{ + { + desc: "disable report config successfully", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: rptConfig.ID, + status: reports.DisabledStatus, + res: reports.ReportConfig{ + ID: rptConfig.ID, + Name: rptConfig.Name, + DomainID: rptConfig.DomainID, + Status: reports.DisabledStatus, + Schedule: schedule, + UpdatedBy: userID, + UpdatedAt: time.Now(), + }, + err: nil, + }, + { + desc: "disable report config with failed repo", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + id: rptConfig.ID, + status: reports.DisabledStatus, + err: svcerr.ErrUpdateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := repo.On("UpdateReportConfigStatus", mock.Anything, mock.Anything).Return(tc.res, tc.err) + res, err := svc.DisableReportConfig(context.Background(), tc.session, tc.id) + + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, tc.res, res) + } + defer repoCall.Unset() + }) + } +} + +func TestGenerateInstantEmailReport(t *testing.T) { + // nolint:dogsled + svc, _, _, _ := newService(t, make(chan pkglog.RunInfo)) + + validEmailConfig := reports.EmailSetting{ + To: []string{"test@example.com"}, + Subject: "Test Report", + Content: "Please find the attached report.", + } + + validConfig := reports.ReportConfig{ + ID: testsutil.GenerateUUID(&testing.T{}), + Name: "Test Report", + DomainID: domainID, + Status: reports.DisabledStatus, + Email: &validEmailConfig, + Config: &reports.MetricConfig{ + Title: "Test Report", + FileFormat: reports.PDF, + From: "now-1h", + To: "now", + Aggregation: reports.AggConfig{ + AggType: reports.AggregationAVG, + Interval: "1h", + }, + }, + Metrics: []reports.ReqMetric{ + { + ChannelID: testsutil.GenerateUUID(&testing.T{}), + Name: "temperature", + ClientIDs: []string{testsutil.GenerateUUID(&testing.T{})}, + }, + }, + ReportTemplate: template, + } + + cases := []struct { + desc string + session authn.Session + config reports.ReportConfig + action reports.ReportAction + err error + }{ + { + desc: "Generate instant email report with disabled config should succeed", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + config: validConfig, + action: reports.EmailReport, + err: nil, + }, + { + desc: "Generate instant email report with enabled config should succeed", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + config: func() reports.ReportConfig { + cfg := validConfig + cfg.Status = reports.EnabledStatus + return cfg + }(), + action: reports.EmailReport, + err: nil, + }, + { + desc: "Generate view report with disabled config should succeed", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + config: validConfig, + action: reports.ViewReport, + err: nil, + }, + { + desc: "Generate download report with disabled config should succeed", + session: authn.Session{ + UserID: userID, + DomainID: domainID, + }, + config: validConfig, + action: reports.DownloadReport, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + _, err := svc.GenerateReport(context.Background(), tc.session, tc.config, tc.action) + + if tc.err != nil { + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } else { + assert.False(t, errors.Contains(err, svcerr.ErrInvalidStatus), fmt.Sprintf("%s: should not get ErrInvalidStatus for instant reports, got %s\n", tc.desc, err)) + } + }) + } +} diff --git a/reports/status.go b/reports/status.go new file mode 100644 index 000000000..b223f4793 --- /dev/null +++ b/reports/status.go @@ -0,0 +1,80 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package reports + +import ( + "encoding/json" + "strings" + + svcerr "github.com/absmach/supermq/pkg/errors/service" +) + +// Status represents Rule status. +type Status uint8 + +// Possible User status values. +const ( + // EnabledStatus represents enabled Rule. + EnabledStatus Status = iota + // DisabledStatus represents disabled Rule. + DisabledStatus + // DeletedStatus represents a rule that will be deleted. + DeletedStatus + + // AllStatus is used for querying purposes to list rules irrespective + // of their status - both enabled and disabled. It is never stored in the + // database as the actual User status and should always be the largest + // value in this enumeration. + AllStatus +) + +// String representation of the possible status values. +const ( + Disabled = "disabled" + Enabled = "enabled" + Deleted = "deleted" + All = "all" + Unknown = "unknown" +) + +func (s Status) String() string { + switch s { + case DisabledStatus: + return Disabled + case EnabledStatus: + return Enabled + case DeletedStatus: + return Deleted + case AllStatus: + return All + default: + return Unknown + } +} + +// ToStatus converts string value to a valid status. +func ToStatus(status string) (Status, error) { + switch status { + case "", Enabled: + return EnabledStatus, nil + case Disabled: + return DisabledStatus, nil + case Deleted: + return DeletedStatus, nil + case All: + return AllStatus, nil + } + return Status(0), svcerr.ErrInvalidStatus +} + +func (s Status) MarshalJSON() ([]byte, error) { + return json.Marshal(s.String()) +} + +func (s *Status) UnmarshalJSON(data []byte) error { + str := strings.Trim(string(data), "\"") + val, err := ToStatus(str) + *s = val + return err +} diff --git a/reports/template.go b/reports/template.go new file mode 100644 index 000000000..0d655b56e --- /dev/null +++ b/reports/template.go @@ -0,0 +1,164 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package reports + +import ( + "encoding/json" + "fmt" + "text/template" + "text/template/parse" +) + +type ReportTemplate string + +func (temp ReportTemplate) String() string { + return string(temp) +} + +func (temp ReportTemplate) MarshalJSON() ([]byte, error) { + return json.Marshal(string(temp)) +} + +func (temp *ReportTemplate) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + *temp = ReportTemplate(s) + return nil +} + +func (temp ReportTemplate) Validate() error { + templateStr := string(temp) + + // Validate template syntax using Go's template parser + tmpl := template.New("validate").Funcs(template.FuncMap{ + "add": func(a, b int) int { return a + b }, + "sub": func(a, b int) int { return a - b }, + "div": func(a, b int) int { + if b == 0 { + return 0 + } + return a / b + }, + "mod": func(a, b int) int { + if b == 0 { + return 0 + } + return a % b + }, + "eq": func(a, b int) bool { return a == b }, + "ge": func(a, b int) bool { return a >= b }, + "lt": func(a, b int) bool { return a < b }, + "iterate": func(count int) []int { return make([]int, count) }, + "getStartRow": func(pageNum, firstPageRows, continuationPageRows int) int { return 0 }, + "getEndRow": func(pageNum, firstPageRows, continuationPageRows, totalMessages int) int { return 0 }, + "formatTime": func(t any) string { return "" }, + "formatValue": func(v any) string { return "" }, + }) + + parsed, err := tmpl.Parse(templateStr) + if err != nil { + return fmt.Errorf("template syntax error: %w", err) + } + + var hasTitle, hasRange, hasFormatTime, hasFormatValue, hasEnd bool + // Validate essential fields are present using template parsing + if err := validateEssentialFields(parsed.Tree.Root, &hasTitle, &hasRange, &hasFormatTime, &hasFormatValue, &hasEnd); err != nil { + return err + } + + if !hasTitle { + return fmt.Errorf("missing essential template field: {{$.Title}}") + } + if !hasRange { + return fmt.Errorf("missing essential template field: {{range .Messages}} or {{range .Reports}}") + } + if !hasFormatTime { + return fmt.Errorf("missing essential template field: {{formatTime .Time}}") + } + if !hasFormatValue { + return fmt.Errorf("missing essential template field: {{formatValue .}}") + } + if !hasEnd { + return fmt.Errorf("missing essential template field: {{end}}") + } + + return nil +} + +func validateEssentialFields(node parse.Node, hasTitle, hasRange, hasFormatTime, hasFormatValue, hasEnd *bool) error { + if node == nil { + return nil + } + + switch n := node.(type) { + case *parse.ListNode: + for _, sub := range n.Nodes { + if err := validateEssentialFields(sub, hasTitle, hasRange, hasFormatTime, hasFormatValue, hasEnd); err != nil { + return err + } + } + + case *parse.ActionNode: + if n.Pipe != nil { + for _, cmd := range n.Pipe.Cmds { + cmdStr := cmd.String() + if cmdStr == "$.Title" { + *hasTitle = true + } + if len(cmd.Args) > 0 { + firstArg := cmd.Args[0].String() + if firstArg == "formatTime" { + *hasFormatTime = true + } + if firstArg == "formatValue" { + *hasFormatValue = true + } + } + } + } + + case *parse.RangeNode: + if n.Pipe != nil && len(n.Pipe.Cmds) > 0 { + cmdStr := n.Pipe.Cmds[0].String() + // Accept .Messages, .Reports, or $report.Messages + if cmdStr == ".Messages" || cmdStr == ".Reports" || cmdStr == "$report.Messages" { + *hasRange = true + } + } + if err := validateEssentialFields(n.List, hasTitle, hasRange, hasFormatTime, hasFormatValue, hasEnd); err != nil { + return err + } + if n.ElseList != nil { + if err := validateEssentialFields(n.ElseList, hasTitle, hasRange, hasFormatTime, hasFormatValue, hasEnd); err != nil { + return err + } + } + *hasEnd = true + + case *parse.IfNode: + if err := validateEssentialFields(n.List, hasTitle, hasRange, hasFormatTime, hasFormatValue, hasEnd); err != nil { + return err + } + if n.ElseList != nil { + if err := validateEssentialFields(n.ElseList, hasTitle, hasRange, hasFormatTime, hasFormatValue, hasEnd); err != nil { + return err + } + } + + case *parse.WithNode: + if err := validateEssentialFields(n.List, hasTitle, hasRange, hasFormatTime, hasFormatValue, hasEnd); err != nil { + return err + } + if n.ElseList != nil { + if err := validateEssentialFields(n.ElseList, hasTitle, hasRange, hasFormatTime, hasFormatValue, hasEnd); err != nil { + return err + } + } + } + + return nil +} diff --git a/reports/template_test.go b/reports/template_test.go new file mode 100644 index 000000000..65d4e5279 --- /dev/null +++ b/reports/template_test.go @@ -0,0 +1,376 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package reports_test + +import ( + "fmt" + "testing" + + "github.com/absmach/supermq/reports" + "github.com/stretchr/testify/assert" +) + +const ( + validTemplate = ` + + + {{$.Title}} + + + +
+

{{$.Title}}

+

Generated on: {{$.GeneratedDate}}

+
+
+

Messages

+ {{range .Messages}} +
+

Time: {{formatTime .Time}}

+

Value: {{formatValue .}}

+
+ {{end}} +
+ +` + + templateWithoutTitle = ` + + + Report + + + +

Report

+ {{range .Messages}} +

Time: {{formatTime .Time}}

+

Value: {{formatValue .}}

+ {{end}} + +` + + templateWithoutRange = ` + + + {{$.Title}} + + +

{{$.Title}}

+

No messages to display

+ +` + + templateWithoutFormatTime = ` + + + {{$.Title}} + + +

{{$.Title}}

+ {{range .Messages}} +

Time: {{.Time}}

+

Value: {{formatValue .}}

+ {{end}} + +` + + templateWithoutFormatValue = ` + + + {{$.Title}} + + +

{{$.Title}}

+ {{range .Messages}} +

Time: {{formatTime .Time}}

+

Value: {{.}}

+ {{end}} + +` + + templateWithoutEnd = ` + + + {{$.Title}} + + +

{{$.Title}}

+

Time: {{formatTime "test"}}

+

Value: {{formatValue "test"}}

+

No range block with end

+ +` + + templateWithSyntaxError = ` + + + {{$.Title}} + + +

{{$.Title}}

+ {{range .Messages}} +

Time: {{formatTime .Time}}

+

Value: {{formatValue .}}

+ {{end + +` + + templateWithUndefinedFunction = ` + + + {{$.Title}} + + +

{{$.Title}}

+ {{range .Messages}} +

Time: {{formatTime .Time}}

+

Value: {{formatValue .}}

+

Custom: {{customFunction .}}

+ {{end}} + +` + + templateWithIfCondition = ` + + + {{$.Title}} + + +

{{$.Title}}

+ {{if .Messages}} + {{range .Messages}} +

Time: {{formatTime .Time}}

+

Value: {{formatValue .}}

+ {{end}} + {{else}} +

No messages available

+ {{end}} + +` + + templateWithWithCondition = ` + + + {{$.Title}} + + +

{{$.Title}}

+ {{with .Data}} + {{range .Messages}} +

Time: {{formatTime .Time}}

+

Value: {{formatValue .}}

+ {{end}} + {{else}} +

No data available

+ {{end}} + +` + + templateWithNestedConditions = ` + + + {{$.Title}} + + +

{{$.Title}}

+ {{if .HasMessages}} + {{with .Data}} + {{range .Messages}} +

Time: {{formatTime .Time}}

+

Value: {{formatValue .}}

+ {{end}} + {{else}} +

Data not available

+ {{end}} + {{else}} +

No messages flag set

+ {{end}} + +` + + templateWithIfMissingFields = ` + + + {{$.Title}} + + +

{{$.Title}}

+ {{if .Messages}} + {{range .Messages}} +

Time: {{.Time}}

+

Value: {{.}}

+ {{end}} + {{else}} +

No messages available

+ {{end}} + +` + + templateWithWithMissingFields = ` + + + {{$.Title}} + + +

{{$.Title}}

+ {{with .Data}} + {{range .Messages}} +

Time: {{.Time}}

+

Value: {{formatValue .}}

+ {{end}} + {{else}} +

No data available

+ {{end}} + +` +) + +func TestReportTemplate_Validate(t *testing.T) { + cases := []struct { + desc string + template reports.ReportTemplate + err error + }{ + { + desc: "validate template successfully", + template: reports.ReportTemplate(validTemplate), + err: nil, + }, + { + desc: "validate template without title field", + template: reports.ReportTemplate(templateWithoutTitle), + err: fmt.Errorf("missing essential template field: {{$.Title}}"), + }, + { + desc: "validate template without range field", + template: reports.ReportTemplate(templateWithoutRange), + err: fmt.Errorf("missing essential template field: {{range .Messages}}"), + }, + { + desc: "validate template without formatTime field", + template: reports.ReportTemplate(templateWithoutFormatTime), + err: fmt.Errorf("missing essential template field: {{formatTime .Time}}"), + }, + { + desc: "validate template without formatValue field", + template: reports.ReportTemplate(templateWithoutFormatValue), + err: fmt.Errorf("missing essential template field: {{formatValue .}}"), + }, + { + desc: "validate template without end field", + template: reports.ReportTemplate(templateWithoutEnd), + err: fmt.Errorf("missing essential template field: {{range .Messages}}"), + }, + { + desc: "validate template with syntax error", + template: reports.ReportTemplate(templateWithSyntaxError), + err: fmt.Errorf("template syntax error"), + }, + { + desc: "validate template with undefined function", + template: reports.ReportTemplate(templateWithUndefinedFunction), + err: fmt.Errorf("template syntax error"), + }, + { + desc: "validate empty template", + template: reports.ReportTemplate(""), + err: fmt.Errorf("missing essential template field: {{$.Title}}"), + }, + { + desc: "validate template with if condition successfully", + template: reports.ReportTemplate(templateWithIfCondition), + err: nil, + }, + { + desc: "validate template `with` with condition successfully", + template: reports.ReportTemplate(templateWithWithCondition), + err: nil, + }, + { + desc: "validate template with nested conditions successfully", + template: reports.ReportTemplate(templateWithNestedConditions), + err: nil, + }, + { + desc: "validate template with if condition missing formatTime", + template: reports.ReportTemplate(templateWithIfMissingFields), + err: fmt.Errorf("missing essential template field: {{formatTime .Time}}"), + }, + { + desc: "validate template `with` with condition missing formatTime", + template: reports.ReportTemplate(templateWithWithMissingFields), + err: fmt.Errorf("missing essential template field: {{formatTime .Time}}"), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tc.template.Validate() + if tc.err != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestReportTemplate_String(t *testing.T) { + template := reports.ReportTemplate(validTemplate) + result := template.String() + + assert.Equal(t, validTemplate, result) +} + +func TestReportTemplate_MarshalJSON(t *testing.T) { + template := reports.ReportTemplate("simple template") + data, err := template.MarshalJSON() + + assert.NoError(t, err) + assert.NotNil(t, data) + assert.Equal(t, `"simple template"`, string(data)) +} + +func TestReportTemplate_UnmarshalJSON(t *testing.T) { + cases := []struct { + desc string + data []byte + expected string + err error + }{ + { + desc: "unmarshal valid JSON successfully", + data: []byte(`"simple template"`), + expected: "simple template", + err: nil, + }, + { + desc: "unmarshal invalid JSON", + data: []byte(`invalid json`), + err: fmt.Errorf("invalid character"), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + var template reports.ReportTemplate + err := template.UnmarshalJSON(tc.data) + + if tc.err != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expected, string(template)) + } + }) + } +} diff --git a/reports/tz.go b/reports/tz.go new file mode 100644 index 000000000..d2e317686 --- /dev/null +++ b/reports/tz.go @@ -0,0 +1,25 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package reports + +import ( + "strings" + "time" +) + +// resolveTimezone returns a *time.Location from a user-provided IANA timezone name. +// Supported inputs: +// - IANA names (e.g., "Europe/Paris", "America/New_York"). +// - Empty string defaults to UTC. +func resolveTimezone(s string) (*time.Location, error) { + s = strings.TrimSpace(s) + if s == "" { + return time.UTC, nil + } + loc, err := time.LoadLocation(s) + if err != nil { + return time.UTC, err + } + return loc, nil +} diff --git a/scripts/certs.sh b/scripts/certs.sh deleted file mode 100755 index fec77e57b..000000000 --- a/scripts/certs.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 - -### -# Fetches the latest version of the docker files from the Certs repository. -### - -set -e -set -o pipefail - -REPO_URL=https://github.com/absmach/certs -TEMP_DIR="certs" -DOCKER_DIR="docker" -DOCKER_DST_DIR="../docker/addons/certs" -DEST_DIR="../../docker/addons/certs" - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -cd "$SCRIPT_DIR" || exit 1 - -if [ -n "$(git status --porcelain "$DOCKER_DST_DIR")" ]; then - echo "There are uncommitted changes in '$DOCKER_DST_DIR' dir. Please commit or stash them before running this script." - exit 1 -fi - -cleanup() { - rm -rf "$TEMP_DIR" -} -cleanup -trap cleanup EXIT - -git clone --depth 1 --filter=blob:none --sparse "$REPO_URL" -cd "$TEMP_DIR" -git sparse-checkout set "$DOCKER_DIR" - -if [ -d "$DEST_DIR" ]; then - rm -r "$DEST_DIR" -fi -mkdir -p "$DEST_DIR" -mv -f "$DOCKER_DIR"/.??* "$DOCKER_DIR"/* "$DEST_DIR"/ -cd .. -rm -rf "$TEMP_DIR" diff --git a/scripts/ci.sh b/scripts/ci.sh index b1d77a71c..92b2f811e 100755 --- a/scripts/ci.sh +++ b/scripts/ci.sh @@ -69,10 +69,10 @@ setup_mg() { exit 1 fi done - echo "Compile check for rabbitmq..." - SMQ_MESSAGE_BROKER_TYPE=msg_rabbitmq make http - echo "Compile check for redis..." - SMQ_ES_TYPE=es_redis make http + echo "Compile check for nats message broker..." + MG_MESSAGE_BROKER_TYPE=msg_nats make re + echo "Compile check for redis event store..." + MG_ES_TYPE=es_redis make bootstrap make -j$NPROC } diff --git a/scripts/run.sh b/scripts/run.sh index bf9ab26cb..e6567aaa6 100755 --- a/scripts/run.sh +++ b/scripts/run.sh @@ -38,28 +38,28 @@ done ### # Users ### -SMQ_USERS_LOG_LEVEL=info SMQ_USERS_HTTP_PORT=9002 SMQ_USERS_GRPC_PORT=7001 SMQ_USERS_ADMIN_EMAIL=admin@supermq.com SMQ_USERS_ADMIN_PASSWORD=12345678 SMQ_USERS_ADMIN_USERNAME=admin SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost:9002/password/reset SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=../docker/templates/reset-password-email.tmpl SMQ_VERIFICATION_URL_PREFIX=http://localhost:9002/users/verify-email SMQ_VERIFICATION_EMAIL_TEMPLATE=../docker/templates/verification-email.tmpl $BUILD_DIR/supermq-users & +MG_USERS_LOG_LEVEL=info MG_USERS_HTTP_PORT=9002 MG_USERS_GRPC_PORT=7001 MG_USERS_ADMIN_EMAIL=admin@supermq.com MG_USERS_ADMIN_PASSWORD=12345678 MG_USERS_ADMIN_USERNAME=admin MG_PASSWORD_RESET_URL_PREFIX=http://localhost:9002/password/reset MG_PASSWORD_RESET_EMAIL_TEMPLATE=../docker/templates/reset-password-email.tmpl MG_VERIFICATION_URL_PREFIX=http://localhost:9002/users/verify-email MG_VERIFICATION_EMAIL_TEMPLATE=../docker/templates/verification-email.tmpl $BUILD_DIR/supermq-users & ### # Clients ### -SMQ_CLIENTS_LOG_LEVEL=info SMQ_CLIENTS_HTTP_PORT=9000 SMQ_CLIENTS_GRPC_PORT=7000 SMQ_CLIENTS_AUTH_HTTP_PORT=9002 $BUILD_DIR/supermq-clients & +MG_CLIENTS_LOG_LEVEL=info MG_CLIENTS_HTTP_PORT=9000 MG_CLIENTS_GRPC_PORT=7000 MG_CLIENTS_AUTH_HTTP_PORT=9002 $BUILD_DIR/supermq-clients & ### # HTTP ### -SMQ_HTTP_ADAPTER_LOG_LEVEL=info SMQ_HTTP_ADAPTER_PORT=8008 SMQ_CLIENTS_GRPC_URL=localhost:7000 $BUILD_DIR/supermq-http & +MG_HTTP_ADAPTER_LOG_LEVEL=info MG_HTTP_ADAPTER_PORT=8008 MG_CLIENTS_GRPC_URL=localhost:7000 $BUILD_DIR/supermq-http & ### # MQTT ### -SMQ_MQTT_ADAPTER_LOG_LEVEL=info SMQ_CLIENTS_GRPC_URL=localhost:7000 $BUILD_DIR/supermq-mqtt & +MG_MQTT_ADAPTER_LOG_LEVEL=info MG_CLIENTS_GRPC_URL=localhost:7000 $BUILD_DIR/supermq-mqtt & ### # CoAP ### -SMQ_COAP_ADAPTER_LOG_LEVEL=info SMQ_COAP_ADAPTER_PORT=5683 SMQ_CLIENTS_GRPC_URL=localhost:7000 $BUILD_DIR/supermq-coap & +MG_COAP_ADAPTER_LOG_LEVEL=info MG_COAP_ADAPTER_PORT=5683 MG_CLIENTS_GRPC_URL=localhost:7000 $BUILD_DIR/supermq-coap & trap cleanup EXIT diff --git a/tools/config/.mockery.yaml b/tools/config/.mockery.yaml index 7ae69ef80..3f3d2334d 100644 --- a/tools/config/.mockery.yaml +++ b/tools/config/.mockery.yaml @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 pkgname: mocks +template: testify structname: "{{.InterfaceName}}" filename: "{{snakecase .InterfaceName}}.go" dir: "{{.InterfaceDirRelative}}/mocks" @@ -46,6 +47,20 @@ packages: dir: "./groups/mocks" structname: "GroupsServiceClient" filename: "groups_client.go" + github.com/absmach/supermq/api/grpc/certs/v1: + interfaces: + CertsServiceClient: + config: + dir: "./certs/mocks" + structname: "CertsServiceClient" + filename: "certs_client.go" + github.com/absmach/supermq/api/grpc/readers/v1: + interfaces: + ReadersServiceClient: + config: + dir: "./readers/mocks" + structname: "ReadersServiceClient" + filename: "readers_client.go" github.com/absmach/supermq/pkg/sdk: interfaces: SDK: @@ -80,9 +95,18 @@ packages: github.com/absmach/supermq/clients/private: interfaces: Service: + github.com/absmach/supermq/certs: + interfaces: + Agent: + Service: + Repository: github.com/absmach/supermq/consumers: interfaces: Notifier: + github.com/absmach/supermq/consumers/notifiers: + interfaces: + Service: + SubscriptionsRepository: github.com/absmach/supermq/domains: interfaces: Repository: @@ -98,9 +122,6 @@ packages: github.com/absmach/supermq/groups/private: interfaces: Service: - github.com/absmach/supermq/http: - interfaces: - Service: github.com/absmach/supermq/journal: interfaces: Repository: @@ -112,6 +133,9 @@ packages: github.com/absmach/supermq/pkg/authz: interfaces: Authorization: + github.com/absmach/supermq/pkg/emailer: + interfaces: + Emailer: github.com/absmach/supermq/pkg/events: interfaces: Publisher: @@ -134,9 +158,32 @@ packages: github.com/absmach/supermq/pkg/callout: interfaces: Callout: + github.com/absmach/supermq/pkg/ticker: + interfaces: + Ticker: github.com/absmach/supermq/readers: interfaces: MessageRepository: + github.com/absmach/supermq/re: + interfaces: + Repository: + Service: + github.com/absmach/supermq/bootstrap: + interfaces: + ConfigRepository: + ConfigReader: + Service: + github.com/absmach/supermq/provision: + interfaces: + Service: + github.com/absmach/supermq/alarms: + interfaces: + Service: + Repository: + github.com/absmach/supermq/reports: + interfaces: + Service: + Repository: github.com/absmach/supermq/users: interfaces: Emailer: diff --git a/tools/e2e/Makefile b/tools/e2e/Makefile new file mode 100644 index 000000000..fd27a8a22 --- /dev/null +++ b/tools/e2e/Makefile @@ -0,0 +1,15 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +PROGRAM = e2e +SOURCES = $(wildcard *.go) cmd/main.go + +all: $(PROGRAM) + +.PHONY: all clean + +$(PROGRAM): $(SOURCES) + go build -ldflags "-s -w" -o $@ cmd/main.go + +clean: + rm -rf $(PROGRAM) diff --git a/tools/e2e/README.md b/tools/e2e/README.md new file mode 100644 index 000000000..b455a9b87 --- /dev/null +++ b/tools/e2e/README.md @@ -0,0 +1,93 @@ +# SuperMQ Users Groups Clients and Channels E2E Testing Tool + +A simple utility to create a list of groups and users connected to these groups and channels and clients connected to these channels. + +## Installation + +```bash +cd tools/e2e +make +``` + +### Usage + +```bash +./e2e --help +Tool for testing end-to-end flow of SuperMQ by doing a couple of operations namely: +1. Creating, viewing, updating and changing status of users, groups, clients and channels. +2. Connecting users and groups to each other and clients and channels to each other. +3. Sending messages from clients to channels on all 4 protocol adapters (HTTP, WS, CoAP and MQTT). +Complete documentation is available at https://docs.supermq.absmach.eu + + +Usage: + + e2e [flags] + + +Examples: + +Here is a simple example of using e2e tool. +Use the following commands from the root SuperMQ directory: + +go run tools/e2e/cmd/main.go +go run tools/e2e/cmd/main.go --host 142.93.118.47 +go run tools/e2e/cmd/main.go --host localhost --num 10 --num_of_messages 100 --prefix e2e + + +Flags: + + -h, --help help for e2e + -H, --host string address for a running SuperMQ instance (default "localhost") + -n, --num uint number of users, groups, channels and clients to create and connect (default 10) + -N, --num_of_messages uint number of messages to send (default 10) + -p, --prefix string name prefix for users, groups, clients and channels +``` + +To use `-H` option, you can specify the address for the SuperMQ instance as an argument when running the program. For example, if the SuperMQ instance is running on another computer with the IP address 192.168.0.1, you could use the following command: + +```bash +go run tools/e2e/cmd/main.go --host 142.93.118.47 +``` + +This will tell the program to connect to the SuperMQ instance running on the specified IP address. + +If you want to create a list of channels with certificates: + +```bash +go run tools/e2e/cmd/main.go --host localhost --num 10 --num_of_messages 100 --prefix e2e +``` + +Example of output: + +```bash +created user with token eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE2ODEyMDYwMjMsImlhdCI6MTY4MTIwNTEyMywiaWRlbnRpdHkiOiJlMmUtbGF0ZS1zaWxlbmNlQGVtYWlsLmNvbSIsImlzcyI6ImNsaWVudHMuYXV0aCIsInN1YiI6IjdlZDIyY2IyLTRlMzQtNDhiZi04Y2RlLTIxMjZiYzYyYzY4MyIsInR5cGUiOiJhY2Nlc3MifQ.AdExNYs5mVQNpo_ejJDq7KTC5dKkZWmgM9FJvTM2T_GM2LE9ASQv0ymC4wS3PDXKWf-OcaR8DJIxE6WiG3fztQ +created users of ids: +9e87bc1d-0889-4252-a3df-36e02edfc859 +c1e4901a-fb7f-45e9-b934-c55194b1d028 +c341a9cb-542b-4c3b-afd6-c98e04ed5e7e +8cfc886b-21fa-4205-80b4-3601827b94ff +334984d7-30eb-4b06-92b8-5ec182bebac5 +created groups of ids: +7744ec55-c767-4137-be96-0d79699772a4 +c8fe4d9d-3ad6-4687-83c0-171356f3e4f6 +513f7295-0923-4e21-b41a-3cfd1cb7b9b9 +54bd71ea-3c22-401e-89ea-d58162b983c0 +ae91b327-4c40-4e68-91fe-cd6223ee4e99 +created clients of ids: +5909a907-7413-47d4-b793-e1eb36988a5f +f9b6bc18-1862-4a24-8973-adde11cb3303 +c2bd6eed-6f38-464c-989c-fe8ec8c084ba +8c76702c-0534-4246-8ed7-21816b4f91cf +25005ca8-e886-465f-9cd1-4f3c4a95c6c1 +created channels of ids: +ebb0e5f3-2241-4770-a7cc-f4bbd06134ca +d654948d-d6c1-4eae-b69a-29c853282c3d +2c2a5496-89cf-47e6-9d38-5fd5542337bd +7ab3319d-269c-4b07-9dc5-f9906693e894 +5d8fa139-10e7-4683-94f3-4e881b4db041 +created policies for users, groups, clients and channels +viewed users, groups, clients and channels +updated users, groups, clients and channels +sent messages to channels +``` diff --git a/tools/e2e/cmd/main.go b/tools/e2e/cmd/main.go new file mode 100644 index 000000000..3e68c73ef --- /dev/null +++ b/tools/e2e/cmd/main.go @@ -0,0 +1,58 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package main contains e2e tool for testing SuperMQ. +package main + +import ( + "log" + + "github.com/absmach/supermq/tools/e2e" + cc "github.com/ivanpirog/coloredcobra" + "github.com/spf13/cobra" +) + +const defNum = uint64(10) + +func main() { + econf := e2e.Config{} + + rootCmd := &cobra.Command{ + Use: "e2e", + Short: "e2e is end-to-end testing tool for SuperMQ", + Long: "Tool for testing end-to-end flow of supermq by doing a couple of operations namely:\n" + + "1. Creating, viewing, updating and changing status of users, groups, clients and channels.\n" + + "2. Connecting users and groups to each other and clients and channels to each other.\n" + + "3. Sending messages from clients to channels on all 4 protocol adapters (HTTP, WS, CoAP and MQTT).\n" + + "Complete documentation is available at https://docs.supermq.absmach.eu", + Example: "Here is a simple example of using e2e tool.\n" + + "Use the following commands from the root supermq directory:\n\n" + + "go run tools/e2e/cmd/main.go\n" + + "go run tools/e2e/cmd/main.go --host 142.93.118.47\n" + + "go run tools/e2e/cmd/main.go --host localhost --num 10 --num_of_messages 100 --prefix e2e", + Run: func(cmd *cobra.Command, _ []string) { + e2e.Test(cmd.Context(), econf) + }, + } + + cc.Init(&cc.Config{ + RootCmd: rootCmd, + Headings: cc.HiCyan + cc.Bold + cc.Underline, + CmdShortDescr: cc.Magenta, + Example: cc.Italic + cc.Magenta, + ExecName: cc.Bold, + Flags: cc.HiGreen + cc.Bold, + FlagsDescr: cc.Green, + FlagsDataType: cc.White + cc.Italic, + }) + + // Root Flags + rootCmd.PersistentFlags().StringVarP(&econf.Host, "host", "H", "localhost", "address for a running supermq instance") + rootCmd.PersistentFlags().StringVarP(&econf.Prefix, "prefix", "p", "", "name prefix for users, groups, clients and channels") + rootCmd.PersistentFlags().Uint64VarP(&econf.Num, "num", "n", defNum, "number of users, groups, channels and clients to create and connect") + rootCmd.PersistentFlags().Uint64VarP(&econf.NumOfMsg, "num_of_messages", "N", defNum, "number of messages to send") + + if err := rootCmd.Execute(); err != nil { + log.Fatal(err) + } +} diff --git a/tools/e2e/doc.go b/tools/e2e/doc.go new file mode 100644 index 000000000..eb7fb081d --- /dev/null +++ b/tools/e2e/doc.go @@ -0,0 +1,5 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package e2e contains entry point for end-to-end tests. +package e2e diff --git a/tools/e2e/e2e.go b/tools/e2e/e2e.go new file mode 100644 index 000000000..5d27bed1b --- /dev/null +++ b/tools/e2e/e2e.go @@ -0,0 +1,643 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package e2e + +import ( + "context" + "fmt" + "math/rand" + "net/http" + "os" + "os/exec" + "reflect" + "strings" + "time" + + "github.com/0x6flab/namegenerator" + sdk "github.com/absmach/supermq/pkg/sdk" + "github.com/gookit/color" + "github.com/gorilla/websocket" + "golang.org/x/sync/errgroup" +) + +const ( + defPass = "12345678" + defWSPort = "8186" + numAdapters = 4 + batchSize = 99 + usersPort = "9002" + groupsPort = "9004" + clientsPort = "9006" + channelsPort = "9005" + domainsPort = "9003" +) + +var ( + namesgenerator = namegenerator.NewGenerator() + msgFormat = `[{"bn":"demo", "bu":"V", "t": %d, "bver":5, "n":"voltage", "u":"V", "v":%d}]` +) + +// Config - test configuration. +type Config struct { + Host string + Num uint64 + NumOfMsg uint64 + SSL bool + CA string + CAKey string + Prefix string +} + +// Test - function that does actual end to end testing. +// The operations are: +// - Create a user +// - Create other users +// - Do Read, Update and Change of Status operations on users. + +// - Create groups using hierarchy +// - Do Read, Update and Change of Status operations on groups. + +// - Create clients +// - Do Read, Update and Change of Status operations on clients. + +// - Create channels +// - Do Read, Update and Change of Status operations on channels. + +// - Connect client to channel +// - Publish message from HTTP, MQTT, WS and CoAP Adapters. +func Test(ctx context.Context, conf Config) { + sdkConf := sdk.Config{ + UsersURL: fmt.Sprintf("http://%s:%s", conf.Host, usersPort), + GroupsURL: fmt.Sprintf("http://%s:%s", conf.Host, groupsPort), + DomainsURL: fmt.Sprintf("http://%s:%s", conf.Host, domainsPort), + ClientsURL: fmt.Sprintf("http://%s:%s", conf.Host, clientsPort), + ChannelsURL: fmt.Sprintf("http://%s:%s", conf.Host, channelsPort), + HTTPAdapterURL: fmt.Sprintf("http://%s/http", conf.Host), + MsgContentType: sdk.CTJSONSenML, + TLSVerification: false, + } + + s := sdk.NewSDK(sdkConf) + + magenta := color.FgLightMagenta.Render + + domainID, token, err := createUser(ctx, s, conf) + if err != nil { + errExit(fmt.Errorf("unable to create user: %w", err)) + } + color.Success.Printf("created user with token %s\n", magenta(token)) + color.Success.Printf("created domain with ID %s\n", magenta(domainID)) + + users, err := createUsers(ctx, s, conf, token) + if err != nil { + errExit(fmt.Errorf("unable to create users: %w", err)) + } + color.Success.Printf("created users of ids:\n%s\n", magenta(getIDS(users))) + + groups, err := createGroups(ctx, s, conf, domainID, token) + if err != nil { + errExit(fmt.Errorf("unable to create groups: %w", err)) + } + color.Success.Printf("created groups of ids:\n%s\n", magenta(getIDS(groups))) + + clients, err := createClients(ctx, s, conf, domainID, token) + if err != nil { + errExit(fmt.Errorf("unable to create clients: %w", err)) + } + color.Success.Printf("created clients of ids:\n%s\n", magenta(getIDS(clients))) + + channels, err := createChannels(ctx, s, conf, domainID, token) + if err != nil { + errExit(fmt.Errorf("unable to create channels: %w", err)) + } + color.Success.Printf("created channels of ids:\n%s\n", magenta(getIDS(channels))) + + // List users, groups, clients and channels + if err := read(ctx, s, conf, domainID, token, users, groups, clients, channels); err != nil { + errExit(fmt.Errorf("unable to read users, groups, clients and channels: %w", err)) + } + color.Success.Println("viewed users, groups, clients and channels") + + // Update users, groups, clients and channels + if err := update(ctx, s, domainID, token, users, groups, clients, channels); err != nil { + errExit(fmt.Errorf("unable to update users, groups, clients and channels: %w", err)) + } + color.Success.Println("updated users, groups, clients and channels") + + // Send messages to channels + if err := messaging(ctx, s, conf, domainID, token, clients, channels); err != nil { + errExit(fmt.Errorf("unable to send messages to channels: %w", err)) + } + color.Success.Println("sent messages to channels") +} + +func errExit(err error) { + color.Error.Println(err.Error()) + os.Exit(1) +} + +func createUser(ctx context.Context, s sdk.SDK, conf Config) (string, string, error) { + user := sdk.User{ + FirstName: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), + LastName: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), + Email: fmt.Sprintf("%s%s@email.com", conf.Prefix, namesgenerator.Generate()), + Credentials: sdk.Credentials{ + Username: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), + Secret: defPass, + }, + Status: sdk.EnabledStatus, + Role: "admin", + } + + if _, err := s.CreateUser(ctx, user, ""); err != nil { + return "", "", fmt.Errorf("unable to create user: %w", err) + } + + login := sdk.Login{ + Username: user.Credentials.Username, + Password: user.Credentials.Secret, + } + token, err := s.CreateToken(ctx, login) + if err != nil { + return "", "", fmt.Errorf("unable to login user: %w", err) + } + + dname := fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()) + domain := sdk.Domain{ + Name: dname, + Route: strings.ToLower(dname), + Permission: "admin", + } + + domain, err = s.CreateDomain(ctx, domain, token.AccessToken) + if err != nil { + return "", "", fmt.Errorf("unable to create domain: %w", err) + } + + login = sdk.Login{ + Username: user.Credentials.Username, + Password: user.Credentials.Secret, + } + token, err = s.CreateToken(ctx, login) + if err != nil { + return "", "", fmt.Errorf("unable to login user: %w", err) + } + + return domain.ID, token.AccessToken, nil +} + +func createUsers(ctx context.Context, s sdk.SDK, conf Config, token string) ([]sdk.User, error) { + var err error + users := []sdk.User{} + + for i := uint64(0); i < conf.Num; i++ { + user := sdk.User{ + FirstName: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), + LastName: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), + Email: fmt.Sprintf("%s%s@email.com", conf.Prefix, namesgenerator.Generate()), + Credentials: sdk.Credentials{ + Username: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), + Secret: defPass, + }, + Status: sdk.EnabledStatus, + } + + user, err = s.CreateUser(ctx, user, token) + if err != nil { + return []sdk.User{}, fmt.Errorf("failed to create the users: %w", err) + } + users = append(users, user) + } + + return users, nil +} + +func createGroups(ctx context.Context, s sdk.SDK, conf Config, domainID, token string) ([]sdk.Group, error) { + var err error + groups := []sdk.Group{} + + for i := uint64(0); i < conf.Num; i++ { + group := sdk.Group{ + Name: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), + Status: sdk.EnabledStatus, + } + + group, err = s.CreateGroup(ctx, group, domainID, token) + if err != nil { + return []sdk.Group{}, fmt.Errorf("failed to create the group: %w", err) + } + groups = append(groups, group) + } + + return groups, nil +} + +func createClientsInBatch(ctx context.Context, s sdk.SDK, conf Config, domainID, token string, num uint64) ([]sdk.Client, error) { + var err error + clients := make([]sdk.Client, num) + + for i := uint64(0); i < num; i++ { + clients[i] = sdk.Client{ + Name: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), + } + } + + clients, err = s.CreateClients(ctx, clients, domainID, token) + if err != nil { + return []sdk.Client{}, fmt.Errorf("failed to create the clients: %w", err) + } + + return clients, nil +} + +func createClients(ctx context.Context, s sdk.SDK, conf Config, domainID, token string) ([]sdk.Client, error) { + clients := []sdk.Client{} + + if conf.Num > batchSize { + batches := int(conf.Num) / batchSize + for i := 0; i < batches; i++ { + ths, err := createClientsInBatch(ctx, s, conf, domainID, token, batchSize) + if err != nil { + return []sdk.Client{}, fmt.Errorf("failed to create the clients: %w", err) + } + clients = append(clients, ths...) + } + ths, err := createClientsInBatch(ctx, s, conf, domainID, token, conf.Num%uint64(batchSize)) + if err != nil { + return []sdk.Client{}, fmt.Errorf("failed to create the clients: %w", err) + } + clients = append(clients, ths...) + } else { + ths, err := createClientsInBatch(ctx, s, conf, domainID, token, conf.Num) + if err != nil { + return []sdk.Client{}, fmt.Errorf("failed to create the clients: %w", err) + } + clients = append(clients, ths...) + } + + return clients, nil +} + +func createChannelsInBatch(ctx context.Context, s sdk.SDK, conf Config, domainID, token string, num uint64) ([]sdk.Channel, error) { + var err error + channels := make([]sdk.Channel, num) + + for i := uint64(0); i < num; i++ { + channels[i] = sdk.Channel{ + Name: fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()), + } + channels[i], err = s.CreateChannel(ctx, channels[i], domainID, token) + if err != nil { + return []sdk.Channel{}, fmt.Errorf("failed to create the channels: %w", err) + } + } + + return channels, nil +} + +func createChannels(ctx context.Context, s sdk.SDK, conf Config, domainID, token string) ([]sdk.Channel, error) { + channels := []sdk.Channel{} + + if conf.Num > batchSize { + batches := int(conf.Num) / batchSize + for i := 0; i < batches; i++ { + chs, err := createChannelsInBatch(ctx, s, conf, token, domainID, batchSize) + if err != nil { + return []sdk.Channel{}, fmt.Errorf("failed to create the channels: %w", err) + } + channels = append(channels, chs...) + } + chs, err := createChannelsInBatch(ctx, s, conf, domainID, token, conf.Num%uint64(batchSize)) + if err != nil { + return []sdk.Channel{}, fmt.Errorf("failed to create the channels: %w", err) + } + channels = append(channels, chs...) + } else { + chs, err := createChannelsInBatch(ctx, s, conf, domainID, token, conf.Num) + if err != nil { + return []sdk.Channel{}, fmt.Errorf("failed to create the channels: %w", err) + } + channels = append(channels, chs...) + } + + return channels, nil +} + +func read(ctx context.Context, s sdk.SDK, conf Config, domainID, token string, users []sdk.User, groups []sdk.Group, clients []sdk.Client, channels []sdk.Channel) error { + for _, user := range users { + if _, err := s.User(ctx, user.ID, token); err != nil { + return fmt.Errorf("failed to get user %w", err) + } + } + up, err := s.Users(ctx, sdk.PageMetadata{}, token) + if err != nil { + return fmt.Errorf("failed to get users %w", err) + } + if up.Total < conf.Num { + return fmt.Errorf("returned users %d less than created users %d", up.Total, conf.Num) + } + for _, group := range groups { + if _, err := s.Group(ctx, group.ID, domainID, token); err != nil { + return fmt.Errorf("failed to get group %w", err) + } + } + gp, err := s.Groups(ctx, sdk.PageMetadata{}, domainID, token) + if err != nil { + return fmt.Errorf("failed to get groups %w", err) + } + if gp.Total < conf.Num { + return fmt.Errorf("returned groups %d less than created groups %d", gp.Total, conf.Num) + } + for _, c := range clients { + if _, err := s.Client(ctx, c.ID, domainID, token); err != nil { + return fmt.Errorf("failed to get client %w", err) + } + } + tp, err := s.Clients(ctx, sdk.PageMetadata{}, domainID, token) + if err != nil { + return fmt.Errorf("failed to get clients %w", err) + } + if tp.Total < conf.Num { + return fmt.Errorf("returned clients %d less than created clients %d", tp.Total, conf.Num) + } + for _, channel := range channels { + if _, err := s.Channel(ctx, channel.ID, domainID, token); err != nil { + return fmt.Errorf("failed to get channel %w", err) + } + } + cp, err := s.Channels(ctx, sdk.PageMetadata{}, domainID, token) + if err != nil { + return fmt.Errorf("failed to get channels %w", err) + } + if cp.Total < conf.Num { + return fmt.Errorf("returned channels %d less than created channels %d", cp.Total, conf.Num) + } + + return nil +} + +func update(ctx context.Context, s sdk.SDK, domainID, token string, users []sdk.User, groups []sdk.Group, clients []sdk.Client, channels []sdk.Channel) error { + for _, user := range users { + user.FirstName = namesgenerator.Generate() + user.Metadata = sdk.Metadata{"Update": namesgenerator.Generate()} + rUser, err := s.UpdateUser(ctx, user, token) + if err != nil { + return fmt.Errorf("failed to update user %w", err) + } + if rUser.FirstName != user.FirstName { + return fmt.Errorf("failed to update user name before %s after %s", user.FirstName, rUser.FirstName) + } + if rUser.Metadata["Update"] != user.Metadata["Update"] { + return fmt.Errorf("failed to update user metadata before %s after %s", user.Metadata["Update"], rUser.Metadata["Update"]) + } + user = rUser + user.Credentials.Username = namesgenerator.Generate() + rUser, err = s.UpdateUsername(ctx, user, token) + if err != nil { + return fmt.Errorf("failed to update username %w", err) + } + if rUser.Credentials.Username != user.Credentials.Username { + return fmt.Errorf("failed to update user name before %s after %s", user.Credentials.Username, rUser.Credentials.Username) + } + user = rUser + rUser, err = s.UpdateUserEmail(ctx, user, token) + if err != nil { + return fmt.Errorf("failed to update user identity %w", err) + } + if rUser.Email != user.Email { + return fmt.Errorf("failed to update user identity before %s after %s", user.Email, rUser.Email) + } + user = rUser + user.Tags = []string{namesgenerator.Generate()} + rUser, err = s.UpdateUserTags(ctx, user, token) + if err != nil { + return fmt.Errorf("failed to update user tags %w", err) + } + if rUser.Tags[0] != user.Tags[0] { + return fmt.Errorf("failed to update user tags before %s after %s", user.Tags[0], rUser.Tags[0]) + } + user = rUser + rUser, err = s.DisableUser(ctx, user.ID, token) + if err != nil { + return fmt.Errorf("failed to disable user %w", err) + } + if rUser.Status != sdk.DisabledStatus { + return fmt.Errorf("failed to disable user before %s after %s", user.Status, rUser.Status) + } + user = rUser + rUser, err = s.EnableUser(ctx, user.ID, token) + if err != nil { + return fmt.Errorf("failed to enable user %w", err) + } + if rUser.Status != sdk.EnabledStatus { + return fmt.Errorf("failed to enable user before %s after %s", user.Status, rUser.Status) + } + } + for _, group := range groups { + group.Name = namesgenerator.Generate() + group.Metadata = sdk.Metadata{"Update": namesgenerator.Generate()} + rGroup, err := s.UpdateGroup(ctx, group, domainID, token) + if err != nil { + return fmt.Errorf("failed to update group %w", err) + } + if rGroup.Name != group.Name { + return fmt.Errorf("failed to update group name before %s after %s", group.Name, rGroup.Name) + } + if rGroup.Metadata["Update"] != group.Metadata["Update"] { + return fmt.Errorf("failed to update group metadata before %s after %s", group.Metadata["Update"], rGroup.Metadata["Update"]) + } + group = rGroup + rGroup, err = s.DisableGroup(ctx, group.ID, domainID, token) + if err != nil { + return fmt.Errorf("failed to disable group %w", err) + } + if rGroup.Status != sdk.DisabledStatus { + return fmt.Errorf("failed to disable group before %s after %s", group.Status, rGroup.Status) + } + group = rGroup + rGroup, err = s.EnableGroup(ctx, group.ID, domainID, token) + if err != nil { + return fmt.Errorf("failed to enable group %w", err) + } + if rGroup.Status != sdk.EnabledStatus { + return fmt.Errorf("failed to enable group before %s after %s", group.Status, rGroup.Status) + } + } + for _, t := range clients { + t.Name = namesgenerator.Generate() + t.Metadata = sdk.Metadata{"Update": namesgenerator.Generate()} + rClient, err := s.UpdateClient(ctx, t, domainID, token) + if err != nil { + return fmt.Errorf("failed to update client %w", err) + } + if rClient.Name != t.Name { + return fmt.Errorf("failed to update client name before %s after %s", t.Name, rClient.Name) + } + if rClient.Metadata["Update"] != t.Metadata["Update"] { + return fmt.Errorf("failed to update client metadata before %s after %s", t.Metadata["Update"], rClient.Metadata["Update"]) + } + t = rClient + rClient, err = s.UpdateClientSecret(ctx, t.ID, t.Credentials.Secret, domainID, token) + if err != nil { + return fmt.Errorf("failed to update client secret %w", err) + } + t = rClient + t.Tags = []string{namesgenerator.Generate()} + rClient, err = s.UpdateClientTags(ctx, t, domainID, token) + if err != nil { + return fmt.Errorf("failed to update client tags %w", err) + } + if rClient.Tags[0] != t.Tags[0] { + return fmt.Errorf("failed to update client tags before %s after %s", t.Tags[0], rClient.Tags[0]) + } + t = rClient + rClient, err = s.DisableClient(ctx, t.ID, domainID, token) + if err != nil { + return fmt.Errorf("failed to disable client %w", err) + } + if rClient.Status != sdk.DisabledStatus { + return fmt.Errorf("failed to disable client before %s after %s", t.Status, rClient.Status) + } + t = rClient + rClient, err = s.EnableClient(ctx, t.ID, domainID, token) + if err != nil { + return fmt.Errorf("failed to enable client %w", err) + } + if rClient.Status != sdk.EnabledStatus { + return fmt.Errorf("failed to enable client before %s after %s", t.Status, rClient.Status) + } + } + for _, channel := range channels { + channel.Name = namesgenerator.Generate() + channel.Metadata = sdk.Metadata{"Update": namesgenerator.Generate()} + rChannel, err := s.UpdateChannel(ctx, channel, domainID, token) + if err != nil { + return fmt.Errorf("failed to update channel %w", err) + } + if rChannel.Name != channel.Name { + return fmt.Errorf("failed to update channel name before %s after %s", channel.Name, rChannel.Name) + } + if rChannel.Metadata["Update"] != channel.Metadata["Update"] { + return fmt.Errorf("failed to update channel metadata before %s after %s", channel.Metadata["Update"], rChannel.Metadata["Update"]) + } + channel = rChannel + rChannel, err = s.DisableChannel(ctx, channel.ID, domainID, token) + if err != nil { + return fmt.Errorf("failed to disable channel %w", err) + } + if rChannel.Status != sdk.DisabledStatus { + return fmt.Errorf("failed to disable channel before %s after %s", channel.Status, rChannel.Status) + } + channel = rChannel + rChannel, err = s.EnableChannel(ctx, channel.ID, domainID, token) + if err != nil { + return fmt.Errorf("failed to enable channel %w", err) + } + } + + return nil +} + +func messaging(ctx context.Context, s sdk.SDK, conf Config, domainID, token string, clients []sdk.Client, channels []sdk.Channel) error { + for _, c := range clients { + for _, channel := range channels { + conn := sdk.Connection{ + ClientIDs: []string{c.ID}, + ChannelIDs: []string{channel.ID}, + Types: []string{"publish", "subscribe"}, + } + if err := s.Connect(ctx, conn, domainID, token); err != nil { + return fmt.Errorf("failed to connect client %s to channel %s", c.ID, channel.ID) + } + } + } + + g := new(errgroup.Group) + + bt := time.Now().Unix() + for i := uint64(0); i < conf.NumOfMsg; i++ { + for _, client := range clients { + for _, channel := range channels { + func(num int64, client sdk.Client, channel sdk.Channel) { + g.Go(func() error { + msg := fmt.Sprintf(msgFormat, num+1, rand.Int()) + return sendHTTPMessage(ctx, s, msg, client, channel.ID) + }) + g.Go(func() error { + msg := fmt.Sprintf(msgFormat, num+2, rand.Int()) + return sendCoAPMessage(msg, client, channel.ID) + }) + g.Go(func() error { + msg := fmt.Sprintf(msgFormat, num+3, rand.Int()) + return sendMQTTMessage(msg, client, channel.ID) + }) + g.Go(func() error { + msg := fmt.Sprintf(msgFormat, num+4, rand.Int()) + return sendWSMessage(conf, msg, client, channel.ID) + }) + }(bt, client, channel) + bt += numAdapters + } + } + } + + return g.Wait() +} + +func sendHTTPMessage(ctx context.Context, s sdk.SDK, msg string, client sdk.Client, chanID string) error { + if err := s.SendMessage(ctx, client.DomainID, chanID, msg, client.Credentials.Secret); err != nil { + return fmt.Errorf("HTTP failed to send message from client %s to channel %s: %w", client.ID, chanID, err) + } + + return nil +} + +func sendCoAPMessage(msg string, client sdk.Client, chanID string) error { + cmd := exec.Command("coap-cli", "post", fmt.Sprintf("m/%s/c/%s", client.DomainID, chanID), "--auth", client.Credentials.Secret, "-d", msg) + if _, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("CoAP failed to send message from client %s to channel %s: %w", client.ID, chanID, err) + } + + return nil +} + +func sendMQTTMessage(msg string, client sdk.Client, chanID string) error { + cmd := exec.Command("mosquitto_pub", "--id-prefix", "supermq", "-u", client.ID, "-P", client.Credentials.Secret, "-t", fmt.Sprintf("m/%s/c/%s", client.DomainID, chanID), "-h", "localhost", "-m", msg) + if _, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("MQTT failed to send message from client %s to channel %s: %w", client.ID, chanID, err) + } + + return nil +} + +func sendWSMessage(conf Config, msg string, client sdk.Client, chanID string) error { + socketURL := fmt.Sprintf("ws://%s:%s/m/%s/c/%s", conf.Host, defWSPort, client.DomainID, chanID) + header := http.Header{"Authorization": []string{client.Credentials.Secret}} + conn, _, err := websocket.DefaultDialer.Dial(socketURL, header) + if err != nil { + return fmt.Errorf("unable to connect to websocket: %w", err) + } + defer conn.Close() + if err := conn.WriteMessage(websocket.TextMessage, []byte(msg)); err != nil { + return fmt.Errorf("WS failed to send message from client %s to channel %s: %w", client.ID, chanID, err) + } + + return nil +} + +// getIDS returns a list of IDs of the given objects. +func getIDS(objects any) string { + v := reflect.ValueOf(objects) + if v.Kind() != reflect.Slice { + panic("objects argument must be a slice") + } + ids := make([]string, v.Len()) + for i := 0; i < v.Len(); i++ { + id := v.Index(i).FieldByName("ID").String() + ids[i] = id + } + idList := strings.Join(ids, "\n") + + return idList +} diff --git a/tools/mqtt-bench/Makefile b/tools/mqtt-bench/Makefile new file mode 100644 index 000000000..f2b3bed0d --- /dev/null +++ b/tools/mqtt-bench/Makefile @@ -0,0 +1,15 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +PROGRAM = mqtt-bench +SOURCES = $(wildcard *.go) cmd/main.go + +all: $(PROGRAM) + +.PHONY: all clean + +$(PROGRAM): $(SOURCES) + go build -ldflags "-s -w" -o $@ cmd/main.go + +clean: + rm -rf $(PROGRAM) diff --git a/tools/mqtt-bench/README.md b/tools/mqtt-bench/README.md new file mode 100644 index 000000000..2acd5dd00 --- /dev/null +++ b/tools/mqtt-bench/README.md @@ -0,0 +1,109 @@ +# MQTT Benchmarking Tool + +A simple MQTT benchmarking tool for SuperMQ platform. + +It connects SuperMQ clients as subscribers over a number of channels and +uses other SuperMQ clients to publish messages and create MQTT load. + +SuperMQ clients used must be pre-provisioned first, and SuperMQ `provision` tool can be used for this purpose. + +## Installation + +``` +cd tools/mqtt-bench +make +``` + +## Usage + +The tool supports multiple concurrent clients, publishers and subscribers configurable message size, etc: + +``` +./mqtt-bench --help +Tool for extensive load and benchmarking of MQTT brokers used within SuperMQ platform. +Complete documentation is available at https://docs.supermq.absmach.eu + +Usage: + mqtt-bench [flags] + +Flags: + -b, --broker string address for mqtt broker, for secure use tcps and 8883 (default "tcp://localhost:1883") + --ca string CA file (default "ca.crt") + -c, --config string config file for mqtt-bench (default "config.toml") + -n, --count int Number of messages sent per publisher (default 100) + -f, --format string Output format: text|json (default "text") + -h, --help help for mqtt-bench + -m, --supermq string config file for SuperMQ connections (default "connections.toml") + --mtls Use mtls for connection + -p, --pubs int Number of publishers (default 10) + -q, --qos int QoS for published messages, values 0 1 2 + --quiet Suppress messages + -r, --retain Retain mqtt messages + -z, --size int Size of message payload bytes (default 100) + -t, --skipTLSVer Skip tls verification + -t, --timeout Timeout mqtt messages (default 10000) +``` + +Two output formats supported: human-readable plain text and JSON. + +Before use you need a `mgconn.toml` - a TOML file that describes SuperMQ connection data (channels, clientIDs, clientKeys, certs). +You can use `provision` tool (in tools/provision) to create this TOML config file. + +```bash +go run tools/mqtt-bench/cmd/main.go -u test@supermq.com -p test1234 --host http://127.0.0.1 --num 100 > tools/mqtt-bench/mgconn.toml +``` + +Example use and output + +Without mtls: + +``` +go run tools/mqtt-bench/cmd/main.go --broker tcp://localhost:1883 --count 100 --size 100 --qos 0 --format text --pubs 10 --supermq tools/mqtt-bench/mgconn.toml +``` + +With mtls +go run tools/mqtt-bench/cmd/main.go --broker tcps://localhost:8883 --count 100 --size 100 --qos 0 --format text --pubs 10 --supermq tools/mqtt-bench/mgconn.toml --mtls -ca docker/ssl/certs/ca.crt + +``` + +You can use `config.toml` to create tests with this tool: + +``` + +go run tools/mqtt-bench/cmd/main.go --config tools/mqtt-bench/config.toml + +``` + +Example of `config.toml`: + +``` + +[mqtt] +[mqtt.broker] +url = "tcp://localhost:1883" + +[mqtt.message] +size = 100 +format = "text" +qos = 2 +retain = true + +[mqtt.tls] +mtls = false +skiptlsver = true +ca = "ca.crt" + +[test] +pubs = 3 +count = 100 + +[log] +quiet = false + +[supermq] +connections_file = "smqconn.toml" + +``` + +Based on this, a test scenario is provided in `templates/reference.toml` file. +``` diff --git a/tools/mqtt-bench/bench.go b/tools/mqtt-bench/bench.go new file mode 100644 index 000000000..0e10dabc2 --- /dev/null +++ b/tools/mqtt-bench/bench.go @@ -0,0 +1,205 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bench + +import ( + "crypto/rand" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "os" + "strconv" + "sync" + "time" + + smqlog "github.com/absmach/supermq/logger" + "github.com/pelletier/go-toml" +) + +// Benchmark - main benchmarking function. +func Benchmark(cfg Config) error { + if err := checkConnection(cfg.MQTT.Broker.URL, 1); err != nil { + return err + } + logger, err := smqlog.New(os.Stdout, "debug") + if err != nil { + return err + } + + subsResults := map[string](*[]float64){} + var caByte []byte + if cfg.MQTT.TLS.MTLS { + caFile, err := os.Open(cfg.MQTT.TLS.CA) + + defer func() { + if err = caFile.Close(); err != nil { + logger.Warn(fmt.Sprintf("Could not close file: %s", err)) + } + }() + if err != nil { + logger.Warn(err.Error()) + } + caByte, _ = io.ReadAll(caFile) + } + + data, err := os.ReadFile(cfg.Smq.ConnFile) + if err != nil { + return fmt.Errorf("error loading connections file: %s", err) + } + + mg := superMQ{} + if err := toml.Unmarshal(data, &mg); err != nil { + return fmt.Errorf("cannot load SuperMQ connections config %s \nUse tools/provision to create file", cfg.Smq.ConnFile) + } + + resCh := make(chan *runResults) + finishedPub := make(chan bool) + + startStamp := time.Now() + + n := len(mg.Channels) + var cert tls.Certificate + + start := time.Now() + + var wg sync.WaitGroup + errorChan := make(chan error, cfg.Test.Pubs) + + for i := 0; i < cfg.Test.Pubs; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + mgChan := mg.Channels[i%n] + mgCli := mg.Clients[i%n] + + if cfg.MQTT.TLS.MTLS { + cert, err = tls.X509KeyPair([]byte(mgCli.MTLSCert), []byte(mgCli.MTLSKey)) + if err != nil { + errorChan <- err + return + } + } + c, err := makeClient(i, cfg, mgChan, mgCli, startStamp, caByte, cert) + if err != nil { + errorChan <- fmt.Errorf("unable to create message payload %s", err.Error()) + return + } + + c.publish(resCh, errorChan) + }(i) + } + + go func() { + wg.Wait() + close(errorChan) + }() + + for err := range errorChan { + if err != nil { + return err + } + } + + // Collect the results + var results []*runResults + if cfg.Test.Pubs > 0 { + results = make([]*runResults, cfg.Test.Pubs) + } + + // Wait for publishers to finish + go func() { + for i := 0; i < cfg.Test.Pubs; i++ { + results[i] = <-resCh + } + finishedPub <- true + }() + + <-finishedPub + + totalTime := time.Since(start) + totals := calculateTotalResults(results, totalTime, subsResults) + if totals == nil { + return fmt.Errorf("totals not assigned") + } + + printResults(results, totals, cfg.MQTT.Message.Format, cfg.Log.Quiet) + return nil +} + +func getBytePayload(size int, m message) (handler, error) { + // Calculate payload size. + var b []byte + s, err := json.Marshal(&m) + if err != nil { + return nil, err + } + n := len(s) + if n < size { + sz := size - n + for { + b = make([]byte, sz) + if _, err = rand.Read(b); err != nil { + return nil, err + } + m.Payload = b + content, err := json.Marshal(&m) + if err != nil { + return nil, err + } + l := len(content) + // Use range because the size of generated JSON + // depends on current time and random byte array. + if l <= size+5 && l >= size-5 { + break + } + if l > size { + sz-- + } + if l < size { + sz++ + } + } + } + + ret := func(m *message) ([]byte, error) { + m.Payload = b + m.Sent = time.Now() + return json.Marshal(m) + } + return ret, nil +} + +func makeClient(i int, cfg Config, mgChan channel, cli client, start time.Time, caCert []byte, clientCert tls.Certificate) (*Client, error) { + c := &Client{ + ID: strconv.Itoa(i), + BrokerURL: cfg.MQTT.Broker.URL, + BrokerUser: cli.ClientID, + BrokerPass: cli.ClientSecret, + MsgTopic: fmt.Sprintf("channels/%s/messages/%d/test", mgChan.ChannelID, start.UnixNano()), + MsgSize: cfg.MQTT.Message.Size, + MsgCount: cfg.Test.Count, + MsgQoS: byte(cfg.MQTT.Message.QoS), + Quiet: cfg.Log.Quiet, + MTLS: cfg.MQTT.TLS.MTLS, + SkipTLSVer: cfg.MQTT.TLS.SkipTLSVer, + CA: caCert, + timeout: cfg.MQTT.Timeout, + ClientCert: clientCert, + Retain: cfg.MQTT.Message.Retain, + } + msg := message{ + Topic: c.MsgTopic, + QoS: c.MsgQoS, + ID: c.ID, + Sent: time.Now(), + } + h, err := getBytePayload(cfg.MQTT.Message.Size, msg) + if err != nil { + return nil, err + } + + c.SendMsg = h + return c, nil +} diff --git a/tools/mqtt-bench/client.go b/tools/mqtt-bench/client.go new file mode 100644 index 000000000..1372990c4 --- /dev/null +++ b/tools/mqtt-bench/client.go @@ -0,0 +1,221 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bench + +import ( + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "log" + "net" + "strings" + "sync" + "time" + + mqtt "github.com/eclipse/paho.mqtt.golang" +) + +// Set default ping timeout to large value, so that ping +// won't fail in the case of broker pingresp delay. +const pingTimeout = 10000 + +// Client - represents mqtt client. +type Client struct { + ID string + BrokerURL string + BrokerUser string + BrokerPass string + MsgTopic string + MsgSize int + MsgCount int + MsgQoS byte + Quiet bool + timeout int + mqttClient *mqtt.Client + MTLS bool + SkipTLSVer bool + Retain bool + CA []byte + ClientCert tls.Certificate + ClientKey *rsa.PrivateKey + SendMsg handler +} + +type message struct { + ID string `json:"id"` + Topic string `json:"topic"` + QoS byte `json:"qos"` + Payload []byte `json:"payload"` + Sent time.Time `json:"sent"` + Delivered time.Time `json:"delivered"` + Error bool `json:"error"` +} + +type handler func(*message) ([]byte, error) + +func (c *Client) publish(r chan *runResults, errChan chan<- error) { + res := &runResults{} + times := make([]*float64, c.MsgCount) + + start := time.Now() + if c.connect() != nil { + flushMessages := make([]message, c.MsgCount) + for i, m := range flushMessages { + m.Error = true + times[i] = calcMsgRes(&m, res) + } + r <- calcRes(res, start, arr(times)) + } + if !c.Quiet { + log.Printf("Client %v is connected to the broker %v\n", c.ID, c.BrokerURL) + } + wg := sync.WaitGroup{} + mu := sync.Mutex{} + // Use a single message. + m := message{ + Topic: c.MsgTopic, + QoS: c.MsgQoS, + ID: c.ID, + Sent: time.Now(), + } + payload, err := c.SendMsg(&m) + if err != nil { + errChan <- fmt.Errorf("failed to marshal payload - %s", err.Error()) + } + + for i := 0; i < c.MsgCount; i++ { + wg.Add(1) + go func(mut *sync.Mutex, wg *sync.WaitGroup, i int, m message) { + defer wg.Done() + m.Sent = time.Now() + + token := (*c.mqttClient).Publish(m.Topic, m.QoS, c.Retain, payload) + if !token.WaitTimeout(time.Second*time.Duration(c.timeout)) || token.Error() != nil || !(*c.mqttClient).IsConnectionOpen() { + m.Error = true + mu.Lock() + times[i] = calcMsgRes(&m, res) + mu.Unlock() + return + } + + m.Delivered = time.Now() + m.Error = false + mu.Lock() + times[i] = calcMsgRes(&m, res) + mu.Unlock() + + if !c.Quiet && i > 0 && i%100 == 0 { + log.Printf("Client %v published %v messages and keeps publishing...\n", c.ID, i) + } + }(&mu, &wg, i, m) + } + wg.Wait() + + r <- calcRes(res, start, arr(times)) +} + +func (c *Client) connect() error { + opts := mqtt.NewClientOptions(). + AddBroker(c.BrokerURL). + SetClientID(c.ID). + SetCleanSession(false). + SetAutoReconnect(false). + SetOnConnectHandler(c.connected). + SetConnectionLostHandler(c.connLost). + SetPingTimeout(time.Second * pingTimeout). + SetAutoReconnect(true). + SetCleanSession(false) + + if c.BrokerUser != "" && c.BrokerPass != "" { + opts.SetUsername(c.BrokerUser) + opts.SetPassword(c.BrokerPass) + } + + if c.MTLS { + cfg := &tls.Config{ + InsecureSkipVerify: c.SkipTLSVer, + } + + if c.CA != nil { + cfg.RootCAs = x509.NewCertPool() + cfg.RootCAs.AppendCertsFromPEM(c.CA) + } + if c.ClientCert.Certificate != nil { + cfg.Certificates = []tls.Certificate{c.ClientCert} + } + + opts.SetTLSConfig(cfg) + opts.SetProtocolVersion(4) + } + + client := mqtt.NewClient(opts) + token := client.Connect() + token.Wait() + + c.mqttClient = &client + + if token.Error() != nil { + log.Printf("Client %v had error connecting to the broker: %s\n", c.ID, token.Error().Error()) + return token.Error() + } + + return nil +} + +func checkConnection(broker string, timeoutSecs int) error { + s := strings.Split(broker, ":") + if len(s) != 3 { + return errors.New("wrong host address format") + } + + network := s[0] + host := strings.Trim(s[1], "/") + port := s[2] + + log.Println("Testing connection...") + conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%s", host, port), time.Duration(timeoutSecs)*time.Second) + conClose := func() { + if conn != nil { + log.Println("Closing testing connection...") + conn.Close() + } + } + + defer conClose() + if err, ok := err.(*net.OpError); ok && err.Timeout() { + return fmt.Errorf("timeout error: %s", err.Error()) + } + + if err != nil { + return fmt.Errorf("error: %s", err.Error()) + } + + log.Printf("Connection to %s://%s:%s looks OK\n", network, host, port) + return nil +} + +func arr(a []*float64) []float64 { + ret := []float64{} + for _, v := range a { + if v != nil { + ret = append(ret, *v) + } + } + if len(ret) == 0 { + ret = append(ret, 0) + } + return ret +} + +func (c *Client) connected(client mqtt.Client) { + if !c.Quiet { + log.Printf("Client %v is connected to the broker %v\n", c.ID, c.BrokerURL) + } +} + +func (c *Client) connLost(client mqtt.Client, reason error) { + log.Printf("Client %v had lost connection to the broker: %s\n", c.ID, reason.Error()) +} diff --git a/tools/mqtt-bench/cmd/main.go b/tools/mqtt-bench/cmd/main.go new file mode 100644 index 000000000..27f6e19ea --- /dev/null +++ b/tools/mqtt-bench/cmd/main.go @@ -0,0 +1,77 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package main contains the entry point of the mqtt-bench tool. +package main + +import ( + "log" + + bench "github.com/absmach/supermq/tools/mqtt-bench" + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +func main() { + confFile := "" + bconf := bench.Config{} + + // Command + rootCmd := &cobra.Command{ + Use: "mqtt-bench", + Short: "mqtt-bench is MQTT benchmark tool for SuperMQ", + Long: `Tool for extensive load and benchmarking of MQTT brokers used within the SuperMQ platform. +Complete documentation is available at https://docs.supermq.absmach.eu`, + Run: func(cmd *cobra.Command, args []string) { + if confFile != "" { + viper.SetConfigFile(confFile) + + if err := viper.ReadInConfig(); err != nil { + log.Printf("Failed to load config - %s", err) + } + + if err := viper.Unmarshal(&bconf); err != nil { + log.Printf("Unable to decode into struct, %v", err) + } + } + + if err := bench.Benchmark(bconf); err != nil { + log.Fatal(err) + } + }, + } + + // Flags + // MQTT Broker + rootCmd.PersistentFlags().StringVarP(&bconf.MQTT.Broker.URL, "broker", "b", "tcp://localhost:1883", + "address for mqtt broker, for secure use tcps and 8883") + + // MQTT Message + rootCmd.PersistentFlags().IntVarP(&bconf.MQTT.Message.Size, "size", "z", 100, "Size of message payload bytes") + rootCmd.PersistentFlags().StringVarP(&bconf.MQTT.Message.Payload, "payload", "l", "", "Template message") + rootCmd.PersistentFlags().StringVarP(&bconf.MQTT.Message.Format, "format", "f", "text", "Output format: text|json") + rootCmd.PersistentFlags().IntVarP(&bconf.MQTT.Message.QoS, "qos", "q", 0, "QoS for published messages, values 0 1 2") + rootCmd.PersistentFlags().BoolVarP(&bconf.MQTT.Message.Retain, "retain", "r", false, "Retain mqtt messages") + rootCmd.PersistentFlags().IntVarP(&bconf.MQTT.Timeout, "timeout", "o", 10000, "Timeout mqtt messages") + + // MQTT TLS + rootCmd.PersistentFlags().BoolVarP(&bconf.MQTT.TLS.MTLS, "mtls", "", false, "Use mtls for connection") + rootCmd.PersistentFlags().BoolVarP(&bconf.MQTT.TLS.SkipTLSVer, "skipTLSVer", "t", false, "Skip tls verification") + rootCmd.PersistentFlags().StringVarP(&bconf.MQTT.TLS.CA, "ca", "", "ca.crt", "CA file") + + // Test params + rootCmd.PersistentFlags().IntVarP(&bconf.Test.Count, "count", "n", 100, "Number of messages sent per publisher") + rootCmd.PersistentFlags().IntVarP(&bconf.Test.Subs, "subs", "s", 10, "Number of subscribers") + rootCmd.PersistentFlags().IntVarP(&bconf.Test.Pubs, "pubs", "p", 10, "Number of publishers") + + // Log params + rootCmd.PersistentFlags().BoolVarP(&bconf.Log.Quiet, "quiet", "", false, "Suppress messages") + + // Config file + rootCmd.PersistentFlags().StringVarP(&confFile, "config", "c", "config.toml", "config file for mqtt-bench") + rootCmd.PersistentFlags().StringVarP(&bconf.Smq.ConnFile, "supermq", "m", "connections.toml", "config file for SuperMQ connections") + + if err := rootCmd.Execute(); err != nil { + log.Fatal(err) + } +} diff --git a/tools/mqtt-bench/config.go b/tools/mqtt-bench/config.go new file mode 100644 index 000000000..0a59a29b4 --- /dev/null +++ b/tools/mqtt-bench/config.go @@ -0,0 +1,68 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bench + +// Keep struct names exported, otherwise Viper unmarshalling won't work. +type mqttBrokerConfig struct { + URL string `toml:"url" mapstructure:"url"` +} + +type mqttMessageConfig struct { + Size int `toml:"size" mapstructure:"size"` + Payload string `toml:"payload" mapstructure:"payload"` + Format string `toml:"format" mapstructure:"format"` + QoS int `toml:"qos" mapstructure:"qos"` + Retain bool `toml:"retain" mapstructure:"retain"` +} + +type mqttTLSConfig struct { + MTLS bool `toml:"mtls" mapstructure:"mtls"` + SkipTLSVer bool `toml:"skiptlsver" mapstructure:"skiptlsver"` + CA string `toml:"ca" mapstructure:"ca"` +} + +type mqttConfig struct { + Broker mqttBrokerConfig `toml:"broker" mapstructure:"broker"` + Message mqttMessageConfig `toml:"message" mapstructure:"message"` + Timeout int `toml:"timeout" mapstructure:"timeout"` + TLS mqttTLSConfig `toml:"tls" mapstructure:"tls"` +} + +type testConfig struct { + Count int `toml:"count" mapstructure:"count"` + Pubs int `toml:"pubs" mapstructure:"pubs"` + Subs int `toml:"subs" mapstructure:"subs"` +} + +type logConfig struct { + Quiet bool `toml:"quiet" mapstructure:"quiet"` +} + +type smqFile struct { + ConnFile string `toml:"connections_file" mapstructure:"connections_file"` +} + +type client struct { + ClientID string `toml:"client_id" mapstructure:"client_id"` + ClientSecret string `toml:"client_secret" mapstructure:"client_secret"` + MTLSCert string `toml:"mtls_cert" mapstructure:"mtls_cert"` + MTLSKey string `toml:"mtls_key" mapstructure:"mtls_key"` +} + +type channel struct { + ChannelID string `toml:"channel_id" mapstructure:"channel_id"` +} + +type superMQ struct { + Clients []client `toml:"clients" mapstructure:"clients"` + Channels []channel `toml:"channels" mapstructure:"channels"` +} + +// Config struct holds benchmark configuration. +type Config struct { + MQTT mqttConfig `toml:"mqtt" mapstructure:"mqtt"` + Test testConfig `toml:"test" mapstructure:"test"` + Log logConfig `toml:"log" mapstructure:"log"` + Smq smqFile `toml:"supermq" mapstructure:"supermq"` +} diff --git a/tools/mqtt-bench/doc.go b/tools/mqtt-bench/doc.go new file mode 100644 index 000000000..624651473 --- /dev/null +++ b/tools/mqtt-bench/doc.go @@ -0,0 +1,5 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package bench contains benchmarking tool for MQTT broker. +package bench diff --git a/tools/mqtt-bench/results.go b/tools/mqtt-bench/results.go new file mode 100644 index 000000000..6d397e0f0 --- /dev/null +++ b/tools/mqtt-bench/results.go @@ -0,0 +1,194 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bench + +import ( + "bytes" + "encoding/json" + "fmt" + "log" + "time" + + "gonum.org/v1/gonum/mat" + "gonum.org/v1/gonum/stat" +) + +type subsResults map[string](*[]float64) + +type runResults struct { + ID string `json:"id"` + Successes int64 `json:"successes"` + Failures int64 `json:"failures"` + RunTime float64 `json:"run_time"` + MsgTimeMin float64 `json:"msg_time_min"` + MsgTimeMax float64 `json:"msg_time_max"` + MsgTimeMean float64 `json:"msg_time_mean"` + MsgTimeStd float64 `json:"msg_time_std"` + MsgDelTimeMin float64 `json:"msg_del_time_min"` + MsgDelTimeMax float64 `json:"msg_del_time_max"` + MsgDelTimeMean float64 `json:"msg_del_time_mean"` + MsgDelTimeStd float64 `json:"msg_del_time_std"` + MsgsPerSec float64 `json:"msgs_per_sec"` +} + +type totalResults struct { + Ratio float64 `json:"ratio"` + Successes int64 `json:"successes"` + Failures int64 `json:"failures"` + TotalRunTime float64 `json:"total_run_time"` + AvgRunTime float64 `json:"avg_run_time"` + MsgTimeMin float64 `json:"msg_time_min"` + MsgTimeMax float64 `json:"msg_time_max"` + MsgDelTimeMin float64 `json:"msg_del_time_min"` + MsgDelTimeMax float64 `json:"msg_del_time_max"` + MsgTimeMeanAvg float64 `json:"msg_time_mean_avg"` + MsgTimeMeanStd float64 `json:"msg_time_mean_std"` + MsgDelTimeMeanAvg float64 `json:"msg_del_time_mean_avg"` + MsgDelTimeMeanStd float64 `json:"msg_del_time_mean_std"` + TotalMsgsPerSec float64 `json:"total_msgs_per_sec"` + AvgMsgsPerSec float64 `json:"avg_msgs_per_sec"` +} + +// JSONResults are used to export results as a JSON document. +type JSONResults struct { + Runs []*runResults `json:"runs"` + Totals *totalResults `json:"totals"` +} + +func calcMsgRes(m *message, res *runResults) *float64 { + if m.Error { + res.Failures++ + return nil + } + res.Successes++ + diff := float64(m.Delivered.Sub(m.Sent).Nanoseconds() / 1000) // in microseconds + return &diff +} + +func calcRes(r *runResults, start time.Time, times []float64) *runResults { + duration := time.Since(start) + timeMatrix := mat.NewDense(1, len(times), times) + r.MsgTimeMin = mat.Min(timeMatrix) + r.MsgTimeMax = mat.Max(timeMatrix) + r.MsgTimeMean = stat.Mean(times, nil) + r.MsgTimeStd = stat.StdDev(times, nil) + r.RunTime = duration.Seconds() + r.MsgsPerSec = float64(r.Successes) / duration.Seconds() + return r +} + +func calculateTotalResults(results []*runResults, totalTime time.Duration, sr subsResults) *totalResults { + if results == nil || len(results) < 1 { + return nil + } + totals := new(totalResults) + msgTimeMeans := make([]float64, len(results)) + msgTimeMeansDelivered := make([]float64, len(results)) + msgsPerSecs := make([]float64, len(results)) + runTimes := make([]float64, len(results)) + bws := make([]float64, len(results)) + + totals.TotalRunTime = totalTime.Seconds() + + totals.MsgTimeMin = results[0].MsgTimeMin + for i, res := range results { + totals.Successes += res.Successes + totals.Failures += res.Failures + totals.TotalMsgsPerSec += res.MsgsPerSec + + // Don't count those client that sent no messages. + if res.MsgsPerSec == 0 { + continue + } + + if res.MsgTimeMin < totals.MsgTimeMin { + totals.MsgTimeMin = res.MsgTimeMin + } + + if res.MsgTimeMax > totals.MsgTimeMax { + totals.MsgTimeMax = res.MsgTimeMax + } + + if res.MsgDelTimeMin < totals.MsgDelTimeMin { + totals.MsgDelTimeMin = res.MsgDelTimeMin + } + + if res.MsgDelTimeMax > totals.MsgDelTimeMax { + totals.MsgDelTimeMax = res.MsgDelTimeMax + } + + msgTimeMeansDelivered[i] = res.MsgDelTimeMean + msgTimeMeans[i] = res.MsgTimeMean + msgsPerSecs[i] = res.MsgsPerSec + runTimes[i] = res.RunTime + bws[i] = res.MsgsPerSec + } + + for _, v := range sr { + times := mat.NewDense(1, len(*v), *v) + totals.MsgDelTimeMin = mat.Min(times) / 1000 + totals.MsgDelTimeMax = mat.Max(times) / 1000 + totals.MsgDelTimeMeanAvg = stat.Mean(*v, nil) / 1000 + totals.MsgDelTimeMeanStd = stat.StdDev(*v, nil) / 1000 + } + + totals.Ratio = float64(totals.Successes) / float64(totals.Successes+totals.Failures) + totals.AvgMsgsPerSec = stat.Mean(msgsPerSecs, nil) + totals.AvgRunTime = stat.Mean(runTimes, nil) + totals.MsgDelTimeMeanAvg = stat.Mean(msgTimeMeansDelivered, nil) + totals.MsgDelTimeMeanStd = stat.StdDev(msgTimeMeansDelivered, nil) + totals.MsgTimeMeanAvg = stat.Mean(msgTimeMeans, nil) + totals.MsgTimeMeanStd = stat.StdDev(msgTimeMeans, nil) + + return totals +} + +func printResults(results []*runResults, totals *totalResults, format string, quiet bool) { + switch format { + case "json": + jr := JSONResults{ + Runs: results, + Totals: totals, + } + data, err := json.Marshal(jr) + if err != nil { + log.Printf("Failed to prepare results for printing - %s\n", err.Error()) + } + var out bytes.Buffer + if err = json.Indent(&out, data, "", "\t"); err != nil { + return + } + + fmt.Println(out.String()) + default: + if !quiet { + for _, res := range results { + fmt.Printf("======= CLIENT %s =======\n", res.ID) + fmt.Printf("Ratio: %.6f (%d/%d)\n", float64(res.Successes)/float64(res.Successes+res.Failures), res.Successes, res.Successes+res.Failures) + fmt.Printf("Succeeded: %d\n", res.Successes) + fmt.Printf("Failed: %d\n", res.Failures) + fmt.Printf("Runtime (s): %.3f\n", res.RunTime) + fmt.Printf("Msg time min (µs): %.3f\n", res.MsgTimeMin) + fmt.Printf("Msg time max (µs): %.3f\n", res.MsgTimeMax) + fmt.Printf("Msg time mean (µs): %.3f\n", res.MsgTimeMean) + fmt.Printf("Msg time std (µs): %.3f\n\n", res.MsgTimeStd) + + fmt.Printf("Bandwidth (msg/sec): %.3f\n\n", res.MsgsPerSec) + } + } + fmt.Printf("========= TOTAL (%d) =========\n", len(results)) + fmt.Printf("Total Ratio: %.3f (%d/%d)\n", totals.Ratio, totals.Successes, totals.Successes+totals.Failures) + fmt.Printf("Succeeded: %d\n", totals.Successes) + fmt.Printf("Failed: %d\n", totals.Failures) + fmt.Printf("Total Runtime (sec): %.3f\n", totals.TotalRunTime) + fmt.Printf("Average Runtime (sec): %.3f\n", totals.AvgRunTime) + fmt.Printf("Msg time min (µs): %.3f\n", totals.MsgTimeMin) + fmt.Printf("Msg time max (µs): %.3f\n", totals.MsgTimeMax) + fmt.Printf("Msg time mean (µs): %.3f\n", totals.MsgTimeMeanAvg) + fmt.Printf("Msg time mean std (µs): %.3f\n", totals.MsgTimeMeanStd) + + fmt.Printf("Average Bandwidth (msg/sec): %.3f\n", totals.AvgMsgsPerSec) + fmt.Printf("Total Bandwidth (msg/sec): %.3f\n", totals.TotalMsgsPerSec) + } +} diff --git a/tools/mqtt-bench/scripts/mqtt-bench.sh b/tools/mqtt-bench/scripts/mqtt-bench.sh new file mode 100755 index 000000000..5142b7bfe --- /dev/null +++ b/tools/mqtt-bench/scripts/mqtt-bench.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +i=0 +echo "BEGIN TEST " > result.$1.out +for mtls in true +do + for ret in false true + do + for qos in 0 1 2 + do + for pub in 1 10 100 + do + for sub in 1 10 + do + for message in 100 1000 + do + if [[ $pub -eq 100 && $message -eq 1000 ]]; + then + continue + fi + + for size in 100 500 + do + let "i += 1" + echo "=================================TEST $i=========================================" >> $1-$i.out + echo "MTLS: $mtls RETAIN: $ret, QOS $qos" >> $1-$i.out + echo "Pub:" $pub ", Sub:" $sub ", MsgSize:" $size ", MsgPerPub:" $message >> $1-$i.out + echo "=================================================================================" >> $1-$i.out + if [ "$mtls" = true ]; + then + echo "| " >> $1-$i.out + echo "| ./mqtt-bench --channels $3 -s $size -n $message --subs $sub --pubs $pub -q $qos --retain=$ret -m=true -b tcps://$2:8883 --quiet=true --ca ../../../docker/ssl/certs/ca.crt -t=true" >> $1-$i.out + echo "| " >> $1-$i.out + ../cmd/mqtt-bench --channels $3 -s $size -n $message --subs $sub --pubs $pub -q $qos --retain=$ret -m=true -b tcps://$2:8883 --quiet=true --ca ../../../docker/ssl/certs/ca.crt -t=true >> $1-$i.out + else + echo "| " >> $1-$i.out + echo "| ./mqtt-bench --channels $3 -s $size -n $message --subs $sub --pubs $pub -q $qos --retain=$ret -b tcp://$2:1883 --quiet=true" >> $1-$i.out + echo "| " >> $1-$i.out + ../cmd/mqtt-bench --channels $3 -s $size -n $message --subs $sub --pubs $pub -q $qos --retain=$ret -b tcp://$2:1883 --quiet=true >> $1-$i.out + fi + sleep 2 + done + done + done + done + done + + done +done +files=`ls test*.out | sort --version-sort ` +for file in $files +do + cat $file >> result.$1.out +done +echo "END TEST " >> result.$1.out diff --git a/tools/mqtt-bench/templates/reference.toml b/tools/mqtt-bench/templates/reference.toml new file mode 100644 index 000000000..c03235826 --- /dev/null +++ b/tools/mqtt-bench/templates/reference.toml @@ -0,0 +1,29 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +[mqtt] + timeout = 1000 + [mqtt.broker] + url = "tcp://localhost:1883" + + [mqtt.message] + size = 1000 + format = "text" + qos = 2 + retain = true + payload = "{\"bn\":\"some-base-name\",\"bt\":1.276020076001e+09, \"bu\":\"A\",\"bver\":5, \"n\":\"voltage\",\"u\":\"V\",\"v\":120.1}" + + [mqtt.tls] + mtls = false + skiptlsver = true + ca = "ca.crt" + +[test] +pubs = 2000 +count = 70 + +[log] +quiet = true + +[supermq] +connections_file = "../provision/mgconn.toml" diff --git a/tools/provision/Makefile b/tools/provision/Makefile new file mode 100644 index 000000000..7b8abc566 --- /dev/null +++ b/tools/provision/Makefile @@ -0,0 +1,15 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +PROGRAM = provision +SOURCES = $(wildcard *.go) cmd/main.go + +all: $(PROGRAM) + +.PHONY: all clean + +$(PROGRAM): $(SOURCES) + go build -ldflags "-s -w" -o $@ cmd/main.go + +clean: + rm -rf $(PROGRAM) diff --git a/tools/provision/README.md b/tools/provision/README.md new file mode 100644 index 000000000..6cf027720 --- /dev/null +++ b/tools/provision/README.md @@ -0,0 +1,148 @@ +# SuperMQ Clients and Channels Provisioning Tool + +A simple utility to create a list of channels and clients connected to these channels with possibility to create certificates for mTLS use case. + +This tool is useful for testing, and it creates a TOML format output (on stdout, can be redirected into the file as needed) +that can be used by SuperMQ MQTT benchmarking tool (`mqtt-bench`). + +## Installation + +```bash +cd tools/provision +make +``` + +### Usage + +```bash +./provision --help +Tool for provisioning series of SuperMQ channels and clients and connecting them together. +Complete documentation is available at https://docs.supermq.absmach.eu + +Usage: + provision [flags] + +Flags: + --ca string CA for creating and signing clients certificate (default "ca.crt") + --cakey string ca.key for creating and signing clients certificate (default "ca.key") + -h, --help help for provision + --host string address for supermq instance (default "https://localhost") + --num int number of channels and clients to create and connect (default 10) + -p, --password string supermq users password + --ssl create certificates for mTLS access + -u, --username string supermq user + --prefix string name prefix for clients and channels +``` + +Example: + +```bash +go run tools/provision/cmd/main.go -u test@supermq.com -p test1234 --host https://142.93.118.47 +``` + +If you want to create a list of channels with certificates: + +```bash +go run tools/provision/cmd/main.go --host http://localhost --num 10 -u test@supermq.com -p test1234 --ssl true --ca docker/ssl/certs/ca.crt --cakey docker/ssl/certs/ca.key + +``` + +> `ca.crt` and `ca.key` are used for creating clients certificate and for HTTPS, +> if you are provisioning on remote server you will have to get these files to your local +> directory so that you can create certificates for clients + +Example of output: + +```bash +# List of clients that can be connected to MQTT broker +[[clients]] +client_id = "0eac601b-6d54-4767-b8b7-594aaf9990d3" +client_key = "07713103-513f-43c7-b7fe-500c1af23d7d" +mtls_cert = """-----BEGIN CERTIFICATE----- +MIIEmTCCA4GgAwIBAgIRAO50qOfXsU+cHm/QY2NYu+0wDQYJKoZIhvcNAQELBQAw +VzESMBAGA1UEAwwJbG9jYWxob3N0MREwDwYDVQQKDAhNYWluZmx1eDEMMAoGA1UE +CwwDSW9UMSAwHgYJKoZIhvcNAQkBFhFpbmZvQG1haW5mbHV4LmNvbTAeFw0xOTEx +MTUxNzU2MzhaFw0yMDAyMjMxNzU2MzhaMFUxETAPBgNVBAoTCE1haW5mbHV4MREw +DwYDVQQLEwhtYWluZmx1eDEtMCsGA1UEAxMkMDc3MTMxMDMtNTEzZi00M2M3LWI3 +ZmUtNTAwYzFhZjIzZDdkMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA +zsIYoovZJGJxfu7e4X3P3wnHDi9/wvRMhGW1EZEB5vNvfxvmmt4PhiE1c73mCypT +AUdui0j+hrCx8P90v12LEcJqty3yBnw+ge2/xCLNLKZh2/MjBQ7A7PMQpmOo31LR +hxFSthW41C296iwVYyvRa19y7g5mcUrzWvI2EVZbbGEDym1U/PI4aKhdQ3a7fF6B +GfvXYbGOa4/8VUIj8KHTRg2Z6/iLhxYgUnHd3xMCjihQkwLvB7/avVr9Ih9oLEe+ +h7H9Pl5hMEpHP4BvHokUFhtbzqofuHNBKuEUf5r/cQ1oVAl6F77Fs5vZbQ59bLxw +etclDxW7nvOgIxEIUcJAkdd+nOxhpfbDM8QFsPXGSfb9vWUTaoQDIeWx9pPY5tsY +tbtW2HeKRGHO9jGFSzonY6sbTiaIzQ0F2PNPS1BoBIo2A95YNwt2ScfuRTs5ZK62 +2+RNWbs+pDXJ5ZGcWDfjSxEYXy+jGUyvDExGCtryUu5Ufp7XuZ4O767iDzaj7dFG +rXSXfXrqwm8u2CMwucNzdVqikNG2gDToHDyIjLRd62m2pHk9gXbk3FGI+5x52pBs ++xdRaddMY8+DJ2R88PFoq3kqexxs2HJathCu6RfoP452zH9iU0gvPLR7fXuPoZ6Y +5NqE1CebZ6IiwwivD7kU1LxmhmQUY9DaHdHNYd66bd0CAwEAAaNiMGAwDgYDVR0P +AQH/BAQDAgeAMB0GA1UdJQQWMBQGCCsGAQUFBwMCBggrBgEFBQcDATAOBgNVHQ4E +BwQFAQIDBAYwHwYDVR0jBBgwFoAUbOMUfdahIzURpsN/dcUu8ek3PvIwDQYJKoZI +hvcNAQELBQADggEBAI+DdKYKKPVi4CPUbl+R81dq+Otd8L9i/RxM7G89XU0aGkSO +GSJzURKYbmLGgWdVWcdYMUfbpiE8vH1dLuDQdRywpDDjSMx7h0PwpYvk25HHKMSs +OIKpxvI1DyuNcwxrPuH863zw1Mo1hpGGin7yZc8VBf6nbR3RMNbQ2elMH1m7no4v +YM4HrTeR9n1bakIVw9OLnFpB03sT3keBdWsLDbAZ0yZfvxqdn6Hr7NRnab3vyrOz +GrYPJ51B/FGZC9n0ZR+SWzipen15vaG46SvoCv9HfDZ9cbSVR4eyPy/OIx+5CBVY +uGpJ+kN8jH5tuoxrmHZOsPMA+a6CZD2cKTaRu+Y= +-----END CERTIFICATE----- +""" +mtls_key = """-----BEGIN RSA PRIVATE KEY----- +MIIJKQIBAAKCAgEAzsIYoovZJGJxfu7e4X3P3wnHDi9/wvRMhGW1EZEB5vNvfxvm +mt4PhiE1c73mCypTAUdui0j+hrCx8P90v12LEcJqty3yBnw+ge2/xCLNLKZh2/Mj +BQ7A7PMQpmOo31LRhxFSthW41C296iwVYyvRa19y7g5mcUrzWvI2EVZbbGEDym1U +/PI4aKhdQ3a7fF6BGfvXYbGOa4/8VUIj8KHTRg2Z6/iLhxYgUnHd3xMCjihQkwLv +B7/avVr9Ih9oLEe+h7H9Pl5hMEpHP4BvHokUFhtbzqofuHNBKuEUf5r/cQ1oVAl6 +F77Fs5vZbQ59bLxwetclDxW7nvOgIxEIUcJAkdd+nOxhpfbDM8QFsPXGSfb9vWUT +aoQDIeWx9pPY5tsYtbtW2HeKRGHO9jGFSzonY6sbTiaIzQ0F2PNPS1BoBIo2A95Y +Nwt2ScfuRTs5ZK622+RNWbs+pDXJ5ZGcWDfjSxEYXy+jGUyvDExGCtryUu5Ufp7X +uZ4O767iDzaj7dFGrXSXfXrqwm8u2CMwucNzdVqikNG2gDToHDyIjLRd62m2pHk9 +gXbk3FGI+5x52pBs+xdRaddMY8+DJ2R88PFoq3kqexxs2HJathCu6RfoP452zH9i +U0gvPLR7fXuPoZ6Y5NqE1CebZ6IiwwivD7kU1LxmhmQUY9DaHdHNYd66bd0CAwEA +AQKCAgAj2sr03TWhtqSh84CZL/0tW3+2eQw53a2rRAv7aN8gktSiAU+jSaD9jKK9 +WJAdHZDZZu7Hnrfs2ZVyCorPaMRmJwXkkEYpU8BvPbCErdhQxuWvg+FtzhosvRYF +FMFDQRRuzNVAGFI+EVSe2Fg5I28kpJ/EoqCnQu0it2Ai74vZJpXGs+EKIGMh2xiZ +S2zF64mN3PuDyIu/IXALxPWAlD+UJWWs4yQnH/Io+fAU8DIAPwOCCv8yo9WmArJl +CXdCPorO81HMUAegnTDv1TDv5aujDcmE9EGd9fa2HeQ1IMbtbvrJn/8ZQQ79z6gL +3nhns+H5m3ekvwsTTIJXsmtz6jDSCek5C78gKJ6fIH/urKkgG0Pcw4HdOtt5PYQS +KnAKN9KuPEqwxJCDpwKcENDxBul9Huc9i4m1J8hq4qtEBk8k1rqfjWAxigBmhdQV +jY0q//ou/VYgD07RIqezCovVZwJDqvEKg2A5e2YmUXIbYmG1BTCN5NIDcnwqO65C +gD4V9vgn2+ek7z8rBr5VHJ/3LNqc+XFzQW+GjzVFLUfzkgipMGt4DVQdseXWKaiz +v6LV7Nn4hPKETZ5pYzNll4SH+PkVG0Pwc9g8yZF0CcvQt/4wry78LdihgXUBtI7G ++5cH/DXOCd1itaauggHQwEm6GF4VR3uPthoU++QvPKqSAvWnQQKCAQEA7n6xDE2J +iWEBCj8gDYcKKgMUlwWmnWc7MprOU2oCR4DXLcDNcmJLKwb2UC1Z4dxQy5pJs6Yk +5f6rOFwQ0sMM36PcmRJcBNeMTsj2ilZ79TbVYl4pgtjZLJl4JptwXFZFeVdTx1Sa +QoZasqlyO44Uw5D3+ztddHpnOVPCLd36xV6R3e1scKuXCrE4Pl/+YmkYG8NrRKoe +vHUhmmtcukxsEPhGJhQqpbMhm75hBFfHJw2gMu1bBGDGYzfX9bBkF1ZRq+7X6/g0 +Zvr5Gh1tZhkHDR9JwRMNbTSQgVvJD0eToBo5kZbWF4+giAhNkV+wGiCMJgdGWJQo +4Cz5rY+Nv2Rz7QKCAQEA3e8SzLm4Gvft9AZUy96kuk5uKckAXW/FnDKfa+zFoT7w +KyEz9yOZRFXoPdrReZLzgk8GDZVbYAyXmONx9Sjq1GmZ/fDkXpUtdr6PmDR19Hea +CVqUfkBYmMTmA0zFpS6rsI+dIwCP2h7slJQ4eUESYVRiXWyOKEhQVGM0t9liUfrr +lfRnVj6q9I3vqCcqgBuODoAS/iFaFpSfh05XSKdl9XW2t/sd33acPqh9zKBczlsR +H6dyrO02znbbOgrBCBbxtFdq4YLuHKsBB2umz/NKfpnoOUHLeTU2VaqyOtDK9BIA +XtCPu6KJNZ86eFAbtHwBpHn7u7iQZtcaWK9LuESDsQKCAQEAiMV/I18UEQTgY8/v +wdI/sfgyRqmm833QJSVCTfPterQYstRu/boBAZvshe58LVr7usewnKYbYwq5hojF +3RieuWJvkBlHTD+Q5124hX0zeV0I4nC9vZw+b6VTklByD4IqNXwvP5D1JlGGkg86 +w4ynu7/XduyEm9fWerneEg/LUIT7gho2pibBaBBaAOtsJ2O9v65CRg6Jseo6ayRG ++U/6aYD4Ob429u/Txk1XtfXg8DSQOqSEHe6h1ySfZPbTb87A56kBiwG8i5JCaQeX +RYX01UGsOl2Cxa3vcUAB/hE+SALCIQwvmzNzDJA2a7hEdbdUqDpjzUiqaGViinZZ +A/nHwQKCAQAkTxLCT7ghIWLaw5Zn7DsDCAXZ7DqVDs5DqbyPSaNjqApe5AW+byKK +HYvrYrtWqoYQUaFp43+ZjTXYG43vUAxrSAObmieimcFgZfjUK/EIV/Dpito0dY6J +H92JuKu1RJduQXCx40ulod2OyVkb7Vt2dPnK0xHG4V3TEI/1bCk7xFN6qwuk/oe1 +jusglZfMcbWiBa4VyZsViqc22chJ6KkzqViFbR4MCzmwvpwmOC42zItWpGyMghqv +WJ6xNkUyb56HpK2ly2ftZMS8VA5sgx8y6zck9vC1GdGT3mNeX/50Q+WvnWuGhSbx +kOVd/a0qsAcMw7A9nApz6Mk0rSk0MnFhAoIBAQCI6dU5c1sTp/LNp+z6yQmcJD3Z +HNYdVhf8pxHpRWZ8r5otFwi1lr5vk15Zh59B5nMLQHP3UWJ7R66HUjXCtFe86ojV +xngL3lXJNtLcCWXQHM/nkWZ1TVCeZ6mS8aJndcy4sY0lPUqRtYaXSV/EyzpQJUmf +xcEeQuOhBZ4s8uSyuLgEPYbeYyi7Vpujm7UpplTN55dIZrQ7tMefRNgHjybFfC8P +QsxPR4lWoFpr9xFvtBORlP+In8LjD3Z2EDm2guIRAWebEJGsY7ftAv7CEFrLOJd5 +uCRt+TFMyEfqilipmNsV7esgbroiyEGXGMI8JdBY9OsnK6ZSlXaMnQ9vq2kK +-----END RSA PRIVATE KEY----- +""" + +# List of channels that clients can publish to +# each channel is connected to each client from clients list +# Clients connected to channel 1f18afa1-29c4-4634-99d1-68dfa1b74e6a: 0eac601b-6d54-4767-b8b7-594aaf9990d3 +[[channels]] +channel_id = "1f18afa1-29c4-4634-99d1-68dfa1b74e6a" + +``` diff --git a/tools/provision/cmd/main.go b/tools/provision/cmd/main.go new file mode 100644 index 000000000..98f4be727 --- /dev/null +++ b/tools/provision/cmd/main.go @@ -0,0 +1,42 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package main contains entry point for provisioning tool. +package main + +import ( + "log" + + "github.com/absmach/supermq/tools/provision" + "github.com/spf13/cobra" +) + +func main() { + pconf := provision.Config{} + + rootCmd := &cobra.Command{ + Use: "provision", + Short: "provision is provisioning tool for SuperMQ", + Long: `Tool for provisioning series of SuperMQ channels and clients and connecting them together. +Complete documentation is available at https://docs.supermq.absmach.eu`, + Run: func(cmd *cobra.Command, _ []string) { + if err := provision.Provision(cmd.Context(), pconf); err != nil { + log.Fatal(err) + } + }, + } + + // Root Flags + rootCmd.PersistentFlags().StringVarP(&pconf.Host, "host", "", "https://localhost", "address for supermq instance") + rootCmd.PersistentFlags().StringVarP(&pconf.Prefix, "prefix", "", "", "name prefix for clients and channels") + rootCmd.PersistentFlags().StringVarP(&pconf.Username, "username", "u", "", "supermq user") + rootCmd.PersistentFlags().StringVarP(&pconf.Password, "password", "p", "", "supermq users password") + rootCmd.PersistentFlags().IntVarP(&pconf.Num, "num", "", 10, "number of channels and clients to create and connect") + rootCmd.PersistentFlags().BoolVarP(&pconf.SSL, "ssl", "", false, "create certificates for mTLS access") + rootCmd.PersistentFlags().StringVarP(&pconf.CAKey, "cakey", "", "ca.key", "ca.key for creating and signing clients certificate") + rootCmd.PersistentFlags().StringVarP(&pconf.CA, "ca", "", "ca.crt", "CA for creating and signing clients certificate") + + if err := rootCmd.Execute(); err != nil { + log.Fatal(err) + } +} diff --git a/tools/provision/doc.go b/tools/provision/doc.go new file mode 100644 index 000000000..da596dc21 --- /dev/null +++ b/tools/provision/doc.go @@ -0,0 +1,7 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package provision is a simple utility to create +// a list of channels and clients connected to these channels +// with possibility to create certificates for mTLS use case. +package provision diff --git a/tools/provision/provision.go b/tools/provision/provision.go new file mode 100644 index 000000000..917566338 --- /dev/null +++ b/tools/provision/provision.go @@ -0,0 +1,301 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package provision + +import ( + "bufio" + "bytes" + "context" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "log" + "math/big" + "os" + "strings" + "time" + + "github.com/0x6flab/namegenerator" + sdk "github.com/absmach/supermq/pkg/sdk" + supermqSDK "github.com/absmach/supermq/pkg/sdk" +) + +const ( + defPass = "12345678" + defReaderURL = "http://localhost:9005" +) + +var namesgenerator = namegenerator.NewGenerator() + +// MgConn - structure describing SuperMQ connection set. +type MgConn struct { + ClientID string + ClinetSecret string + ChannelID string + MTLSCert string + MTLSKey string +} + +// Config - provisioning configuration. +type Config struct { + Host string + Username string + Email string + Password string + Num int + SSL bool + CA string + CAKey string + Prefix string +} + +// Provision - function that does actual provisiong. +func Provision(ctx context.Context, conf Config) error { + const ( + rsaBits = 4096 + ttl = "2400h" + ) + + msgContentType := string(supermqSDK.CTJSONSenML) + sdkConf := sdk.Config{ + ClientsURL: conf.Host, + UsersURL: conf.Host, + ReaderURL: defReaderURL, + HTTPAdapterURL: fmt.Sprintf("%s/http", conf.Host), + BootstrapURL: conf.Host, + CertsURL: conf.Host, + MsgContentType: supermqSDK.ContentType(msgContentType), + TLSVerification: false, + } + + s := sdk.NewSDK(sdkConf) + + user := supermqSDK.User{ + Email: conf.Email, + Credentials: supermqSDK.Credentials{ + Username: conf.Username, + Secret: conf.Password, + }, + } + + if user.Email == "" { + user.Email = fmt.Sprintf("%s@email.com", namesgenerator.Generate()) + user.Credentials.Secret = defPass + } + + // Create new user + if _, err := s.CreateUser(ctx, user, ""); err != nil { + return fmt.Errorf("unable to create new user: %s", err.Error()) + } + + var err error + + // Login user + token, err := s.CreateToken(ctx, supermqSDK.Login{Username: user.Credentials.Username, Password: user.Credentials.Secret}) + if err != nil { + return fmt.Errorf("unable to login user: %s", err.Error()) + } + + // Create new domain + dname := fmt.Sprintf("%s%s", conf.Prefix, namesgenerator.Generate()) + domain := supermqSDK.Domain{ + Name: dname, + Route: strings.ToLower(dname), + Permission: "admin", + } + + domain, err = s.CreateDomain(ctx, domain, token.AccessToken) + if err != nil { + return fmt.Errorf("unable to create domain: %w", err) + } + // Login to domain + token, err = s.CreateToken(ctx, supermqSDK.Login{ + Username: user.Credentials.Username, + Password: user.Credentials.Secret, + }) + if err != nil { + return fmt.Errorf("unable to login user: %w", err) + } + + var tlsCert tls.Certificate + var caCert *x509.Certificate + + if conf.SSL { + tlsCert, err = tls.LoadX509KeyPair(conf.CA, conf.CAKey) + if err != nil { + return fmt.Errorf("failed to load CA cert") + } + + b, err := os.ReadFile(conf.CA) + if err != nil { + return fmt.Errorf("failed to load CA cert") + } + + block, _ := pem.Decode(b) + if block == nil { + return fmt.Errorf("no PEM data found, failed to decode CA") + } + + caCert, err = x509.ParseCertificate(block.Bytes) + if err != nil { + return fmt.Errorf("failed to decode certificate - %s", err.Error()) + } + } + + // Create clients and channels + clients := make([]supermqSDK.Client, conf.Num) + channels := make([]supermqSDK.Channel, conf.Num) + cIDs := []string{} + tIDs := []string{} + + fmt.Println("# List of clients that can be connected to MQTT broker") + + for i := 0; i < conf.Num; i++ { + clients[i] = supermqSDK.Client{Name: fmt.Sprintf("%s-client-%d", conf.Prefix, i)} + channels[i] = supermqSDK.Channel{Name: fmt.Sprintf("%s-channel-%d", conf.Prefix, i)} + } + + clients, err = s.CreateClients(ctx, clients, domain.ID, token.AccessToken) + if err != nil { + return fmt.Errorf("failed to create the clients: %s", err.Error()) + } + + var chs []supermqSDK.Channel + for _, c := range channels { + c, err = s.CreateChannel(ctx, c, domain.ID, token.AccessToken) + if err != nil { + return fmt.Errorf("failed to create the chennels: %s", err.Error()) + } + chs = append(chs, c) + } + channels = chs + + for _, t := range clients { + tIDs = append(tIDs, t.ID) + } + + for _, c := range channels { + cIDs = append(cIDs, c.ID) + } + + for i := 0; i < conf.Num; i++ { + cert := "" + key := "" + + if conf.SSL { + var priv any + priv, _ = rsa.GenerateKey(rand.Reader, rsaBits) + + notBefore := time.Now() + validFor, err := time.ParseDuration(ttl) + if err != nil { + return fmt.Errorf("failed to set date %v", validFor) + } + notAfter := notBefore.Add(validFor) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return fmt.Errorf("failed to generate serial number: %s", err) + } + + tmpl := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"SuperMQ"}, + CommonName: clients[i].Credentials.Secret, + OrganizationalUnit: []string{"supermq"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, caCert, publicKey(priv), tlsCert.PrivateKey) + if err != nil { + return fmt.Errorf("failed to create certificate: %s", err) + } + + var bw, keyOut bytes.Buffer + buffWriter := bufio.NewWriter(&bw) + buffKeyOut := bufio.NewWriter(&keyOut) + + if err := pem.Encode(buffWriter, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + return fmt.Errorf("failed to write cert pem data: %s", err) + } + buffWriter.Flush() + cert = bw.String() + + if err := pem.Encode(buffKeyOut, pemBlockForKey(priv)); err != nil { + return fmt.Errorf("failed to write key pem data: %s", err) + } + buffKeyOut.Flush() + key = keyOut.String() + } + + // Print output + fmt.Printf("[[clients]]\nclient_id = \"%s\"\nclient_key = \"%s\"\n", clients[i].ID, clients[i].Credentials.Secret) + if conf.SSL { + fmt.Printf("mtls_cert = \"\"\"%s\"\"\"\n", cert) + fmt.Printf("mtls_key = \"\"\"%s\"\"\"\n", key) + } + fmt.Println("") + } + + fmt.Printf("# List of channels that clients can publish to\n" + + "# each channel is connected to each client from clients list\n") + for i := 0; i < conf.Num; i++ { + fmt.Printf("[[channels]]\nchannel_id = \"%s\"\n\n", cIDs[i]) + } + + for _, cID := range cIDs { + for _, tID := range tIDs { + conIDs := supermqSDK.Connection{ + ClientIDs: []string{tID}, + ChannelIDs: []string{cID}, + Types: []string{"publish", "subscribe"}, + } + if err := s.Connect(ctx, conIDs, domain.ID, token.AccessToken); err != nil { + log.Fatalf("Failed to connect clients %s to channels %s: %s", tID, cID, err) + } + } + } + + return nil +} + +func publicKey(priv any) any { + switch k := priv.(type) { + case *rsa.PrivateKey: + return &k.PublicKey + case *ecdsa.PrivateKey: + return &k.PublicKey + default: + return nil + } +} + +func pemBlockForKey(priv any) *pem.Block { + switch k := priv.(type) { + case *rsa.PrivateKey: + return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)} + case *ecdsa.PrivateKey: + b, err := x509.MarshalECPrivateKey(k) + if err != nil { + fmt.Fprintf(os.Stderr, "Unable to marshal ECDSA private key: %v", err) + os.Exit(2) + } + return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b} + default: + return nil + } +} diff --git a/users/README.md b/users/README.md index aa9356e43..bf51418b7 100644 --- a/users/README.md +++ b/users/README.md @@ -14,49 +14,49 @@ The service is configured using the environment variables presented in the follo | Variable | Description | Default | | --------------------------------- | ----------------------------------------------------------------------- | --------------------------------- | -| `SMQ_USERS_LOG_LEVEL` | Log level for users service (debug, info, warn, error) | info | -| `SMQ_USERS_ADMIN_EMAIL` | Default user, created on startup | | -| `SMQ_USERS_ADMIN_PASSWORD` | Default user password, created on startup | 12345678 | -| `SMQ_USERS_PASS_REGEX` | Password regex | ^.{8,}$ | -| `SMQ_USERS_HTTP_HOST` | Users service HTTP host | localhost | -| `SMQ_USERS_HTTP_PORT` | Users service HTTP port | 9002 | -| `SMQ_USERS_HTTP_SERVER_CERT` | Path to the PEM encoded server certificate file | "" | -| `SMQ_USERS_HTTP_SERVER_KEY` | Path to the PEM encoded server key file | "" | -| `SMQ_USERS_HTTP_SERVER_CA_CERTS` | Path to the PEM encoded server CA certificate file | "" | -| `SMQ_USERS_HTTP_CLIENT_CA_CERTS` | Path to the PEM encoded client CA certificate file | "" | -| `SMQ_AUTH_GRPC_URL` | Auth service GRPC URL | localhost:8181 | -| `SMQ_AUTH_GRPC_TIMEOUT` | Auth service GRPC timeout | 1s | -| `SMQ_AUTH_GRPC_CLIENT_CERT` | Path to the PEM encoded client certificate file | "" | -| `SMQ_AUTH_GRPC_CLIENT_KEY` | Path to the PEM encoded client key file | "" | -| `SMQ_AUTH_GRPC_SERVER_CA_CERTS` | Path to the PEM encoded server CA certificate file | "" | -| `SMQ_USERS_DB_HOST` | Database host address | localhost | -| `SMQ_USERS_DB_PORT` | Database host port | 5432 | -| `SMQ_USERS_DB_USER` | Database user | supermq | -| `SMQ_USERS_DB_PASS` | Database password | supermq | -| `SMQ_USERS_DB_NAME` | Name of the database used by the service | users | -| `SMQ_USERS_DB_SSL_MODE` | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | -| `SMQ_USERS_DB_SSL_CERT` | Path to the PEM encoded certificate file | "" | -| `SMQ_USERS_DB_SSL_KEY` | Path to the PEM encoded key file | "" | -| `SMQ_USERS_DB_SSL_ROOT_CERT` | Path to the PEM encoded root certificate file | "" | -| `SMQ_EMAIL_HOST` | Mail server host | localhost | -| `SMQ_EMAIL_PORT` | Mail server port | 25 | -| `SMQ_EMAIL_USERNAME` | Mail server username | "" | -| `SMQ_EMAIL_PASSWORD` | Mail server password | "" | -| `SMQ_EMAIL_FROM_ADDRESS` | Email "from" address | "" | -| `SMQ_EMAIL_FROM_NAME` | Email "from" name | "" | -| `SMQ_PASSWORD_RESET_URL_PREFIX` | Password reset URL prefix | | -| `SMQ_PASSWORD_RESET_EMAIL_TEMPLATE` | Password reset email template | reset-password-email.tmpl | -| `SMQ_VERIFICATION_URL_PREFIX` | Verification URL prefix | | -| `SMQ_VERIFICATION_EMAIL_TEMPLATE` | Verification email template | verification-email.tmpl | -| `SMQ_USERS_ES_URL` | Event store URL | | -| `SMQ_JAEGER_URL` | Jaeger server URL | | -| `SMQ_OAUTH_UI_REDIRECT_URL` | OAuth UI redirect URL | | -| `SMQ_OAUTH_UI_ERROR_URL` | OAuth UI error URL | | -| `SMQ_USERS_DELETE_INTERVAL` | Interval for deleting users | 24h | -| `SMQ_USERS_DELETE_AFTER` | Time after which users are deleted | 720h | -| `SMQ_JAEGER_TRACE_RATIO` | Jaeger sampling ratio | 1.0 | -| `SMQ_SEND_TELEMETRY` | Send telemetry to supermq call home server. | true | -| `SMQ_USERS_INSTANCE_ID` | SuperMQ instance ID | "" | +| `MG_USERS_LOG_LEVEL` | Log level for users service (debug, info, warn, error) | info | +| `MG_USERS_ADMIN_EMAIL` | Default user, created on startup | | +| `MG_USERS_ADMIN_PASSWORD` | Default user password, created on startup | 12345678 | +| `MG_USERS_PASS_REGEX` | Password regex | ^.{8,}$ | +| `MG_USERS_HTTP_HOST` | Users service HTTP host | localhost | +| `MG_USERS_HTTP_PORT` | Users service HTTP port | 9002 | +| `MG_USERS_HTTP_SERVER_CERT` | Path to the PEM encoded server certificate file | "" | +| `MG_USERS_HTTP_SERVER_KEY` | Path to the PEM encoded server key file | "" | +| `MG_USERS_HTTP_SERVER_CA_CERTS` | Path to the PEM encoded server CA certificate file | "" | +| `MG_USERS_HTTP_CLIENT_CA_CERTS` | Path to the PEM encoded client CA certificate file | "" | +| `MG_AUTH_GRPC_URL` | Auth service GRPC URL | localhost:8181 | +| `MG_AUTH_GRPC_TIMEOUT` | Auth service GRPC timeout | 1s | +| `MG_AUTH_GRPC_CLIENT_CERT` | Path to the PEM encoded client certificate file | "" | +| `MG_AUTH_GRPC_CLIENT_KEY` | Path to the PEM encoded client key file | "" | +| `MG_AUTH_GRPC_SERVER_CA_CERTS` | Path to the PEM encoded server CA certificate file | "" | +| `MG_USERS_DB_HOST` | Database host address | localhost | +| `MG_USERS_DB_PORT` | Database host port | 5432 | +| `MG_USERS_DB_USER` | Database user | supermq | +| `MG_USERS_DB_PASS` | Database password | supermq | +| `MG_USERS_DB_NAME` | Name of the database used by the service | users | +| `MG_USERS_DB_SSL_MODE` | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | +| `MG_USERS_DB_SSL_CERT` | Path to the PEM encoded certificate file | "" | +| `MG_USERS_DB_SSL_KEY` | Path to the PEM encoded key file | "" | +| `MG_USERS_DB_SSL_ROOT_CERT` | Path to the PEM encoded root certificate file | "" | +| `MG_EMAIL_HOST` | Mail server host | localhost | +| `MG_EMAIL_PORT` | Mail server port | 25 | +| `MG_EMAIL_USERNAME` | Mail server username | "" | +| `MG_EMAIL_PASSWORD` | Mail server password | "" | +| `MG_EMAIL_FROM_ADDRESS` | Email "from" address | "" | +| `MG_EMAIL_FROM_NAME` | Email "from" name | "" | +| `MG_PASSWORD_RESET_URL_PREFIX` | Password reset URL prefix | | +| `MG_PASSWORD_RESET_EMAIL_TEMPLATE` | Password reset email template | reset-password-email.tmpl | +| `MG_VERIFICATION_URL_PREFIX` | Verification URL prefix | | +| `MG_VERIFICATION_EMAIL_TEMPLATE` | Verification email template | verification-email.tmpl | +| `MG_USERS_ES_URL` | Event store URL | | +| `MG_JAEGER_URL` | Jaeger server URL | | +| `MG_OAUTH_UI_REDIRECT_URL` | OAuth UI redirect URL | | +| `MG_OAUTH_UI_ERROR_URL` | OAuth UI error URL | | +| `MG_USERS_DELETE_INTERVAL` | Interval for deleting users | 24h | +| `MG_USERS_DELETE_AFTER` | Time after which users are deleted | 720h | +| `MG_JAEGER_TRACE_RATIO` | Jaeger sampling ratio | 1.0 | +| `MG_SEND_TELEMETRY` | Send telemetry to supermq call home server. | true | +| `MG_USERS_INSTANCE_ID` | SuperMQ instance ID | "" | ## Deployment @@ -77,57 +77,57 @@ make users make install # set the environment variables and run the service -SMQ_USERS_LOG_LEVEL=info \ -SMQ_USERS_ADMIN_EMAIL=admin@example.com \ -SMQ_USERS_ADMIN_PASSWORD=12345678 \ -SMQ_USERS_PASS_REGEX="^.{8,}$" \ -SMQ_USERS_HTTP_HOST=localhost \ -SMQ_USERS_HTTP_PORT=9002 \ -SMQ_USERS_HTTP_SERVER_CERT="" \ -SMQ_USERS_HTTP_SERVER_KEY="" \ -SMQ_USERS_HTTP_SERVER_CA_CERTS="" \ -SMQ_USERS_HTTP_CLIENT_CA_CERTS="" \ -SMQ_AUTH_GRPC_URL=localhost:8181 \ -SMQ_AUTH_GRPC_TIMEOUT=1s \ -SMQ_AUTH_GRPC_CLIENT_CERT="" \ -SMQ_AUTH_GRPC_CLIENT_KEY="" \ -SMQ_AUTH_GRPC_SERVER_CA_CERTS="" \ -SMQ_USERS_DB_HOST=localhost \ -SMQ_USERS_DB_PORT=5432 \ -SMQ_USERS_DB_USER=supermq \ -SMQ_USERS_DB_PASS=supermq \ -SMQ_USERS_DB_NAME=users \ -SMQ_USERS_DB_SSL_MODE=disable \ -SMQ_USERS_DB_SSL_CERT="" \ -SMQ_USERS_DB_SSL_KEY="" \ -SMQ_USERS_DB_SSL_ROOT_CERT="" \ -SMQ_EMAIL_HOST=smtp.mailtrap.io \ -SMQ_EMAIL_PORT=2525 \ -SMQ_EMAIL_USERNAME="18bf7f7070513" \ -SMQ_EMAIL_PASSWORD="2b0d302e775b1e" \ -SMQ_EMAIL_FROM_ADDRESS=from@example.com \ -SMQ_EMAIL_FROM_NAME=Example \ -SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost:9002/password/reset \ -SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=docker/templates/reset-password-email.tmpl \ -SMQ_VERIFICATION_URL_PREFIX=http://localhost:9002/users/verify-email \ -SMQ_VERIFICATION_EMAIL_TEMPLATE=docker/templates/verification-email.tmpl \ -SMQ_USERS_ES_URL=nats://localhost:4222 \ -SMQ_JAEGER_URL=http://localhost:14268/api/traces \ -SMQ_JAEGER_TRACE_RATIO=1.0 \ -SMQ_SEND_TELEMETRY=true \ -SMQ_OAUTH_UI_REDIRECT_URL=http://localhost:9095/domains \ -SMQ_OAUTH_UI_ERROR_URL=http://localhost:9095/error \ -SMQ_USERS_DELETE_INTERVAL=24h \ -SMQ_USERS_DELETE_AFTER=720h \ -SMQ_USERS_INSTANCE_ID="" \ +MG_USERS_LOG_LEVEL=info \ +MG_USERS_ADMIN_EMAIL=admin@example.com \ +MG_USERS_ADMIN_PASSWORD=12345678 \ +MG_USERS_PASS_REGEX="^.{8,}$" \ +MG_USERS_HTTP_HOST=localhost \ +MG_USERS_HTTP_PORT=9002 \ +MG_USERS_HTTP_SERVER_CERT="" \ +MG_USERS_HTTP_SERVER_KEY="" \ +MG_USERS_HTTP_SERVER_CA_CERTS="" \ +MG_USERS_HTTP_CLIENT_CA_CERTS="" \ +MG_AUTH_GRPC_URL=localhost:8181 \ +MG_AUTH_GRPC_TIMEOUT=1s \ +MG_AUTH_GRPC_CLIENT_CERT="" \ +MG_AUTH_GRPC_CLIENT_KEY="" \ +MG_AUTH_GRPC_SERVER_CA_CERTS="" \ +MG_USERS_DB_HOST=localhost \ +MG_USERS_DB_PORT=5432 \ +MG_USERS_DB_USER=supermq \ +MG_USERS_DB_PASS=supermq \ +MG_USERS_DB_NAME=users \ +MG_USERS_DB_SSL_MODE=disable \ +MG_USERS_DB_SSL_CERT="" \ +MG_USERS_DB_SSL_KEY="" \ +MG_USERS_DB_SSL_ROOT_CERT="" \ +MG_EMAIL_HOST=smtp.mailtrap.io \ +MG_EMAIL_PORT=2525 \ +MG_EMAIL_USERNAME="18bf7f7070513" \ +MG_EMAIL_PASSWORD="2b0d302e775b1e" \ +MG_EMAIL_FROM_ADDRESS=from@example.com \ +MG_EMAIL_FROM_NAME=Example \ +MG_PASSWORD_RESET_URL_PREFIX=http://localhost:9002/password/reset \ +MG_PASSWORD_RESET_EMAIL_TEMPLATE=docker/templates/reset-password-email.tmpl \ +MG_VERIFICATION_URL_PREFIX=http://localhost:9002/users/verify-email \ +MG_VERIFICATION_EMAIL_TEMPLATE=docker/templates/verification-email.tmpl \ +MG_USERS_ES_URL=nats://localhost:4222 \ +MG_JAEGER_URL=http://localhost:14268/api/traces \ +MG_JAEGER_TRACE_RATIO=1.0 \ +MG_SEND_TELEMETRY=true \ +MG_OAUTH_UI_REDIRECT_URL=http://localhost:9095/domains \ +MG_OAUTH_UI_ERROR_URL=http://localhost:9095/error \ +MG_USERS_DELETE_INTERVAL=24h \ +MG_USERS_DELETE_AFTER=720h \ +MG_USERS_INSTANCE_ID="" \ $GOBIN/supermq-users ``` -If `SMQ_EMAIL_TEMPLATE` doesn't point to any file service will function but password reset functionality will not work. The email environment variables are used to send emails with password reset link. The service expects a file in Go template format. The template should be something like [this](https://github.com/absmach/supermq/blob/main/docker/templates/users.tmpl). +If `MG_EMAIL_TEMPLATE` doesn't point to any file service will function but password reset functionality will not work. The email environment variables are used to send emails with password reset link. The service expects a file in Go template format. The template should be something like [this](https://github.com/absmach/supermq/blob/main/docker/templates/users.tmpl). -Setting `SMQ_USERS_HTTP_SERVER_CERT` and `SMQ_USERS_HTTP_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key. Setting `SMQ_USERS_HTTP_SERVER_CA_CERTS` will enable TLS against the service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. Setting `SMQ_USERS_HTTP_CLIENT_CA_CERTS` will enable TLS against the service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. +Setting `MG_USERS_HTTP_SERVER_CERT` and `MG_USERS_HTTP_SERVER_KEY` will enable TLS against the service. The service expects a file in PEM format for both the certificate and the key. Setting `MG_USERS_HTTP_SERVER_CA_CERTS` will enable TLS against the service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. Setting `MG_USERS_HTTP_CLIENT_CA_CERTS` will enable TLS against the service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. -Setting `SMQ_AUTH_GRPC_CLIENT_CERT` and `SMQ_AUTH_GRPC_CLIENT_KEY` will enable TLS against the auth service. The service expects a file in PEM format for both the certificate and the key. Setting `SMQ_AUTH_GRPC_SERVER_CA_CERTS` will enable TLS against the auth service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. +Setting `MG_AUTH_GRPC_CLIENT_CERT` and `MG_AUTH_GRPC_CLIENT_KEY` will enable TLS against the auth service. The service expects a file in PEM format for both the certificate and the key. Setting `MG_AUTH_GRPC_SERVER_CA_CERTS` will enable TLS against the auth service trusting only those CAs that are provided. The service expects a file in PEM format of trusted CAs. ## HTTP API @@ -153,7 +153,7 @@ Base URL defaults to `http://localhost:9002`. Unless otherwise noted, endpoints - Disable self-registration in production; onboard users via admin tokens or your IdP. - Keep `allow_unverified_user` false and require email verification before granting domain roles. - Enforce TLS for HTTP and mTLS for gRPC by setting server/client cert env vars. -- Harden passwords with `SMQ_USERS_PASS_REGEX` and rotate credentials; purge stale accounts via `SMQ_USERS_DELETE_AFTER`. +- Harden passwords with `MG_USERS_PASS_REGEX` and rotate credentials; purge stale accounts via `MG_USERS_DELETE_AFTER`. - Rate-limit token issuance and password reset endpoints at your API gateway; export Prometheus metrics to watch for abuse. - Store SMTP credentials and certificates in a secrets manager; avoid embedding secrets in images or repos. diff --git a/users/api/endpoints.go b/users/api/endpoints.go index a53bd4771..6b6d3772d 100644 --- a/users/api/endpoints.go +++ b/users/api/endpoints.go @@ -274,9 +274,9 @@ func updateEmailEndpoint(svc users.Service) endpoint.Endpoint { // Password reset request endpoint. // When successful password reset link is generated. -// Link is generated using SMQ_TOKEN_RESET_ENDPOINT env. +// Link is generated using MG_TOKEN_RESET_ENDPOINT env. // and value from Referer header for host. -// {Referer}+{SMQ_TOKEN_RESET_ENDPOINT}+{token=TOKEN} +// {Referer}+{MG_TOKEN_RESET_ENDPOINT}+{token=TOKEN} // http://supermq.com/reset-request?token=xxxxxxxxxxx. // Email with a link is being sent to the user. // When user clicks on a link it should get the ui with form to diff --git a/users/emailer/emailer.go b/users/emailer/emailer.go index 6cf4bdfb3..24ca40405 100644 --- a/users/emailer/emailer.go +++ b/users/emailer/emailer.go @@ -41,10 +41,10 @@ func New(resetURL, verificationURL string, resetConfig, verifyConfig *email.Conf func (e *emailer) SendPasswordReset(to []string, user, token string) error { url := fmt.Sprintf("%s?token=%s", e.resetURL, token) - return e.resetAgent.Send(to, "", "Password Reset Request", "", user, url, "") + return e.resetAgent.Send(to, "", "Password Reset Request", "", user, url, "", nil) } func (e *emailer) SendVerification(to []string, user, verificationToken string) error { url := fmt.Sprintf("%s?token=%s", e.verificationURL, verificationToken) - return e.verifyAgent.Send(to, "", "Email Verification", "", user, url, "") + return e.verifyAgent.Send(to, "", "Email Verification", "", user, url, "", nil) } diff --git a/users/events/streams.go b/users/events/streams.go index fb176adc8..26edd66c5 100644 --- a/users/events/streams.go +++ b/users/events/streams.go @@ -53,7 +53,7 @@ type eventStore struct { // NewEventStoreMiddleware returns wrapper around users service that sends // events to event store. func NewEventStoreMiddleware(ctx context.Context, svc users.Service, url string) (users.Service, error) { - publisher, err := store.NewPublisher(ctx, url) + publisher, err := store.NewPublisher(ctx, url, "users-es-pub") if err != nil { return nil, err } diff --git a/users/middleware/authorization.go b/users/middleware/authorization.go index ba229f7d2..9d3ee9e0c 100644 --- a/users/middleware/authorization.go +++ b/users/middleware/authorization.go @@ -127,7 +127,7 @@ func (am *authorizationMiddleware) UpdateRole(ctx context.Context, session authn return users.User{}, err } session.SuperAdmin = true - if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, user.ID, policies.MembershipPermission, policies.PlatformType, policies.SuperMQObject); err != nil { + if err := am.authorize(ctx, session, "", policies.UserType, policies.UsersKind, user.ID, policies.MembershipPermission, policies.PlatformType, policies.MagistralaObject); err != nil { return users.User{}, err } @@ -183,7 +183,7 @@ func (am *authorizationMiddleware) OAuthCallback(ctx context.Context, user users } func (am *authorizationMiddleware) OAuthAddUserPolicy(ctx context.Context, user users.User) error { - if err := am.authorize(ctx, authn.Session{}, "", policies.UserType, policies.UsersKind, user.ID, policies.MembershipPermission, policies.PlatformType, policies.SuperMQObject); err == nil { + if err := am.authorize(ctx, authn.Session{}, "", policies.UserType, policies.UsersKind, user.ID, policies.MembershipPermission, policies.PlatformType, policies.MagistralaObject); err == nil { return nil } return am.svc.OAuthAddUserPolicy(ctx, user) @@ -198,7 +198,7 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session Subject: session.UserID, Permission: policies.AdminPermission, ObjectType: policies.PlatformType, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, }, nil); err != nil { return err } diff --git a/users/service.go b/users/service.go index ad9515294..233c4a14e 100644 --- a/users/service.go +++ b/users/service.go @@ -682,7 +682,7 @@ func (svc service) addUserPolicy(ctx context.Context, userID string, role Role) Subject: userID, Relation: policies.MemberRelation, ObjectType: policies.PlatformType, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, }) if role == AdminRole { @@ -691,7 +691,7 @@ func (svc service) addUserPolicy(ctx context.Context, userID string, role Role) Subject: userID, Relation: policies.AdministratorRelation, ObjectType: policies.PlatformType, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, }) } err := svc.policies.AddPolicies(ctx, policyList) @@ -710,7 +710,7 @@ func (svc service) addUserPolicyRollback(ctx context.Context, userID string, rol Subject: userID, Relation: policies.MemberRelation, ObjectType: policies.PlatformType, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, }) if role == AdminRole { @@ -719,7 +719,7 @@ func (svc service) addUserPolicyRollback(ctx context.Context, userID string, rol Subject: userID, Relation: policies.AdministratorRelation, ObjectType: policies.PlatformType, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, }) } err := svc.policies.DeletePolicies(ctx, policyList) @@ -738,7 +738,7 @@ func (svc service) updateUserPolicy(ctx context.Context, userID string, role Rol Subject: userID, Relation: policies.AdministratorRelation, ObjectType: policies.PlatformType, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, }) if err != nil { return errors.Wrap(svcerr.ErrAddPolicies, err) @@ -753,7 +753,7 @@ func (svc service) updateUserPolicy(ctx context.Context, userID string, role Rol Subject: userID, Relation: policies.AdministratorRelation, ObjectType: policies.PlatformType, - Object: policies.SuperMQObject, + Object: policies.MagistralaObject, }) if err != nil { return errors.Wrap(svcerr.ErrDeletePolicies, err)