diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index fa59a21e6..cf999d40b 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -17,7 +17,6 @@ on: - "domains/api/http/**" - "groups/api/http/**" - "http/api/**" - - "invitations/api/**" - "journal/api/**" - "users/api/**" @@ -33,7 +32,6 @@ env: CHANNELS_URL: http://localhost:9005 GROUPS_URL: http://localhost:9004 HTTP_ADAPTER_URL: http://localhost:8008 - INVITATIONS_URL: http://localhost:9020 AUTH_URL: http://localhost:9001 CERTS_URL: http://localhost:9019 JOURNAL_URL: http://localhost:9021 @@ -95,11 +93,6 @@ jobs: - "apidocs/openapi/http.yml" - "http/api/**" - invitations: - - ".github/workflows/api-tests.yml" - - "apidocs/openapi/invitations.yml" - - "invitations/api/**" - clients: - ".github/workflows/api-tests.yml" - "apidocs/openapi/clients.yml" @@ -170,16 +163,6 @@ jobs: report: false args: '--header "Authorization: Client ${{ env.CLIENT_SECRET }}" --contrib-openapi-formats-uuid --hypothesis-suppress-health-check=filter_too_much --stateful=links' - - name: Run Invitations API tests - if: steps.changes.outputs.invitations == 'true' - uses: schemathesis/action@v1 - with: - schema: apidocs/openapi/invitations.yml - base-url: ${{ env.INVITATIONS_URL }} - checks: all - report: false - args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --contrib-openapi-formats-uuid --hypothesis-suppress-health-check=filter_too_much --stateful=links' - - name: Run Auth API tests if: steps.changes.outputs.auth == 'true' uses: schemathesis/action@v1 diff --git a/.github/workflows/check-generated-files.yml b/.github/workflows/check-generated-files.yml index 192703fbf..caf9ab5ce 100644 --- a/.github/workflows/check-generated-files.yml +++ b/.github/workflows/check-generated-files.yml @@ -58,7 +58,6 @@ jobs: - "auth/service.go" - "pkg/events/events.go" - "pkg/groups/groups.go" - - "invitations/invitations.go" - "users/emailer.go" - "users/hasher.go" - "certs/certs.go" @@ -119,8 +118,6 @@ jobs: MOCKERY_VERSION=v2.43.2 go install github.com/vektra/mockery/v2@$MOCKERY_VERSION - mv ./invitations/mocks/repository.go ./invitations/mocks/repository.go.tmp - mv ./invitations/mocks/service.go ./invitations/mocks/service.go.tmp mv ./auth/mocks/token_client.go ./auth/mocks/token_client.go.tmp mv ./auth/mocks/authz.go ./auth/mocks/authz.go.tmp mv ./auth/mocks/keys.go ./auth/mocks/keys.go.tmp @@ -177,8 +174,6 @@ jobs: fi } - check_mock_changes ./invitations/mocks/repository.go " ./invitations/mocks/repository.go" - check_mock_changes ./invitations/mocks/service.go " ./invitations/mocks/service.go" check_mock_changes ./auth/mocks/token_client.go " ./auth/mocks/token_client.go" check_mock_changes ./auth/mocks/authz.go " ./auth/mocks/authz.go" check_mock_changes ./auth/mocks/keys.go " ./auth/mocks/keys.go" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6c74ee3c3..fe0049e67 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -164,14 +164,6 @@ jobs: internal: - "internal/**" - invitations: - - "invitations/**" - - "cmd/invitations/**" - - "auth.pb.go" - - "auth_grpc.pb.go" - - "auth/**" - - "pkg/sdk/**" - journal: - "journal/**" - "cmd/journal/**" @@ -292,11 +284,6 @@ jobs: run: | go test --race -v -count=1 -coverprofile=coverage/internal.out ./internal/... - - name: Run invitations tests - if: steps.changes.outputs.invitations == 'true' || steps.changes.outputs.workflow == 'true' - run: | - go test --race -v -count=1 -coverprofile=coverage/invitations.out ./invitations/... - - name: Run logger tests if: steps.changes.outputs.logger == 'true' || steps.changes.outputs.workflow == 'true' run: | diff --git a/Makefile b/Makefile index ad4af7929..5318a3c3b 100644 --- a/Makefile +++ b/Makefile @@ -3,8 +3,8 @@ SMQ_DOCKER_IMAGE_NAME_PREFIX ?= supermq BUILD_DIR ?= build -SERVICES = auth users clients groups channels domains http coap ws cli mqtt certs invitations journal -TEST_API_SERVICES = journal auth certs http invitations clients users channels groups domains +SERVICES = auth users clients groups channels domains http coap ws cli mqtt certs journal +TEST_API_SERVICES = journal auth certs http clients users channels groups domains TEST_API = $(addprefix test_api_,$(TEST_API_SERVICES)) DOCKERS = $(addprefix docker_,$(SERVICES)) DOCKERS_DEV = $(addprefix docker_dev_,$(SERVICES)) @@ -172,7 +172,6 @@ test_api_domains: TEST_API_URL := http://localhost:9003 test_api_channels: TEST_API_URL := http://localhost:9005 test_api_groups: TEST_API_URL := http://localhost:9004 test_api_http: TEST_API_URL := http://localhost:8008 -test_api_invitations: TEST_API_URL := http://localhost:9020 test_api_auth: TEST_API_URL := http://localhost:9001 test_api_certs: TEST_API_URL := http://localhost:9019 test_api_journal: TEST_API_URL := http://localhost:9021 diff --git a/apidocs/openapi/README.md b/apidocs/openapi/README.md index b2aa8eb6f..a5c61cf65 100644 --- a/apidocs/openapi/README.md +++ b/apidocs/openapi/README.md @@ -2,4 +2,4 @@ This folder contains an OpenAPI specifications for SuperMQ API. -View specification in Swagger UI at [docs.api.magistrala.abstractmachines.fr](https://docs.api.supermq.abstractmachines.fr) \ No newline at end of file +View specification in Swagger UI at [docs.api.magistrala.abstractmachines.fr](https://docs.api.supermq.abstractmachines.fr) diff --git a/apidocs/openapi/domains.yml b/apidocs/openapi/domains.yml index 02f63f9ac..283661bb3 100644 --- a/apidocs/openapi/domains.yml +++ b/apidocs/openapi/domains.yml @@ -26,10 +26,15 @@ tags: description: Find out more about domains url: https://docs.magistrala.abstractmachines.fr/ - name: Roles - description: All operations involving roles for clients + description: All operations involving roles for domains externalDocs: - description: Find out more about roles - url: https://docs.supermq.abstractmachines.fr/ + description: Find out more about roles + url: https://docs.supermq.abstractmachines.fr/ + - name: Invitations + description: All operations involving invitations for domains + externalDocs: + description: Find out more about Invitations + url: https://docs.supermq.abstractmachines.fr/ - name: Health description: Service health check endpoint. externalDocs: @@ -242,7 +247,7 @@ paths: - bearerAuth: [] responses: "201": - $ref: "./schemas/roles.yml#/components/responses/CreateRoleRes" + $ref: "./schemas/roles.yml#/components/responses/CreateRoleRes" "400": description: Failed due to malformed domain's ID. "401": @@ -683,6 +688,216 @@ paths: "500": $ref: "#/components/responses/ServiceError" + /domains/{domainID}/invitations: + post: + operationId: sendInvitation + tags: + - Invitations + summary: Send invitation + description: | + Send invitation to user to join domain. + parameters: + - $ref: "#/components/parameters/DomainID" + requestBody: + $ref: "#/components/requestBodies/SendInvitationReq" + security: + - bearerAuth: [] + responses: + "201": + description: Invitation sent. + "400": + description: Failed due to malformed JSON. + "401": + description: Missing or invalid access token provided. + "403": + description: Failed to perform authorization over the entity. + "404": + description: A non-existent entity request. + "409": + description: Failed due to using an existing identity. + "415": + description: Missing or invalid content type. + "500": + $ref: "#/components/responses/ServiceError" + + get: + operationId: listDomainInvitations + tags: + - Invitations + summary: List domain invitations + description: | + Retrieves a list of invitations for a given domain. Due to performance concerns, data + is retrieved in subsets. The API must ensure that the entire + dataset is consumed either by making subsequent requests, or by + increasing the subset size of the initial request. + parameters: + - $ref: "#/components/parameters/DomainID" + - $ref: "#/components/parameters/Limit" + - $ref: "#/components/parameters/Offset" + - $ref: "#/components/parameters/user_id" + - $ref: "#/components/parameters/InvitedBy" + - $ref: "#/components/parameters/State" + security: + - bearerAuth: [] + responses: + "200": + $ref: "#/components/responses/InvitationPageRes" + "400": + description: Failed due to malformed query parameters. + "401": + description: | + Missing or invalid access token provided. + This endpoint is available only for administrators. + "403": + description: Failed to perform authorization over the entity. + "404": + description: A non-existent entity request. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /domains/{domainID}/invitations/{userID}: + get: + operationId: getInvitation + summary: Retrieves a specific invitation + description: | + Retrieves a specific invitation that is identifier by the user ID and domain ID. + tags: + - Invitations + parameters: + - $ref: "#/components/parameters/DomainID" + - $ref: "#/components/parameters/UserID" + security: + - bearerAuth: [] + responses: + "200": + $ref: "#/components/responses/InvitationRes" + "400": + description: Failed due to malformed query parameters. + "401": + description: Missing or invalid access token provided. + "403": + description: Failed to perform authorization over the entity. + "404": + description: A non-existent entity request. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + delete: + operationId: deleteInvitation + summary: Deletes a specific invitation + description: | + Deletes a specific invitation that is identifier by the user ID and domain ID. + tags: + - Invitations + parameters: + - $ref: "#/components/parameters/DomainID" + - $ref: "#/components/parameters/UserID" + security: + - bearerAuth: [] + responses: + "204": + description: Invitation deleted. + "400": + description: Failed due to malformed JSON. + "403": + description: Failed to perform authorization over the entity. + "404": + description: Failed due to non existing user. + "401": + description: Missing or invalid access token provided. + "500": + $ref: "#/components/responses/ServiceError" + + /invitations: + get: + operationId: listUserInvitations + tags: + - Invitations + summary: List user invitations + description: | + Retrieves a list of invitations for the current user. Due to performance concerns, data + is retrieved in subsets. The API must ensure that the entire + dataset is consumed either by making subsequent requests, or by + increasing the subset size of the initial request. + parameters: + - $ref: "#/components/parameters/domain_id" + - $ref: "#/components/parameters/Limit" + - $ref: "#/components/parameters/Offset" + - $ref: "#/components/parameters/user_id" + - $ref: "#/components/parameters/InvitedBy" + - $ref: "#/components/parameters/State" + security: + - bearerAuth: [] + responses: + "200": + $ref: "#/components/responses/InvitationPageRes" + "400": + description: Failed due to malformed query parameters. + "401": + description: | + Missing or invalid access token provided. + This endpoint is available only for administrators. + "403": + description: Failed to perform authorization over the entity. + "404": + description: A non-existent entity request. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /invitations/accept: + post: + operationId: acceptInvitation + summary: Accept invitation + description: | + Current logged in user accepts invitation to join domain. + tags: + - Invitations + security: + - bearerAuth: [] + requestBody: + $ref: "#/components/requestBodies/AcceptInvitationReq" + responses: + "204": + description: Invitation accepted. + "400": + description: Failed due to malformed query parameters. + "401": + description: Missing or invalid access token provided. + "404": + description: A non-existent entity request. + "500": + $ref: "#/components/responses/ServiceError" + + /invitations/reject: + post: + operationId: rejectInvitation + summary: Reject invitation + description: | + Current logged in user rejects invitation to join domain. + tags: + - Invitations + security: + - bearerAuth: [] + requestBody: + $ref: "#/components/requestBodies/AcceptInvitationReq" + responses: + "204": + description: Invitation rejected. + "400": + description: Failed due to malformed query parameters. + "401": + description: Missing or invalid access token provided. + "404": + description: A non-existent entity request. + "500": + $ref: "#/components/responses/ServiceError" + /health: get: summary: Retrieves service health check info. @@ -824,6 +1039,107 @@ components: example: domain alias description: Domain alias. + SendInvitationReqObj: + type: object + properties: + invitee_user_id: + type: string + format: uuid + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 + description: User unique identifier. + role_id: + type: string + format: uuid + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 + description: Identifier for the role to be assigned to the user. + required: + - invitee_user_id + - role_id + + Invitation: + type: object + properties: + invited_by: + type: string + format: uuid + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 + description: User unique identifier. + invitee_user_id: + type: string + format: uuid + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 + description: Invitee user unique identifier. + domain_id: + type: string + format: uuid + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 + description: Domain unique identifier. + role_id: + type: string + format: uuid + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 + description: Role unique identifier. + role_name: + type: string + example: editor + description: Role name. + actions: + type: array + items: + type: string + example: ["read", "write"] + description: Actions allowed for the role. + created_at: + type: string + format: date-time + example: "2019-11-26 13:31:52" + description: Time when the group was created. + updated_at: + type: string + format: date-time + example: "2019-11-26 13:31:52" + description: Time when the group was created. + confirmed_at: + type: string + format: date-time + example: "2019-11-26 13:31:52" + description: Time when the group was created. + xml: + name: invitation + + InvitationPage: + type: object + properties: + invitations: + type: array + minItems: 0 + uniqueItems: true + items: + $ref: "#/components/schemas/Invitation" + total: + type: integer + example: 1 + description: Total number of items. + offset: + type: integer + description: Number of items to skip during retrieval. + limit: + type: integer + example: 10 + description: Maximum number of items to return in one page. + required: + - invitations + - total + - offset + + Error: + type: object + properties: + error: + type: string + description: Error message + example: { "error": "malformed entity specification" } + parameters: DomainID: name: domainID @@ -910,6 +1226,61 @@ components: schema: type: string required: false + UserID: + name: userID + description: Unique user identifier. + in: path + schema: + type: string + format: uuid + required: true + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 + user_id: + name: user_id + description: Unique user identifier. + in: query + schema: + type: string + format: uuid + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 + domain_id: + name: domain_id + description: Unique identifier for a domain. + in: query + schema: + type: string + format: uuid + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 + InvitedBy: + name: invited_by + description: Unique identifier for a user that invited the user. + in: query + schema: + type: string + format: uuid + required: false + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 + State: + name: state + description: Invitation state. + in: query + schema: + type: string + enum: + - pending + - accepted + - all + required: false + example: accepted + RoleID: + name: roleID + description: Unique role identifier. + in: query + schema: + type: string + format: uuid + required: true + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 requestBodies: DomainCreateReq: @@ -926,6 +1297,28 @@ components: application/json: schema: $ref: "#/components/schemas/DomainUpdate" + SendInvitationReq: + description: JSON-formatted document describing request for sending invitation + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/SendInvitationReqObj" + AcceptInvitationReq: + description: JSON-formatted document describing request for accepting invitation + required: true + content: + application/json: + schema: + type: object + properties: + domain_id: + type: string + format: uuid + example: bb7edb32-2eac-4aad-aebe-ed96fe073879 + description: Domain unique identifier. + required: + - domain_id responses: ServiceError: @@ -956,6 +1349,25 @@ components: application/json: schema: $ref: "#/components/schemas/DomainsPage" + InvitationRes: + description: Data retrieved. + content: + application/json: + schema: + $ref: "#/components/schemas/Invitation" + links: + delete: + operationId: deleteInvitation + parameters: + user_id: $response.body#/user_id + domain_id: $response.body#/domain_id + + InvitationPageRes: + description: Data retrieved. + content: + application/json: + schema: + $ref: "#/components/schemas/InvitationPage" HealthRes: description: Service Health Check. content: diff --git a/apidocs/openapi/invitations.yml b/apidocs/openapi/invitations.yml deleted file mode 100644 index 10cc2adc3..000000000 --- a/apidocs/openapi/invitations.yml +++ /dev/null @@ -1,537 +0,0 @@ -# Copyright (c) Abstract Machines -# SPDX-License-Identifier: Apache-2.0 - -openapi: 3.0.3 -info: - title: SuperMQ Invitations Service - description: | - This is the Invitations Server based on the OpenAPI 3.0 specification. It is the HTTP API for managing platform invitations. You can now help us improve the API whether it's by making changes to the definition itself or to the code. - Some useful links: - - [The SuperMQ repository](https://github.com/absmach/supermq) - contact: - email: info@abstractmachines.fr - license: - name: Apache 2.0 - url: https://github.com/absmach/supermq/blob/main/LICENSE - version: 0.15.1 - -servers: - - url: http://localhost:9020 - - url: https://localhost:9020 - -tags: - - name: Invitations - description: Everything about your Invitations - externalDocs: - description: Find out more about Invitations - url: https://docs.supermq.abstractmachines.fr/ - -paths: - /invitations: - post: - operationId: sendInvitation - tags: - - Invitations - summary: Send invitation - description: | - Send invitation to user to join domain. - requestBody: - $ref: "#/components/requestBodies/SendInvitationReq" - security: - - bearerAuth: [] - responses: - "201": - description: Invitation sent. - "400": - description: Failed due to malformed JSON. - "401": - description: Missing or invalid access token provided. - "403": - description: Failed to perform authorization over the entity. - "404": - description: A non-existent entity request. - "409": - description: Failed due to using an existing identity. - "415": - description: Missing or invalid content type. - "500": - $ref: "#/components/responses/ServiceError" - - get: - operationId: listInvitations - tags: - - Invitations - summary: List invitations - description: | - Retrieves a list of invitations. Due to performance concerns, data - is retrieved in subsets. The API must ensure that the entire - dataset is consumed either by making subsequent requests, or by - increasing the subset size of the initial request. - parameters: - - $ref: "#/components/parameters/Limit" - - $ref: "#/components/parameters/Offset" - - $ref: "#/components/parameters/UserID" - - $ref: "#/components/parameters/InvitedBy" - - $ref: "#/components/parameters/DomainID" - - $ref: "#/components/parameters/Relation" - - $ref: "#/components/parameters/State" - security: - - bearerAuth: [] - responses: - "200": - $ref: "#/components/responses/InvitationPageRes" - "400": - description: Failed due to malformed query parameters. - "401": - description: | - Missing or invalid access token provided. - This endpoint is available only for administrators. - "403": - description: Failed to perform authorization over the entity. - "404": - description: A non-existent entity request. - "422": - description: Database can't process request. - "500": - $ref: "#/components/responses/ServiceError" - - /invitations/accept: - post: - operationId: acceptInvitation - summary: Accept invitation - description: | - Current logged in user accepts invitation to join domain. - tags: - - Invitations - security: - - bearerAuth: [] - requestBody: - $ref: "#/components/requestBodies/AcceptInvitationReq" - responses: - "204": - description: Invitation accepted. - "400": - description: Failed due to malformed query parameters. - "401": - description: Missing or invalid access token provided. - "404": - description: A non-existent entity request. - "500": - $ref: "#/components/responses/ServiceError" - - /invitations/reject: - post: - operationId: rejectInvitation - summary: Reject invitation - description: | - Current logged in user rejects invitation to join domain. - tags: - - Invitations - security: - - bearerAuth: [] - requestBody: - $ref: "#/components/requestBodies/AcceptInvitationReq" - responses: - "204": - description: Invitation rejected. - "400": - description: Failed due to malformed query parameters. - "401": - description: Missing or invalid access token provided. - "404": - description: A non-existent entity request. - "500": - $ref: "#/components/responses/ServiceError" - - /invitations/{user_id}/{domain_id}: - get: - operationId: getInvitation - summary: Retrieves a specific invitation - description: | - Retrieves a specific invitation that is identifier by the user ID and domain ID. - tags: - - Invitations - parameters: - - $ref: "#/components/parameters/user_id" - - $ref: "#/components/parameters/domain_id" - security: - - bearerAuth: [] - responses: - "200": - $ref: "#/components/responses/InvitationRes" - "400": - description: Failed due to malformed query parameters. - "401": - description: Missing or invalid access token provided. - "403": - description: Failed to perform authorization over the entity. - "404": - description: A non-existent entity request. - "422": - description: Database can't process request. - "500": - $ref: "#/components/responses/ServiceError" - - delete: - operationId: deleteInvitation - summary: Deletes a specific invitation - description: | - Deletes a specific invitation that is identifier by the user ID and domain ID. - tags: - - Invitations - parameters: - - $ref: "#/components/parameters/user_id" - - $ref: "#/components/parameters/domain_id" - security: - - bearerAuth: [] - responses: - "204": - description: Invitation deleted. - "400": - description: Failed due to malformed JSON. - "403": - description: Failed to perform authorization over the entity. - "404": - description: Failed due to non existing user. - "401": - description: Missing or invalid access token provided. - "500": - $ref: "#/components/responses/ServiceError" - - /health: - get: - summary: Retrieves service health check info. - tags: - - health - security: [] - responses: - "200": - $ref: "#/components/responses/HealthRes" - "500": - $ref: "#/components/responses/ServiceError" - -components: - schemas: - SendInvitationReqObj: - type: object - properties: - user_id: - type: string - format: uuid - example: bb7edb32-2eac-4aad-aebe-ed96fe073879 - description: User unique identifier. - domain_id: - type: string - format: uuid - example: bb7edb32-2eac-4aad-aebe-ed96fe073879 - description: Domain unique identifier. - relation: - type: string - enum: - - administrator - - editor - - contributor - - member - - guest - - domain - - parent_group - - role_group - - group - - platform - example: editor - description: Relation between user and domain. - resend: - type: boolean - example: true - description: Resend invitation. - required: - - user_id - - domain_id - - relation - - Invitation: - type: object - properties: - invited_by: - type: string - format: uuid - example: bb7edb32-2eac-4aad-aebe-ed96fe073879 - description: User unique identifier. - user_id: - type: string - format: uuid - example: bb7edb32-2eac-4aad-aebe-ed96fe073879 - description: User unique identifier. - domain_id: - type: string - format: uuid - example: bb7edb32-2eac-4aad-aebe-ed96fe073879 - description: Domain unique identifier. - relation: - type: string - enum: - - administrator - - editor - - contributor - - member - - guest - - domain - - parent_group - - role_group - - group - - platform - example: editor - description: Relation between user and domain. - created_at: - type: string - format: date-time - example: "2019-11-26 13:31:52" - description: Time when the group was created. - updated_at: - type: string - format: date-time - example: "2019-11-26 13:31:52" - description: Time when the group was created. - confirmed_at: - type: string - format: date-time - example: "2019-11-26 13:31:52" - description: Time when the group was created. - xml: - name: invitation - - InvitationPage: - type: object - properties: - invitations: - type: array - minItems: 0 - uniqueItems: true - items: - $ref: "#/components/schemas/Invitation" - total: - type: integer - example: 1 - description: Total number of items. - offset: - type: integer - description: Number of items to skip during retrieval. - limit: - type: integer - example: 10 - description: Maximum number of items to return in one page. - required: - - invitations - - total - - offset - - Error: - type: object - properties: - error: - type: string - description: Error message - example: { "error": "malformed entity specification" } - - HealthRes: - type: object - properties: - status: - type: string - description: Service status. - enum: - - pass - version: - type: string - description: Service version. - example: 0.14.0 - commit: - type: string - description: Service commit hash. - example: 7d6f4dc4f7f0c1fa3dc24eddfb18bb5073ff4f62 - description: - type: string - description: Service description. - example: service - build_time: - type: string - description: Service build time. - example: 1970-01-01_00:00:00 - - parameters: - Offset: - name: offset - description: Number of items to skip during retrieval. - in: query - schema: - type: integer - default: 0 - minimum: 0 - required: false - example: "0" - - Limit: - name: limit - description: Size of the subset to retrieve. - in: query - schema: - type: integer - default: 10 - maximum: 10 - minimum: 1 - required: false - example: "10" - - UserID: - name: user_id - description: Unique user identifier. - in: query - schema: - type: string - format: uuid - required: true - example: bb7edb32-2eac-4aad-aebe-ed96fe073879 - - user_id: - name: user_id - description: Unique user identifier. - in: path - schema: - type: string - format: uuid - required: true - example: bb7edb32-2eac-4aad-aebe-ed96fe073879 - - DomainID: - name: domain_id - description: Unique identifier for a domain. - in: query - schema: - type: string - format: uuid - required: false - example: bb7edb32-2eac-4aad-aebe-ed96fe073879 - - domain_id: - name: domain_id - description: Unique identifier for a domain. - in: path - schema: - type: string - format: uuid - required: true - example: bb7edb32-2eac-4aad-aebe-ed96fe073879 - - InvitedBy: - name: invited_by - description: Unique identifier for a user that invited the user. - in: query - schema: - type: string - format: uuid - required: false - example: bb7edb32-2eac-4aad-aebe-ed96fe073879 - - Relation: - name: relation - description: Relation between user and domain. - in: query - schema: - type: string - enum: - - administrator - - editor - - contributor - - member - - guest - - domain - - parent_group - - role_group - - group - - platform - required: false - example: editor - - State: - name: state - description: Invitation state. - in: query - schema: - type: string - enum: - - pending - - accepted - - all - required: false - example: accepted - - requestBodies: - SendInvitationReq: - description: JSON-formatted document describing request for sending invitation - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/SendInvitationReqObj" - - AcceptInvitationReq: - description: JSON-formatted document describing request for accepting invitation - required: true - content: - application/json: - schema: - type: object - properties: - domain_id: - type: string - format: uuid - example: bb7edb32-2eac-4aad-aebe-ed96fe073879 - description: Domain unique identifier. - required: - - domain_id - - responses: - InvitationRes: - description: Data retrieved. - content: - application/json: - schema: - $ref: "#/components/schemas/Invitation" - links: - delete: - operationId: deleteInvitation - parameters: - user_id: $response.body#/user_id - domain_id: $response.body#/domain_id - - InvitationPageRes: - description: Data retrieved. - content: - application/json: - schema: - $ref: "#/components/schemas/InvitationPage" - - HealthRes: - description: Service Health Check. - content: - application/health+json: - schema: - $ref: "#/components/schemas/HealthRes" - - ServiceError: - description: Unexpected server-side error occurred. - content: - application/json: - schema: - $ref: "#/components/schemas/Error" - - securitySchemes: - bearerAuth: - type: http - scheme: bearer - bearerFormat: JWT - description: | - * User access: "Authorization: Bearer " - -security: - - bearerAuth: [] diff --git a/cli/config.go b/cli/config.go index d92849963..f94870b89 100644 --- a/cli/config.go +++ b/cli/config.go @@ -25,7 +25,6 @@ const ( defChannelsURL string = defURL + ":9005" defGroupsURL string = defURL + ":9004" defCertsURL string = defURL + ":9019" - defInvitationsURL string = defURL + ":9020" defHTTPURL string = defURL + ":8008" defJournalURL string = defURL + ":9021" defTLSVerification bool = false @@ -43,7 +42,6 @@ type remotes struct { GroupsURL string `toml:"groups_url"` HTTPAdapterURL string `toml:"http_adapter_url"` CertsURL string `toml:"certs_url"` - InvitationsURL string `toml:"invitations_url"` JournalURL string `toml:"journal_url"` HostURL string `toml:"host_url"` TLSVerification bool `toml:"tls_verification"` @@ -114,7 +112,6 @@ func ParseConfig(sdkConf smqsdk.Config) (smqsdk.Config, error) { GroupsURL: defGroupsURL, HTTPAdapterURL: defHTTPURL, CertsURL: defCertsURL, - InvitationsURL: defInvitationsURL, JournalURL: defJournalURL, HostURL: defURL, TLSVerification: defTLSVerification, @@ -199,10 +196,6 @@ func ParseConfig(sdkConf smqsdk.Config) (smqsdk.Config, error) { sdkConf.CertsURL = config.Remotes.CertsURL } - if sdkConf.InvitationsURL == "" && config.Remotes.InvitationsURL != "" { - sdkConf.InvitationsURL = config.Remotes.InvitationsURL - } - if sdkConf.JournalURL == "" && config.Remotes.JournalURL != "" { sdkConf.JournalURL = config.Remotes.JournalURL } diff --git a/cli/invitations.go b/cli/invitations.go index 2afc36a36..524cd499c 100644 --- a/cli/invitations.go +++ b/cli/invitations.go @@ -10,20 +10,20 @@ import ( var cmdInvitations = []cobra.Command{ { - Use: "send ", + Use: "send ", Short: "Send invitation", Long: "Send invitation to user\n" + "For example:\n" + - "\tsupermq-cli invitations send 39f97daf-d6b6-40f4-b229-2697be8006ef 4ef09eff-d500-4d56-b04f-d23a512d6f2a administrator $USER_AUTH_TOKEN\n", + "\tsupermq-cli invitations send 39f97daf-d6b6-40f4-b229-2697be8006ef 4ef09eff-d500-4d56-b04f-d23a512d6f2a ba4c904c-e6d4-4978-9417-1694aac6793e $USER_AUTH_TOKEN\n", Run: func(cmd *cobra.Command, args []string) { if len(args) != 4 { logUsageCmd(*cmd, cmd.Use) return } inv := smqsdk.Invitation{ - UserID: args[0], - DomainID: args[1], - Relation: args[2], + InviteeUserID: args[0], + DomainID: args[1], + RoleID: args[2], } if err := sdk.SendInvitation(inv, args[3]); err != nil { logErrorCmd(*cmd, err) diff --git a/cli/invitations_test.go b/cli/invitations_test.go index 2768b7b21..cee686791 100644 --- a/cli/invitations_test.go +++ b/cli/invitations_test.go @@ -21,9 +21,9 @@ import ( ) var invitation = mgsdk.Invitation{ - InvitedBy: testsutil.GenerateUUID(&testing.T{}), - UserID: user.ID, - DomainID: domain.ID, + InvitedBy: testsutil.GenerateUUID(&testing.T{}), + InviteeUserID: user.ID, + DomainID: domain.ID, } func TestSendUserInvitationCmd(t *testing.T) { diff --git a/cmd/channels/main.go b/cmd/channels/main.go index a3f20e908..8b9c56c08 100644 --- a/cmd/channels/main.go +++ b/cmd/channels/main.go @@ -233,7 +233,7 @@ func main() { } ddatabase := pg.NewDatabase(db, dbConfig, tracer) - drepo := dpostgres.New(ddatabase) + drepo := dpostgres.NewRepository(ddatabase) if err := dconsumer.DomainsEventsSubscribe(ctx, drepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil { logger.Error(fmt.Sprintf("failed to create domains event store : %s", err)) diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 029e98751..8a6faa9f6 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -100,14 +100,6 @@ func main() { "HTTP adapter URL", ) - rootCmd.PersistentFlags().StringVarP( - &sdkConf.InvitationsURL, - "invitations-url", - "v", - sdkConf.InvitationsURL, - "Inivitations URL", - ) - rootCmd.PersistentFlags().StringVarP( &sdkConf.JournalURL, "journal-url", diff --git a/cmd/clients/main.go b/cmd/clients/main.go index 55a065a2d..b3030976c 100644 --- a/cmd/clients/main.go +++ b/cmd/clients/main.go @@ -250,7 +250,7 @@ func main() { } ddatabase := pg.NewDatabase(db, dbConfig, tracer) - drepo := dpostgres.New(ddatabase) + drepo := dpostgres.NewRepository(ddatabase) if err := dconsumer.DomainsEventsSubscribe(ctx, drepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil { logger.Error(fmt.Sprintf("failed to create domains event store : %s", err)) diff --git a/cmd/domains/main.go b/cmd/domains/main.go index 4f98d9fe2..0a4195380 100644 --- a/cmd/domains/main.go +++ b/cmd/domains/main.go @@ -159,7 +159,7 @@ func main() { logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure()) database := postgres.NewDatabase(db, dbConfig, tracer) - domainsRepo := dpostgres.New(database) + domainsRepo := dpostgres.NewRepository(database) cacheclient, err := redisclient.Connect(cfg.CacheURL) if err != nil { diff --git a/cmd/groups/main.go b/cmd/groups/main.go index a41520385..a4fa60d32 100644 --- a/cmd/groups/main.go +++ b/cmd/groups/main.go @@ -230,7 +230,7 @@ func main() { } ddatabase := pg.NewDatabase(db, dbConfig, tracer) - drepo := dpostgres.New(ddatabase) + drepo := dpostgres.NewRepository(ddatabase) if err := dconsumer.DomainsEventsSubscribe(ctx, drepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil { logger.Error(fmt.Sprintf("failed to create domains event store : %s", err)) diff --git a/cmd/invitations/main.go b/cmd/invitations/main.go deleted file mode 100644 index f0b6e02bb..000000000 --- a/cmd/invitations/main.go +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package main contains invitations main function to start the invitations service. -package main - -import ( - "context" - "fmt" - "log" - "log/slog" - "net/url" - "os" - - chclient "github.com/absmach/callhome/pkg/client" - "github.com/absmach/supermq" - grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" - "github.com/absmach/supermq/invitations" - httpapi "github.com/absmach/supermq/invitations/api" - "github.com/absmach/supermq/invitations/middleware" - invitationspg "github.com/absmach/supermq/invitations/postgres" - smqlog "github.com/absmach/supermq/logger" - authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" - smqauthz "github.com/absmach/supermq/pkg/authz" - authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" - domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" - "github.com/absmach/supermq/pkg/grpcclient" - "github.com/absmach/supermq/pkg/jaeger" - "github.com/absmach/supermq/pkg/postgres" - clientspg "github.com/absmach/supermq/pkg/postgres" - "github.com/absmach/supermq/pkg/prometheus" - mgsdk "github.com/absmach/supermq/pkg/sdk" - "github.com/absmach/supermq/pkg/server" - "github.com/absmach/supermq/pkg/server/http" - "github.com/absmach/supermq/pkg/uuid" - "github.com/caarlos0/env/v11" - "github.com/jmoiron/sqlx" - "go.opentelemetry.io/otel/trace" - "golang.org/x/sync/errgroup" -) - -const ( - svcName = "invitations" - envPrefixDB = "SMQ_INVITATIONS_DB_" - envPrefixHTTP = "SMQ_INVITATIONS_HTTP_" - envPrefixAuth = "SMQ_AUTH_GRPC_" - envPrefixDomains = "SMQ_DOMAINS_GRPC_" - defDB = "invitations" - defSvcHTTPPort = "9020" -) - -type config struct { - LogLevel string `env:"SMQ_INVITATIONS_LOG_LEVEL" envDefault:"info"` - UsersURL string `env:"SMQ_USERS_URL" envDefault:"http://localhost:9002"` - DomainsURL string `env:"SMQ_DOMAINS_URL" envDefault:"http://localhost:9003"` - InstanceID string `env:"SMQ_INVITATIONS_INSTANCE_ID" envDefault:""` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` -} - -func main() { - ctx, cancel := context.WithCancel(context.Background()) - g, ctx := errgroup.WithContext(ctx) - - cfg := config{} - if err := env.Parse(&cfg); err != nil { - log.Fatalf("failed to load %s configuration : %s", svcName, err) - } - - logger, err := smqlog.New(os.Stdout, cfg.LogLevel) - if err != nil { - log.Fatalf("failed to init logger: %s", err.Error()) - } - - var exitCode int - defer smqlog.ExitWithError(&exitCode) - - if cfg.InstanceID == "" { - if cfg.InstanceID, err = uuid.New().ID(); err != nil { - logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err)) - exitCode = 1 - return - } - } - - dbConfig := clientspg.Config{Name: defDB} - if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { - logger.Error(fmt.Sprintf("failed to load %s database configuration : %s", svcName, err)) - exitCode = 1 - return - } - db, err := clientspg.Setup(dbConfig, *invitationspg.Migration()) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer db.Close() - - authClientCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&authClientCfg, env.Options{Prefix: envPrefixAuth}); err != nil { - logger.Error(fmt.Sprintf("failed to load auth gRPC client configuration : %s", err.Error())) - exitCode = 1 - return - } - tokenClient, tokenHandler, err := grpcclient.SetupTokenClient(ctx, authClientCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer tokenHandler.Close() - logger.Info("Token service client successfully connected to auth gRPC server " + tokenHandler.Secure()) - - authn, authnHandler, err := authsvcAuthn.NewAuthentication(ctx, authClientCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer authnHandler.Close() - logger.Info("AuthN successfully connected to auth gRPC server " + authnHandler.Secure()) - - domsGrpcCfg := grpcclient.Config{} - if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { - logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) - exitCode = 1 - return - } - domAuthz, _, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer domainsHandler.Close() - - authz, authzHandler, err := authsvcAuthz.NewAuthorization(ctx, authClientCfg, domAuthz) - if err != nil { - logger.Error(err.Error()) - exitCode = 1 - return - } - defer authzHandler.Close() - logger.Info("Authz successfully connected to auth gRPC server " + authzHandler.Secure()) - - tp, err := jaeger.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) - if err != nil { - logger.Error(fmt.Sprintf("failed to init Jaeger: %s", err)) - exitCode = 1 - return - } - defer func() { - if err := tp.Shutdown(ctx); err != nil { - logger.Error(fmt.Sprintf("error shutting down tracer provider: %v", err)) - } - }() - tracer := tp.Tracer(svcName) - - svc, err := newService(db, dbConfig, authz, tokenClient, tracer, cfg, logger) - if err != nil { - logger.Error(fmt.Sprintf("failed to create %s service: %s", svcName, err)) - exitCode = 1 - return - } - - httpServerConfig := server.Config{Port: defSvcHTTPPort} - if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { - logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) - exitCode = 1 - return - } - - httpSvr := http.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, logger, authn, cfg.InstanceID), logger) - - if cfg.SendTelemetry { - chc := chclient.New(svcName, supermq.Version, logger, cancel) - go chc.CallHome(ctx) - } - - g.Go(func() error { - return httpSvr.Start() - }) - - g.Go(func() error { - return server.StopSignalHandler(ctx, cancel, logger, svcName, httpSvr) - }) - - if err := g.Wait(); err != nil { - logger.Error(fmt.Sprintf("%s service terminated: %s", svcName, err)) - } -} - -func newService(db *sqlx.DB, dbConfig clientspg.Config, authz smqauthz.Authorization, token grpcTokenV1.TokenServiceClient, tracer trace.Tracer, conf config, logger *slog.Logger) (invitations.Service, error) { - database := postgres.NewDatabase(db, dbConfig, tracer) - repo := invitationspg.NewRepository(database) - - config := mgsdk.Config{ - UsersURL: conf.UsersURL, - DomainsURL: conf.DomainsURL, - } - sdk := mgsdk.NewSDK(config) - - svc := invitations.NewService(token, repo, sdk) - svc = middleware.AuthorizationMiddleware(authz, svc) - svc = middleware.Tracing(svc, tracer) - svc = middleware.Logging(logger, svc) - counter, latency := prometheus.MakeMetrics(svcName, "api") - svc = middleware.Metrics(counter, latency, svc) - - return svc, nil -} diff --git a/config.toml b/config.toml index 387e459d8..cbfa63773 100644 --- a/config.toml +++ b/config.toml @@ -17,7 +17,6 @@ user_token = "" groups_url = "http://localhost:9004" host_url = "http://localhost" http_adapter_url = "http://localhost:8008" - invitations_url = "http://localhost:9020" journal_url = "http://localhost:9021" tls_verification = false users_url = "http://localhost:9002" diff --git a/docker/.env b/docker/.env index b4e293c1d..7e006a36e 100644 --- a/docker/.env +++ b/docker/.env @@ -143,23 +143,6 @@ SMQ_SPICEDB_HOST=supermq-spicedb SMQ_SPICEDB_PORT=50051 SMQ_SPICEDB_DATASTORE_ENGINE=postgres -### Invitations -SMQ_INVITATIONS_LOG_LEVEL=info -SMQ_INVITATIONS_HTTP_HOST=invitations -SMQ_INVITATIONS_HTTP_PORT=9020 -SMQ_INVITATIONS_HTTP_SERVER_CERT= -SMQ_INVITATIONS_HTTP_SERVER_KEY= -SMQ_INVITATIONS_DB_HOST=invitations-db -SMQ_INVITATIONS_DB_PORT=5432 -SMQ_INVITATIONS_DB_USER=supermq -SMQ_INVITATIONS_DB_PASS=supermq -SMQ_INVITATIONS_DB_NAME=invitations -SMQ_INVITATIONS_DB_SSL_MODE=disable -SMQ_INVITATIONS_DB_SSL_CERT= -SMQ_INVITATIONS_DB_SSL_KEY= -SMQ_INVITATIONS_DB_SSL_ROOT_CERT= -SMQ_INVITATIONS_INSTANCE_ID= - ### UI SMQ_UI_LOG_LEVEL=debug SMQ_UI_PORT=9095 diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 1f527d735..476cef3bc 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -19,7 +19,6 @@ volumes: supermq-pat-db-volume: supermq-domains-db-volume: supermq-domains-redis-volume: - supermq-invitations-db-volume: supermq-ui-db-volume: services: @@ -293,83 +292,6 @@ services: bind: create_host_path: true - invitations-db: - image: postgres:16.2-alpine - container_name: supermq-invitations-db - restart: on-failure - command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}" - environment: - POSTGRES_USER: ${SMQ_INVITATIONS_DB_USER} - POSTGRES_PASSWORD: ${SMQ_INVITATIONS_DB_PASS} - POSTGRES_DB: ${SMQ_INVITATIONS_DB_NAME} - SMQ_POSTGRES_MAX_CONNECTIONS: ${SMQ_POSTGRES_MAX_CONNECTIONS} - ports: - - 6021:5432 - networks: - - supermq-base-net - volumes: - - supermq-invitations-db-volume:/var/lib/postgresql/data - - invitations: - image: supermq/invitations:${SMQ_RELEASE_TAG} - container_name: supermq-invitations - restart: on-failure - depends_on: - - auth - - invitations-db - environment: - SMQ_INVITATIONS_LOG_LEVEL: ${SMQ_INVITATIONS_LOG_LEVEL} - SMQ_USERS_URL: ${SMQ_USERS_URL} - SMQ_DOMAINS_URL: ${SMQ_DOMAINS_URL} - SMQ_INVITATIONS_HTTP_HOST: ${SMQ_INVITATIONS_HTTP_HOST} - SMQ_INVITATIONS_HTTP_PORT: ${SMQ_INVITATIONS_HTTP_PORT} - SMQ_INVITATIONS_HTTP_SERVER_CERT: ${SMQ_INVITATIONS_HTTP_SERVER_CERT} - SMQ_INVITATIONS_HTTP_SERVER_KEY: ${SMQ_INVITATIONS_HTTP_SERVER_KEY} - SMQ_INVITATIONS_DB_HOST: ${SMQ_INVITATIONS_DB_HOST} - SMQ_INVITATIONS_DB_USER: ${SMQ_INVITATIONS_DB_USER} - SMQ_INVITATIONS_DB_PASS: ${SMQ_INVITATIONS_DB_PASS} - SMQ_INVITATIONS_DB_PORT: ${SMQ_INVITATIONS_DB_PORT} - SMQ_INVITATIONS_DB_NAME: ${SMQ_INVITATIONS_DB_NAME} - SMQ_INVITATIONS_DB_SSL_MODE: ${SMQ_INVITATIONS_DB_SSL_MODE} - SMQ_INVITATIONS_DB_SSL_CERT: ${SMQ_INVITATIONS_DB_SSL_CERT} - SMQ_INVITATIONS_DB_SSL_KEY: ${SMQ_INVITATIONS_DB_SSL_KEY} - SMQ_INVITATIONS_DB_SSL_ROOT_CERT: ${SMQ_INVITATIONS_DB_SSL_ROOT_CERT} - SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL} - SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT} - SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt} - SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key} - SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt} - SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL} - SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT} - SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} - SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} - SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} - SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} - SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} - SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} - SMQ_INVITATIONS_INSTANCE_ID: ${SMQ_INVITATIONS_INSTANCE_ID} - ports: - - ${SMQ_INVITATIONS_HTTP_PORT}:${SMQ_INVITATIONS_HTTP_PORT} - networks: - - supermq-base-net - volumes: - # Auth gRPC client certificates - - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} - target: /auth-grpc-client${SMQ_AUTH_GRPC_CLIENT_CERT:+.crt} - bind: - create_host_path: true - - type: bind - source: ${SMQ_AUTH_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key} - target: /auth-grpc-client${SMQ_AUTH_GRPC_CLIENT_KEY:+.key} - bind: - create_host_path: true - - type: bind - source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca} - target: /auth-grpc-server-ca${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+.crt} - bind: - create_host_path: true - nginx: image: nginx:1.25.4-alpine container_name: supermq-nginx diff --git a/docker/nginx/entrypoint.sh b/docker/nginx/entrypoint.sh index c1a3e0a6b..66b2c51d3 100755 --- a/docker/nginx/entrypoint.sh +++ b/docker/nginx/entrypoint.sh @@ -22,7 +22,6 @@ envsubst ' ${SMQ_HTTP_ADAPTER_PORT} ${SMQ_NGINX_MQTT_PORT} ${SMQ_NGINX_MQTTS_PORT} - ${SMQ_INVITATIONS_HTTP_PORT} ${SMQ_WS_ADAPTER_HTTP_PORT}' < /etc/nginx/nginx.conf.template > /etc/nginx/nginx.conf exec nginx -g "daemon off;" diff --git a/docker/nginx/nginx-key.conf b/docker/nginx/nginx-key.conf index 9779466dc..2b4d38368 100644 --- a/docker/nginx/nginx-key.conf +++ b/docker/nginx/nginx-key.conf @@ -85,20 +85,13 @@ http { proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT}; } - # Proxy pass to domains service + # Proxy pass to channels service location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(channels)" { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; proxy_pass http://channels:${SMQ_CHANNELS_HTTP_PORT}; } - # Proxy pass to invitations service - location ~ ^/(invitations) { - include snippets/proxy-headers.conf; - add_header Access-Control-Expose-Headers Location; - proxy_pass http://invitations:${SMQ_INVITATIONS_HTTP_PORT}; - } - location /health { include snippets/proxy-headers.conf; proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT}; diff --git a/docker/nginx/nginx-x509.conf b/docker/nginx/nginx-x509.conf index 5d913a5e3..3514898e8 100644 --- a/docker/nginx/nginx-x509.conf +++ b/docker/nginx/nginx-x509.conf @@ -94,20 +94,13 @@ http { proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT}; } - # Proxy pass to domains service + # Proxy pass to channels service location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(channels)" { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; proxy_pass http://channels:${SMQ_CHANNELS_HTTP_PORT}; } - # Proxy pass to invitations service - location ~ ^/(invitations) { - include snippets/proxy-headers.conf; - add_header Access-Control-Expose-Headers Location; - proxy_pass http://invitations:${SMQ_INVITATIONS_HTTP_PORT}; - } - location /health { include snippets/proxy-headers.conf; proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT}; diff --git a/domains/api/http/decode.go b/domains/api/http/decode.go index 13999acf1..82e625424 100644 --- a/domains/api/http/decode.go +++ b/domains/api/http/decode.go @@ -16,6 +16,15 @@ import ( "github.com/go-chi/chi/v5" ) +const ( + inviteeUserIDKey = "invitee_user_id" + domainIDKey = "domain_id" + invitedByKey = "invited_by" + roleIDKey = "role_id" + roleNameKey = "role_name" + stateKey = "state" +) + func decodeCreateDomainRequest(_ context.Context, r *http.Request) (interface{}, error) { if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) @@ -157,3 +166,86 @@ func decodePageRequest(_ context.Context, r *http.Request) (page, error) { status: st, }, nil } + +func decodeSendInvitationReq(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) + } + + var req sendInvitationReq + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity)) + } + + return req, nil +} + +func decodeListInvitationsReq(_ context.Context, r *http.Request) (interface{}, error) { + offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + inviteeUserID, err := apiutil.ReadStringQuery(r, inviteeUserIDKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + invitedBy, err := apiutil.ReadStringQuery(r, invitedByKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + roleID, err := apiutil.ReadStringQuery(r, roleIDKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + domainID, err := apiutil.ReadStringQuery(r, domainIDKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + st, err := apiutil.ReadStringQuery(r, stateKey, domains.AllState.String()) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + state, err := domains.ToState(st) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + req := listInvitationsReq{ + InvitationPageMeta: domains.InvitationPageMeta{ + Offset: offset, + Limit: limit, + InvitedBy: invitedBy, + InviteeUserID: inviteeUserID, + RoleID: roleID, + DomainID: domainID, + State: state, + }, + } + + return req, nil +} + +func decodeAcceptInvitationReq(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) + } + + var req acceptInvitationReq + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity)) + } + + return req, nil +} + +func decodeInvitationReq(_ context.Context, r *http.Request) (interface{}, error) { + req := invitationReq{ + userID: chi.URLParam(r, "userID"), + domainID: chi.URLParam(r, "domainID"), + } + + return req, nil +} diff --git a/domains/api/http/endpoint.go b/domains/api/http/endpoint.go index b50fe898a..b3cd3ae7e 100644 --- a/domains/api/http/endpoint.go +++ b/domains/api/http/endpoint.go @@ -15,6 +15,9 @@ import ( "github.com/go-kit/kit/endpoint" ) +// InvitationSent is the message returned when an invitation is sent. +const InvitationSent = "invitation sent" + func createDomainEndpoint(svc domains.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { req := request.(createDomainReq) @@ -182,3 +185,161 @@ func freezeDomainEndpoint(svc domains.Service) endpoint.Endpoint { return freezeDomainRes{}, nil } } + +func sendInvitationEndpoint(svc domains.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(sendInvitationReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + invitation := domains.Invitation{ + InviteeUserID: req.InviteeUserID, + DomainID: session.DomainID, + RoleID: req.RoleID, + } + + if err := svc.SendInvitation(ctx, session, invitation); err != nil { + return nil, err + } + + return sendInvitationRes{ + Message: InvitationSent, + }, nil + } +} + +func viewInvitationEndpoint(svc domains.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(invitationReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + session.DomainID = req.domainID + invitation, err := svc.ViewInvitation(ctx, session, req.userID, req.domainID) + if err != nil { + return nil, err + } + + return viewInvitationRes{ + Invitation: invitation, + }, nil + } +} + +func listDomainInvitationsEndpoint(svc domains.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(listInvitationsReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + req.InvitationPageMeta.DomainID = session.DomainID + + page, err := svc.ListInvitations(ctx, session, req.InvitationPageMeta) + if err != nil { + return nil, err + } + + return listInvitationsRes{ + page, + }, nil + } +} + +func listUserInvitationsEndpoint(svc domains.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(listInvitationsReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + session.DomainID = req.DomainID + + page, err := svc.ListInvitations(ctx, session, req.InvitationPageMeta) + if err != nil { + return nil, err + } + + return listInvitationsRes{ + page, + }, nil + } +} + +func acceptInvitationEndpoint(svc domains.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(acceptInvitationReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + if err := svc.AcceptInvitation(ctx, session, req.DomainID); err != nil { + return nil, err + } + + return acceptInvitationRes{}, nil + } +} + +func rejectInvitationEndpoint(svc domains.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(acceptInvitationReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + + if err := svc.RejectInvitation(ctx, session, req.DomainID); err != nil { + return nil, err + } + + return rejectInvitationRes{}, nil + } +} + +func deleteInvitationEndpoint(svc domains.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(invitationReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(api.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthorization + } + session.DomainID = req.domainID + + if err := svc.DeleteInvitation(ctx, session, req.userID, req.domainID); err != nil { + return nil, err + } + + return deleteInvitationRes{}, nil + } +} diff --git a/domains/api/http/endpoint_test.go b/domains/api/http/endpoint_test.go index 51a9b31ad..d5bfdc10c 100644 --- a/domains/api/http/endpoint_test.go +++ b/domains/api/http/endpoint_test.go @@ -45,6 +45,8 @@ var ( inValidToken = "invalid" invalid = "invalid" userID = testsutil.GenerateUUID(&testing.T{}) + validID = testsutil.GenerateUUID(&testing.T{}) + domainID = testsutil.GenerateUUID(&testing.T{}) ) const ( @@ -1072,6 +1074,587 @@ func TestFreezeDomain(t *testing.T) { } } +func TestSendInvitation(t *testing.T) { + is, svc, auth := newDomainsServer() + + cases := []struct { + desc string + token string + domainID string + data string + session authn.Session + contentType string + status int + authnErr error + svcErr error + }{ + { + desc: "send invitation with valid request", + token: validToken, + domainID: domainID, + data: fmt.Sprintf(`{"invitee_user_id": "%s","role_id": "%s"}`, validID, validID), + status: http.StatusCreated, + contentType: contentType, + svcErr: nil, + }, + { + desc: "send invitation with invalid token", + token: "", + domainID: domainID, + data: fmt.Sprintf(`{"invitee_user_id": "%s","role_id": "%s"}`, validID, validID), + status: http.StatusUnauthorized, + contentType: contentType, + svcErr: nil, + }, + { + desc: "send invitation with empty domain_id", + token: validToken, + domainID: "", + data: fmt.Sprintf(`{"invitee_user_id": "%s","role_id": "%s"}`, validID, validID), + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + { + desc: "send invitation with invalid content type", + token: validToken, + domainID: domainID, + data: fmt.Sprintf(`{"invitee_user_id": "%s","role_id": "%s"}`, validID, validID), + status: http.StatusUnsupportedMediaType, + contentType: "text/plain", + svcErr: nil, + }, + { + desc: "send invitation with invalid data", + token: validToken, + domainID: domainID, + data: `data`, + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + { + desc: "send invitation with service error", + token: validToken, + domainID: domainID, + data: fmt.Sprintf(`{"invitee_user_id": "%s","role_id": "%s"}`, validID, validID), + status: http.StatusForbidden, + contentType: contentType, + svcErr: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = authn.Session{UserID: userID, DomainID: domainID, DomainUserID: domainID + "_" + userID} + } + authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + repoCall := svc.On("SendInvitation", mock.Anything, tc.session, mock.Anything).Return(tc.svcErr) + req := testRequest{ + client: is.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/domains/%s/invitations", is.URL, tc.domainID), + token: tc.token, + contentType: tc.contentType, + body: strings.NewReader(tc.data), + } + + res, err := req.make() + assert.Nil(t, err, tc.desc) + assert.Equal(t, tc.status, res.StatusCode, tc.desc) + repoCall.Unset() + authnCall.Unset() + }) + } +} + +func TestListInvitation(t *testing.T) { + is, svc, auth := newDomainsServer() + + cases := []struct { + desc string + token string + session authn.Session + query string + contentType string + status int + svcErr error + authnErr error + }{ + { + desc: "list invitations with valid request", + token: validToken, + status: http.StatusOK, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with invalid token", + token: "", + status: http.StatusUnauthorized, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with offset", + token: validToken, + query: "offset=1", + status: http.StatusOK, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with invalid offset", + token: validToken, + query: "offset=invalid", + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with limit", + token: validToken, + query: "limit=1", + status: http.StatusOK, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with invalid limit", + token: validToken, + query: "limit=invalid", + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with invitee_user_id", + token: validToken, + query: fmt.Sprintf("invitee_user_id=%s", validID), + status: http.StatusOK, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with duplicate invitee_user_id", + token: validToken, + query: "invitee_user_id=1&invitee_user_id=2", + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with invited_by", + token: validToken, + query: fmt.Sprintf("invited_by=%s", validID), + status: http.StatusOK, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with duplicate invited_by", + token: validToken, + query: "invited_by=1&invited_by=2", + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with state", + token: validToken, + query: "state=pending", + status: http.StatusOK, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with invalid state", + token: validToken, + query: "state=invalid", + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with duplicate state", + token: validToken, + query: "state=all&state=all", + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + { + desc: "list invitations with service error", + token: validToken, + status: http.StatusForbidden, + contentType: contentType, + svcErr: svcerr.ErrAuthorization, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = authn.Session{UserID: userID} + } + authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + repoCall := svc.On("ListInvitations", mock.Anything, tc.session, mock.Anything).Return(domains.InvitationPage{}, tc.svcErr) + req := testRequest{ + client: is.Client(), + method: http.MethodGet, + url: is.URL + "/invitations?" + tc.query, + token: tc.token, + contentType: tc.contentType, + } + res, err := req.make() + assert.Nil(t, err, tc.desc) + assert.Equal(t, tc.status, res.StatusCode, tc.desc) + repoCall.Unset() + authnCall.Unset() + }) + } +} + +func TestViewInvitation(t *testing.T) { + is, svc, auth := newDomainsServer() + + cases := []struct { + desc string + token string + session authn.Session + domainID string + userID string + contentType string + status int + svcErr error + authnErr error + }{ + { + desc: "view invitation with valid request", + token: validToken, + userID: validID, + domainID: domainID, + status: http.StatusOK, + contentType: contentType, + svcErr: nil, + }, + { + desc: "view invitation with invalid token", + token: "", + userID: validID, + domainID: domainID, + status: http.StatusUnauthorized, + contentType: contentType, + svcErr: nil, + }, + { + desc: "view invitation with service error", + token: validToken, + userID: validID, + domainID: domainID, + status: http.StatusBadRequest, + contentType: contentType, + svcErr: svcerr.ErrViewEntity, + }, + { + desc: "view invitation with empty domain", + token: validToken, + userID: validID, + domainID: "", + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + { + desc: "view invitation with empty invitee_user_id and domain_id", + token: validToken, + userID: "", + domainID: "", + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = authn.Session{UserID: userID, DomainID: domainID, DomainUserID: domainID + "_" + userID} + } + authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + repoCall := svc.On("ViewInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(domains.Invitation{}, tc.svcErr) + req := testRequest{ + client: is.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/domains/%s/invitations/%s", is.URL, tc.domainID, tc.userID), + token: tc.token, + contentType: tc.contentType, + } + + res, err := req.make() + assert.Nil(t, err, tc.desc) + assert.Equal(t, tc.status, res.StatusCode, tc.desc) + repoCall.Unset() + authnCall.Unset() + }) + } +} + +func TestDeleteInvitation(t *testing.T) { + is, svc, auth := newDomainsServer() + + cases := []struct { + desc string + token string + session authn.Session + domainID string + userID string + contentType string + status int + svcErr error + authnErr error + }{ + { + desc: "delete invitation with valid request", + token: validToken, + userID: validID, + domainID: domainID, + status: http.StatusNoContent, + contentType: contentType, + svcErr: nil, + }, + { + desc: "delete invitation with invalid token", + token: "", + userID: validID, + domainID: domainID, + status: http.StatusUnauthorized, + contentType: contentType, + svcErr: nil, + }, + { + desc: "delete invitation with service error", + token: validToken, + userID: validID, + domainID: domainID, + status: http.StatusForbidden, + contentType: contentType, + svcErr: svcerr.ErrAuthorization, + }, + { + desc: "delete invitation with empty invitee_user_id", + token: validToken, + userID: "", + domainID: domainID, + status: http.StatusMethodNotAllowed, + contentType: contentType, + svcErr: nil, + }, + { + desc: "delete invitation with empty domain_id", + token: validToken, + userID: validID, + domainID: "", + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + { + desc: "delete invitation with empty invitee_user_id and domain_id", + token: validToken, + userID: "", + domainID: "", + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = authn.Session{UserID: userID, DomainID: domainID, DomainUserID: domainID + "_" + userID} + } + authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + repoCall := svc.On("DeleteInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(tc.svcErr) + req := testRequest{ + client: is.Client(), + method: http.MethodDelete, + url: fmt.Sprintf("%s/domains/%s/invitations/%s", is.URL, tc.domainID, tc.userID), + token: tc.token, + contentType: tc.contentType, + } + + res, err := req.make() + assert.Nil(t, err, tc.desc) + assert.Equal(t, tc.status, res.StatusCode, tc.desc) + repoCall.Unset() + authnCall.Unset() + }) + } +} + +func TestAcceptInvitation(t *testing.T) { + is, svc, auth := newDomainsServer() + + cases := []struct { + desc string + token string + session authn.Session + data string + contentType string + status int + svcErr error + authnErr error + }{ + { + desc: "accept invitation with valid request", + data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), + token: validToken, + status: http.StatusNoContent, + contentType: contentType, + svcErr: nil, + }, + { + desc: "accept invitation with invalid token", + token: "", + data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), + status: http.StatusUnauthorized, + contentType: contentType, + svcErr: nil, + }, + { + desc: "accept invitation with service error", + token: validToken, + data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), + status: http.StatusForbidden, + contentType: contentType, + svcErr: svcerr.ErrAuthorization, + }, + { + desc: "accept invitation with invalid content type", + token: validToken, + data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), + status: http.StatusUnsupportedMediaType, + contentType: "text/plain", + svcErr: nil, + }, + { + desc: "accept invitation with invalid data", + token: validToken, + data: `data`, + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = authn.Session{UserID: userID, DomainID: domainID} + } + authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + repoCall := svc.On("AcceptInvitation", mock.Anything, tc.session, mock.Anything).Return(tc.svcErr) + req := testRequest{ + client: is.Client(), + method: http.MethodPost, + url: is.URL + "/invitations/accept", + token: tc.token, + contentType: tc.contentType, + body: strings.NewReader(tc.data), + } + + res, err := req.make() + assert.Nil(t, err, tc.desc) + assert.Equal(t, tc.status, res.StatusCode, tc.desc) + repoCall.Unset() + authnCall.Unset() + }) + } +} + +func TestRejectInvitation(t *testing.T) { + is, svc, auth := newDomainsServer() + + cases := []struct { + desc string + token string + session authn.Session + data string + contentType string + status int + svcErr error + authnErr error + }{ + { + desc: "reject invitation with valid request", + token: validToken, + data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), + status: http.StatusNoContent, + contentType: contentType, + svcErr: nil, + }, + { + desc: "reject invitation with invalid token", + token: "", + data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), + status: http.StatusUnauthorized, + contentType: contentType, + svcErr: nil, + }, + { + desc: "reject invitation with unauthorized error", + token: validToken, + data: fmt.Sprintf(`{"domain_id": "%s"}`, "invalid"), + status: http.StatusForbidden, + contentType: contentType, + svcErr: svcerr.ErrAuthorization, + }, + { + desc: "reject invitation with invalid content type", + token: validToken, + data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), + status: http.StatusUnsupportedMediaType, + contentType: "text/plain", + svcErr: nil, + }, + { + desc: "reject invitation with invalid data", + token: validToken, + data: `data`, + status: http.StatusBadRequest, + contentType: contentType, + svcErr: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + if tc.token == validToken { + tc.session = authn.Session{UserID: userID, DomainID: domainID} + } + authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr) + repoCall := svc.On("RejectInvitation", mock.Anything, tc.session, mock.Anything).Return(tc.svcErr) + req := testRequest{ + client: is.Client(), + method: http.MethodPost, + url: is.URL + "/invitations/reject", + token: tc.token, + contentType: tc.contentType, + body: strings.NewReader(tc.data), + } + + res, err := req.make() + assert.Nil(t, err, tc.desc) + assert.Equal(t, tc.status, res.StatusCode, tc.desc) + repoCall.Unset() + authnCall.Unset() + }) + } +} + type respBody struct { Err string `json:"error"` Message string `json:"message"` diff --git a/domains/api/http/requests.go b/domains/api/http/requests.go index 7f12eaa1a..b723044d6 100644 --- a/domains/api/http/requests.go +++ b/domains/api/http/requests.go @@ -8,6 +8,8 @@ import ( "github.com/absmach/supermq/domains" ) +const maxLimitSize = 100 + type page struct { offset uint64 limit uint64 @@ -112,3 +114,56 @@ func (req freezeDomainReq) validate() error { return nil } + +type sendInvitationReq struct { + InviteeUserID string `json:"invitee_user_id,omitempty"` + RoleID string `json:"role_id,omitempty"` +} + +func (req *sendInvitationReq) validate() error { + if req.InviteeUserID == "" || req.RoleID == "" { + return apiutil.ErrMissingID + } + + return nil +} + +type listInvitationsReq struct { + domains.InvitationPageMeta +} + +func (req *listInvitationsReq) validate() error { + if req.InvitationPageMeta.Limit > maxLimitSize || req.InvitationPageMeta.Limit < 1 { + return apiutil.ErrLimitSize + } + + return nil +} + +type acceptInvitationReq struct { + DomainID string `json:"domain_id,omitempty"` +} + +func (req *acceptInvitationReq) validate() error { + if req.DomainID == "" { + return apiutil.ErrMissingDomainID + } + + return nil +} + +type invitationReq struct { + userID string + domainID string +} + +func (req *invitationReq) validate() error { + if req.userID == "" { + return apiutil.ErrMissingID + } + if req.domainID == "" { + return apiutil.ErrMissingDomainID + } + + return nil +} diff --git a/domains/api/http/responses.go b/domains/api/http/responses.go index ae8e07636..a573e734f 100644 --- a/domains/api/http/responses.go +++ b/domains/api/http/responses.go @@ -14,6 +14,15 @@ var ( _ supermq.Response = (*createDomainRes)(nil) _ supermq.Response = (*retrieveDomainRes)(nil) _ supermq.Response = (*listDomainsRes)(nil) + _ supermq.Response = (*enableDomainRes)(nil) + _ supermq.Response = (*disableDomainRes)(nil) + _ supermq.Response = (*freezeDomainRes)(nil) + _ supermq.Response = (*sendInvitationRes)(nil) + _ supermq.Response = (*viewInvitationRes)(nil) + _ supermq.Response = (*listInvitationsRes)(nil) + _ supermq.Response = (*acceptInvitationRes)(nil) + _ supermq.Response = (*rejectInvitationRes)(nil) + _ supermq.Response = (*deleteInvitationRes)(nil) ) type createDomainRes struct { @@ -121,3 +130,93 @@ func (res freezeDomainRes) Headers() map[string]string { func (res freezeDomainRes) Empty() bool { return true } + +type sendInvitationRes struct { + Message string `json:"message"` +} + +func (res sendInvitationRes) Code() int { + return http.StatusCreated +} + +func (res sendInvitationRes) Headers() map[string]string { + return map[string]string{} +} + +func (res sendInvitationRes) Empty() bool { + return true +} + +type viewInvitationRes struct { + domains.Invitation `json:",inline"` +} + +func (res viewInvitationRes) Code() int { + return http.StatusOK +} + +func (res viewInvitationRes) Headers() map[string]string { + return map[string]string{} +} + +func (res viewInvitationRes) Empty() bool { + return false +} + +type listInvitationsRes struct { + domains.InvitationPage `json:",inline"` +} + +func (res listInvitationsRes) Code() int { + return http.StatusOK +} + +func (res listInvitationsRes) Headers() map[string]string { + return map[string]string{} +} + +func (res listInvitationsRes) Empty() bool { + return false +} + +type acceptInvitationRes struct{} + +func (res acceptInvitationRes) Code() int { + return http.StatusNoContent +} + +func (res acceptInvitationRes) Headers() map[string]string { + return map[string]string{} +} + +func (res acceptInvitationRes) Empty() bool { + return true +} + +type deleteInvitationRes struct{} + +func (res deleteInvitationRes) Code() int { + return http.StatusNoContent +} + +func (res deleteInvitationRes) Headers() map[string]string { + return map[string]string{} +} + +func (res deleteInvitationRes) Empty() bool { + return true +} + +type rejectInvitationRes struct{} + +func (res rejectInvitationRes) Code() int { + return http.StatusNoContent +} + +func (res rejectInvitationRes) Headers() map[string]string { + return map[string]string{} +} + +func (res rejectInvitationRes) Empty() bool { + return true +} diff --git a/domains/api/http/transport.go b/domains/api/http/transport.go index fa68fe2f3..a8da26df2 100644 --- a/domains/api/http/transport.go +++ b/domains/api/http/transport.go @@ -5,6 +5,7 @@ package http import ( "log/slog" + "net/http" "github.com/absmach/supermq" api "github.com/absmach/supermq/api/http" @@ -18,7 +19,8 @@ import ( "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) -func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string) *chi.Mux { +// MakeHandler returns a HTTP handler for Domains and Invitations API endpoints. +func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler { opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), } @@ -82,9 +84,61 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux, ), "freeze_domain").ServeHTTP) roleManagerHttp.EntityRoleMangerRouter(svc, d, r, opts) }) + + r.Route("/{domainID}/invitations", func(r chi.Router) { + r.Use(api.AuthenticateMiddleware(authn, true)) + r.Post("/", otelhttp.NewHandler(kithttp.NewServer( + sendInvitationEndpoint(svc), + decodeSendInvitationReq, + api.EncodeResponse, + opts..., + ), "send_invitation").ServeHTTP) + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + listDomainInvitationsEndpoint(svc), + decodeListInvitationsReq, + api.EncodeResponse, + opts..., + ), "list_domain_invitations").ServeHTTP) + r.Route("/{userID}", func(r chi.Router) { + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + viewInvitationEndpoint(svc), + decodeInvitationReq, + api.EncodeResponse, + opts..., + ), "view_invitation").ServeHTTP) + r.Delete("/", otelhttp.NewHandler(kithttp.NewServer( + deleteInvitationEndpoint(svc), + decodeInvitationReq, + api.EncodeResponse, + opts..., + ), "delete_invitation").ServeHTTP) + }) + }) }) - mux.Get("/health", supermq.Health("auth", instanceID)) + mux.Route("/invitations", func(r chi.Router) { + r.Use(api.AuthenticateMiddleware(authn, false)) + r.Get("/", otelhttp.NewHandler(kithttp.NewServer( + listUserInvitationsEndpoint(svc), + decodeListInvitationsReq, + api.EncodeResponse, + opts..., + ), "list_user_invitations").ServeHTTP) + r.Post("/accept", otelhttp.NewHandler(kithttp.NewServer( + acceptInvitationEndpoint(svc), + decodeAcceptInvitationReq, + api.EncodeResponse, + opts..., + ), "accept_invitation").ServeHTTP) + r.Post("/reject", otelhttp.NewHandler(kithttp.NewServer( + rejectInvitationEndpoint(svc), + decodeAcceptInvitationReq, + api.EncodeResponse, + opts..., + ), "reject_invitation").ServeHTTP) + }) + + mux.Get("/health", supermq.Health("domains", instanceID)) mux.Handle("/metrics", promhttp.Handler()) return mux diff --git a/domains/domains.go b/domains/domains.go index 5df2acf23..b153a55d7 100644 --- a/domains/domains.go +++ b/domains/domains.go @@ -105,6 +105,7 @@ type DomainReq struct { UpdatedBy *string `json:"updated_by,omitempty"` UpdatedAt *time.Time `json:"updated_at,omitempty"` } + type Domain struct { ID string `json:"id"` Name string `json:"name"` @@ -137,7 +138,7 @@ type Page struct { ID string `json:"id,omitempty"` IDs []string `json:"-"` Identity string `json:"identity,omitempty"` - UserID string `json:"-"` + UserID string `json:"user_id,omitempty"` } type DomainsPage struct { @@ -164,13 +165,64 @@ func (page DomainsPage) MarshalJSON() ([]byte, error) { //go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines" type Service interface { + // CreateDomain creates a new domain. CreateDomain(ctx context.Context, sesssion authn.Session, d Domain) (Domain, []roles.RoleProvision, error) + + // RetrieveDomain retrieves a domain specified by the provided ID. RetrieveDomain(ctx context.Context, sesssion authn.Session, id string) (Domain, error) + + // UpdateDomain updates the domain specified by the provided ID. UpdateDomain(ctx context.Context, sesssion authn.Session, id string, d DomainReq) (Domain, error) + + // EnableDomain enables the domain specified by the provided ID. EnableDomain(ctx context.Context, sesssion authn.Session, id string) (Domain, error) + + // DisableDomain disables the domain specified by the provided ID. + // Only platform administrators and domain admins can disable domains. DisableDomain(ctx context.Context, sesssion authn.Session, id string) (Domain, error) + + // FreezeDomain freezes the domain specified by the provided ID. + // Only platform administrators can freeze domains. FreezeDomain(ctx context.Context, sesssion authn.Session, id string) (Domain, error) + + // ListDomains returns a list of domains. ListDomains(ctx context.Context, sesssion authn.Session, page Page) (DomainsPage, error) + + // SendInvitation sends an invitation to the given user. + // Only domain administrators and platform administrators can send invitations. + SendInvitation(ctx context.Context, session authn.Session, invitation Invitation) (err error) + + // ViewInvitation returns an invitation. + // People who can view invitations are: + // - the invited user: they can view their own invitations + // - the user who sent the invitation + // - domain administrators + // - platform administrators + ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (invitation Invitation, err error) + + // ListInvitations returns a list of invitations. + // People who can list invitations are: + // - platform administrators can list all invitations + // - domain administrators can list invitations for their domain + // By default, it will list invitations the current user has sent or received. + ListInvitations(ctx context.Context, session authn.Session, page InvitationPageMeta) (invitations InvitationPage, err error) + + // AcceptInvitation accepts an invitation by adding the user to the domain. + AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) + + // DeleteInvitation deletes an invitation. + // People who can delete invitations are: + // - the invited user: they can delete their own invitations + // - the user who sent the invitation + // - domain administrators + // - platform administrators + DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (err error) + + // RejectInvitation rejects an invitation. + // People who can reject invitations are: + // - the invited user: they can reject their own invitations + RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) + roles.RoleManager } @@ -178,26 +230,45 @@ type Service interface { // //go:generate mockery --name Repository --output=./mocks --filename repository.go --quiet --note "Copyright (c) Abstract Machines" type Repository interface { - // Save creates db insert transaction for the given domain. - Save(ctx context.Context, d Domain) (Domain, error) + // SaveDomain creates db insert transaction for the given domain. + SaveDomain(ctx context.Context, d Domain) (Domain, error) - // RetrieveByID retrieves Domain by its unique ID. - RetrieveByID(ctx context.Context, id string) (Domain, error) + // RetrieveDomainByID retrieves a domain by its unique ID. + RetrieveDomainByID(ctx context.Context, id string) (Domain, error) - RetrieveByUserAndID(ctx context.Context, userID, id string) (Domain, error) + // RetrieveDomainByUserAndID retrieves a domain by its unique ID and user ID. + RetrieveDomainByUserAndID(ctx context.Context, userID, id string) (Domain, error) - // RetrieveAllByIDs retrieves for given Domain IDs. - RetrieveAllByIDs(ctx context.Context, pm Page) (DomainsPage, error) + // RetrieveAllDomainsByIDs retrieves for given Domain IDs. + RetrieveAllDomainsByIDs(ctx context.Context, pm Page) (DomainsPage, error) - // Update updates the client name and metadata. - Update(ctx context.Context, id string, d DomainReq) (Domain, error) + // UpdateDomain updates the domain name and metadata. + UpdateDomain(ctx context.Context, id string, d DomainReq) (Domain, error) - // Delete - Delete(ctx context.Context, id string) error + // DeleteDomain deletes the domain. + DeleteDomain(ctx context.Context, id string) error // ListDomains list all the domains ListDomains(ctx context.Context, pm Page) (DomainsPage, error) + // CreateInvitation creates an invitation. + SaveInvitation(ctx context.Context, invitation Invitation) (err error) + + // RetrieveInvitation retrieves an invitation. + RetrieveInvitation(ctx context.Context, userID, domainID string) (Invitation, error) + + // RetrieveAllInvitations retrieves all invitations. + RetrieveAllInvitations(ctx context.Context, page InvitationPageMeta) (invitations InvitationPage, err error) + + // UpdateConfirmation updates an invitation by setting the confirmation time. + UpdateConfirmation(ctx context.Context, invitation Invitation) (err error) + + // UpdateRejection updates an invitation by setting the rejection time. + UpdateRejection(ctx context.Context, invitation Invitation) (err error) + + // Delete deletes an invitation. + DeleteInvitation(ctx context.Context, userID, domainID string) (err error) + roles.Repository } diff --git a/domains/events/events.go b/domains/events/events.go index ee60fdd8e..aa5b300b6 100644 --- a/domains/events/events.go +++ b/domains/events/events.go @@ -23,6 +23,13 @@ const ( domainFreeze = domainPrefix + "freeze" domainList = domainPrefix + "list" domainUserDelete = domainPrefix + "user_delete" + invitationPrefix = "invitation." + invitationSend = invitationPrefix + "send" + invitationAccept = invitationPrefix + "accept" + invitationReject = invitationPrefix + "reject" + invitationList = invitationPrefix + "list" + invitationRetrieve = invitationPrefix + "retrieve" + invitationDelete = invitationPrefix + "delete" ) var ( @@ -34,6 +41,12 @@ var ( _ events.Event = (*disableDomainEvent)(nil) _ events.Event = (*freezeDomainEvent)(nil) _ events.Event = (*listDomainsEvent)(nil) + _ events.Event = (*sendInvitationEvent)(nil) + _ events.Event = (*viewInvitationEvent)(nil) + _ events.Event = (*listInvitationsEvent)(nil) + _ events.Event = (*acceptInvitationEvent)(nil) + _ events.Event = (*rejectInvitationEvent)(nil) + _ events.Event = (*deleteInvitationEvent)(nil) ) type createDomainEvent struct { @@ -275,3 +288,130 @@ func (lde listDomainsEvent) Encode() (map[string]interface{}, error) { return val, nil } + +type sendInvitationEvent struct { + invitation domains.Invitation + session authn.Session +} + +func (sie sendInvitationEvent) Encode() (map[string]interface{}, error) { + val := map[string]interface{}{ + "operation": invitationSend, + "invitee_user_id": sie.invitation.InviteeUserID, + "domain_id": sie.invitation.DomainID, + "invited_by": sie.session.UserID, + "role_id": sie.invitation.RoleID, + "token_type": sie.session.Type.String(), + "super_admin": sie.session.SuperAdmin, + } + + return val, nil +} + +type viewInvitationEvent struct { + inviteeUserID string + domainID string + roleID string + roleName string + session authn.Session +} + +func (vie viewInvitationEvent) Encode() (map[string]interface{}, error) { + val := map[string]interface{}{ + "operation": invitationRetrieve, + "invitee_user_id": vie.inviteeUserID, + "domain_id": vie.domainID, + "role_id": vie.roleID, + "role_name": vie.roleName, + "token_type": vie.session.Type.String(), + "super_admin": vie.session.SuperAdmin, + } + + return val, nil +} + +type listInvitationsEvent struct { + domains.InvitationPageMeta + session authn.Session +} + +func (lie listInvitationsEvent) Encode() (map[string]interface{}, error) { + val := map[string]interface{}{ + "operation": invitationList, + "offset": lie.Offset, + "limit": lie.Limit, + "user_id": lie.session.UserID, + "token_type": lie.session.Type.String(), + "super_admin": lie.session.SuperAdmin, + } + + if lie.InvitedBy != "" { + val["invited_by"] = lie.InvitedBy + } + if lie.InviteeUserID != "" { + val["invitee_user_id"] = lie.InviteeUserID + } + if lie.DomainID != "" { + val["domain_id"] = lie.DomainID + } + if lie.RoleID != "" { + val["role_id"] = lie.RoleID + } + if lie.State.String() != "" { + val["state"] = lie.State.String() + } + + return val, nil +} + +type acceptInvitationEvent struct { + domainID string + session authn.Session +} + +func (aie acceptInvitationEvent) Encode() (map[string]interface{}, error) { + val := map[string]interface{}{ + "operation": invitationAccept, + "domain_id": aie.domainID, + "invitee_user_id": aie.session.UserID, + "token_type": aie.session.Type.String(), + "super_admin": aie.session.SuperAdmin, + } + + return val, nil +} + +type rejectInvitationEvent struct { + domainID string + session authn.Session +} + +func (rie rejectInvitationEvent) Encode() (map[string]interface{}, error) { + val := map[string]interface{}{ + "operation": invitationReject, + "domain_id": rie.domainID, + "invitee_user_id": rie.session.UserID, + "token_type": rie.session.Type.String(), + "super_admin": rie.session.SuperAdmin, + } + + return val, nil +} + +type deleteInvitationEvent struct { + inviteeUserID string + domainID string + session authn.Session +} + +func (die deleteInvitationEvent) Encode() (map[string]interface{}, error) { + val := map[string]interface{}{ + "operation": invitationDelete, + "invitee_user_id": die.inviteeUserID, + "domain_id": die.domainID, + "token_type": die.session.Type.String(), + "super_admin": die.session.SuperAdmin, + } + + return val, nil +} diff --git a/domains/events/streams.go b/domains/events/streams.go index 12a382588..3e07b5791 100644 --- a/domains/events/streams.go +++ b/domains/events/streams.go @@ -176,3 +176,95 @@ func (es *eventStore) ListDomains(ctx context.Context, session authn.Session, p return dp, nil } + +func (es *eventStore) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) error { + if err := es.svc.SendInvitation(ctx, session, invitation); err != nil { + return err + } + + event := sendInvitationEvent{ + invitation: invitation, + session: session, + } + + return es.Publish(ctx, event) +} + +func (es *eventStore) ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (domains.Invitation, error) { + invitation, err := es.svc.ViewInvitation(ctx, session, userID, domainID) + if err != nil { + return invitation, err + } + + event := viewInvitationEvent{ + inviteeUserID: userID, + domainID: domainID, + roleID: invitation.RoleID, + roleName: invitation.RoleName, + session: session, + } + + if err := es.Publish(ctx, event); err != nil { + return invitation, err + } + + return invitation, nil +} + +func (es *eventStore) ListInvitations(ctx context.Context, session authn.Session, pm domains.InvitationPageMeta) (domains.InvitationPage, error) { + ip, err := es.svc.ListInvitations(ctx, session, pm) + if err != nil { + return ip, err + } + + event := listInvitationsEvent{ + InvitationPageMeta: pm, + session: session, + } + + if err := es.Publish(ctx, event); err != nil { + return ip, err + } + + return ip, nil +} + +func (es *eventStore) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) error { + if err := es.svc.AcceptInvitation(ctx, session, domainID); err != nil { + return err + } + + event := acceptInvitationEvent{ + domainID: domainID, + session: session, + } + + return es.Publish(ctx, event) +} + +func (es *eventStore) RejectInvitation(ctx context.Context, session authn.Session, domainID string) error { + if err := es.svc.RejectInvitation(ctx, session, domainID); err != nil { + return err + } + + event := rejectInvitationEvent{ + domainID: domainID, + session: session, + } + + return es.Publish(ctx, event) +} + +func (es *eventStore) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) error { + if err := es.svc.DeleteInvitation(ctx, session, inviteeUserID, domainID); err != nil { + return err + } + + event := deleteInvitationEvent{ + inviteeUserID: inviteeUserID, + domainID: domainID, + session: session, + } + + return es.Publish(ctx, event) +} diff --git a/domains/invitations.go b/domains/invitations.go new file mode 100644 index 000000000..c85525302 --- /dev/null +++ b/domains/invitations.go @@ -0,0 +1,57 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package domains + +import ( + "encoding/json" + "time" +) + +// Invitation is an invitation to join a domain. +type Invitation struct { + InvitedBy string `json:"invited_by"` + InviteeUserID string `json:"invitee_user_id"` + DomainID string `json:"domain_id"` + RoleID string `json:"role_id,omitempty"` + RoleName string `json:"role_name,omitempty"` + Actions []string `json:"actions,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at,omitempty"` + ConfirmedAt time.Time `json:"confirmed_at,omitempty"` + RejectedAt time.Time `json:"rejected_at,omitempty"` +} + +// InvitationPage is a page of invitations. +type InvitationPage struct { + Total uint64 `json:"total"` + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Invitations []Invitation `json:"invitations"` +} + +func (page InvitationPage) MarshalJSON() ([]byte, error) { + type Alias InvitationPage + a := struct { + Alias + }{ + Alias: Alias(page), + } + + if a.Invitations == nil { + a.Invitations = make([]Invitation, 0) + } + + return json.Marshal(a) +} + +type InvitationPageMeta struct { + Offset uint64 `json:"offset" db:"offset"` + Limit uint64 `json:"limit" db:"limit"` + InvitedBy string `json:"invited_by,omitempty" db:"invited_by,omitempty"` + InviteeUserID string `json:"invitee_user_id,omitempty" db:"invitee_user_id,omitempty"` + DomainID string `json:"domain_id,omitempty" db:"domain_id,omitempty"` + RoleID string `json:"role_id,omitempty" db:"role_id,omitempty"` + InvitedByOrUserID string `db:"invited_by_or_user_id,omitempty"` + State State `json:"state,omitempty"` +} diff --git a/domains/invitations_test.go b/domains/invitations_test.go new file mode 100644 index 000000000..197dfceab --- /dev/null +++ b/domains/invitations_test.go @@ -0,0 +1,50 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package domains_test + +import ( + "fmt" + "testing" + + "github.com/absmach/supermq/domains" + "github.com/stretchr/testify/assert" +) + +func TestInvitation_MarshalJSON(t *testing.T) { + cases := []struct { + desc string + page domains.InvitationPage + res string + }{ + { + desc: "empty page", + page: domains.InvitationPage{ + Invitations: []domains.Invitation(nil), + }, + res: `{"total":0,"offset":0,"limit":0,"invitations":[]}`, + }, + { + desc: "page with invitations", + page: domains.InvitationPage{ + Total: 1, + Offset: 0, + Limit: 0, + Invitations: []domains.Invitation{ + { + InvitedBy: "John", + InviteeUserID: "123", + DomainID: "123", + }, + }, + }, + res: `{"total":1,"offset":0,"limit":0,"invitations":[{"invited_by":"John","invitee_user_id":"123","domain_id":"123","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z","confirmed_at":"0001-01-01T00:00:00Z","rejected_at":"0001-01-01T00:00:00Z"}]}`, + }, + } + + for _, tc := range cases { + data, err := tc.page.MarshalJSON() + assert.NoError(t, err, "Unexpected error: %v", err) + assert.Equal(t, tc.res, string(data), fmt.Sprintf("%s: expected %s, got %s", tc.desc, tc.res, string(data))) + } +} diff --git a/domains/middleware/authorization.go b/domains/middleware/authorization.go index 0b8e9ef63..f043e7624 100644 --- a/domains/middleware/authorization.go +++ b/domains/middleware/authorization.go @@ -6,10 +6,13 @@ package middleware import ( "context" + "github.com/absmach/supermq/auth" "github.com/absmach/supermq/domains" "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/authz" smqauthz "github.com/absmach/supermq/pkg/authz" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/absmach/supermq/pkg/policies" "github.com/absmach/supermq/pkg/roles" rmMW "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" @@ -18,6 +21,9 @@ import ( var _ domains.Service = (*authorizationMiddleware)(nil) +// ErrMemberExist indicates that the user is already a member of the domain. +var ErrMemberExist = errors.New("user is already a member of the domain") + type authorizationMiddleware struct { svc domains.Service authz smqauthz.Authorization @@ -134,6 +140,69 @@ func (am *authorizationMiddleware) ListDomains(ctx context.Context, session auth return am.svc.ListDomains(ctx, session, page) } +func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) (err error) { + domainUserId := auth.EncodeDomainUserID(invitation.DomainID, invitation.InviteeUserID) + if err := am.extAuthorize(ctx, domainUserId, policies.MembershipPermission, policies.DomainType, invitation.DomainID); err == nil { + // return error if the user is already a member of the domain + return errors.Wrap(svcerr.ErrConflict, ErrMemberExist) + } + + if err := am.checkAdmin(ctx, session); err != nil { + return err + } + + return am.svc.SendInvitation(ctx, session, invitation) +} + +func (am *authorizationMiddleware) ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID, domain string) (invitation domains.Invitation, err error) { + session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID) + if session.UserID != inviteeUserID { + if err := am.checkAdmin(ctx, session); err != nil { + return domains.Invitation{}, err + } + } + + return am.svc.ViewInvitation(ctx, session, inviteeUserID, domain) +} + +func (am *authorizationMiddleware) ListInvitations(ctx context.Context, session authn.Session, page domains.InvitationPageMeta) (invs domains.InvitationPage, err error) { + session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID) + if err := am.extAuthorize(ctx, session.UserID, policies.AdminPermission, policies.PlatformType, policies.SuperMQObject); err == nil { + session.SuperAdmin = true + page.DomainID = "" + } + + if !session.SuperAdmin { + switch { + case page.DomainID != "": + if err := am.extAuthorize(ctx, session.DomainUserID, policies.AdminPermission, policies.DomainType, page.DomainID); err != nil { + return domains.InvitationPage{}, err + } + default: + page.InvitedByOrUserID = session.UserID + } + } + + return am.svc.ListInvitations(ctx, session, page) +} + +func (am *authorizationMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { + return am.svc.AcceptInvitation(ctx, session, domainID) +} + +func (am *authorizationMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { + return am.svc.RejectInvitation(ctx, session, domainID) +} + +func (am *authorizationMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (err error) { + session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID) + if err := am.checkAdmin(ctx, session); err != nil { + return err + } + + return am.svc.DeleteInvitation(ctx, session, inviteeUserID, domainID) +} + func (am *authorizationMiddleware) authorize(ctx context.Context, op svcutil.Operation, authReq authz.PolicyReq) error { perm, err := am.opp.GetPermission(op) if err != nil { @@ -147,3 +216,49 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, op svcutil.Ope return nil } + +// checkAdmin checks if the given user is a domain or platform administrator. +func (am *authorizationMiddleware) checkAdmin(ctx context.Context, session authn.Session) error { + req := smqauthz.PolicyReq{ + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.DomainUserID, + Permission: policies.AdminPermission, + ObjectType: policies.DomainType, + Object: session.DomainID, + } + if err := am.authz.Authorize(ctx, req); err == nil { + return nil + } + + req = smqauthz.PolicyReq{ + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: session.UserID, + Permission: policies.AdminPermission, + ObjectType: policies.PlatformType, + Object: policies.SuperMQObject, + } + + if err := am.authz.Authorize(ctx, req); err == nil { + return nil + } + + return svcerr.ErrAuthorization +} + +func (am *authorizationMiddleware) extAuthorize(ctx context.Context, subj, perm, objType, obj string) error { + req := authz.PolicyReq{ + SubjectType: policies.UserType, + SubjectKind: policies.UsersKind, + Subject: subj, + Permission: perm, + ObjectType: objType, + Object: obj, + } + if err := am.authz.Authorize(ctx, req); err != nil { + return err + } + + return nil +} diff --git a/domains/middleware/logging.go b/domains/middleware/logging.go index ef6da0c03..c4b2caf99 100644 --- a/domains/middleware/logging.go +++ b/domains/middleware/logging.go @@ -164,3 +164,106 @@ func (lm *loggingMiddleware) ListDomains(ctx context.Context, session authn.Sess }(time.Now()) return lm.svc.ListDomains(ctx, session, page) } + +func (lm *loggingMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("invitee_user_id", invitation.InviteeUserID), + slog.String("domain_id", invitation.DomainID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Send invitation failed", args...) + return + } + lm.logger.Info("Send invitation completed successfully", args...) + }(time.Now()) + return lm.svc.SendInvitation(ctx, session, invitation) +} + +func (lm *loggingMiddleware) ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (invitation domains.Invitation, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("invitee_user_id", inviteeUserID), + slog.String("domain_id", domainID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("View invitation failed", args...) + return + } + lm.logger.Info("View invitation completed successfully", args...) + }(time.Now()) + return lm.svc.ViewInvitation(ctx, session, inviteeUserID, domainID) +} + +func (lm *loggingMiddleware) ListInvitations(ctx context.Context, session authn.Session, pm domains.InvitationPageMeta) (invs domains.InvitationPage, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.Group("page", + slog.Uint64("offset", pm.Offset), + slog.Uint64("limit", pm.Limit), + slog.Uint64("total", invs.Total), + ), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("List invitations failed", args...) + return + } + lm.logger.Info("List invitations completed successfully", args...) + }(time.Now()) + return lm.svc.ListInvitations(ctx, session, pm) +} + +func (lm *loggingMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", domainID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Accept invitation failed", args...) + return + } + lm.logger.Info("Accept invitation completed successfully", args...) + }(time.Now()) + return lm.svc.AcceptInvitation(ctx, session, domainID) +} + +func (lm *loggingMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("domain_id", domainID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Reject invitation failed", args...) + return + } + lm.logger.Info("Reject invitation completed successfully", args...) + }(time.Now()) + return lm.svc.RejectInvitation(ctx, session, domainID) +} + +func (lm *loggingMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("invitee_user_id", inviteeUserID), + slog.String("domain_id", domainID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Delete invitation failed", args...) + return + } + lm.logger.Info("Delete invitation completed successfully", args...) + }(time.Now()) + return lm.svc.DeleteInvitation(ctx, session, inviteeUserID, domainID) +} diff --git a/domains/middleware/metrics.go b/domains/middleware/metrics.go index 6ed643184..64ea289e9 100644 --- a/domains/middleware/metrics.go +++ b/domains/middleware/metrics.go @@ -92,3 +92,51 @@ func (ms *metricsMiddleware) ListDomains(ctx context.Context, session authn.Sess }(time.Now()) return ms.svc.ListDomains(ctx, session, page) } + +func (mm *metricsMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "send_invitation").Add(1) + mm.latency.With("method", "send_invitation").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.SendInvitation(ctx, session, invitation) +} + +func (mm *metricsMiddleware) ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (invitation domains.Invitation, err error) { + defer func(begin time.Time) { + mm.counter.With("method", "view_invitation").Add(1) + mm.latency.With("method", "view_invitation").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.ViewInvitation(ctx, session, userID, domainID) +} + +func (mm *metricsMiddleware) ListInvitations(ctx context.Context, session authn.Session, pm domains.InvitationPageMeta) (invs domains.InvitationPage, err error) { + defer func(begin time.Time) { + mm.counter.With("method", "list_invitations").Add(1) + mm.latency.With("method", "list_invitations").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.ListInvitations(ctx, session, pm) +} + +func (mm *metricsMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "accept_invitation").Add(1) + mm.latency.With("method", "accept_invitation").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.AcceptInvitation(ctx, session, domainID) +} + +func (mm *metricsMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "reject_invitation").Add(1) + mm.latency.With("method", "reject_invitation").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.RejectInvitation(ctx, session, domainID) +} + +func (mm *metricsMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) { + defer func(begin time.Time) { + mm.counter.With("method", "delete_invitation").Add(1) + mm.latency.With("method", "delete_invitation").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.DeleteInvitation(ctx, session, userID, domainID) +} diff --git a/domains/mocks/repository.go b/domains/mocks/repository.go index e8aad64c1..84124dfff 100644 --- a/domains/mocks/repository.go +++ b/domains/mocks/repository.go @@ -48,12 +48,12 @@ func (_m *Repository) AddRoles(ctx context.Context, rps []roles.RoleProvision) ( return r0, r1 } -// Delete provides a mock function with given fields: ctx, id -func (_m *Repository) Delete(ctx context.Context, id string) error { +// DeleteDomain provides a mock function with given fields: ctx, id +func (_m *Repository) DeleteDomain(ctx context.Context, id string) error { ret := _m.Called(ctx, id) if len(ret) == 0 { - panic("no return value specified for Delete") + panic("no return value specified for DeleteDomain") } var r0 error @@ -66,6 +66,24 @@ func (_m *Repository) Delete(ctx context.Context, id string) error { return r0 } +// DeleteInvitation provides a mock function with given fields: ctx, userID, domainID +func (_m *Repository) DeleteInvitation(ctx context.Context, userID string, domainID string) error { + ret := _m.Called(ctx, userID, domainID) + + if len(ret) == 0 { + panic("no return value specified for DeleteInvitation") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, userID, domainID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // ListDomains provides a mock function with given fields: ctx, pm func (_m *Repository) ListDomains(ctx context.Context, pm domains.Page) (domains.DomainsPage, error) { ret := _m.Called(ctx, pm) @@ -176,12 +194,12 @@ func (_m *Repository) RemoveRoles(ctx context.Context, roleIDs []string) error { return r0 } -// RetrieveAllByIDs provides a mock function with given fields: ctx, pm -func (_m *Repository) RetrieveAllByIDs(ctx context.Context, pm domains.Page) (domains.DomainsPage, error) { +// RetrieveAllDomainsByIDs provides a mock function with given fields: ctx, pm +func (_m *Repository) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.Page) (domains.DomainsPage, error) { ret := _m.Called(ctx, pm) if len(ret) == 0 { - panic("no return value specified for RetrieveAllByIDs") + panic("no return value specified for RetrieveAllDomainsByIDs") } var r0 domains.DomainsPage @@ -204,6 +222,34 @@ func (_m *Repository) RetrieveAllByIDs(ctx context.Context, pm domains.Page) (do return r0, r1 } +// RetrieveAllInvitations provides a mock function with given fields: ctx, page +func (_m *Repository) RetrieveAllInvitations(ctx context.Context, page domains.InvitationPageMeta) (domains.InvitationPage, error) { + ret := _m.Called(ctx, page) + + if len(ret) == 0 { + panic("no return value specified for RetrieveAllInvitations") + } + + var r0 domains.InvitationPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, domains.InvitationPageMeta) (domains.InvitationPage, error)); ok { + return rf(ctx, page) + } + if rf, ok := ret.Get(0).(func(context.Context, domains.InvitationPageMeta) domains.InvitationPage); ok { + r0 = rf(ctx, page) + } else { + r0 = ret.Get(0).(domains.InvitationPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, domains.InvitationPageMeta) error); ok { + r1 = rf(ctx, page) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // RetrieveAllRoles provides a mock function with given fields: ctx, entityID, limit, offset func (_m *Repository) RetrieveAllRoles(ctx context.Context, entityID string, limit uint64, offset uint64) (roles.RolePage, error) { ret := _m.Called(ctx, entityID, limit, offset) @@ -232,12 +278,12 @@ func (_m *Repository) RetrieveAllRoles(ctx context.Context, entityID string, lim return r0, r1 } -// RetrieveByID provides a mock function with given fields: ctx, id -func (_m *Repository) RetrieveByID(ctx context.Context, id string) (domains.Domain, error) { +// RetrieveDomainByID provides a mock function with given fields: ctx, id +func (_m *Repository) RetrieveDomainByID(ctx context.Context, id string) (domains.Domain, error) { ret := _m.Called(ctx, id) if len(ret) == 0 { - panic("no return value specified for RetrieveByID") + panic("no return value specified for RetrieveDomainByID") } var r0 domains.Domain @@ -260,12 +306,12 @@ func (_m *Repository) RetrieveByID(ctx context.Context, id string) (domains.Doma return r0, r1 } -// RetrieveByUserAndID provides a mock function with given fields: ctx, userID, id -func (_m *Repository) RetrieveByUserAndID(ctx context.Context, userID string, id string) (domains.Domain, error) { +// RetrieveDomainByUserAndID provides a mock function with given fields: ctx, userID, id +func (_m *Repository) RetrieveDomainByUserAndID(ctx context.Context, userID string, id string) (domains.Domain, error) { ret := _m.Called(ctx, userID, id) if len(ret) == 0 { - panic("no return value specified for RetrieveByUserAndID") + panic("no return value specified for RetrieveDomainByUserAndID") } var r0 domains.Domain @@ -355,6 +401,34 @@ func (_m *Repository) RetrieveEntityRole(ctx context.Context, entityID string, r return r0, r1 } +// RetrieveInvitation provides a mock function with given fields: ctx, userID, domainID +func (_m *Repository) RetrieveInvitation(ctx context.Context, userID string, domainID string) (domains.Invitation, error) { + ret := _m.Called(ctx, userID, domainID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveInvitation") + } + + var r0 domains.Invitation + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (domains.Invitation, error)); ok { + return rf(ctx, userID, domainID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) domains.Invitation); ok { + r0 = rf(ctx, userID, domainID) + } else { + r0 = ret.Get(0).(domains.Invitation) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, userID, domainID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // RetrieveRole provides a mock function with given fields: ctx, roleID func (_m *Repository) RetrieveRole(ctx context.Context, roleID string) (roles.Role, error) { ret := _m.Called(ctx, roleID) @@ -629,12 +703,12 @@ func (_m *Repository) RoleRemoveMembers(ctx context.Context, role roles.Role, me return r0 } -// Save provides a mock function with given fields: ctx, d -func (_m *Repository) Save(ctx context.Context, d domains.Domain) (domains.Domain, error) { +// SaveDomain provides a mock function with given fields: ctx, d +func (_m *Repository) SaveDomain(ctx context.Context, d domains.Domain) (domains.Domain, error) { ret := _m.Called(ctx, d) if len(ret) == 0 { - panic("no return value specified for Save") + panic("no return value specified for SaveDomain") } var r0 domains.Domain @@ -657,12 +731,48 @@ func (_m *Repository) Save(ctx context.Context, d domains.Domain) (domains.Domai return r0, r1 } -// Update provides a mock function with given fields: ctx, id, d -func (_m *Repository) Update(ctx context.Context, id string, d domains.DomainReq) (domains.Domain, error) { +// SaveInvitation provides a mock function with given fields: ctx, invitation +func (_m *Repository) SaveInvitation(ctx context.Context, invitation domains.Invitation) error { + ret := _m.Called(ctx, invitation) + + if len(ret) == 0 { + panic("no return value specified for SaveInvitation") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, domains.Invitation) error); ok { + r0 = rf(ctx, invitation) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateConfirmation provides a mock function with given fields: ctx, invitation +func (_m *Repository) UpdateConfirmation(ctx context.Context, invitation domains.Invitation) error { + ret := _m.Called(ctx, invitation) + + if len(ret) == 0 { + panic("no return value specified for UpdateConfirmation") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, domains.Invitation) error); ok { + r0 = rf(ctx, invitation) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateDomain provides a mock function with given fields: ctx, id, d +func (_m *Repository) UpdateDomain(ctx context.Context, id string, d domains.DomainReq) (domains.Domain, error) { ret := _m.Called(ctx, id, d) if len(ret) == 0 { - panic("no return value specified for Update") + panic("no return value specified for UpdateDomain") } var r0 domains.Domain @@ -685,6 +795,24 @@ func (_m *Repository) Update(ctx context.Context, id string, d domains.DomainReq return r0, r1 } +// UpdateRejection provides a mock function with given fields: ctx, invitation +func (_m *Repository) UpdateRejection(ctx context.Context, invitation domains.Invitation) error { + ret := _m.Called(ctx, invitation) + + if len(ret) == 0 { + panic("no return value specified for UpdateRejection") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, domains.Invitation) error); ok { + r0 = rf(ctx, invitation) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UpdateRole provides a mock function with given fields: ctx, ro func (_m *Repository) UpdateRole(ctx context.Context, ro roles.Role) (roles.Role, error) { ret := _m.Called(ctx, ro) diff --git a/domains/mocks/service.go b/domains/mocks/service.go index 5523941a8..236d635b7 100644 --- a/domains/mocks/service.go +++ b/domains/mocks/service.go @@ -21,6 +21,24 @@ type Service struct { mock.Mock } +// AcceptInvitation provides a mock function with given fields: ctx, session, domainID +func (_m *Service) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) error { + ret := _m.Called(ctx, session, domainID) + + if len(ret) == 0 { + panic("no return value specified for AcceptInvitation") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = rf(ctx, session, domainID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // AddRole provides a mock function with given fields: ctx, session, entityID, roleName, optionalActions, optionalMembers func (_m *Service) AddRole(ctx context.Context, session authn.Session, entityID string, roleName string, optionalActions []string, optionalMembers []string) (roles.RoleProvision, error) { ret := _m.Called(ctx, session, entityID, roleName, optionalActions, optionalMembers) @@ -86,6 +104,24 @@ func (_m *Service) CreateDomain(ctx context.Context, sesssion authn.Session, d d return r0, r1, r2 } +// DeleteInvitation provides a mock function with given fields: ctx, session, inviteeUserID, domainID +func (_m *Service) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID string, domainID string) error { + ret := _m.Called(ctx, session, inviteeUserID, domainID) + + if len(ret) == 0 { + panic("no return value specified for DeleteInvitation") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) error); ok { + r0 = rf(ctx, session, inviteeUserID, domainID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // DisableDomain provides a mock function with given fields: ctx, sesssion, id func (_m *Service) DisableDomain(ctx context.Context, sesssion authn.Session, id string) (domains.Domain, error) { ret := _m.Called(ctx, sesssion, id) @@ -256,6 +292,52 @@ func (_m *Service) ListEntityMembers(ctx context.Context, session authn.Session, return r0, r1 } +// ListInvitations provides a mock function with given fields: ctx, session, page +func (_m *Service) ListInvitations(ctx context.Context, session authn.Session, page domains.InvitationPageMeta) (domains.InvitationPage, error) { + ret := _m.Called(ctx, session, page) + + if len(ret) == 0 { + panic("no return value specified for ListInvitations") + } + + var r0 domains.InvitationPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, domains.InvitationPageMeta) (domains.InvitationPage, error)); ok { + return rf(ctx, session, page) + } + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, domains.InvitationPageMeta) domains.InvitationPage); ok { + r0 = rf(ctx, session, page) + } else { + r0 = ret.Get(0).(domains.InvitationPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, authn.Session, domains.InvitationPageMeta) error); ok { + r1 = rf(ctx, session, page) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RejectInvitation provides a mock function with given fields: ctx, session, domainID +func (_m *Service) RejectInvitation(ctx context.Context, session authn.Session, domainID string) error { + ret := _m.Called(ctx, session, domainID) + + if len(ret) == 0 { + panic("no return value specified for RejectInvitation") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = rf(ctx, session, domainID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // RemoveEntityMembers provides a mock function with given fields: ctx, session, entityID, members func (_m *Service) RemoveEntityMembers(ctx context.Context, session authn.Session, entityID string, members []string) error { ret := _m.Called(ctx, session, entityID, members) @@ -640,6 +722,24 @@ func (_m *Service) RoleRemoveMembers(ctx context.Context, session authn.Session, return r0 } +// SendInvitation provides a mock function with given fields: ctx, session, invitation +func (_m *Service) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) error { + ret := _m.Called(ctx, session, invitation) + + if len(ret) == 0 { + panic("no return value specified for SendInvitation") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, domains.Invitation) error); ok { + r0 = rf(ctx, session, invitation) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UpdateDomain provides a mock function with given fields: ctx, sesssion, id, d func (_m *Service) UpdateDomain(ctx context.Context, sesssion authn.Session, id string, d domains.DomainReq) (domains.Domain, error) { ret := _m.Called(ctx, sesssion, id, d) @@ -696,6 +796,34 @@ func (_m *Service) UpdateRoleName(ctx context.Context, session authn.Session, en return r0, r1 } +// ViewInvitation provides a mock function with given fields: ctx, session, inviteeUserID, domainID +func (_m *Service) ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID string, domainID string) (domains.Invitation, error) { + ret := _m.Called(ctx, session, inviteeUserID, domainID) + + if len(ret) == 0 { + panic("no return value specified for ViewInvitation") + } + + var r0 domains.Invitation + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) (domains.Invitation, error)); ok { + return rf(ctx, session, inviteeUserID, domainID) + } + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) domains.Invitation); ok { + r0 = rf(ctx, session, inviteeUserID, domainID) + } else { + r0 = ret.Get(0).(domains.Invitation) + } + + if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, string) error); ok { + r1 = rf(ctx, session, inviteeUserID, domainID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewService(t interface { diff --git a/domains/postgres/domains.go b/domains/postgres/domains.go index e1892573b..3894e0e69 100644 --- a/domains/postgres/domains.go +++ b/domains/postgres/domains.go @@ -36,9 +36,9 @@ type domainRepo struct { rolesPostgres.Repository } -// New instantiates a PostgreSQL +// NewRepository instantiates a PostgreSQL // implementation of Domain repository. -func New(db postgres.Database) domains.Repository { +func NewRepository(db postgres.Database) domains.Repository { rmsvcRepo := rolesPostgres.NewRepository(db, policies.DomainType, rolesTableNamePrefix, entityTableName, entityIDColumnName) return &domainRepo{ db: db, @@ -46,7 +46,7 @@ func New(db postgres.Database) domains.Repository { } } -func (repo domainRepo) Save(ctx context.Context, d domains.Domain) (dd domains.Domain, err error) { +func (repo domainRepo) SaveDomain(ctx context.Context, d domains.Domain) (dd domains.Domain, err error) { q := `INSERT INTO domains (id, name, tags, alias, metadata, created_at, updated_at, updated_by, created_by, status) VALUES (:id, :name, :tags, :alias, :metadata, :created_at, :updated_at, :updated_by, :created_by, :status) RETURNING id, name, tags, alias, metadata, created_at, updated_at, updated_by, created_by, status;` @@ -76,8 +76,8 @@ func (repo domainRepo) Save(ctx context.Context, d domains.Domain) (dd domains.D return domain, nil } -// RetrieveByID retrieves Domain by its unique ID. -func (repo domainRepo) RetrieveByID(ctx context.Context, id string) (domains.Domain, error) { +// RetrieveDomainByID retrieves Domain by its unique ID. +func (repo domainRepo) RetrieveDomainByID(ctx context.Context, id string) (domains.Domain, error) { q := `SELECT d.id as id, d.name as name, d.tags as tags, d.alias as alias, d.metadata as metadata, d.created_at as created_at, d.updated_at as updated_at, d.updated_by as updated_by, d.created_by as created_by, d.status as status FROM domains d WHERE d.id = :id` @@ -107,7 +107,7 @@ func (repo domainRepo) RetrieveByID(ctx context.Context, id string) (domains.Dom return domains.Domain{}, repoerr.ErrNotFound } -func (repo domainRepo) RetrieveByUserAndID(ctx context.Context, userID, id string) (domains.Domain, error) { +func (repo domainRepo) RetrieveDomainByUserAndID(ctx context.Context, userID, id string) (domains.Domain, error) { q := repo.userDomainsBaseQuery() + `SELECT d.id as id, @@ -156,7 +156,7 @@ func (repo domainRepo) RetrieveByUserAndID(ctx context.Context, userID, id strin } // RetrieveAllByIDs retrieves for given Domain IDs . -func (repo domainRepo) RetrieveAllByIDs(ctx context.Context, pm domains.Page) (domains.DomainsPage, error) { +func (repo domainRepo) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.Page) (domains.DomainsPage, error) { var q string if len(pm.IDs) == 0 { return domains.DomainsPage{}, nil @@ -170,7 +170,7 @@ func (repo domainRepo) RetrieveAllByIDs(ctx context.Context, pm domains.Page) (d FROM domains d` q = fmt.Sprintf("%s %s LIMIT %d OFFSET %d;", q, query, pm.Limit, pm.Offset) - dbPage, err := toDBClientsPage(pm) + dbPage, err := toDBDomainsPage(pm) if err != nil { return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) } @@ -254,7 +254,7 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain q = fmt.Sprintf(q, squery) - dbPage, err := toDBClientsPage(pm) + dbPage, err := toDBDomainsPage(pm) if err != nil { return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) } @@ -294,8 +294,8 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain }, nil } -// Update updates the client name and metadata. -func (repo domainRepo) Update(ctx context.Context, id string, dr domains.DomainReq) (domains.Domain, error) { +// UpdateDomain updates the client name and metadata. +func (repo domainRepo) UpdateDomain(ctx context.Context, id string, dr domains.DomainReq) (domains.Domain, error) { var query []string var upq string d := domains.Domain{ID: id} @@ -361,7 +361,7 @@ func (repo domainRepo) Update(ctx context.Context, id string, dr domains.DomainR } // Delete delete domain from database. -func (repo domainRepo) Delete(ctx context.Context, id string) error { +func (repo domainRepo) DeleteDomain(ctx context.Context, id string) error { q := "DELETE FROM domains WHERE id = $1;" res, err := repo.db.ExecContext(ctx, q, id) @@ -553,7 +553,7 @@ type dbDomainsPage struct { UserID string `db:"member_id"` } -func toDBClientsPage(pm domains.Page) (dbDomainsPage, error) { +func toDBDomainsPage(pm domains.Page) (dbDomainsPage, error) { _, data, err := postgres.CreateMetadataQuery("", pm.Metadata) if err != nil { return dbDomainsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) diff --git a/domains/postgres/domains_test.go b/domains/postgres/domains_test.go index 1c4cc8c22..745daee91 100644 --- a/domains/postgres/domains_test.go +++ b/domains/postgres/domains_test.go @@ -27,13 +27,13 @@ var ( userID = testsutil.GenerateUUID(&testing.T{}) ) -func TestSave(t *testing.T) { +func TestSaveDomain(t *testing.T) { t.Cleanup(func() { _, err := db.Exec("DELETE FROM domains") require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err)) }) - repo := postgres.New(database) + repo := postgres.NewRepository(database) cases := []struct { desc string @@ -134,7 +134,7 @@ func TestSave(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - domain, err := repo.Save(context.Background(), tc.domain) + domain, err := repo.SaveDomain(context.Background(), tc.domain) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if err == nil { assert.Equal(t, tc.domain, domain, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.domain, domain)) @@ -149,7 +149,7 @@ func TestRetrieveByID(t *testing.T) { require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err)) }) - repo := postgres.New(database) + repo := postgres.NewRepository(database) domain := domains.Domain{ ID: domainID, @@ -166,7 +166,7 @@ func TestRetrieveByID(t *testing.T) { Status: domains.EnabledStatus, } - _, err := repo.Save(context.Background(), domain) + _, err := repo.SaveDomain(context.Background(), domain) require.Nil(t, err, fmt.Sprintf("failed to save client %s", domain.ID)) cases := []struct { @@ -197,7 +197,7 @@ func TestRetrieveByID(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - d, err := repo.RetrieveByID(context.Background(), tc.domainID) + d, err := repo.RetrieveDomainByID(context.Background(), tc.domainID) assert.Equal(t, tc.response, d, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, d)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err)) }) @@ -210,7 +210,7 @@ func TestRetrieveAllByIDs(t *testing.T) { require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err)) }) - repo := postgres.New(database) + repo := postgres.NewRepository(database) items := []domains.Domain{} for i := 0; i < 10; i++ { @@ -233,7 +233,7 @@ func TestRetrieveAllByIDs(t *testing.T) { "test1": "test1", } } - _, err := repo.Save(context.Background(), domain) + _, err := repo.SaveDomain(context.Background(), domain) require.Nil(t, err, fmt.Sprintf("save domain unexpected error: %s", err)) items = append(items, domain) } @@ -434,7 +434,7 @@ func TestRetrieveAllByIDs(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - dp, err := repo.RetrieveAllByIDs(context.Background(), tc.pm) + dp, err := repo.RetrieveAllDomainsByIDs(context.Background(), tc.pm) assert.Equal(t, tc.response, dp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, dp)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err)) }) @@ -455,7 +455,7 @@ func TestUpdate(t *testing.T) { updatedStatus := domains.DisabledStatus updatedAlias := "test1" - repo := postgres.New(database) + repo := postgres.NewRepository(database) domain := domains.Domain{ ID: domainID, @@ -470,7 +470,7 @@ func TestUpdate(t *testing.T) { Status: domains.EnabledStatus, } - _, err := repo.Save(context.Background(), domain) + _, err := repo.SaveDomain(context.Background(), domain) require.Nil(t, err, fmt.Sprintf("failed to save client %s", domain.ID)) cases := []struct { @@ -561,7 +561,7 @@ func TestUpdate(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - d, err := repo.Update(context.Background(), tc.domainID, tc.d) + d, err := repo.UpdateDomain(context.Background(), tc.domainID, tc.d) d.UpdatedAt = tc.response.UpdatedAt assert.Equal(t, tc.response, d, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, d)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) @@ -575,7 +575,7 @@ func TestDelete(t *testing.T) { require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err)) }) - repo := postgres.New(database) + repo := postgres.NewRepository(database) domain := domains.Domain{ ID: domainID, @@ -590,7 +590,7 @@ func TestDelete(t *testing.T) { Status: domains.EnabledStatus, } - _, err := repo.Save(context.Background(), domain) + _, err := repo.SaveDomain(context.Background(), domain) require.Nil(t, err, fmt.Sprintf("failed to save client %s", domain.ID)) cases := []struct { @@ -617,7 +617,7 @@ func TestDelete(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - err := repo.Delete(context.Background(), tc.domainID) + err := repo.DeleteDomain(context.Background(), tc.domainID) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) }) } @@ -629,7 +629,7 @@ func TestListDomains(t *testing.T) { require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err)) }) - repo := postgres.New(database) + repo := postgres.NewRepository(database) items := []domains.Domain{} for i := 0; i < 10; i++ { @@ -652,7 +652,7 @@ func TestListDomains(t *testing.T) { "test1": "test1", } } - _, err := repo.Save(context.Background(), domain) + _, err := repo.SaveDomain(context.Background(), domain) require.Nil(t, err, fmt.Sprintf("save domain unexpected error: %s", err)) items = append(items, domain) } diff --git a/domains/postgres/init.go b/domains/postgres/init.go index 35d928412..ae912a8ee 100644 --- a/domains/postgres/init.go +++ b/domains/postgres/init.go @@ -40,6 +40,27 @@ func Migration() (*migrate.MemoryMigrationSource, error) { `DROP TABLE IF EXISTS domains`, }, }, + { + Id: "domain_2", + Up: []string{ + `CREATE TABLE IF NOT EXISTS invitations ( + invited_by VARCHAR(36) NOT NULL, + invitee_user_id VARCHAR(36) NOT NULL, + domain_id VARCHAR(36) NOT NULL, + role_id VARCHAR(36) NOT NULL, + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP, + confirmed_at TIMESTAMP, + rejected_at TIMESTAMP, + UNIQUE (invitee_user_id, domain_id), + PRIMARY KEY (invitee_user_id, domain_id), + FOREIGN KEY (domain_id) REFERENCES domains(id) ON DELETE CASCADE + );`, + }, + Down: []string{ + `DROP TABLE IF EXISTS invitations`, + }, + }, }, } diff --git a/domains/postgres/invitations.go b/domains/postgres/invitations.go new file mode 100644 index 000000000..9d3f08382 --- /dev/null +++ b/domains/postgres/invitations.go @@ -0,0 +1,229 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + "github.com/absmach/supermq/domains" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/postgres" +) + +func (repo domainRepo) SaveInvitation(ctx context.Context, invitation domains.Invitation) (err error) { + q := `INSERT INTO invitations (invited_by, invitee_user_id, domain_id, role_id, created_at) + VALUES (:invited_by, :invitee_user_id, :domain_id, :role_id, :created_at)` + + dbInv := toDBInvitation(invitation) + if _, err = repo.db.NamedExecContext(ctx, q, dbInv); err != nil { + return postgres.HandleError(repoerr.ErrCreateEntity, err) + } + + return nil +} + +func (repo domainRepo) RetrieveInvitation(ctx context.Context, inviteeUserID, domainID string) (domains.Invitation, error) { + q := `SELECT invited_by, invitee_user_id, domain_id, role_id, created_at, updated_at, confirmed_at, rejected_at FROM invitations WHERE invitee_user_id = :invitee_user_id AND domain_id = :domain_id;` + + dbinv := dbInvitation{ + InviteeUserID: inviteeUserID, + DomainID: domainID, + } + rows, err := repo.db.NamedQueryContext(ctx, q, dbinv) + if err != nil { + return domains.Invitation{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + dbinv = dbInvitation{} + if rows.Next() { + if err = rows.StructScan(&dbinv); err != nil { + return domains.Invitation{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + + return toInvitation(dbinv), nil + } + + return domains.Invitation{}, repoerr.ErrNotFound +} + +func (repo domainRepo) RetrieveAllInvitations(ctx context.Context, pm domains.InvitationPageMeta) (domains.InvitationPage, error) { + query := pageQuery(pm) + + q := fmt.Sprintf("SELECT invited_by, invitee_user_id, domain_id, role_id, created_at, updated_at, confirmed_at, rejected_at FROM invitations %s LIMIT :limit OFFSET :offset;", query) + + rows, err := repo.db.NamedQueryContext(ctx, q, pm) + if err != nil { + return domains.InvitationPage{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + var items []domains.Invitation + for rows.Next() { + var dbinv dbInvitation + if err = rows.StructScan(&dbinv); err != nil { + return domains.InvitationPage{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + items = append(items, toInvitation(dbinv)) + } + + tq := fmt.Sprintf(`SELECT COUNT(*) FROM invitations %s`, query) + + total, err := postgres.Total(ctx, repo.db, tq, pm) + if err != nil { + return domains.InvitationPage{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + + invPage := domains.InvitationPage{ + Total: total, + Offset: pm.Offset, + Limit: pm.Limit, + Invitations: items, + } + + return invPage, nil +} + +func (repo domainRepo) UpdateConfirmation(ctx context.Context, invitation domains.Invitation) (err error) { + q := `UPDATE invitations SET confirmed_at = :confirmed_at, updated_at = :updated_at WHERE invitee_user_id = :invitee_user_id AND domain_id = :domain_id` + + dbinv := toDBInvitation(invitation) + result, err := repo.db.NamedExecContext(ctx, q, dbinv) + if err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + if rows, _ := result.RowsAffected(); rows == 0 { + return repoerr.ErrNotFound + } + + return nil +} + +func (repo domainRepo) UpdateRejection(ctx context.Context, invitation domains.Invitation) (err error) { + q := `UPDATE invitations SET rejected_at = :rejected_at, updated_at = :updated_at WHERE invitee_user_id = :invitee_user_id AND domain_id = :domain_id` + + dbInv := toDBInvitation(invitation) + result, err := repo.db.NamedExecContext(ctx, q, dbInv) + if err != nil { + return postgres.HandleError(repoerr.ErrUpdateEntity, err) + } + if rows, _ := result.RowsAffected(); rows == 0 { + return repoerr.ErrNotFound + } + + return nil +} + +func (repo domainRepo) DeleteInvitation(ctx context.Context, inviteeUserID, domain string) (err error) { + q := `DELETE FROM invitations WHERE invitee_user_id = $1 AND domain_id = $2` + + result, err := repo.db.ExecContext(ctx, q, inviteeUserID, domain) + if err != nil { + return postgres.HandleError(repoerr.ErrRemoveEntity, err) + } + if rows, _ := result.RowsAffected(); rows == 0 { + return repoerr.ErrNotFound + } + + return nil +} + +func pageQuery(pm domains.InvitationPageMeta) string { + var query []string + var emq string + if pm.DomainID != "" { + query = append(query, "domain_id = :domain_id") + } + if pm.InviteeUserID != "" { + query = append(query, "invitee_user_id = :invitee_user_id") + } + if pm.InvitedBy != "" { + query = append(query, "invited_by = :invited_by") + } + if pm.RoleID != "" { + query = append(query, "role_id = :role_id") + } + if pm.InvitedByOrUserID != "" { + query = append(query, "(invited_by = :invited_by_or_user_id OR invitee_user_id = :invited_by_or_user_id)") + } + if pm.State == domains.Accepted { + query = append(query, "confirmed_at IS NOT NULL") + } + if pm.State == domains.Pending { + query = append(query, "confirmed_at IS NULL AND rejected_at IS NULL") + } + if pm.State == domains.Rejected { + query = append(query, "rejected_at IS NOT NULL") + } + + if len(query) > 0 { + emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")) + } + + return emq +} + +type dbInvitation struct { + InvitedBy string `db:"invited_by"` + InviteeUserID string `db:"invitee_user_id"` + DomainID string `db:"domain_id"` + RoleID string `db:"role_id,omitempty"` + Relation string `db:"relation"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt sql.NullTime `db:"updated_at,omitempty"` + ConfirmedAt sql.NullTime `db:"confirmed_at,omitempty"` + RejectedAt sql.NullTime `db:"rejected_at,omitempty"` +} + +func toDBInvitation(inv domains.Invitation) dbInvitation { + var updatedAt, confirmedAt, rejectedAt sql.NullTime + if inv.UpdatedAt != (time.Time{}) { + updatedAt = sql.NullTime{Time: inv.UpdatedAt, Valid: true} + } + if inv.ConfirmedAt != (time.Time{}) { + confirmedAt = sql.NullTime{Time: inv.ConfirmedAt, Valid: true} + } + if inv.RejectedAt != (time.Time{}) { + rejectedAt = sql.NullTime{Time: inv.RejectedAt, Valid: true} + } + + return dbInvitation{ + InvitedBy: inv.InvitedBy, + InviteeUserID: inv.InviteeUserID, + DomainID: inv.DomainID, + RoleID: inv.RoleID, + CreatedAt: inv.CreatedAt, + UpdatedAt: updatedAt, + ConfirmedAt: confirmedAt, + RejectedAt: rejectedAt, + } +} + +func toInvitation(dbinv dbInvitation) domains.Invitation { + var updatedAt, confirmedAt, rejectedAt time.Time + if dbinv.UpdatedAt.Valid { + updatedAt = dbinv.UpdatedAt.Time + } + if dbinv.ConfirmedAt.Valid { + confirmedAt = dbinv.ConfirmedAt.Time + } + if dbinv.RejectedAt.Valid { + rejectedAt = dbinv.RejectedAt.Time + } + + return domains.Invitation{ + InvitedBy: dbinv.InvitedBy, + InviteeUserID: dbinv.InviteeUserID, + DomainID: dbinv.DomainID, + RoleID: dbinv.RoleID, + CreatedAt: dbinv.CreatedAt, + UpdatedAt: updatedAt, + ConfirmedAt: confirmedAt, + RejectedAt: rejectedAt, + } +} diff --git a/domains/postgres/invitations_test.go b/domains/postgres/invitations_test.go new file mode 100644 index 000000000..167d3b474 --- /dev/null +++ b/domains/postgres/invitations_test.go @@ -0,0 +1,833 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/absmach/supermq/domains" + "github.com/absmach/supermq/domains/postgres" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var invalidUUID = strings.Repeat("a", 37) + +func TestSaveInvitation(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM invitations") + require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err)) + _, err = db.Exec("DELETE FROM domains") + require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err)) + }) + repo := postgres.NewRepository(database) + + domainID := saveDomain(t, repo) + userID := testsutil.GenerateUUID(t) + roleID := testsutil.GenerateUUID(t) + + cases := []struct { + desc string + invitation domains.Invitation + err error + }{ + { + desc: "add new invitation successfully", + invitation: domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: userID, + DomainID: domainID, + RoleID: roleID, + CreatedAt: time.Now(), + }, + err: nil, + }, + { + desc: "add new invitation with an confirmed_at date", + invitation: domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: domainID, + CreatedAt: time.Now(), + RoleID: roleID, + ConfirmedAt: time.Now(), + }, + err: nil, + }, + { + desc: "add invitation with duplicate invitation", + invitation: domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: userID, + DomainID: domainID, + RoleID: roleID, + CreatedAt: time.Now(), + }, + err: repoerr.ErrConflict, + }, + { + desc: "add invitation with invalid invitation invited_by", + invitation: domains.Invitation{ + InvitedBy: invalidUUID, + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: domainID, + RoleID: roleID, + CreatedAt: time.Now(), + }, + err: repoerr.ErrMalformedEntity, + }, + { + desc: "add invitation with invalid invitation domain", + invitation: domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: invalidUUID, + RoleID: roleID, + CreatedAt: time.Now(), + }, + err: repoerr.ErrMalformedEntity, + }, + { + desc: "add invitation with invalid invitation invitee user id", + invitation: domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: invalidUUID, + DomainID: testsutil.GenerateUUID(t), + RoleID: roleID, + CreatedAt: time.Now(), + }, + err: repoerr.ErrMalformedEntity, + }, + { + desc: "add invitation with empty invitation domain", + invitation: domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + RoleID: roleID, + CreatedAt: time.Now(), + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "add invitation with empty invitation invitee user id", + invitation: domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + DomainID: domainID, + RoleID: roleID, + CreatedAt: time.Now(), + }, + err: nil, + }, + { + desc: "add invitation with empty invitation invited_by", + invitation: domains.Invitation{ + DomainID: domainID, + InviteeUserID: testsutil.GenerateUUID(t), + RoleID: roleID, + CreatedAt: time.Now(), + }, + err: nil, + }, + { + desc: "add invitation with empty invitation role id", + invitation: domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: domainID, + CreatedAt: time.Now(), + }, + err: nil, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.SaveInvitation(context.Background(), tc.invitation) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err)) + }) + } +} + +func TestInvitationRetrieve(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM invitations") + require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err)) + _, err = db.Exec("DELETE FROM domains") + require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err)) + }) + repo := postgres.NewRepository(database) + + domainID := saveDomain(t, repo) + + invitation := domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: domainID, + RoleID: testsutil.GenerateUUID(t), + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + } + + err := repo.SaveInvitation(context.Background(), invitation) + require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) + + cases := []struct { + desc string + userID string + domainID string + response domains.Invitation + err error + }{ + { + desc: "retrieve invitations successfully", + userID: invitation.InviteeUserID, + domainID: invitation.DomainID, + response: invitation, + err: nil, + }, + { + desc: "retrieve invitations with invalid invitee user id", + userID: testsutil.GenerateUUID(t), + domainID: invitation.DomainID, + response: domains.Invitation{}, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve invitations with invalid invitation domain_id", + userID: invitation.InviteeUserID, + domainID: testsutil.GenerateUUID(t), + response: domains.Invitation{}, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve invitations with invalid invitee user id and domain_id", + userID: testsutil.GenerateUUID(t), + domainID: testsutil.GenerateUUID(t), + response: domains.Invitation{}, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve invitations with empty invitee user id", + userID: "", + domainID: invitation.DomainID, + response: domains.Invitation{}, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve invitations with empty invitation domain_id", + userID: invitation.InviteeUserID, + domainID: "", + response: domains.Invitation{}, + err: repoerr.ErrNotFound, + }, + { + desc: "retrieve invitations with empty invitation user id and domain_id", + userID: "", + domainID: "", + response: domains.Invitation{}, + err: repoerr.ErrNotFound, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + inv, err := repo.RetrieveInvitation(context.Background(), tc.userID, tc.domainID) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.response, inv, fmt.Sprintf("desc: %s\n", tc.desc)) + }) + } +} + +func TestInvitationRetrieveAll(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM invitations") + require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err)) + _, err = db.Exec("DELETE FROM domains") + require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err)) + }) + repo := postgres.NewRepository(database) + + domainID := saveDomain(t, repo) + + num := 200 + + var items []domains.Invitation + for i := 0; i < num; i++ { + invitation := domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: domainID, + RoleID: testsutil.GenerateUUID(t), + CreatedAt: time.Now().UTC().Truncate(time.Microsecond), + } + err := repo.SaveInvitation(context.Background(), invitation) + require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) + items = append(items, invitation) + } + items[100].ConfirmedAt = time.Now().UTC().Truncate(time.Microsecond) + err := repo.UpdateConfirmation(context.Background(), items[100]) + require.Nil(t, err, fmt.Sprintf("update invitation unexpected error: %s", err)) + + swap := items[100] + items = append(items[:100], items[101:]...) + items = append(items, swap) + + cases := []struct { + desc string + page domains.InvitationPageMeta + response domains.InvitationPage + err error + }{ + { + desc: "retrieve invitations successfully", + page: domains.InvitationPageMeta{ + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: uint64(num), + Offset: 0, + Limit: 10, + Invitations: items[:10], + }, + err: nil, + }, + { + desc: "retrieve invitations with offset", + page: domains.InvitationPageMeta{ + Offset: 10, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: uint64(num), + Offset: 10, + Limit: 10, + Invitations: items[10:20], + }, + }, + { + desc: "retrieve invitations with limit", + page: domains.InvitationPageMeta{ + Offset: 0, + Limit: 50, + }, + response: domains.InvitationPage{ + Total: uint64(num), + Offset: 0, + Limit: 50, + Invitations: items[:50], + }, + }, + { + desc: "retrieve invitations with offset and limit", + page: domains.InvitationPageMeta{ + Offset: 10, + Limit: 50, + }, + response: domains.InvitationPage{ + Total: uint64(num), + Offset: 10, + Limit: 50, + Invitations: items[10:60], + }, + }, + { + desc: "retrieve invitations with offset out of range", + page: domains.InvitationPageMeta{ + Offset: 1000, + Limit: 50, + }, + response: domains.InvitationPage{ + Total: uint64(num), + Offset: 1000, + Limit: 50, + Invitations: []domains.Invitation(nil), + }, + }, + { + desc: "retrieve invitations with offset and limit out of range", + page: domains.InvitationPageMeta{ + Offset: 170, + Limit: 50, + }, + response: domains.InvitationPage{ + Total: uint64(num), + Offset: 170, + Limit: 50, + Invitations: items[170:200], + }, + }, + { + desc: "retrieve invitations with limit out of range", + page: domains.InvitationPageMeta{ + Offset: 0, + Limit: 1000, + }, + response: domains.InvitationPage{ + Total: uint64(num), + Offset: 0, + Limit: 1000, + Invitations: items, + }, + }, + { + desc: "retrieve invitations with empty page", + page: domains.InvitationPageMeta{}, + response: domains.InvitationPage{ + Total: uint64(num), + Offset: 0, + Limit: 0, + Invitations: []domains.Invitation(nil), + }, + }, + { + desc: "retrieve invitations with domain", + page: domains.InvitationPageMeta{ + DomainID: items[0].DomainID, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: uint64(num), + Offset: 0, + Limit: 10, + Invitations: items[:10], + }, + }, + { + desc: "retrieve invitations with invitee user id", + page: domains.InvitationPageMeta{ + InviteeUserID: items[0].InviteeUserID, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 1, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation{items[0]}, + }, + }, + { + desc: "retrieve invitations with invited_by", + page: domains.InvitationPageMeta{ + InvitedBy: items[0].InvitedBy, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 1, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation{items[0]}, + }, + }, + { + desc: "retrieve invitations with role_id", + page: domains.InvitationPageMeta{ + RoleID: items[3].RoleID, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 1, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation{items[3]}, + }, + }, + { + desc: "retrieve invitations with invited_by_or_user_id", + page: domains.InvitationPageMeta{ + InvitedByOrUserID: items[0].InviteeUserID, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 1, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation{items[0]}, + }, + }, + { + desc: "retrieve invitations with domain_id and invitee user id", + page: domains.InvitationPageMeta{ + DomainID: items[0].DomainID, + InviteeUserID: items[0].InviteeUserID, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 1, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation{items[0]}, + }, + }, + { + desc: "retrieve invitations with domain_id and invited_by", + page: domains.InvitationPageMeta{ + DomainID: items[0].DomainID, + InvitedBy: items[0].InvitedBy, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 1, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation{items[0]}, + }, + }, + { + desc: "retrieve invitations with invitee user id and invited_by", + page: domains.InvitationPageMeta{ + InviteeUserID: items[0].InviteeUserID, + InvitedBy: items[0].InvitedBy, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 1, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation{items[0]}, + }, + }, + { + desc: "retrieve invitations with domain_id, invitee user id and invited_by", + page: domains.InvitationPageMeta{ + DomainID: items[0].DomainID, + InviteeUserID: items[0].InviteeUserID, + InvitedBy: items[0].InvitedBy, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 1, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation{items[0]}, + }, + }, + { + desc: "retrieve invitations with domain_id, invitee user id, invited_by and role_id", + page: domains.InvitationPageMeta{ + DomainID: items[0].DomainID, + InviteeUserID: items[0].InviteeUserID, + InvitedBy: items[0].InvitedBy, + RoleID: items[0].RoleID, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 1, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation{items[0]}, + }, + }, + { + desc: "retrieve invitations with invalid domain", + page: domains.InvitationPageMeta{ + DomainID: invalidUUID, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 0, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation(nil), + }, + }, + { + desc: "retrieve invitations with invalid invitee user id", + page: domains.InvitationPageMeta{ + InviteeUserID: testsutil.GenerateUUID(t), + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 0, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation(nil), + }, + }, + { + desc: "retrieve invitations with invalid invited_by", + page: domains.InvitationPageMeta{ + InvitedBy: invalidUUID, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 0, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation(nil), + }, + }, + { + desc: "retrieve invitations with invalid role_id", + page: domains.InvitationPageMeta{ + RoleID: invalidUUID, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 0, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation(nil), + }, + }, + { + desc: "retrieve invitations with accepted state", + page: domains.InvitationPageMeta{ + State: domains.Accepted, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: 1, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation{items[num-1]}, + }, + }, + { + desc: "retrieve invitations with pending state", + page: domains.InvitationPageMeta{ + State: domains.Pending, + Offset: 0, + Limit: 10, + }, + response: domains.InvitationPage{ + Total: uint64(num - 1), + Offset: 0, + Limit: 10, + Invitations: items[0:10], + }, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + page, err := repo.RetrieveAllInvitations(context.Background(), tc.page) + assert.Equal(t, tc.response.Total, page.Total, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Total, page.Total)) + assert.Equal(t, tc.response.Offset, page.Offset, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Offset, page.Offset)) + assert.Equal(t, tc.response.Limit, page.Limit, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Limit, page.Limit)) + assert.ElementsMatch(t, page.Invitations, tc.response.Invitations, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response.Invitations, page.Invitations)) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + }) + } +} + +func TestInvitationUpdateConfirmation(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM invitations") + require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err)) + _, err = db.Exec("DELETE FROM domains") + require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err)) + }) + repo := postgres.NewRepository(database) + + domainID := saveDomain(t, repo) + + invitation := domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: domainID, + RoleID: testsutil.GenerateUUID(t), + CreatedAt: time.Now(), + } + err := repo.SaveInvitation(context.Background(), invitation) + require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) + + cases := []struct { + desc string + invitation domains.Invitation + err error + }{ + { + desc: "update invitation successfully", + invitation: domains.Invitation{ + DomainID: invitation.DomainID, + InviteeUserID: invitation.InviteeUserID, + ConfirmedAt: time.Now(), + }, + err: nil, + }, + { + desc: "update invitation with invalid invitee user id", + invitation: domains.Invitation{ + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: invitation.InviteeUserID, + ConfirmedAt: time.Now(), + }, + err: repoerr.ErrNotFound, + }, + { + desc: "update invitation with invalid domain", + invitation: domains.Invitation{ + InviteeUserID: invitation.InviteeUserID, + DomainID: testsutil.GenerateUUID(t), + ConfirmedAt: time.Now(), + }, + err: repoerr.ErrNotFound, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.UpdateConfirmation(context.Background(), tc.invitation) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + }) + } +} + +func TestInvitationUpdateRejection(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM invitations") + require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err)) + _, err = db.Exec("DELETE FROM domains") + require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err)) + }) + repo := postgres.NewRepository(database) + + domainID := saveDomain(t, repo) + + invitation := domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: domainID, + RoleID: testsutil.GenerateUUID(t), + CreatedAt: time.Now(), + } + err := repo.SaveInvitation(context.Background(), invitation) + require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) + + cases := []struct { + desc string + invitation domains.Invitation + err error + }{ + { + desc: "update invitation successfully", + invitation: domains.Invitation{ + DomainID: invitation.DomainID, + InviteeUserID: invitation.InviteeUserID, + RejectedAt: time.Now(), + }, + err: nil, + }, + { + desc: "update invitation with invalid invitee user id", + invitation: domains.Invitation{ + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: invitation.InviteeUserID, + RejectedAt: time.Now(), + }, + err: repoerr.ErrNotFound, + }, + { + desc: "update invitation with invalid domain", + invitation: domains.Invitation{ + InviteeUserID: invitation.InviteeUserID, + DomainID: testsutil.GenerateUUID(t), + RejectedAt: time.Now(), + }, + err: repoerr.ErrNotFound, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.UpdateRejection(context.Background(), tc.invitation) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + }) + } +} + +func TestInvitationDelete(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM invitations") + require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err)) + _, err = db.Exec("DELETE FROM domains") + require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err)) + }) + repo := postgres.NewRepository(database) + + domainID := saveDomain(t, repo) + + invitation := domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: domainID, + RoleID: testsutil.GenerateUUID(t), + CreatedAt: time.Now(), + } + err := repo.SaveInvitation(context.Background(), invitation) + require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) + + cases := []struct { + desc string + invitation domains.Invitation + err error + }{ + { + desc: "delete invitation successfully", + invitation: domains.Invitation{ + InviteeUserID: invitation.InviteeUserID, + DomainID: invitation.DomainID, + }, + err: nil, + }, + { + desc: "delete invitation with invalid invitation id", + invitation: domains.Invitation{ + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + }, + err: repoerr.ErrNotFound, + }, + { + desc: "delete invitation with empty invitation id", + invitation: domains.Invitation{}, + err: repoerr.ErrNotFound, + }, + } + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := repo.DeleteInvitation(context.Background(), tc.invitation.InviteeUserID, tc.invitation.DomainID) + assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + }) + } +} + +func saveDomain(t *testing.T, repo domains.Repository) string { + domain := domains.Domain{ + ID: testsutil.GenerateUUID(t), + Name: "test", + Alias: "test", + Tags: []string{"test"}, + Metadata: map[string]interface{}{ + "test": "test", + }, + CreatedBy: userID, + UpdatedBy: userID, + CreatedAt: time.Now().UTC().Truncate(time.Millisecond), + UpdatedAt: time.Now().UTC().Truncate(time.Millisecond), + Status: domains.EnabledStatus, + } + + _, err := repo.SaveDomain(context.Background(), domain) + require.Nil(t, err, fmt.Sprintf("failed to save domain %s", domain.ID)) + + return domain.ID +} diff --git a/domains/private/service.go b/domains/private/service.go index 53de8a3f3..8d944c0df 100644 --- a/domains/private/service.go +++ b/domains/private/service.go @@ -38,7 +38,7 @@ func (svc service) RetrieveEntity(ctx context.Context, id string) (domains.Domai if err == nil { return domains.Domain{ID: id, Status: status}, nil } - dom, err := svc.repo.RetrieveByID(ctx, id) + dom, err := svc.repo.RetrieveDomainByID(ctx, id) if err != nil { return domains.Domain{}, errors.Wrap(svcerr.ErrViewEntity, err) } diff --git a/domains/service.go b/domains/service.go index 7f6cba134..8d88b7fed 100644 --- a/domains/service.go +++ b/domains/service.go @@ -61,13 +61,13 @@ func (svc service) CreateDomain(ctx context.Context, session authn.Session, d Do d.CreatedAt = time.Now() // Domain is created in repo first, because Roles table have foreign key relation with Domain ID - dom, err := svc.repo.Save(ctx, d) + dom, err := svc.repo.SaveDomain(ctx, d) if err != nil { return Domain{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrCreateEntity, err) } defer func() { if retErr != nil { - if errRollBack := svc.repo.Delete(ctx, domainID); errRollBack != nil { + if errRollBack := svc.repo.DeleteDomain(ctx, domainID); errRollBack != nil { retErr = errors.Wrap(retErr, errors.Wrap(errRollbackRepo, errRollBack)) } } @@ -100,9 +100,9 @@ func (svc service) RetrieveDomain(ctx context.Context, session authn.Session, id var err error switch session.SuperAdmin { case true: - domain, err = svc.repo.RetrieveByID(ctx, id) + domain, err = svc.repo.RetrieveDomainByID(ctx, id) default: - domain, err = svc.repo.RetrieveByUserAndID(ctx, session.UserID, id) + domain, err = svc.repo.RetrieveDomainByUserAndID(ctx, session.UserID, id) } if err != nil { return Domain{}, errors.Wrap(svcerr.ErrViewEntity, err) @@ -114,7 +114,7 @@ func (svc service) UpdateDomain(ctx context.Context, session authn.Session, id s updatedAt := time.Now() d.UpdatedAt = &updatedAt d.UpdatedBy = &session.UserID - dom, err := svc.repo.Update(ctx, id, d) + dom, err := svc.repo.UpdateDomain(ctx, id, d) if err != nil { return Domain{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } @@ -124,7 +124,7 @@ func (svc service) UpdateDomain(ctx context.Context, session authn.Session, id s func (svc service) EnableDomain(ctx context.Context, session authn.Session, id string) (Domain, error) { status := EnabledStatus updatedAt := time.Now() - dom, err := svc.repo.Update(ctx, id, DomainReq{Status: &status, UpdatedBy: &session.UserID, UpdatedAt: &updatedAt}) + dom, err := svc.repo.UpdateDomain(ctx, id, DomainReq{Status: &status, UpdatedBy: &session.UserID, UpdatedAt: &updatedAt}) if err != nil { return Domain{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } @@ -138,7 +138,7 @@ func (svc service) EnableDomain(ctx context.Context, session authn.Session, id s func (svc service) DisableDomain(ctx context.Context, session authn.Session, id string) (Domain, error) { status := DisabledStatus updatedAt := time.Now() - dom, err := svc.repo.Update(ctx, id, DomainReq{Status: &status, UpdatedBy: &session.UserID, UpdatedAt: &updatedAt}) + dom, err := svc.repo.UpdateDomain(ctx, id, DomainReq{Status: &status, UpdatedBy: &session.UserID, UpdatedAt: &updatedAt}) if err != nil { return Domain{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } @@ -153,7 +153,7 @@ func (svc service) DisableDomain(ctx context.Context, session authn.Session, id func (svc service) FreezeDomain(ctx context.Context, session authn.Session, id string) (Domain, error) { status := FreezeStatus updatedAt := time.Now() - dom, err := svc.repo.Update(ctx, id, DomainReq{Status: &status, UpdatedBy: &session.UserID, UpdatedAt: &updatedAt}) + dom, err := svc.repo.UpdateDomain(ctx, id, DomainReq{Status: &status, UpdatedBy: &session.UserID, UpdatedAt: &updatedAt}) if err != nil { return Domain{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } @@ -176,3 +176,133 @@ func (svc service) ListDomains(ctx context.Context, session authn.Session, p Pag } return dp, nil } + +func (svc *service) SendInvitation(ctx context.Context, session authn.Session, invitation Invitation) error { + if _, err := svc.repo.RetrieveRole(ctx, invitation.RoleID); err != nil { + return errors.Wrap(svcerr.ErrInvalidRole, err) + } + invitation.InvitedBy = session.UserID + + invitation.CreatedAt = time.Now() + + if err := svc.repo.SaveInvitation(ctx, invitation); err != nil { + return errors.Wrap(svcerr.ErrCreateEntity, err) + } + return nil +} + +func (svc *service) ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (invitation Invitation, err error) { + inv, err := svc.repo.RetrieveInvitation(ctx, inviteeUserID, domainID) + if err != nil { + return Invitation{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + role, err := svc.repo.RetrieveRole(ctx, inv.RoleID) + if err != nil { + return Invitation{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + actions, err := svc.repo.RoleListActions(ctx, inv.RoleID) + if err != nil { + return Invitation{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + inv.Actions = actions + inv.RoleName = role.Name + + return inv, nil +} + +func (svc *service) ListInvitations(ctx context.Context, session authn.Session, page InvitationPageMeta) (invitations InvitationPage, err error) { + ip, err := svc.repo.RetrieveAllInvitations(ctx, page) + if err != nil { + return InvitationPage{}, err + } + return ip, nil +} + +func (svc *service) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) error { + inv, err := svc.repo.RetrieveInvitation(ctx, session.UserID, domainID) + if err != nil { + return errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + if inv.InviteeUserID != session.UserID { + return svcerr.ErrAuthorization + } + + if !inv.ConfirmedAt.IsZero() { + return svcerr.ErrInvitationAlreadyAccepted + } + + if !inv.RejectedAt.IsZero() { + return svcerr.ErrInvitationAlreadyRejected + } + + session.DomainID = domainID + + if _, err := svc.RoleAddMembers(ctx, session, domainID, inv.RoleID, []string{session.UserID}); err != nil { + return errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + inv.ConfirmedAt = time.Now() + inv.UpdatedAt = inv.ConfirmedAt + + if err := svc.repo.UpdateConfirmation(ctx, inv); err != nil { + return errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + return nil +} + +func (svc *service) RejectInvitation(ctx context.Context, session authn.Session, domainID string) error { + inv, err := svc.repo.RetrieveInvitation(ctx, session.UserID, domainID) + if err != nil { + return errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + if inv.InviteeUserID != session.UserID { + return svcerr.ErrAuthorization + } + + if !inv.ConfirmedAt.IsZero() { + return svcerr.ErrInvitationAlreadyAccepted + } + + if !inv.RejectedAt.IsZero() { + return svcerr.ErrInvitationAlreadyRejected + } + + inv.RejectedAt = time.Now() + inv.UpdatedAt = inv.RejectedAt + + if err := svc.repo.UpdateRejection(ctx, inv); err != nil { + return errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + return nil +} + +func (svc *service) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) error { + if session.UserID == inviteeUserID { + if err := svc.repo.DeleteInvitation(ctx, inviteeUserID, domainID); err != nil { + return errors.Wrap(svcerr.ErrRemoveEntity, err) + } + return nil + } + + inv, err := svc.repo.RetrieveInvitation(ctx, inviteeUserID, domainID) + if err != nil { + return errors.Wrap(svcerr.ErrRemoveEntity, err) + } + + if inv.InvitedBy == session.UserID { + if err := svc.repo.DeleteInvitation(ctx, inviteeUserID, domainID); err != nil { + return errors.Wrap(svcerr.ErrRemoveEntity, err) + } + return nil + } + + if err := svc.repo.DeleteInvitation(ctx, inviteeUserID, domainID); err != nil { + return errors.Wrap(svcerr.ErrRemoveEntity, err) + } + + return nil +} diff --git a/domains/service_test.go b/domains/service_test.go index fbfe53475..b4aaaf0b8 100644 --- a/domains/service_test.go +++ b/domains/service_test.go @@ -53,8 +53,13 @@ var ( CreatedBy: validID, UpdatedBy: validID, } - userID = testsutil.GenerateUUID(&testing.T{}) - validSession = authn.Session{UserID: userID} + userID = testsutil.GenerateUUID(&testing.T{}) + validSession = authn.Session{UserID: userID} + validInvitation = domains.Invitation{ + InviteeUserID: testsutil.GenerateUUID(&testing.T{}), + DomainID: testsutil.GenerateUUID(&testing.T{}), + RoleID: testsutil.GenerateUUID(&testing.T{}), + } ) var ( @@ -166,8 +171,8 @@ func TestCreateDomain(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repoCall := drepo.On("Save", mock.Anything, mock.Anything).Return(tc.d, tc.saveDomainErr) - repoCall1 := drepo.On("Delete", mock.Anything, mock.Anything).Return(tc.deleteDomainErr) + repoCall := drepo.On("SaveDomain", mock.Anything, mock.Anything).Return(tc.d, tc.saveDomainErr) + repoCall1 := drepo.On("DeleteDomain", mock.Anything, mock.Anything).Return(tc.deleteDomainErr) repoCall2 := drepo.On("AddRoles", mock.Anything, mock.Anything).Return([]roles.RoleProvision{}, tc.addRolesErr) policyCall := policy.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPoliciesErr) policyCall1 := policy.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePoliciesErr) @@ -228,8 +233,8 @@ func TestRetrieveDomain(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repoCall := drepo.On("RetrieveByID", context.Background(), tc.domainID).Return(tc.retrieveDomainRes, tc.retrieveDomainErr) - repoCall1 := drepo.On("RetrieveByUserAndID", context.Background(), tc.session.UserID, tc.domainID).Return(tc.retrieveDomainRes, tc.retrieveDomainErr) + repoCall := drepo.On("RetrieveDomainByID", context.Background(), tc.domainID).Return(tc.retrieveDomainRes, tc.retrieveDomainErr) + repoCall1 := drepo.On("RetrieveDomainByUserAndID", context.Background(), tc.session.UserID, tc.domainID).Return(tc.retrieveDomainRes, tc.retrieveDomainErr) domain, err := svc.RetrieveDomain(context.Background(), tc.session, tc.domainID) assert.True(t, errors.Contains(err, tc.err)) assert.Equal(t, tc.retrieveDomainRes, domain) @@ -292,7 +297,7 @@ func TestUpdateDomain(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repoCall := drepo.On("Update", context.Background(), tc.domainID, mock.Anything).Return(tc.updateRes, tc.updateErr) + repoCall := drepo.On("UpdateDomain", context.Background(), tc.domainID, mock.Anything).Return(tc.updateRes, tc.updateErr) domain, err := svc.UpdateDomain(context.Background(), tc.session, tc.domainID, tc.updateReq) assert.True(t, errors.Contains(err, tc.err)) assert.Equal(t, tc.updateRes, domain) @@ -354,7 +359,7 @@ func TestEnableDomain(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repoCall := drepo.On("Update", context.Background(), tc.domainID, mock.Anything).Return(tc.enableRes, tc.enableErr) + repoCall := drepo.On("UpdateDomain", context.Background(), tc.domainID, mock.Anything).Return(tc.enableRes, tc.enableErr) cacheCall := dcache.On("Remove", context.Background(), tc.domainID).Return(tc.cacheErr) domain, err := svc.EnableDomain(context.Background(), tc.session, tc.domainID) assert.True(t, errors.Contains(err, tc.err)) @@ -418,7 +423,7 @@ func TestDisableDomain(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repoCall := drepo.On("Update", context.Background(), tc.domainID, mock.Anything).Return(tc.disableRes, tc.disableErr) + repoCall := drepo.On("UpdateDomain", context.Background(), tc.domainID, mock.Anything).Return(tc.disableRes, tc.disableErr) cacheCall := dcache.On("Remove", context.Background(), tc.domainID).Return(tc.cacheErr) domain, err := svc.DisableDomain(context.Background(), tc.session, tc.domainID) assert.True(t, errors.Contains(err, tc.err)) @@ -482,7 +487,7 @@ func TestFreezeDomain(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - repoCall := drepo.On("Update", context.Background(), tc.domainID, mock.Anything).Return(tc.freezeRes, tc.freezeErr) + repoCall := drepo.On("UpdateDomain", context.Background(), tc.domainID, mock.Anything).Return(tc.freezeRes, tc.freezeErr) cacheCall := dcache.On("Remove", context.Background(), tc.domainID).Return(tc.cacheErr) domain, err := svc.FreezeDomain(context.Background(), tc.session, tc.domainID) assert.True(t, errors.Contains(err, tc.err)) @@ -565,3 +570,450 @@ func TestListDomains(t *testing.T) { }) } } + +func TestSendInvitation(t *testing.T) { + svc := newService() + + cases := []struct { + desc string + session authn.Session + req domains.Invitation + retrieveRoleErr error + createInvitationErr error + err error + }{ + { + desc: "send invitation successful", + session: validSession, + req: validInvitation, + err: nil, + }, + { + desc: "send invitation with invalid role id", + session: validSession, + req: domains.Invitation{ + DomainID: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + RoleID: inValid, + }, + retrieveRoleErr: repoerr.ErrNotFound, + err: svcerr.ErrInvalidRole, + }, + { + desc: "send invitations with failed to save invitation", + session: validSession, + req: validInvitation, + createInvitationErr: repoerr.ErrCreateEntity, + err: svcerr.ErrCreateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := drepo.On("RetrieveRole", context.Background(), tc.req.RoleID).Return(roles.Role{}, tc.retrieveRoleErr) + repoCall1 := drepo.On("SaveInvitation", context.Background(), mock.Anything).Return(tc.createInvitationErr) + err := svc.SendInvitation(context.Background(), tc.session, tc.req) + assert.True(t, errors.Contains(err, tc.err)) + repoCall.Unset() + repoCall1.Unset() + }) + } +} + +func TestViewInvitation(t *testing.T) { + svc := newService() + + validInvitation := domains.Invitation{ + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + RoleID: testsutil.GenerateUUID(t), + Actions: []string{"read", "delete"}, + CreatedAt: time.Now().Add(-time.Hour), + UpdatedAt: time.Now().Add(-time.Hour), + ConfirmedAt: time.Now().Add(-time.Hour), + } + cases := []struct { + desc string + userID string + domainID string + session authn.Session + req domains.Invitation + resp domains.Invitation + retrieveInvitationErr error + listRolesErr error + retrieveRoleErr error + err error + }{ + { + desc: "view invitation successful", + userID: validInvitation.InviteeUserID, + domainID: validInvitation.DomainID, + session: validSession, + resp: validInvitation, + err: nil, + }, + { + desc: "view invitation with error retrieving invitation", + userID: validInvitation.InviteeUserID, + domainID: validInvitation.DomainID, + session: validSession, + retrieveInvitationErr: repoerr.ErrNotFound, + err: svcerr.ErrViewEntity, + }, + { + desc: "view invitation with failed to retrieve role actions", + userID: validInvitation.InviteeUserID, + domainID: validInvitation.DomainID, + session: validSession, + listRolesErr: repoerr.ErrNotFound, + err: svcerr.ErrViewEntity, + }, + { + desc: "view invitation with failed to retrieve role", + userID: validInvitation.InviteeUserID, + domainID: validInvitation.DomainID, + session: validSession, + retrieveRoleErr: repoerr.ErrNotFound, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := drepo.On("RetrieveInvitation", context.Background(), mock.Anything, mock.Anything).Return(tc.resp, tc.retrieveInvitationErr) + repoCall1 := drepo.On("RoleListActions", context.Background(), tc.resp.RoleID).Return(tc.resp.Actions, tc.listRolesErr) + repoCall2 := drepo.On("RetrieveRole", context.Background(), tc.resp.RoleID).Return(roles.Role{}, tc.retrieveRoleErr) + inv, err := svc.ViewInvitation(context.Background(), tc.session, tc.userID, tc.domainID) + assert.True(t, errors.Contains(err, tc.err)) + assert.Equal(t, tc.resp, inv, tc.desc) + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + }) + } +} + +func TestListInvitations(t *testing.T) { + svc := newService() + + validPageMeta := domains.InvitationPageMeta{ + Offset: 0, + Limit: 10, + } + validResp := domains.InvitationPage{ + Total: 1, + Offset: 0, + Limit: 10, + Invitations: []domains.Invitation{ + { + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + RoleID: testsutil.GenerateUUID(t), + RoleName: "admin", + CreatedAt: time.Now().Add(-time.Hour), + UpdatedAt: time.Now().Add(-time.Hour), + ConfirmedAt: time.Now().Add(-time.Hour), + }, + }, + } + + cases := []struct { + desc string + session authn.Session + page domains.InvitationPageMeta + resp domains.InvitationPage + err error + repoErr error + }{ + { + desc: "list invitations successful", + session: validSession, + page: validPageMeta, + resp: validResp, + err: nil, + repoErr: nil, + }, + + { + desc: "list invitations unsuccessful", + session: validSession, + page: validPageMeta, + err: repoerr.ErrViewEntity, + resp: domains.InvitationPage{}, + repoErr: repoerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := drepo.On("RetrieveAllInvitations", context.Background(), mock.Anything).Return(tc.resp, tc.repoErr) + resp, err := svc.ListInvitations(context.Background(), tc.session, tc.page) + assert.Equal(t, tc.err, err, tc.desc) + assert.Equal(t, tc.resp, resp, tc.desc) + repoCall.Unset() + }) + } +} + +func TestAcceptInvitation(t *testing.T) { + svc := newService() + + cases := []struct { + desc string + domainID string + session authn.Session + resp domains.Invitation + retrieveInvitationErr error + updateConfirmationErr error + addRoleMemberErr error + err error + }{ + { + desc: "accept invitation successful", + domainID: validID, + session: validSession, + resp: domains.Invitation{ + InviteeUserID: userID, + DomainID: testsutil.GenerateUUID(t), + RoleID: testsutil.GenerateUUID(t), + }, + err: nil, + }, + { + desc: "accept invitation with failed to retrieve invitation", + session: validSession, + retrieveInvitationErr: repoerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "accept invitation with of different user", + session: validSession, + resp: domains.Invitation{ + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + RoleID: testsutil.GenerateUUID(t), + }, + err: svcerr.ErrAuthorization, + }, + { + desc: "accept invitation with failed to add role member", + domainID: validID, + session: validSession, + resp: domains.Invitation{ + InviteeUserID: userID, + DomainID: testsutil.GenerateUUID(t), + RoleID: testsutil.GenerateUUID(t), + }, + addRoleMemberErr: repoerr.ErrMalformedEntity, + err: svcerr.ErrUpdateEntity, + }, + { + desc: "accept invitation with failed update confirmation", + session: validSession, + domainID: validID, + resp: domains.Invitation{ + InviteeUserID: userID, + DomainID: validID, + RoleID: testsutil.GenerateUUID(t), + }, + updateConfirmationErr: repoerr.ErrNotFound, + err: svcerr.ErrUpdateEntity, + }, + { + desc: "accept invitation that is already confirmed", + session: validSession, + domainID: validID, + resp: domains.Invitation{ + InviteeUserID: userID, + DomainID: testsutil.GenerateUUID(t), + RoleID: testsutil.GenerateUUID(t), + ConfirmedAt: time.Now(), + }, + err: svcerr.ErrInvitationAlreadyAccepted, + }, + { + desc: "accept rejected invitation", + session: validSession, + domainID: validID, + resp: domains.Invitation{ + InviteeUserID: userID, + DomainID: testsutil.GenerateUUID(t), + RoleID: testsutil.GenerateUUID(t), + RejectedAt: time.Now(), + }, + err: svcerr.ErrInvitationAlreadyRejected, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := drepo.On("RetrieveInvitation", context.Background(), tc.session.UserID, tc.domainID).Return(tc.resp, tc.retrieveInvitationErr) + repoCall1 := drepo.On("RetrieveEntityRole", context.Background(), tc.domainID, tc.resp.RoleID).Return(roles.Role{}, tc.addRoleMemberErr) + policyCall := policy.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addRoleMemberErr) + repoCall2 := drepo.On("RoleAddMembers", context.Background(), mock.Anything, []string{tc.resp.InviteeUserID}).Return([]string{}, tc.addRoleMemberErr) + repoCall3 := drepo.On("UpdateConfirmation", context.Background(), mock.Anything).Return(tc.updateConfirmationErr) + err := svc.AcceptInvitation(context.Background(), tc.session, tc.domainID) + assert.True(t, errors.Contains(err, tc.err)) + repoCall.Unset() + repoCall1.Unset() + policyCall.Unset() + repoCall2.Unset() + repoCall3.Unset() + }) + } +} + +func TestRejectInvitation(t *testing.T) { + svc := newService() + + cases := []struct { + desc string + domainID string + session authn.Session + resp domains.Invitation + retrieveInvitationErr error + updateConfirmationErr error + addRoleMemberErr error + err error + }{ + { + desc: "reject invitation successful", + domainID: validID, + session: validSession, + resp: domains.Invitation{ + InviteeUserID: userID, + DomainID: testsutil.GenerateUUID(t), + RoleID: testsutil.GenerateUUID(t), + }, + err: nil, + }, + { + desc: "reject invitation with failed to retrieve invitation", + session: validSession, + retrieveInvitationErr: repoerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + { + desc: "reject invitation with of different user", + session: validSession, + resp: domains.Invitation{ + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + RoleID: testsutil.GenerateUUID(t), + }, + err: svcerr.ErrAuthorization, + }, + { + desc: "reject invitation with failed update confirmation", + session: validSession, + domainID: validID, + resp: domains.Invitation{ + InviteeUserID: userID, + DomainID: validID, + RoleID: testsutil.GenerateUUID(t), + }, + updateConfirmationErr: repoerr.ErrNotFound, + err: svcerr.ErrUpdateEntity, + }, + { + desc: "reject invitation that is already confirmed", + session: validSession, + domainID: validID, + resp: domains.Invitation{ + InviteeUserID: userID, + DomainID: testsutil.GenerateUUID(t), + RoleID: testsutil.GenerateUUID(t), + ConfirmedAt: time.Now(), + }, + err: svcerr.ErrInvitationAlreadyAccepted, + }, + { + desc: "reject rejected invitation", + session: validSession, + domainID: validID, + resp: domains.Invitation{ + InviteeUserID: userID, + DomainID: testsutil.GenerateUUID(t), + RoleID: testsutil.GenerateUUID(t), + RejectedAt: time.Now(), + }, + err: svcerr.ErrInvitationAlreadyRejected, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := drepo.On("RetrieveInvitation", context.Background(), tc.session.UserID, tc.domainID).Return(tc.resp, tc.retrieveInvitationErr) + repoCall1 := drepo.On("UpdateRejection", context.Background(), mock.Anything).Return(tc.updateConfirmationErr) + err := svc.RejectInvitation(context.Background(), tc.session, tc.domainID) + assert.True(t, errors.Contains(err, tc.err)) + repoCall.Unset() + repoCall1.Unset() + }) + } +} + +func TestDeleteInvitation(t *testing.T) { + svc := newService() + + cases := []struct { + desc string + userID string + domainID string + resp domains.Invitation + retrieveInvitationErr error + deleteInvitationErr error + err error + }{ + { + desc: "delete invitations successful", + userID: testsutil.GenerateUUID(t), + domainID: testsutil.GenerateUUID(t), + resp: validInvitation, + err: nil, + }, + { + desc: "delete invitations for the same user", + userID: validInvitation.InviteeUserID, + domainID: validInvitation.DomainID, + resp: validInvitation, + err: nil, + }, + { + desc: "delete invitations for the invited user", + userID: validInvitation.InviteeUserID, + domainID: validInvitation.DomainID, + resp: validInvitation, + err: nil, + }, + { + desc: "delete invitation with error retrieving invitation", + userID: validInvitation.InviteeUserID, + domainID: validInvitation.DomainID, + resp: domains.Invitation{}, + retrieveInvitationErr: repoerr.ErrNotFound, + err: svcerr.ErrRemoveEntity, + }, + { + desc: "delete invitation with error deleting invitation", + userID: validInvitation.InviteeUserID, + domainID: validInvitation.DomainID, + resp: domains.Invitation{}, + deleteInvitationErr: repoerr.ErrNotFound, + err: svcerr.ErrRemoveEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := drepo.On("RetrieveInvitation", context.Background(), mock.Anything, mock.Anything).Return(tc.resp, tc.retrieveInvitationErr) + repoCall1 := drepo.On("DeleteInvitation", context.Background(), mock.Anything, mock.Anything).Return(tc.deleteInvitationErr) + err := svc.DeleteInvitation(context.Background(), authn.Session{}, tc.userID, tc.domainID) + assert.True(t, errors.Contains(err, tc.err)) + repoCall.Unset() + repoCall1.Unset() + }) + } +} diff --git a/invitations/state.go b/domains/state.go similarity index 92% rename from invitations/state.go rename to domains/state.go index 50649e865..e18f571bf 100644 --- a/invitations/state.go +++ b/domains/state.go @@ -1,7 +1,7 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -package invitations +package domains import ( "encoding/json" @@ -14,7 +14,7 @@ import ( type State uint8 const ( - All State = iota // All is used for querying purposes to list invitations irrespective of their state - both pending and accepted. + AllState State = iota // All is used for querying purposes to list invitations irrespective of their state - both pending and accepted. Pending // Pending is the state of an invitation that has not been accepted yet. Accepted // Accepted is the state of an invitation that has been accepted. Rejected // Rejected is the state of an invitation that has been rejected. @@ -32,7 +32,7 @@ const ( // String converts invitation state to string literal. func (s State) String() string { switch s { - case All: + case AllState: return all case Pending: return pending @@ -49,7 +49,7 @@ func (s State) String() string { func ToState(status string) (State, error) { switch status { case all: - return All, nil + return AllState, nil case pending: return Pending, nil case accepted: diff --git a/invitations/state_test.go b/domains/state_test.go similarity index 51% rename from invitations/state_test.go rename to domains/state_test.go index 451f209e6..e31756143 100644 --- a/invitations/state_test.go +++ b/domains/state_test.go @@ -1,27 +1,27 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 -package invitations_test +package domains_test import ( "testing" apiutil "github.com/absmach/supermq/api/http/util" - "github.com/absmach/supermq/invitations" + "github.com/absmach/supermq/domains" "github.com/stretchr/testify/assert" ) func TestState_String(t *testing.T) { tests := []struct { name string - state invitations.State + state domains.State expected string }{ - {"Pending", invitations.Pending, "pending"}, - {"Accepted", invitations.Accepted, "accepted"}, - {"Rejected", invitations.Rejected, "rejected"}, - {"All", invitations.All, "all"}, - {"Unknown", invitations.State(100), "unknown"}, + {"Pending", domains.Pending, "pending"}, + {"Accepted", domains.Accepted, "accepted"}, + {"Rejected", domains.Rejected, "rejected"}, + {"All", domains.AllState, "all"}, + {"Unknown", domains.State(100), "unknown"}, } for _, tt := range tests { @@ -34,18 +34,18 @@ func TestToState(t *testing.T) { tests := []struct { name string status string - state invitations.State + state domains.State err error }{ - {"Pending", "pending", invitations.Pending, nil}, - {"Accepted", "accepted", invitations.Accepted, nil}, - {"Rejected", "rejected", invitations.Rejected, nil}, - {"All", "all", invitations.All, nil}, - {"Unknown", "unknown", invitations.State(0), apiutil.ErrInvitationState}, + {"Pending", "pending", domains.Pending, nil}, + {"Accepted", "accepted", domains.Accepted, nil}, + {"Rejected", "rejected", domains.Rejected, nil}, + {"All", "all", domains.AllState, nil}, + {"Unknown", "unknown", domains.State(0), apiutil.ErrInvitationState}, } for _, tt := range tests { - got, err := invitations.ToState(tt.status) + got, err := domains.ToState(tt.status) assert.Equal(t, tt.err, err, "ToState() error = %v, expected %v", err, tt.err) assert.Equal(t, tt.state, got, "ToState() = %v, expected %v", got, tt.state) } @@ -54,15 +54,15 @@ func TestToState(t *testing.T) { func TestState_MarshalJSON(t *testing.T) { tests := []struct { name string - state invitations.State + state domains.State expected []byte err error }{ - {"Pending", invitations.Pending, []byte(`"pending"`), nil}, - {"Accepted", invitations.Accepted, []byte(`"accepted"`), nil}, - {"Rejected", invitations.Rejected, []byte(`"rejected"`), nil}, - {"All", invitations.All, []byte(`"all"`), nil}, - {"Unknown", invitations.State(100), []byte(`"unknown"`), nil}, + {"Pending", domains.Pending, []byte(`"pending"`), nil}, + {"Accepted", domains.Accepted, []byte(`"accepted"`), nil}, + {"Rejected", domains.Rejected, []byte(`"rejected"`), nil}, + {"All", domains.AllState, []byte(`"all"`), nil}, + {"Unknown", domains.State(100), []byte(`"unknown"`), nil}, } for _, tt := range tests { @@ -76,18 +76,18 @@ func TestState_UnmarshalJSON(t *testing.T) { tests := []struct { name string data []byte - state invitations.State + state domains.State err error }{ - {"Pending", []byte(`"pending"`), invitations.Pending, nil}, - {"Accepted", []byte(`"accepted"`), invitations.Accepted, nil}, - {"Rejected", []byte(`"rejected"`), invitations.Rejected, nil}, - {"All", []byte(`"all"`), invitations.All, nil}, - {"Unknown", []byte(`"unknown"`), invitations.State(0), apiutil.ErrInvitationState}, + {"Pending", []byte(`"pending"`), domains.Pending, nil}, + {"Accepted", []byte(`"accepted"`), domains.Accepted, nil}, + {"Rejected", []byte(`"rejected"`), domains.Rejected, nil}, + {"All", []byte(`"all"`), domains.AllState, nil}, + {"Unknown", []byte(`"unknown"`), domains.State(0), apiutil.ErrInvitationState}, } for _, tt := range tests { - var state invitations.State + var state domains.State err := state.UnmarshalJSON(tt.data) assert.Equal(t, tt.err, err, "State.UnmarshalJSON() error = %v, expected %v", err, tt.err) assert.Equal(t, tt.state, state, "State.UnmarshalJSON() = %v, expected %v", state, tt.state) diff --git a/domains/tracing/tracing.go b/domains/tracing/tracing.go index 40e675ea3..b8c057907 100644 --- a/domains/tracing/tracing.go +++ b/domains/tracing/tracing.go @@ -80,3 +80,64 @@ func (tm *tracingMiddleware) ListDomains(ctx context.Context, session authn.Sess defer span.End() return tm.svc.ListDomains(ctx, session, p) } + +func (tm *tracingMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) (err error) { + ctx, span := tm.tracer.Start(ctx, "send_invitation", trace.WithAttributes( + attribute.String("domain_id", invitation.DomainID), + attribute.String("invitee_user_id", invitation.InviteeUserID), + )) + defer span.End() + + return tm.svc.SendInvitation(ctx, session, invitation) +} + +func (tm *tracingMiddleware) ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID, domain string) (invitation domains.Invitation, err error) { + ctx, span := tm.tracer.Start(ctx, "view_invitation", trace.WithAttributes( + attribute.String("invitee_user_id", inviteeUserID), + attribute.String("domain_id", domain), + )) + defer span.End() + + return tm.svc.ViewInvitation(ctx, session, inviteeUserID, domain) +} + +func (tm *tracingMiddleware) ListInvitations(ctx context.Context, session authn.Session, pm domains.InvitationPageMeta) (invs domains.InvitationPage, err error) { + ctx, span := tm.tracer.Start(ctx, "list_invitations", trace.WithAttributes( + attribute.Int("limit", int(pm.Limit)), + attribute.Int("offset", int(pm.Offset)), + attribute.String("invitee_user_id", pm.InviteeUserID), + attribute.String("domain_id", pm.DomainID), + attribute.String("invited_by", pm.InvitedBy), + )) + defer span.End() + + return tm.svc.ListInvitations(ctx, session, pm) +} + +func (tm *tracingMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { + ctx, span := tm.tracer.Start(ctx, "accept_invitation", trace.WithAttributes( + attribute.String("domain_id", domainID), + )) + defer span.End() + + return tm.svc.AcceptInvitation(ctx, session, domainID) +} + +func (tm *tracingMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { + ctx, span := tm.tracer.Start(ctx, "reject_invitation", trace.WithAttributes( + attribute.String("domain_id", domainID), + )) + defer span.End() + + return tm.svc.RejectInvitation(ctx, session, domainID) +} + +func (tm *tracingMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (err error) { + ctx, span := tm.tracer.Start(ctx, "delete_invitation", trace.WithAttributes( + attribute.String("invitee_user_id", inviteeUserID), + attribute.String("domain_id", domainID), + )) + defer span.End() + + return tm.svc.DeleteInvitation(ctx, session, inviteeUserID, domainID) +} diff --git a/invitations/README.md b/invitations/README.md deleted file mode 100644 index 6d2afa8de..000000000 --- a/invitations/README.md +++ /dev/null @@ -1,80 +0,0 @@ -# Invitation Service - -Invitation service is responsible for sending invitations to users to join a domain. - -## Configuration - -The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values. - -| Variable | Description | Default | -| ------------------------------- | ------------------------------------------------ | ----------------------- | -| SMQ_INVITATION_LOG_LEVEL | Log level for the Invitation service | debug | -| SMQ_USERS_URL | Users service URL | | -| SMQ_DOMAINS_URL | Domains service URL | | -| SMQ_INVITATIONS_HTTP_HOST | Invitation service HTTP listening host | localhost | -| SMQ_INVITATIONS_HTTP_PORT | Invitation service HTTP listening port | 9020 | -| SMQ_INVITATIONS_HTTP_SERVER_CERT | Invitation service server certificate | "" | -| SMQ_INVITATIONS_HTTP_SERVER_KEY | Invitation service server key | "" | -| SMQ_AUTH_GRPC_URL | Auth service gRPC URL | localhost:8181 | -| SMQ_AUTH_GRPC_TIMEOUT | Auth service gRPC request timeout in seconds | 1s | -| SMQ_AUTH_GRPC_CLIENT_CERT | Path to client certificate in PEM format | "" | -| SMQ_AUTH_GRPC_CLIENT_KEY | Path to client key in PEM format | "" | -| SMQ_AUTH_GRPC_CLIENT_CA_CERTS | Path to trusted CAs in PEM format | "" | -| SMQ_INVITATIONS_DB_HOST | Invitation service database host | localhost | -| SMQ_INVITATIONS_DB_USER | Invitation service database user | supermq | -| SMQ_INVITATIONS_DB_PASS | Invitation service database password | supermq | -| SMQ_INVITATIONS_DB_PORT | Invitation service database port | 5432 | -| SMQ_INVITATIONS_DB_NAME | Invitation service database name | invitations | -| SMQ_INVITATIONS_DB_SSL_MODE | Invitation service database SSL mode | disable | -| SMQ_INVITATIONS_DB_SSL_CERT | Invitation service database SSL certificate | "" | -| SMQ_INVITATIONS_DB_SSL_KEY | Invitation service database SSL key | "" | -| SMQ_INVITATIONS_DB_SSL_ROOT_CERT | Invitation service database SSL root certificate | "" | -| SMQ_INVITATIONS_INSTANCE_ID | Invitation service instance ID | | - -## Deployment - -The service itself is distributed as Docker container. Check the [`invitation`](https://github.com/absmach/amdm/blob/main/docker/docker-compose.yml) service section in docker-compose file to see how service is deployed. - -To start the service outside of the container, execute the following shell script: - -```bash -# download the latest version of the service -git clone https://github.com/absmach/supermq - -cd supermq - -# compile the http -make invitation - -# copy binary to bin -make install - -# set the environment variables and run the service -SMQ_INVITATION_LOG_LEVEL=info \ -SMQ_INVITATIONS_ENDPOINT=/invitations \ -SMQ_USERS_URL="http://localhost:9002" \ -SMQ_DOMAINS_URL="http://localhost:8189" \ -SMQ_INVITATIONS_HTTP_HOST=localhost \ -SMQ_INVITATIONS_HTTP_PORT=9020 \ -SMQ_INVITATIONS_HTTP_SERVER_CERT="" \ -SMQ_INVITATIONS_HTTP_SERVER_KEY="" \ -SMQ_AUTH_GRPC_URL=localhost:8181 \ -SMQ_AUTH_GRPC_TIMEOUT=1s \ -SMQ_AUTH_GRPC_CLIENT_CERT="" \ -SMQ_AUTH_GRPC_CLIENT_KEY="" \ -SMQ_AUTH_GRPC_CLIENT_CA_CERTS="" \ -SMQ_INVITATIONS_DB_HOST=localhost \ -SMQ_INVITATIONS_DB_USER=supermq \ -SMQ_INVITATIONS_DB_PASS=supermq \ -SMQ_INVITATIONS_DB_PORT=5432 \ -SMQ_INVITATIONS_DB_NAME=invitations \ -SMQ_INVITATIONS_DB_SSL_MODE=disable \ -SMQ_INVITATIONS_DB_SSL_CERT="" \ -SMQ_INVITATIONS_DB_SSL_KEY="" \ -SMQ_INVITATIONS_DB_SSL_ROOT_CERT="" \ -$GOBIN/supermq-invitation -``` - -## Usage - -For more information about service capabilities and its usage, please check out the [API documentation](https://docs.api.supermq.abstractmachines.fr/?urls.primaryName=invitations.yml). diff --git a/invitations/api/doc.go b/invitations/api/doc.go deleted file mode 100644 index 7cd03c095..000000000 --- a/invitations/api/doc.go +++ /dev/null @@ -1,4 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api diff --git a/invitations/api/endpoint.go b/invitations/api/endpoint.go deleted file mode 100644 index efde368a7..000000000 --- a/invitations/api/endpoint.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -import ( - "context" - - api "github.com/absmach/supermq/api/http" - apiutil "github.com/absmach/supermq/api/http/util" - "github.com/absmach/supermq/invitations" - "github.com/absmach/supermq/pkg/authn" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/go-kit/kit/endpoint" -) - -// InvitationSent is the message returned when an invitation is sent. -const InvitationSent = "invitation sent" - -func sendInvitationEndpoint(svc invitations.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(sendInvitationReq) - if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - session, ok := ctx.Value(api.SessionKey).(authn.Session) - if !ok { - return nil, svcerr.ErrAuthorization - } - session.DomainID = req.DomainID - invitation := invitations.Invitation{ - UserID: req.UserID, - DomainID: req.DomainID, - Relation: req.Relation, - Resend: req.Resend, - } - - if err := svc.SendInvitation(ctx, session, invitation); err != nil { - return nil, err - } - - return sendInvitationRes{ - Message: InvitationSent, - }, nil - } -} - -func viewInvitationEndpoint(svc invitations.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(invitationReq) - if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - session, ok := ctx.Value(api.SessionKey).(authn.Session) - if !ok { - return nil, svcerr.ErrAuthorization - } - session.DomainID = req.domainID - invitation, err := svc.ViewInvitation(ctx, session, req.userID, req.domainID) - if err != nil { - return nil, err - } - - return viewInvitationRes{ - Invitation: invitation, - }, nil - } -} - -func listInvitationsEndpoint(svc invitations.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(listInvitationsReq) - if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - - session, ok := ctx.Value(api.SessionKey).(authn.Session) - if !ok { - return nil, svcerr.ErrAuthorization - } - session.DomainID = req.DomainID - - page, err := svc.ListInvitations(ctx, session, req.Page) - if err != nil { - return nil, err - } - - return listInvitationsRes{ - page, - }, nil - } -} - -func acceptInvitationEndpoint(svc invitations.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(acceptInvitationReq) - if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - - session, ok := ctx.Value(api.SessionKey).(authn.Session) - if !ok { - return nil, svcerr.ErrAuthorization - } - - if err := svc.AcceptInvitation(ctx, session, req.DomainID); err != nil { - return nil, err - } - - return acceptInvitationRes{}, nil - } -} - -func rejectInvitationEndpoint(svc invitations.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(acceptInvitationReq) - if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - - session, ok := ctx.Value(api.SessionKey).(authn.Session) - if !ok { - return nil, svcerr.ErrAuthorization - } - - if err := svc.RejectInvitation(ctx, session, req.DomainID); err != nil { - return nil, err - } - - return rejectInvitationRes{}, nil - } -} - -func deleteInvitationEndpoint(svc invitations.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (interface{}, error) { - req := request.(invitationReq) - if err := req.validate(); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - - session, ok := ctx.Value(api.SessionKey).(authn.Session) - if !ok { - return nil, svcerr.ErrAuthorization - } - session.DomainID = req.domainID - - if err := svc.DeleteInvitation(ctx, session, req.userID, req.domainID); err != nil { - return nil, err - } - - return deleteInvitationRes{}, nil - } -} diff --git a/invitations/api/endpoint_test.go b/invitations/api/endpoint_test.go deleted file mode 100644 index 540c16966..000000000 --- a/invitations/api/endpoint_test.go +++ /dev/null @@ -1,672 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api_test - -import ( - "fmt" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - apiutil "github.com/absmach/supermq/api/http/util" - "github.com/absmach/supermq/internal/testsutil" - "github.com/absmach/supermq/invitations" - "github.com/absmach/supermq/invitations/api" - "github.com/absmach/supermq/invitations/mocks" - smqlog "github.com/absmach/supermq/logger" - smqauthn "github.com/absmach/supermq/pkg/authn" - authnmocks "github.com/absmach/supermq/pkg/authn/mocks" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -var ( - validToken = "valid" - validContenType = "application/json" - validID = testsutil.GenerateUUID(&testing.T{}) - domainID = testsutil.GenerateUUID(&testing.T{}) -) - -type testRequest struct { - client *http.Client - method string - url string - token string - contentType string - body io.Reader -} - -func (tr testRequest) make() (*http.Response, error) { - req, err := http.NewRequest(tr.method, tr.url, tr.body) - if err != nil { - return nil, err - } - - if tr.token != "" { - req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token) - } - - if tr.contentType != "" { - req.Header.Set("Content-Type", tr.contentType) - } - - return tr.client.Do(req) -} - -func newIvitationsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentication) { - svc := new(mocks.Service) - logger := smqlog.NewMock() - authn := new(authnmocks.Authentication) - mux := api.MakeHandler(svc, logger, authn, "test") - return httptest.NewServer(mux), svc, authn -} - -func TestSendInvitation(t *testing.T) { - is, svc, authn := newIvitationsServer() - - cases := []struct { - desc string - token string - data string - contentType string - status int - authnRes smqauthn.Session - authnErr error - svcErr error - }{ - { - desc: "valid request", - token: validToken, - data: fmt.Sprintf(`{"user_id": "%s","domain_id": "%s", "relation": "%s"}`, validID, domainID, "domain"), - authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID}, - status: http.StatusCreated, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "invalid token", - token: "", - data: fmt.Sprintf(`{"user_id": "%s","domain_id": "%s", "relation": "%s"}`, validID, validID, "domain"), - status: http.StatusUnauthorized, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "empty domain_id", - token: validToken, - data: fmt.Sprintf(`{"user_id": "%s","domain_id": "%s", "relation": "%s"}`, validID, "", "domain"), - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "invalid content type", - token: validToken, - data: fmt.Sprintf(`{"user_id": "%s","domain_id": "%s", "relation": "%s"}`, validID, validID, "domain"), - authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID}, - status: http.StatusUnsupportedMediaType, - contentType: "text/plain", - svcErr: nil, - }, - { - desc: "invalid data", - token: validToken, - data: `data`, - authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID}, - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with service error", - token: validToken, - data: fmt.Sprintf(`{"user_id": "%s", "domain_id": "%s", "relation": "%s"}`, validID, domainID, "domain"), - authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID}, - status: http.StatusForbidden, - contentType: validContenType, - svcErr: svcerr.ErrAuthorization, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) - repoCall := svc.On("SendInvitation", mock.Anything, tc.authnRes, mock.Anything).Return(tc.svcErr) - req := testRequest{ - client: is.Client(), - method: http.MethodPost, - url: is.URL + "/invitations", - token: tc.token, - contentType: tc.contentType, - body: strings.NewReader(tc.data), - } - - res, err := req.make() - assert.Nil(t, err, tc.desc) - assert.Equal(t, tc.status, res.StatusCode, tc.desc) - repoCall.Unset() - authnCall.Unset() - }) - } -} - -func TestListInvitation(t *testing.T) { - is, svc, authn := newIvitationsServer() - - cases := []struct { - desc string - token string - query string - contentType string - status int - svcErr error - authnRes smqauthn.Session - authnErr error - }{ - { - desc: "valid request", - authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID}, - token: validToken, - status: http.StatusOK, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "invalid token", - token: "", - status: http.StatusUnauthorized, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with offset", - authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID}, - token: validToken, - query: "offset=1", - status: http.StatusOK, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with invalid offset", - token: validToken, - query: "offset=invalid", - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with limit", - authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID}, - token: validToken, - query: "limit=1", - status: http.StatusOK, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with invalid limit", - token: validToken, - query: "limit=invalid", - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with user_id", - authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID}, - token: validToken, - query: fmt.Sprintf("user_id=%s", validID), - status: http.StatusOK, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with duplicate user_id", - authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID}, - token: validToken, - query: "user_id=1&user_id=2", - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with invited_by", - authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID}, - token: validToken, - query: fmt.Sprintf("invited_by=%s", validID), - status: http.StatusOK, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with duplicate invited_by", - authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID}, - token: validToken, - query: "invited_by=1&invited_by=2", - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with relation", - authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID}, - token: validToken, - query: fmt.Sprintf("relation=%s", "relation"), - status: http.StatusOK, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with duplicate relation", - authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID}, - token: validToken, - query: "relation=1&relation=2", - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with state", - authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID}, - token: validToken, - query: "state=pending", - status: http.StatusOK, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with invalid state", - token: validToken, - query: "state=invalid", - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with duplicate state", - token: validToken, - query: "state=all&state=all", - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with service error", - authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID}, - token: validToken, - status: http.StatusForbidden, - contentType: validContenType, - svcErr: svcerr.ErrAuthorization, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) - repoCall := svc.On("ListInvitations", mock.Anything, tc.authnRes, mock.Anything).Return(invitations.InvitationPage{}, tc.svcErr) - req := testRequest{ - client: is.Client(), - method: http.MethodGet, - url: is.URL + "/invitations?" + tc.query, - token: tc.token, - contentType: tc.contentType, - } - res, err := req.make() - assert.Nil(t, err, tc.desc) - assert.Equal(t, tc.status, res.StatusCode, tc.desc) - repoCall.Unset() - authnCall.Unset() - }) - } -} - -func TestViewInvitation(t *testing.T) { - is, svc, authn := newIvitationsServer() - - cases := []struct { - desc string - token string - domainID string - userID string - contentType string - status int - svcErr error - authnRes smqauthn.Session - authnErr error - }{ - { - desc: "valid request", - authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID}, - token: validToken, - userID: validID, - domainID: domainID, - status: http.StatusOK, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "invalid token", - token: "", - userID: validID, - domainID: domainID, - status: http.StatusUnauthorized, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with service error", - authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID}, - token: validToken, - userID: validID, - domainID: domainID, - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: svcerr.ErrViewEntity, - }, - { - desc: "with empty user_id", - token: validToken, - userID: "", - domainID: domainID, - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with empty domain", - token: validToken, - userID: validID, - domainID: "", - status: http.StatusNotFound, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with empty user_id and domain_id", - token: validToken, - userID: "", - domainID: "", - status: http.StatusNotFound, - contentType: validContenType, - svcErr: nil, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) - repoCall := svc.On("ViewInvitation", mock.Anything, tc.authnRes, tc.userID, tc.domainID).Return(invitations.Invitation{}, tc.svcErr) - req := testRequest{ - client: is.Client(), - method: http.MethodGet, - url: is.URL + "/invitations/" + tc.userID + "/" + tc.domainID, - token: tc.token, - contentType: tc.contentType, - } - - res, err := req.make() - assert.Nil(t, err, tc.desc) - assert.Equal(t, tc.status, res.StatusCode, tc.desc) - repoCall.Unset() - authnCall.Unset() - }) - } -} - -func TestDeleteInvitation(t *testing.T) { - is, svc, authn := newIvitationsServer() - _ = authn - - cases := []struct { - desc string - token string - domainID string - userID string - contentType string - status int - svcErr error - authnRes smqauthn.Session - authnErr error - }{ - { - desc: "valid request", - authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID}, - token: validToken, - userID: validID, - domainID: domainID, - status: http.StatusNoContent, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "invalid token", - token: "", - userID: validID, - domainID: domainID, - status: http.StatusUnauthorized, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with service error", - authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID}, - token: validToken, - userID: validID, - domainID: domainID, - status: http.StatusForbidden, - contentType: validContenType, - svcErr: svcerr.ErrAuthorization, - }, - { - desc: "with empty user_id", - token: validToken, - userID: "", - domainID: domainID, - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with empty domain_id", - token: validToken, - userID: validID, - domainID: "", - status: http.StatusNotFound, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with empty user_id and domain_id", - token: validToken, - userID: "", - domainID: "", - status: http.StatusNotFound, - contentType: validContenType, - svcErr: nil, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) - repoCall := svc.On("DeleteInvitation", mock.Anything, tc.authnRes, tc.userID, tc.domainID).Return(tc.svcErr) - req := testRequest{ - client: is.Client(), - method: http.MethodDelete, - url: is.URL + "/invitations/" + tc.userID + "/" + tc.domainID, - token: tc.token, - contentType: tc.contentType, - } - - res, err := req.make() - assert.Nil(t, err, tc.desc) - assert.Equal(t, tc.status, res.StatusCode, tc.desc) - repoCall.Unset() - authnCall.Unset() - }) - } -} - -func TestAcceptInvitation(t *testing.T) { - is, svc, authn := newIvitationsServer() - _ = authn - cases := []struct { - desc string - token string - data string - contentType string - status int - svcErr error - authnRes smqauthn.Session - authnErr error - }{ - { - desc: "valid request", - authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID}, - data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), - token: validToken, - status: http.StatusNoContent, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "invalid token", - token: "", - data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), - status: http.StatusUnauthorized, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "with service error", - authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID}, - token: validToken, - data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), - status: http.StatusForbidden, - contentType: validContenType, - svcErr: svcerr.ErrAuthorization, - }, - { - desc: "invalid content type", - token: validToken, - data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), - status: http.StatusUnsupportedMediaType, - contentType: "text/plain", - svcErr: nil, - }, - { - desc: "invalid data", - token: validToken, - data: `data`, - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) - repoCall := svc.On("AcceptInvitation", mock.Anything, tc.authnRes, mock.Anything).Return(tc.svcErr) - req := testRequest{ - client: is.Client(), - method: http.MethodPost, - url: is.URL + "/invitations/accept", - token: tc.token, - contentType: tc.contentType, - body: strings.NewReader(tc.data), - } - - res, err := req.make() - assert.Nil(t, err, tc.desc) - assert.Equal(t, tc.status, res.StatusCode, tc.desc) - repoCall.Unset() - authnCall.Unset() - }) - } -} - -func TestRejectInvitation(t *testing.T) { - is, svc, authn := newIvitationsServer() - _ = authn - - cases := []struct { - desc string - token string - data string - contentType string - status int - svcErr error - authnRes smqauthn.Session - authnErr error - }{ - { - desc: "valid request", - authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID}, - token: validToken, - data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), - status: http.StatusNoContent, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "invalid token", - token: "", - data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), - status: http.StatusUnauthorized, - contentType: validContenType, - svcErr: nil, - }, - { - desc: "unauthorized error", - authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID}, - token: validToken, - data: fmt.Sprintf(`{"domain_id": "%s"}`, "invalid"), - status: http.StatusForbidden, - contentType: validContenType, - svcErr: svcerr.ErrAuthorization, - }, - { - desc: "invalid content type", - token: validToken, - data: fmt.Sprintf(`{"domain_id": "%s"}`, validID), - status: http.StatusUnsupportedMediaType, - contentType: "text/plain", - svcErr: nil, - }, - { - desc: "invalid data", - token: validToken, - data: `data`, - status: http.StatusBadRequest, - contentType: validContenType, - svcErr: nil, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) - repoCall := svc.On("RejectInvitation", mock.Anything, tc.authnRes, mock.Anything).Return(tc.svcErr) - req := testRequest{ - client: is.Client(), - method: http.MethodPost, - url: is.URL + "/invitations/reject", - token: tc.token, - contentType: tc.contentType, - body: strings.NewReader(tc.data), - } - - res, err := req.make() - assert.Nil(t, err, tc.desc) - assert.Equal(t, tc.status, res.StatusCode, tc.desc) - repoCall.Unset() - authnCall.Unset() - }) - } -} diff --git a/invitations/api/requests.go b/invitations/api/requests.go deleted file mode 100644 index c81c976ab..000000000 --- a/invitations/api/requests.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -import ( - apiutil "github.com/absmach/supermq/api/http/util" - "github.com/absmach/supermq/invitations" -) - -const maxLimitSize = 100 - -type sendInvitationReq struct { - UserID string `json:"user_id,omitempty"` - DomainID string `json:"domain_id,omitempty"` - Relation string `json:"relation,omitempty"` - Resend bool `json:"resend,omitempty"` -} - -func (req *sendInvitationReq) validate() error { - if req.UserID == "" { - return apiutil.ErrMissingID - } - if req.DomainID == "" { - return apiutil.ErrMissingDomainID - } - if err := invitations.CheckRelation(req.Relation); err != nil { - return err - } - - return nil -} - -type listInvitationsReq struct { - invitations.Page -} - -func (req *listInvitationsReq) validate() error { - if req.Page.Limit > maxLimitSize || req.Page.Limit < 1 { - return apiutil.ErrLimitSize - } - - return nil -} - -type acceptInvitationReq struct { - DomainID string `json:"domain_id,omitempty"` -} - -func (req *acceptInvitationReq) validate() error { - if req.DomainID == "" { - return apiutil.ErrMissingDomainID - } - - return nil -} - -type invitationReq struct { - userID string - domainID string -} - -func (req *invitationReq) validate() error { - if req.userID == "" { - return apiutil.ErrMissingID - } - if req.domainID == "" { - return apiutil.ErrMissingDomainID - } - - return nil -} diff --git a/invitations/api/requests_test.go b/invitations/api/requests_test.go deleted file mode 100644 index 7b8bb5a1b..000000000 --- a/invitations/api/requests_test.go +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -import ( - "fmt" - "testing" - - apiutil "github.com/absmach/supermq/api/http/util" - "github.com/absmach/supermq/invitations" - "github.com/absmach/supermq/pkg/policies" - "github.com/stretchr/testify/assert" -) - -var valid = "valid" - -func TestSendInvitationReqValidation(t *testing.T) { - cases := []struct { - desc string - req sendInvitationReq - err error - }{ - { - desc: "valid request", - req: sendInvitationReq{ - UserID: valid, - DomainID: valid, - Relation: policies.DomainRelation, - Resend: true, - }, - err: nil, - }, - { - desc: "empty user ID", - req: sendInvitationReq{ - UserID: "", - DomainID: valid, - Relation: policies.DomainRelation, - Resend: true, - }, - err: apiutil.ErrMissingID, - }, - { - desc: "empty domain_id", - req: sendInvitationReq{ - UserID: valid, - DomainID: "", - Relation: policies.DomainRelation, - Resend: true, - }, - err: apiutil.ErrMissingDomainID, - }, - { - desc: "missing relation", - req: sendInvitationReq{ - UserID: valid, - DomainID: valid, - Relation: "", - Resend: true, - }, - err: apiutil.ErrMissingRelation, - }, - { - desc: "invalid relation", - req: sendInvitationReq{ - UserID: valid, - DomainID: valid, - Relation: "invalid", - Resend: true, - }, - err: apiutil.ErrInvalidRelation, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - err := tc.req.validate() - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - }) - } -} - -func TestListInvitationsReq(t *testing.T) { - cases := []struct { - desc string - req listInvitationsReq - err error - }{ - { - desc: "valid request", - req: listInvitationsReq{ - Page: invitations.Page{Limit: 1}, - }, - err: nil, - }, - { - desc: "invalid limit", - req: listInvitationsReq{ - Page: invitations.Page{Limit: 1000}, - }, - err: apiutil.ErrLimitSize, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - err := tc.req.validate() - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - }) - } -} - -func TestAcceptInvitationReq(t *testing.T) { - cases := []struct { - desc string - req acceptInvitationReq - err error - }{ - { - desc: "valid request", - req: acceptInvitationReq{ - DomainID: valid, - }, - err: nil, - }, - { - desc: "empty domain_id", - req: acceptInvitationReq{ - DomainID: "", - }, - err: apiutil.ErrMissingDomainID, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - err := tc.req.validate() - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - }) - } -} - -func TestInvitationReqValidation(t *testing.T) { - cases := []struct { - desc string - req invitationReq - err error - }{ - { - desc: "valid request", - req: invitationReq{ - userID: valid, - domainID: valid, - }, - err: nil, - }, - { - desc: "empty user ID", - req: invitationReq{ - userID: "", - domainID: valid, - }, - err: apiutil.ErrMissingID, - }, - { - desc: "empty domain", - req: invitationReq{ - userID: valid, - domainID: "", - }, - err: apiutil.ErrMissingDomainID, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - err := tc.req.validate() - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - }) - } -} diff --git a/invitations/api/responses.go b/invitations/api/responses.go deleted file mode 100644 index fd8aba951..000000000 --- a/invitations/api/responses.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -import ( - "net/http" - - "github.com/absmach/supermq" - "github.com/absmach/supermq/invitations" -) - -var ( - _ supermq.Response = (*sendInvitationRes)(nil) - _ supermq.Response = (*viewInvitationRes)(nil) - _ supermq.Response = (*listInvitationsRes)(nil) - _ supermq.Response = (*acceptInvitationRes)(nil) - _ supermq.Response = (*rejectInvitationRes)(nil) - _ supermq.Response = (*deleteInvitationRes)(nil) -) - -type sendInvitationRes struct { - Message string `json:"message"` -} - -func (res sendInvitationRes) Code() int { - return http.StatusCreated -} - -func (res sendInvitationRes) Headers() map[string]string { - return map[string]string{} -} - -func (res sendInvitationRes) Empty() bool { - return true -} - -type viewInvitationRes struct { - invitations.Invitation `json:",inline"` -} - -func (res viewInvitationRes) Code() int { - return http.StatusOK -} - -func (res viewInvitationRes) Headers() map[string]string { - return map[string]string{} -} - -func (res viewInvitationRes) Empty() bool { - return false -} - -type listInvitationsRes struct { - invitations.InvitationPage `json:",inline"` -} - -func (res listInvitationsRes) Code() int { - return http.StatusOK -} - -func (res listInvitationsRes) Headers() map[string]string { - return map[string]string{} -} - -func (res listInvitationsRes) Empty() bool { - return false -} - -type acceptInvitationRes struct{} - -func (res acceptInvitationRes) Code() int { - return http.StatusNoContent -} - -func (res acceptInvitationRes) Headers() map[string]string { - return map[string]string{} -} - -func (res acceptInvitationRes) Empty() bool { - return true -} - -type deleteInvitationRes struct{} - -func (res deleteInvitationRes) Code() int { - return http.StatusNoContent -} - -func (res deleteInvitationRes) Headers() map[string]string { - return map[string]string{} -} - -func (res deleteInvitationRes) Empty() bool { - return true -} - -type rejectInvitationRes struct{} - -func (res rejectInvitationRes) Code() int { - return http.StatusNoContent -} - -func (res rejectInvitationRes) Headers() map[string]string { - return map[string]string{} -} - -func (res rejectInvitationRes) Empty() bool { - return true -} diff --git a/invitations/api/transport.go b/invitations/api/transport.go deleted file mode 100644 index 06eae2177..000000000 --- a/invitations/api/transport.go +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package api - -import ( - "context" - "encoding/json" - "log/slog" - "net/http" - "strings" - - "github.com/absmach/supermq" - api "github.com/absmach/supermq/api/http" - apiutil "github.com/absmach/supermq/api/http/util" - "github.com/absmach/supermq/invitations" - smqauthn "github.com/absmach/supermq/pkg/authn" - "github.com/absmach/supermq/pkg/errors" - "github.com/go-chi/chi/v5" - kithttp "github.com/go-kit/kit/transport/http" - "github.com/prometheus/client_golang/prometheus/promhttp" - "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" -) - -const ( - userIDKey = "user_id" - domainIDKey = "domain_id" - invitedByKey = "invited_by" - relationKey = "relation" - stateKey = "state" -) - -func MakeHandler(svc invitations.Service, logger *slog.Logger, authn smqauthn.Authentication, instanceID string) http.Handler { - opts := []kithttp.ServerOption{ - kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), - } - - mux := chi.NewRouter() - - mux.Group(func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, false)) - - r.Route("/invitations", func(r chi.Router) { - r.Post("/", otelhttp.NewHandler(kithttp.NewServer( - sendInvitationEndpoint(svc), - decodeSendInvitationReq, - api.EncodeResponse, - opts..., - ), "send_invitation").ServeHTTP) - r.Get("/", otelhttp.NewHandler(kithttp.NewServer( - listInvitationsEndpoint(svc), - decodeListInvitationsReq, - api.EncodeResponse, - opts..., - ), "list_invitations").ServeHTTP) - r.Route("/{user_id}/{domain_id}", func(r chi.Router) { - r.Get("/", otelhttp.NewHandler(kithttp.NewServer( - viewInvitationEndpoint(svc), - decodeInvitationReq, - api.EncodeResponse, - opts..., - ), "view_invitations").ServeHTTP) - r.Delete("/", otelhttp.NewHandler(kithttp.NewServer( - deleteInvitationEndpoint(svc), - decodeInvitationReq, - api.EncodeResponse, - opts..., - ), "delete_invitation").ServeHTTP) - }) - r.Post("/accept", otelhttp.NewHandler(kithttp.NewServer( - acceptInvitationEndpoint(svc), - decodeAcceptInvitationReq, - api.EncodeResponse, - opts..., - ), "accept_invitation").ServeHTTP) - r.Post("/reject", otelhttp.NewHandler(kithttp.NewServer( - rejectInvitationEndpoint(svc), - decodeAcceptInvitationReq, - api.EncodeResponse, - opts..., - ), "reject_invitation").ServeHTTP) - }) - }) - - mux.Get("/health", supermq.Health("invitations", instanceID)) - mux.Handle("/metrics", promhttp.Handler()) - - return mux -} - -func decodeSendInvitationReq(_ context.Context, r *http.Request) (interface{}, error) { - if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { - return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) - } - - var req sendInvitationReq - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity)) - } - - return req, nil -} - -func decodeListInvitationsReq(_ context.Context, r *http.Request) (interface{}, error) { - offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) - if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) - if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - userID, err := apiutil.ReadStringQuery(r, userIDKey, "") - if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - invitedBy, err := apiutil.ReadStringQuery(r, invitedByKey, "") - if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - relation, err := apiutil.ReadStringQuery(r, relationKey, "") - if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - domainID, err := apiutil.ReadStringQuery(r, domainIDKey, "") - if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - st, err := apiutil.ReadStringQuery(r, stateKey, invitations.All.String()) - if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - state, err := invitations.ToState(st) - if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) - } - req := listInvitationsReq{ - Page: invitations.Page{ - Offset: offset, - Limit: limit, - InvitedBy: invitedBy, - UserID: userID, - Relation: relation, - DomainID: domainID, - State: state, - }, - } - - return req, nil -} - -func decodeAcceptInvitationReq(_ context.Context, r *http.Request) (interface{}, error) { - if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { - return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) - } - - var req acceptInvitationReq - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity)) - } - - return req, nil -} - -func decodeInvitationReq(_ context.Context, r *http.Request) (interface{}, error) { - req := invitationReq{ - userID: chi.URLParam(r, "user_id"), - domainID: chi.URLParam(r, "domain_id"), - } - - return req, nil -} diff --git a/invitations/doc.go b/invitations/doc.go deleted file mode 100644 index 124fb7577..000000000 --- a/invitations/doc.go +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package invitations provides the API to manage invitations. -// -// An invitation is a request to join a domain. -package invitations diff --git a/invitations/invitations.go b/invitations/invitations.go deleted file mode 100644 index e7d681756..000000000 --- a/invitations/invitations.go +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package invitations - -import ( - "context" - "encoding/json" - "time" - - apiutil "github.com/absmach/supermq/api/http/util" - "github.com/absmach/supermq/pkg/authn" - "github.com/absmach/supermq/pkg/policies" -) - -// Invitation is an invitation to join a domain. -type Invitation struct { - InvitedBy string `json:"invited_by"` - UserID string `json:"user_id"` - DomainID string `json:"domain_id"` - Token string `json:"token,omitempty"` - Relation string `json:"relation,omitempty"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at,omitempty"` - ConfirmedAt time.Time `json:"confirmed_at,omitempty"` - RejectedAt time.Time `json:"rejected_at,omitempty"` - Resend bool `json:"resend,omitempty"` -} - -// Page is a page of invitations. -type Page struct { - Offset uint64 `json:"offset" db:"offset"` - Limit uint64 `json:"limit" db:"limit"` - InvitedBy string `json:"invited_by,omitempty" db:"invited_by,omitempty"` - UserID string `json:"user_id,omitempty" db:"user_id,omitempty"` - DomainID string `json:"domain_id,omitempty" db:"domain_id,omitempty"` - Relation string `json:"relation,omitempty" db:"relation,omitempty"` - InvitedByOrUserID string `db:"invited_by_or_user_id,omitempty"` - State State `json:"state,omitempty"` -} - -// InvitationPage is a page of invitations. -type InvitationPage struct { - Total uint64 `json:"total"` - Offset uint64 `json:"offset"` - Limit uint64 `json:"limit"` - Invitations []Invitation `json:"invitations"` -} - -func (page InvitationPage) MarshalJSON() ([]byte, error) { - type Alias InvitationPage - a := struct { - Alias - }{ - Alias: Alias(page), - } - - if a.Invitations == nil { - a.Invitations = make([]Invitation, 0) - } - - return json.Marshal(a) -} - -// Service is an interface that defines methods for managing invitations. -// -//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines" -type Service interface { - // SendInvitation sends an invitation to the given user. - // Only domain administrators and platform administrators can send invitations. - SendInvitation(ctx context.Context, session authn.Session, invitation Invitation) (err error) - - // ViewInvitation returns an invitation. - // People who can view invitations are: - // - the invited user: they can view their own invitations - // - the user who sent the invitation - // - domain administrators - // - platform administrators - ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (invitation Invitation, err error) - - // ListInvitations returns a list of invitations. - // People who can list invitations are: - // - platform administrators can list all invitations - // - domain administrators can list invitations for their domain - // By default, it will list invitations the current user has sent or received. - ListInvitations(ctx context.Context, session authn.Session, page Page) (invitations InvitationPage, err error) - - // AcceptInvitation accepts an invitation by adding the user to the domain. - AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) - - // DeleteInvitation deletes an invitation. - // People who can delete invitations are: - // - the invited user: they can delete their own invitations - // - the user who sent the invitation - // - domain administrators - // - platform administrators - DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) - - // RejectInvitation rejects an invitation. - // People who can reject invitations are: - // - the invited user: they can reject their own invitations - RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) -} - -//go:generate mockery --name Repository --output=./mocks --filename repository.go --quiet --note "Copyright (c) Abstract Machines" -type Repository interface { - // Create creates an invitation. - Create(ctx context.Context, invitation Invitation) (err error) - - // Retrieve returns an invitation. - Retrieve(ctx context.Context, userID, domainID string) (Invitation, error) - - // RetrieveAll returns a list of invitations based on the given page. - RetrieveAll(ctx context.Context, page Page) (invitations InvitationPage, err error) - - // UpdateToken updates an invitation by setting the token. - UpdateToken(ctx context.Context, invitation Invitation) (err error) - - // UpdateConfirmation updates an invitation by setting the confirmation time. - UpdateConfirmation(ctx context.Context, invitation Invitation) (err error) - - // UpdateRejection updates an invitation by setting the rejection time. - UpdateRejection(ctx context.Context, invitation Invitation) (err error) - - // Delete deletes an invitation. - Delete(ctx context.Context, userID, domainID string) (err error) -} - -// CheckRelation checks if the given relation is valid. -// It returns an error if the relation is empty or invalid. -func CheckRelation(relation string) error { - if relation == "" { - return apiutil.ErrMissingRelation - } - if relation != policies.AdministratorRelation && - relation != policies.EditorRelation && - relation != policies.ContributorRelation && - relation != policies.MemberRelation && - relation != policies.GuestRelation && - relation != policies.DomainRelation && - relation != policies.ParentGroupRelation && - relation != policies.RoleGroupRelation && - relation != policies.GroupRelation && - relation != policies.PlatformRelation { - return apiutil.ErrInvalidRelation - } - - return nil -} diff --git a/invitations/invitations_test.go b/invitations/invitations_test.go deleted file mode 100644 index 5eb120dff..000000000 --- a/invitations/invitations_test.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package invitations_test - -import ( - "fmt" - "testing" - - apiutil "github.com/absmach/supermq/api/http/util" - "github.com/absmach/supermq/invitations" - "github.com/stretchr/testify/assert" -) - -func TestInvitation_MarshalJSON(t *testing.T) { - cases := []struct { - desc string - page invitations.InvitationPage - res string - }{ - { - desc: "empty page", - page: invitations.InvitationPage{ - Invitations: []invitations.Invitation(nil), - }, - res: `{"total":0,"offset":0,"limit":0,"invitations":[]}`, - }, - { - desc: "page with invitations", - page: invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 0, - Invitations: []invitations.Invitation{ - { - InvitedBy: "John", - UserID: "123", - DomainID: "123", - }, - }, - }, - res: `{"total":1,"offset":0,"limit":0,"invitations":[{"invited_by":"John","user_id":"123","domain_id":"123","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z","confirmed_at":"0001-01-01T00:00:00Z","rejected_at":"0001-01-01T00:00:00Z"}]}`, - }, - } - - for _, tc := range cases { - data, err := tc.page.MarshalJSON() - assert.NoError(t, err, "Unexpected error: %v", err) - assert.Equal(t, tc.res, string(data), fmt.Sprintf("%s: expected %s, got %s", tc.desc, tc.res, string(data))) - } -} - -func TestCheckRelation(t *testing.T) { - cases := []struct { - relation string - err error - }{ - {"", apiutil.ErrMissingRelation}, - {"admin", apiutil.ErrInvalidRelation}, - {"editor", nil}, - {"contributor", nil}, - {"member", nil}, - {"guest", nil}, - {"domain", nil}, - {"parent_group", nil}, - {"role_group", nil}, - {"group", nil}, - {"platform", nil}, - } - - for _, tc := range cases { - err := invitations.CheckRelation(tc.relation) - assert.Equal(t, tc.err, err, "CheckRelation(%q) expected %v, got %v", tc.relation, tc.err, err) - } -} diff --git a/invitations/middleware/authorization.go b/invitations/middleware/authorization.go deleted file mode 100644 index 6d955eb41..000000000 --- a/invitations/middleware/authorization.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package middleware - -import ( - "context" - - "github.com/absmach/supermq/auth" - "github.com/absmach/supermq/invitations" - "github.com/absmach/supermq/pkg/authn" - "github.com/absmach/supermq/pkg/authz" - "github.com/absmach/supermq/pkg/errors" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/policies" -) - -// ErrMemberExist indicates that the user is already a member of the domain. -var ErrMemberExist = errors.New("user is already a member of the domain") - -var _ invitations.Service = (*tracing)(nil) - -type authorizationMiddleware struct { - authz authz.Authorization - svc invitations.Service -} - -func AuthorizationMiddleware(authz authz.Authorization, svc invitations.Service) invitations.Service { - return &authorizationMiddleware{authz, svc} -} - -func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation invitations.Invitation) (err error) { - session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID) - domainUserId := auth.EncodeDomainUserID(invitation.DomainID, invitation.UserID) - if err := am.authorize(ctx, domainUserId, policies.MembershipPermission, policies.DomainType, invitation.DomainID); err == nil { - // return error if the user is already a member of the domain - return errors.Wrap(svcerr.ErrConflict, ErrMemberExist) - } - - if err := am.checkAdmin(ctx, session); err != nil { - return err - } - - return am.svc.SendInvitation(ctx, session, invitation) -} - -func (am *authorizationMiddleware) ViewInvitation(ctx context.Context, session authn.Session, userID, domain string) (invitation invitations.Invitation, err error) { - session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID) - if session.UserID != userID { - if err := am.checkAdmin(ctx, session); err != nil { - return invitations.Invitation{}, err - } - } - - return am.svc.ViewInvitation(ctx, session, userID, domain) -} - -func (am *authorizationMiddleware) ListInvitations(ctx context.Context, session authn.Session, page invitations.Page) (invs invitations.InvitationPage, err error) { - session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID) - if err := am.authorize(ctx, session.UserID, policies.AdminPermission, policies.PlatformType, policies.SuperMQObject); err == nil { - session.SuperAdmin = true - } - - if !session.SuperAdmin { - switch { - case page.DomainID != "": - if err := am.authorize(ctx, session.DomainUserID, policies.AdminPermission, policies.DomainType, page.DomainID); err != nil { - return invitations.InvitationPage{}, err - } - default: - page.InvitedByOrUserID = session.UserID - } - } - - return am.svc.ListInvitations(ctx, session, page) -} - -func (am *authorizationMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { - return am.svc.AcceptInvitation(ctx, session, domainID) -} - -func (am *authorizationMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { - return am.svc.RejectInvitation(ctx, session, domainID) -} - -func (am *authorizationMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) { - session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID) - if err := am.checkAdmin(ctx, session); err != nil { - return err - } - - return am.svc.DeleteInvitation(ctx, session, userID, domainID) -} - -// checkAdmin checks if the given user is a domain or platform administrator. -func (am *authorizationMiddleware) checkAdmin(ctx context.Context, session authn.Session) error { - if err := am.authorize(ctx, session.DomainUserID, policies.AdminPermission, policies.DomainType, session.DomainID); err == nil { - return nil - } - - if err := am.authorize(ctx, session.UserID, policies.AdminPermission, policies.PlatformType, policies.SuperMQObject); err == nil { - return nil - } - - return svcerr.ErrAuthorization -} - -func (am *authorizationMiddleware) authorize(ctx context.Context, subj, perm, objType, obj string) error { - req := authz.PolicyReq{ - SubjectType: policies.UserType, - SubjectKind: policies.UsersKind, - Subject: subj, - Permission: perm, - ObjectType: objType, - Object: obj, - } - if err := am.authz.Authorize(ctx, req); err != nil { - return err - } - - return nil -} diff --git a/invitations/middleware/doc.go b/invitations/middleware/doc.go deleted file mode 100644 index 1fdf252ff..000000000 --- a/invitations/middleware/doc.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package middleware contains the middleware for the invitations service. -// It is responsible for the following: -// - Logging -// - Metrics -// - Tracing -package middleware diff --git a/invitations/middleware/logging.go b/invitations/middleware/logging.go deleted file mode 100644 index f6c1abef5..000000000 --- a/invitations/middleware/logging.go +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package middleware - -import ( - "context" - "log/slog" - "time" - - "github.com/absmach/supermq/invitations" - "github.com/absmach/supermq/pkg/authn" -) - -var _ invitations.Service = (*logging)(nil) - -type logging struct { - logger *slog.Logger - svc invitations.Service -} - -func Logging(logger *slog.Logger, svc invitations.Service) invitations.Service { - return &logging{logger, svc} -} - -func (lm *logging) SendInvitation(ctx context.Context, session authn.Session, invitation invitations.Invitation) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - slog.String("user_id", invitation.UserID), - slog.String("domain_id", invitation.DomainID), - } - if err != nil { - args = append(args, slog.Any("error", err)) - lm.logger.Warn("Send invitation failed", args...) - return - } - lm.logger.Info("Send invitation completed successfully", args...) - }(time.Now()) - return lm.svc.SendInvitation(ctx, session, invitation) -} - -func (lm *logging) ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (invitation invitations.Invitation, err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - slog.String("user_id", userID), - slog.String("domain_id", domainID), - } - if err != nil { - args = append(args, slog.Any("error", err)) - lm.logger.Warn("View invitation failed", args...) - return - } - lm.logger.Info("View invitation completed successfully", args...) - }(time.Now()) - return lm.svc.ViewInvitation(ctx, session, userID, domainID) -} - -func (lm *logging) ListInvitations(ctx context.Context, session authn.Session, page invitations.Page) (invs invitations.InvitationPage, err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - slog.Group("page", - slog.Uint64("offset", page.Offset), - slog.Uint64("limit", page.Limit), - slog.Uint64("total", invs.Total), - ), - } - if err != nil { - args = append(args, slog.Any("error", err)) - lm.logger.Warn("List invitations failed", args...) - return - } - lm.logger.Info("List invitations completed successfully", args...) - }(time.Now()) - return lm.svc.ListInvitations(ctx, session, page) -} - -func (lm *logging) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - slog.String("domain_id", domainID), - } - if err != nil { - args = append(args, slog.Any("error", err)) - lm.logger.Warn("Accept invitation failed", args...) - return - } - lm.logger.Info("Accept invitation completed successfully", args...) - }(time.Now()) - return lm.svc.AcceptInvitation(ctx, session, domainID) -} - -func (lm *logging) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - slog.String("domain_id", domainID), - } - if err != nil { - args = append(args, slog.Any("error", err)) - lm.logger.Warn("Reject invitation failed", args...) - return - } - lm.logger.Info("Reject invitation completed successfully", args...) - }(time.Now()) - return lm.svc.RejectInvitation(ctx, session, domainID) -} - -func (lm *logging) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) { - defer func(begin time.Time) { - args := []any{ - slog.String("duration", time.Since(begin).String()), - slog.String("user_id", userID), - slog.String("domain_id", domainID), - } - if err != nil { - args = append(args, slog.Any("error", err)) - lm.logger.Warn("Delete invitation failed", args...) - return - } - lm.logger.Info("Delete invitation completed successfully", args...) - }(time.Now()) - return lm.svc.DeleteInvitation(ctx, session, userID, domainID) -} diff --git a/invitations/middleware/metrics.go b/invitations/middleware/metrics.go deleted file mode 100644 index b5b02063c..000000000 --- a/invitations/middleware/metrics.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package middleware - -import ( - "context" - "time" - - "github.com/absmach/supermq/invitations" - "github.com/absmach/supermq/pkg/authn" - "github.com/go-kit/kit/metrics" -) - -var _ invitations.Service = (*metricsmw)(nil) - -type metricsmw struct { - counter metrics.Counter - latency metrics.Histogram - svc invitations.Service -} - -func Metrics(counter metrics.Counter, latency metrics.Histogram, svc invitations.Service) invitations.Service { - return &metricsmw{ - counter: counter, - latency: latency, - svc: svc, - } -} - -func (mm *metricsmw) SendInvitation(ctx context.Context, session authn.Session, invitation invitations.Invitation) (err error) { - defer func(begin time.Time) { - mm.counter.With("method", "send_invitation").Add(1) - mm.latency.With("method", "send_invitation").Observe(time.Since(begin).Seconds()) - }(time.Now()) - return mm.svc.SendInvitation(ctx, session, invitation) -} - -func (mm *metricsmw) ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (invitation invitations.Invitation, err error) { - defer func(begin time.Time) { - mm.counter.With("method", "view_invitation").Add(1) - mm.latency.With("method", "view_invitation").Observe(time.Since(begin).Seconds()) - }(time.Now()) - return mm.svc.ViewInvitation(ctx, session, userID, domainID) -} - -func (mm *metricsmw) ListInvitations(ctx context.Context, session authn.Session, page invitations.Page) (invs invitations.InvitationPage, err error) { - defer func(begin time.Time) { - mm.counter.With("method", "list_invitations").Add(1) - mm.latency.With("method", "list_invitations").Observe(time.Since(begin).Seconds()) - }(time.Now()) - return mm.svc.ListInvitations(ctx, session, page) -} - -func (mm *metricsmw) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { - defer func(begin time.Time) { - mm.counter.With("method", "accept_invitation").Add(1) - mm.latency.With("method", "accept_invitation").Observe(time.Since(begin).Seconds()) - }(time.Now()) - return mm.svc.AcceptInvitation(ctx, session, domainID) -} - -func (mm *metricsmw) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { - defer func(begin time.Time) { - mm.counter.With("method", "reject_invitation").Add(1) - mm.latency.With("method", "reject_invitation").Observe(time.Since(begin).Seconds()) - }(time.Now()) - return mm.svc.RejectInvitation(ctx, session, domainID) -} - -func (mm *metricsmw) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) { - defer func(begin time.Time) { - mm.counter.With("method", "delete_invitation").Add(1) - mm.latency.With("method", "delete_invitation").Observe(time.Since(begin).Seconds()) - }(time.Now()) - return mm.svc.DeleteInvitation(ctx, session, userID, domainID) -} diff --git a/invitations/middleware/tracing.go b/invitations/middleware/tracing.go deleted file mode 100644 index 0b647be28..000000000 --- a/invitations/middleware/tracing.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package middleware - -import ( - "context" - - "github.com/absmach/supermq/invitations" - "github.com/absmach/supermq/pkg/authn" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" -) - -var _ invitations.Service = (*tracing)(nil) - -type tracing struct { - tracer trace.Tracer - svc invitations.Service -} - -func Tracing(svc invitations.Service, tracer trace.Tracer) invitations.Service { - return &tracing{tracer, svc} -} - -func (tm *tracing) SendInvitation(ctx context.Context, session authn.Session, invitation invitations.Invitation) (err error) { - ctx, span := tm.tracer.Start(ctx, "send_invitation", trace.WithAttributes( - attribute.String("domain_id", invitation.DomainID), - attribute.String("user_id", invitation.UserID), - )) - defer span.End() - - return tm.svc.SendInvitation(ctx, session, invitation) -} - -func (tm *tracing) ViewInvitation(ctx context.Context, session authn.Session, userID, domain string) (invitation invitations.Invitation, err error) { - ctx, span := tm.tracer.Start(ctx, "view_invitation", trace.WithAttributes( - attribute.String("user_id", userID), - attribute.String("domain_id", domain), - )) - defer span.End() - - return tm.svc.ViewInvitation(ctx, session, userID, domain) -} - -func (tm *tracing) ListInvitations(ctx context.Context, session authn.Session, page invitations.Page) (invs invitations.InvitationPage, err error) { - ctx, span := tm.tracer.Start(ctx, "list_invitations", trace.WithAttributes( - attribute.Int("limit", int(page.Limit)), - attribute.Int("offset", int(page.Offset)), - attribute.String("user_id", page.UserID), - attribute.String("domain_id", page.DomainID), - attribute.String("invited_by", page.InvitedBy), - )) - defer span.End() - - return tm.svc.ListInvitations(ctx, session, page) -} - -func (tm *tracing) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { - ctx, span := tm.tracer.Start(ctx, "accept_invitation", trace.WithAttributes( - attribute.String("domain_id", domainID), - )) - defer span.End() - - return tm.svc.AcceptInvitation(ctx, session, domainID) -} - -func (tm *tracing) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) { - ctx, span := tm.tracer.Start(ctx, "reject_invitation", trace.WithAttributes( - attribute.String("domain_id", domainID), - )) - defer span.End() - - return tm.svc.RejectInvitation(ctx, session, domainID) -} - -func (tm *tracing) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) { - ctx, span := tm.tracer.Start(ctx, "delete_invitation", trace.WithAttributes( - attribute.String("user_id", userID), - attribute.String("domain_id", domainID), - )) - defer span.End() - - return tm.svc.DeleteInvitation(ctx, session, userID, domainID) -} diff --git a/invitations/mocks/doc.go b/invitations/mocks/doc.go deleted file mode 100644 index 4d95a3c13..000000000 --- a/invitations/mocks/doc.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package mocks provides a mock implementation of the invitations repository. -package mocks diff --git a/invitations/mocks/repository.go b/invitations/mocks/repository.go deleted file mode 100644 index 09ade1835..000000000 --- a/invitations/mocks/repository.go +++ /dev/null @@ -1,177 +0,0 @@ -// Code generated by mockery v2.43.2. DO NOT EDIT. - -// Copyright (c) Abstract Machines - -package mocks - -import ( - context "context" - - invitations "github.com/absmach/supermq/invitations" - mock "github.com/stretchr/testify/mock" -) - -// Repository is an autogenerated mock type for the Repository type -type Repository struct { - mock.Mock -} - -// Create provides a mock function with given fields: ctx, invitation -func (_m *Repository) Create(ctx context.Context, invitation invitations.Invitation) error { - ret := _m.Called(ctx, invitation) - - if len(ret) == 0 { - panic("no return value specified for Create") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, invitations.Invitation) error); ok { - r0 = rf(ctx, invitation) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Delete provides a mock function with given fields: ctx, userID, domainID -func (_m *Repository) Delete(ctx context.Context, userID string, domainID string) error { - ret := _m.Called(ctx, userID, domainID) - - if len(ret) == 0 { - panic("no return value specified for Delete") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { - r0 = rf(ctx, userID, domainID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// Retrieve provides a mock function with given fields: ctx, userID, domainID -func (_m *Repository) Retrieve(ctx context.Context, userID string, domainID string) (invitations.Invitation, error) { - ret := _m.Called(ctx, userID, domainID) - - if len(ret) == 0 { - panic("no return value specified for Retrieve") - } - - var r0 invitations.Invitation - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (invitations.Invitation, error)); ok { - return rf(ctx, userID, domainID) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string) invitations.Invitation); ok { - r0 = rf(ctx, userID, domainID) - } else { - r0 = ret.Get(0).(invitations.Invitation) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, userID, domainID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// RetrieveAll provides a mock function with given fields: ctx, page -func (_m *Repository) RetrieveAll(ctx context.Context, page invitations.Page) (invitations.InvitationPage, error) { - ret := _m.Called(ctx, page) - - if len(ret) == 0 { - panic("no return value specified for RetrieveAll") - } - - var r0 invitations.InvitationPage - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, invitations.Page) (invitations.InvitationPage, error)); ok { - return rf(ctx, page) - } - if rf, ok := ret.Get(0).(func(context.Context, invitations.Page) invitations.InvitationPage); ok { - r0 = rf(ctx, page) - } else { - r0 = ret.Get(0).(invitations.InvitationPage) - } - - if rf, ok := ret.Get(1).(func(context.Context, invitations.Page) error); ok { - r1 = rf(ctx, page) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// UpdateConfirmation provides a mock function with given fields: ctx, invitation -func (_m *Repository) UpdateConfirmation(ctx context.Context, invitation invitations.Invitation) error { - ret := _m.Called(ctx, invitation) - - if len(ret) == 0 { - panic("no return value specified for UpdateConfirmation") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, invitations.Invitation) error); ok { - r0 = rf(ctx, invitation) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// UpdateRejection provides a mock function with given fields: ctx, invitation -func (_m *Repository) UpdateRejection(ctx context.Context, invitation invitations.Invitation) error { - ret := _m.Called(ctx, invitation) - - if len(ret) == 0 { - panic("no return value specified for UpdateRejection") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, invitations.Invitation) error); ok { - r0 = rf(ctx, invitation) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// UpdateToken provides a mock function with given fields: ctx, invitation -func (_m *Repository) UpdateToken(ctx context.Context, invitation invitations.Invitation) error { - ret := _m.Called(ctx, invitation) - - if len(ret) == 0 { - panic("no return value specified for UpdateToken") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, invitations.Invitation) error); ok { - r0 = rf(ctx, invitation) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewRepository(t interface { - mock.TestingT - Cleanup(func()) -}) *Repository { - mock := &Repository{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/invitations/mocks/service.go b/invitations/mocks/service.go deleted file mode 100644 index 183a379c0..000000000 --- a/invitations/mocks/service.go +++ /dev/null @@ -1,162 +0,0 @@ -// Code generated by mockery v2.43.2. DO NOT EDIT. - -// Copyright (c) Abstract Machines - -package mocks - -import ( - context "context" - - authn "github.com/absmach/supermq/pkg/authn" - - invitations "github.com/absmach/supermq/invitations" - - mock "github.com/stretchr/testify/mock" -) - -// Service is an autogenerated mock type for the Service type -type Service struct { - mock.Mock -} - -// AcceptInvitation provides a mock function with given fields: ctx, session, domainID -func (_m *Service) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) error { - ret := _m.Called(ctx, session, domainID) - - if len(ret) == 0 { - panic("no return value specified for AcceptInvitation") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { - r0 = rf(ctx, session, domainID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// DeleteInvitation provides a mock function with given fields: ctx, session, userID, domainID -func (_m *Service) DeleteInvitation(ctx context.Context, session authn.Session, userID string, domainID string) error { - ret := _m.Called(ctx, session, userID, domainID) - - if len(ret) == 0 { - panic("no return value specified for DeleteInvitation") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) error); ok { - r0 = rf(ctx, session, userID, domainID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// ListInvitations provides a mock function with given fields: ctx, session, page -func (_m *Service) ListInvitations(ctx context.Context, session authn.Session, page invitations.Page) (invitations.InvitationPage, error) { - ret := _m.Called(ctx, session, page) - - if len(ret) == 0 { - panic("no return value specified for ListInvitations") - } - - var r0 invitations.InvitationPage - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authn.Session, invitations.Page) (invitations.InvitationPage, error)); ok { - return rf(ctx, session, page) - } - if rf, ok := ret.Get(0).(func(context.Context, authn.Session, invitations.Page) invitations.InvitationPage); ok { - r0 = rf(ctx, session, page) - } else { - r0 = ret.Get(0).(invitations.InvitationPage) - } - - if rf, ok := ret.Get(1).(func(context.Context, authn.Session, invitations.Page) error); ok { - r1 = rf(ctx, session, page) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// RejectInvitation provides a mock function with given fields: ctx, session, domainID -func (_m *Service) RejectInvitation(ctx context.Context, session authn.Session, domainID string) error { - ret := _m.Called(ctx, session, domainID) - - if len(ret) == 0 { - panic("no return value specified for RejectInvitation") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { - r0 = rf(ctx, session, domainID) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// SendInvitation provides a mock function with given fields: ctx, session, invitation -func (_m *Service) SendInvitation(ctx context.Context, session authn.Session, invitation invitations.Invitation) error { - ret := _m.Called(ctx, session, invitation) - - if len(ret) == 0 { - panic("no return value specified for SendInvitation") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, authn.Session, invitations.Invitation) error); ok { - r0 = rf(ctx, session, invitation) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// ViewInvitation provides a mock function with given fields: ctx, session, userID, domainID -func (_m *Service) ViewInvitation(ctx context.Context, session authn.Session, userID string, domainID string) (invitations.Invitation, error) { - ret := _m.Called(ctx, session, userID, domainID) - - if len(ret) == 0 { - panic("no return value specified for ViewInvitation") - } - - var r0 invitations.Invitation - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) (invitations.Invitation, error)); ok { - return rf(ctx, session, userID, domainID) - } - if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) invitations.Invitation); ok { - r0 = rf(ctx, session, userID, domainID) - } else { - r0 = ret.Get(0).(invitations.Invitation) - } - - if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, string) error); ok { - r1 = rf(ctx, session, userID, domainID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewService(t interface { - mock.TestingT - Cleanup(func()) -}) *Service { - mock := &Service{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/invitations/postgres/doc.go b/invitations/postgres/doc.go deleted file mode 100644 index 086a7bb4c..000000000 --- a/invitations/postgres/doc.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Package postgres provides a postgres implementation of the invitations repository. -package postgres diff --git a/invitations/postgres/init.go b/invitations/postgres/init.go deleted file mode 100644 index 442d8e615..000000000 --- a/invitations/postgres/init.go +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package postgres - -import ( - _ "github.com/jackc/pgx/v5/stdlib" // required for SQL access - migrate "github.com/rubenv/sql-migrate" -) - -func Migration() *migrate.MemoryMigrationSource { - return &migrate.MemoryMigrationSource{ - Migrations: []*migrate.Migration{ - { - Id: "invitations_01", - // VARCHAR(36) for colums with IDs as UUIDS have a maximum of 36 characters - Up: []string{ - `CREATE TABLE IF NOT EXISTS invitations ( - invited_by VARCHAR(36) NOT NULL, - user_id VARCHAR(36) NOT NULL, - domain_id VARCHAR(36) NOT NULL, - token TEXT NOT NULL, - relation VARCHAR(254) NOT NULL, - created_at TIMESTAMP NOT NULL, - updated_at TIMESTAMP, - confirmed_at TIMESTAMP, - UNIQUE (user_id, domain_id), - PRIMARY KEY (user_id, domain_id) - )`, - }, - Down: []string{ - `DROP TABLE IF EXISTS invitations`, - }, - }, - { - Id: "invitations_02_add_rejection", - Up: []string{ - `ALTER TABLE invitations - ADD COLUMN rejected_at TIMESTAMP`, - }, - Down: []string{ - `ALTER TABLE invitations - DROP COLUMN rejected_at`, - }, - }, - }, - } -} diff --git a/invitations/postgres/invitations.go b/invitations/postgres/invitations.go deleted file mode 100644 index 7387c86fb..000000000 --- a/invitations/postgres/invitations.go +++ /dev/null @@ -1,254 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package postgres - -import ( - "context" - "database/sql" - "fmt" - "strings" - "time" - - "github.com/absmach/supermq/invitations" - repoerr "github.com/absmach/supermq/pkg/errors/repository" - "github.com/absmach/supermq/pkg/postgres" -) - -type repository struct { - db postgres.Database -} - -func NewRepository(db postgres.Database) invitations.Repository { - return &repository{db: db} -} - -func (repo *repository) Create(ctx context.Context, invitation invitations.Invitation) (err error) { - q := `INSERT INTO invitations (invited_by, user_id, domain_id, token, relation, created_at) - VALUES (:invited_by, :user_id, :domain_id, :token, :relation, :created_at)` - - dbInv := toDBInvitation(invitation) - if _, err = repo.db.NamedExecContext(ctx, q, dbInv); err != nil { - return postgres.HandleError(repoerr.ErrCreateEntity, err) - } - - return nil -} - -func (repo *repository) Retrieve(ctx context.Context, userID, domainID string) (invitations.Invitation, error) { - q := `SELECT invited_by, user_id, domain_id, token, relation, created_at, updated_at, confirmed_at, rejected_at FROM invitations WHERE user_id = :user_id AND domain_id = :domain_id;` - - dbinv := dbInvitation{ - UserID: userID, - DomainID: domainID, - } - rows, err := repo.db.NamedQueryContext(ctx, q, dbinv) - if err != nil { - return invitations.Invitation{}, postgres.HandleError(repoerr.ErrViewEntity, err) - } - defer rows.Close() - - dbinv = dbInvitation{} - if rows.Next() { - if err = rows.StructScan(&dbinv); err != nil { - return invitations.Invitation{}, postgres.HandleError(repoerr.ErrViewEntity, err) - } - - return toInvitation(dbinv), nil - } - - return invitations.Invitation{}, repoerr.ErrNotFound -} - -func (repo *repository) RetrieveAll(ctx context.Context, page invitations.Page) (invitations.InvitationPage, error) { - query := pageQuery(page) - - q := fmt.Sprintf("SELECT invited_by, user_id, domain_id, relation, created_at, updated_at, confirmed_at, rejected_at FROM invitations %s LIMIT :limit OFFSET :offset;", query) - - rows, err := repo.db.NamedQueryContext(ctx, q, page) - if err != nil { - return invitations.InvitationPage{}, postgres.HandleError(repoerr.ErrViewEntity, err) - } - defer rows.Close() - - var items []invitations.Invitation - for rows.Next() { - var dbinv dbInvitation - if err = rows.StructScan(&dbinv); err != nil { - return invitations.InvitationPage{}, postgres.HandleError(repoerr.ErrViewEntity, err) - } - items = append(items, toInvitation(dbinv)) - } - - tq := fmt.Sprintf(`SELECT COUNT(*) FROM invitations %s`, query) - - total, err := postgres.Total(ctx, repo.db, tq, page) - if err != nil { - return invitations.InvitationPage{}, postgres.HandleError(repoerr.ErrViewEntity, err) - } - - invPage := invitations.InvitationPage{ - Total: total, - Offset: page.Offset, - Limit: page.Limit, - Invitations: items, - } - - return invPage, nil -} - -func (repo *repository) UpdateToken(ctx context.Context, invitation invitations.Invitation) (err error) { - q := `UPDATE invitations SET token = :token, updated_at = :updated_at WHERE user_id = :user_id AND domain_id = :domain_id` - - dbinv := toDBInvitation(invitation) - result, err := repo.db.NamedExecContext(ctx, q, dbinv) - if err != nil { - return postgres.HandleError(repoerr.ErrUpdateEntity, err) - } - if rows, _ := result.RowsAffected(); rows == 0 { - return repoerr.ErrNotFound - } - - return nil -} - -func (repo *repository) UpdateConfirmation(ctx context.Context, invitation invitations.Invitation) (err error) { - q := `UPDATE invitations SET confirmed_at = :confirmed_at, updated_at = :updated_at WHERE user_id = :user_id AND domain_id = :domain_id` - - dbinv := toDBInvitation(invitation) - result, err := repo.db.NamedExecContext(ctx, q, dbinv) - if err != nil { - return postgres.HandleError(repoerr.ErrUpdateEntity, err) - } - if rows, _ := result.RowsAffected(); rows == 0 { - return repoerr.ErrNotFound - } - - return nil -} - -func (repo *repository) UpdateRejection(ctx context.Context, invitation invitations.Invitation) (err error) { - q := `UPDATE invitations SET rejected_at = :rejected_at, updated_at = :updated_at WHERE user_id = :user_id AND domain_id = :domain_id` - - dbInv := toDBInvitation(invitation) - result, err := repo.db.NamedExecContext(ctx, q, dbInv) - if err != nil { - return postgres.HandleError(repoerr.ErrUpdateEntity, err) - } - if rows, _ := result.RowsAffected(); rows == 0 { - return repoerr.ErrNotFound - } - - return nil -} - -func (repo *repository) Delete(ctx context.Context, userID, domain string) (err error) { - q := `DELETE FROM invitations WHERE user_id = $1 AND domain_id = $2` - - result, err := repo.db.ExecContext(ctx, q, userID, domain) - if err != nil { - return postgres.HandleError(repoerr.ErrRemoveEntity, err) - } - if rows, _ := result.RowsAffected(); rows == 0 { - return repoerr.ErrNotFound - } - - return nil -} - -func pageQuery(pm invitations.Page) string { - var query []string - var emq string - if pm.DomainID != "" { - query = append(query, "domain_id = :domain_id") - } - if pm.UserID != "" { - query = append(query, "user_id = :user_id") - } - if pm.InvitedBy != "" { - query = append(query, "invited_by = :invited_by") - } - if pm.Relation != "" { - query = append(query, "relation = :relation") - } - if pm.InvitedByOrUserID != "" { - query = append(query, "(invited_by = :invited_by_or_user_id OR user_id = :invited_by_or_user_id)") - } - if pm.State == invitations.Accepted { - query = append(query, "confirmed_at IS NOT NULL") - } - if pm.State == invitations.Pending { - query = append(query, "confirmed_at IS NULL AND rejected_at IS NULL") - } - if pm.State == invitations.Rejected { - query = append(query, "rejected_at IS NOT NULL") - } - - if len(query) > 0 { - emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")) - } - - return emq -} - -type dbInvitation struct { - InvitedBy string `db:"invited_by"` - UserID string `db:"user_id"` - DomainID string `db:"domain_id"` - Token string `db:"token,omitempty"` - Relation string `db:"relation"` - CreatedAt time.Time `db:"created_at"` - UpdatedAt sql.NullTime `db:"updated_at,omitempty"` - ConfirmedAt sql.NullTime `db:"confirmed_at,omitempty"` - RejectedAt sql.NullTime `db:"rejected_at,omitempty"` -} - -func toDBInvitation(inv invitations.Invitation) dbInvitation { - var updatedAt, confirmedAt, rejectedAt sql.NullTime - if inv.UpdatedAt != (time.Time{}) { - updatedAt = sql.NullTime{Time: inv.UpdatedAt, Valid: true} - } - if inv.ConfirmedAt != (time.Time{}) { - confirmedAt = sql.NullTime{Time: inv.ConfirmedAt, Valid: true} - } - if inv.RejectedAt != (time.Time{}) { - rejectedAt = sql.NullTime{Time: inv.RejectedAt, Valid: true} - } - - return dbInvitation{ - InvitedBy: inv.InvitedBy, - UserID: inv.UserID, - DomainID: inv.DomainID, - Token: inv.Token, - Relation: inv.Relation, - CreatedAt: inv.CreatedAt, - UpdatedAt: updatedAt, - ConfirmedAt: confirmedAt, - RejectedAt: rejectedAt, - } -} - -func toInvitation(dbinv dbInvitation) invitations.Invitation { - var updatedAt, confirmedAt, rejectedAt time.Time - if dbinv.UpdatedAt.Valid { - updatedAt = dbinv.UpdatedAt.Time - } - if dbinv.ConfirmedAt.Valid { - confirmedAt = dbinv.ConfirmedAt.Time - } - if dbinv.RejectedAt.Valid { - rejectedAt = dbinv.RejectedAt.Time - } - - return invitations.Invitation{ - InvitedBy: dbinv.InvitedBy, - UserID: dbinv.UserID, - DomainID: dbinv.DomainID, - Token: dbinv.Token, - Relation: dbinv.Relation, - CreatedAt: dbinv.CreatedAt, - UpdatedAt: updatedAt, - ConfirmedAt: confirmedAt, - RejectedAt: rejectedAt, - } -} diff --git a/invitations/postgres/invitations_test.go b/invitations/postgres/invitations_test.go deleted file mode 100644 index b1f7a5496..000000000 --- a/invitations/postgres/invitations_test.go +++ /dev/null @@ -1,811 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package postgres_test - -import ( - "context" - "fmt" - "strings" - "testing" - "time" - - "github.com/absmach/supermq/internal/testsutil" - "github.com/absmach/supermq/invitations" - "github.com/absmach/supermq/invitations/postgres" - "github.com/absmach/supermq/pkg/errors" - repoerr "github.com/absmach/supermq/pkg/errors/repository" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var ( - invalidUUID = strings.Repeat("a", 37) - validToken = strings.Repeat("a", 1024) - relation = "relation" -) - -func TestInvitationCreate(t *testing.T) { - t.Cleanup(func() { - _, err := db.Exec("DELETE FROM invitations") - require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err)) - }) - repo := postgres.NewRepository(database) - - domainID := testsutil.GenerateUUID(t) - userID := testsutil.GenerateUUID(t) - - cases := []struct { - desc string - invitation invitations.Invitation - err error - }{ - { - desc: "add new invitation successfully", - invitation: invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: userID, - DomainID: domainID, - Token: validToken, - Relation: relation, - CreatedAt: time.Now(), - }, - err: nil, - }, - { - desc: "add new invitation with an confirmed_at date", - invitation: invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: relation, - CreatedAt: time.Now(), - ConfirmedAt: time.Now(), - }, - err: nil, - }, - { - desc: "add invitation with duplicate invitation", - invitation: invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: userID, - DomainID: domainID, - Token: validToken, - Relation: relation, - CreatedAt: time.Now(), - }, - err: repoerr.ErrConflict, - }, - { - desc: "add invitation with invalid invitation invited_by", - invitation: invitations.Invitation{ - InvitedBy: invalidUUID, - UserID: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: relation, - CreatedAt: time.Now(), - }, - err: repoerr.ErrMalformedEntity, - }, - { - desc: "add invitation with invalid invitation relation", - invitation: invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: strings.Repeat("a", 255), - CreatedAt: time.Now(), - }, - err: repoerr.ErrMalformedEntity, - }, - { - desc: "add invitation with invalid invitation domain", - invitation: invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - DomainID: invalidUUID, - Token: validToken, - Relation: relation, - CreatedAt: time.Now(), - }, - err: repoerr.ErrMalformedEntity, - }, - { - desc: "add invitation with invalid invitation user id", - invitation: invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: invalidUUID, - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: relation, - CreatedAt: time.Now(), - }, - err: repoerr.ErrMalformedEntity, - }, - { - desc: "add invitation with empty invitation domain", - invitation: invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: relation, - CreatedAt: time.Now(), - }, - err: nil, - }, - { - desc: "add invitation with empty invitation user id", - invitation: invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: relation, - CreatedAt: time.Now(), - }, - err: nil, - }, - { - desc: "add invitation with empty invitation invited_by", - invitation: invitations.Invitation{ - DomainID: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: relation, - CreatedAt: time.Now(), - }, - err: nil, - }, - { - desc: "add invitation with empty invitation token", - invitation: invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - Relation: relation, - CreatedAt: time.Now(), - }, - err: nil, - }, - } - for _, tc := range cases { - switch err := repo.Create(context.Background(), tc.invitation); { - case err == nil: - assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - default: - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } - } -} - -func TestInvitationRetrieve(t *testing.T) { - t.Cleanup(func() { - _, err := db.Exec("DELETE FROM invitations") - require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err)) - }) - repo := postgres.NewRepository(database) - - invitation := invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: relation, - CreatedAt: time.Now().UTC().Truncate(time.Microsecond), - } - err := repo.Create(context.Background(), invitation) - require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) - - cases := []struct { - desc string - userID string - domainID string - response invitations.Invitation - err error - }{ - { - desc: "retrieve invitations successfully", - userID: invitation.UserID, - domainID: invitation.DomainID, - response: invitation, - err: nil, - }, - { - desc: "retrieve invitations with invalid invitation user id", - userID: testsutil.GenerateUUID(t), - domainID: invitation.DomainID, - response: invitations.Invitation{}, - err: repoerr.ErrNotFound, - }, - { - desc: "retrieve invitations with invalid invitation domain_id", - userID: invitation.UserID, - domainID: testsutil.GenerateUUID(t), - response: invitations.Invitation{}, - err: repoerr.ErrNotFound, - }, - { - desc: "retrieve invitations with invalid invitation user id and domain_id", - userID: testsutil.GenerateUUID(t), - domainID: testsutil.GenerateUUID(t), - response: invitations.Invitation{}, - err: repoerr.ErrNotFound, - }, - { - desc: "retrieve invitations with empty invitation user id", - userID: "", - domainID: invitation.DomainID, - response: invitations.Invitation{}, - err: repoerr.ErrNotFound, - }, - { - desc: "retrieve invitations with empty invitation domain_id", - userID: invitation.UserID, - domainID: "", - response: invitations.Invitation{}, - err: repoerr.ErrNotFound, - }, - { - desc: "retrieve invitations with empty invitation user id and domain_id", - userID: "", - domainID: "", - response: invitations.Invitation{}, - err: repoerr.ErrNotFound, - }, - } - for _, tc := range cases { - page, err := repo.Retrieve(context.Background(), tc.userID, tc.domainID) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - assert.Equal(t, tc.response, page, fmt.Sprintf("desc: %s\n", tc.desc)) - } -} - -func TestInvitationRetrieveAll(t *testing.T) { - t.Cleanup(func() { - _, err := db.Exec("DELETE FROM invitations") - require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err)) - }) - repo := postgres.NewRepository(database) - - num := 200 - - var items []invitations.Invitation - for i := 0; i < num; i++ { - invitation := invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: fmt.Sprintf("%s-%d", relation, i), - CreatedAt: time.Now().UTC().Truncate(time.Microsecond), - } - err := repo.Create(context.Background(), invitation) - require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) - invitation.Token = "" - items = append(items, invitation) - } - items[100].ConfirmedAt = time.Now().UTC().Truncate(time.Microsecond) - err := repo.UpdateConfirmation(context.Background(), items[100]) - require.Nil(t, err, fmt.Sprintf("update invitation unexpected error: %s", err)) - - swap := items[100] - items = append(items[:100], items[101:]...) - items = append(items, swap) - - cases := []struct { - desc string - page invitations.Page - response invitations.InvitationPage - err error - }{ - { - desc: "retrieve invitations successfully", - page: invitations.Page{ - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: uint64(num), - Offset: 0, - Limit: 10, - Invitations: items[:10], - }, - err: nil, - }, - { - desc: "retrieve invitations with offset", - page: invitations.Page{ - Offset: 10, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: uint64(num), - Offset: 10, - Limit: 10, - Invitations: items[10:20], - }, - }, - { - desc: "retrieve invitations with limit", - page: invitations.Page{ - Offset: 0, - Limit: 50, - }, - response: invitations.InvitationPage{ - Total: uint64(num), - Offset: 0, - Limit: 50, - Invitations: items[:50], - }, - }, - { - desc: "retrieve invitations with offset and limit", - page: invitations.Page{ - Offset: 10, - Limit: 50, - }, - response: invitations.InvitationPage{ - Total: uint64(num), - Offset: 10, - Limit: 50, - Invitations: items[10:60], - }, - }, - { - desc: "retrieve invitations with offset out of range", - page: invitations.Page{ - Offset: 1000, - Limit: 50, - }, - response: invitations.InvitationPage{ - Total: uint64(num), - Offset: 1000, - Limit: 50, - Invitations: []invitations.Invitation(nil), - }, - }, - { - desc: "retrieve invitations with offset and limit out of range", - page: invitations.Page{ - Offset: 170, - Limit: 50, - }, - response: invitations.InvitationPage{ - Total: uint64(num), - Offset: 170, - Limit: 50, - Invitations: items[170:200], - }, - }, - { - desc: "retrieve invitations with limit out of range", - page: invitations.Page{ - Offset: 0, - Limit: 1000, - }, - response: invitations.InvitationPage{ - Total: uint64(num), - Offset: 0, - Limit: 1000, - Invitations: items, - }, - }, - { - desc: "retrieve invitations with empty page", - page: invitations.Page{}, - response: invitations.InvitationPage{ - Total: uint64(num), - Offset: 0, - Limit: 0, - Invitations: []invitations.Invitation(nil), - }, - }, - { - desc: "retrieve invitations with domain", - page: invitations.Page{ - DomainID: items[0].DomainID, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation{items[0]}, - }, - }, - { - desc: "retrieve invitations with user id", - page: invitations.Page{ - UserID: items[0].UserID, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation{items[0]}, - }, - }, - { - desc: "retrieve invitations with invited_by", - page: invitations.Page{ - InvitedBy: items[0].InvitedBy, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation{items[0]}, - }, - }, - { - desc: "retrieve invitations with invited_by_or_user_id", - page: invitations.Page{ - InvitedByOrUserID: items[0].UserID, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation{items[0]}, - }, - }, - { - desc: "retrieve invitations with relation", - page: invitations.Page{ - Relation: relation + "-0", - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation{items[0]}, - }, - }, - { - desc: "retrieve invitations with domain_id and user id", - page: invitations.Page{ - DomainID: items[0].DomainID, - UserID: items[0].UserID, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation{items[0]}, - }, - }, - { - desc: "retrieve invitations with domain_id and invited_by", - page: invitations.Page{ - DomainID: items[0].DomainID, - InvitedBy: items[0].InvitedBy, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation{items[0]}, - }, - }, - { - desc: "retrieve invitations with user id and invited_by", - page: invitations.Page{ - UserID: items[0].UserID, - InvitedBy: items[0].InvitedBy, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation{items[0]}, - }, - }, - { - desc: "retrieve invitations with domain_id, user id and invited_by", - page: invitations.Page{ - DomainID: items[0].DomainID, - UserID: items[0].UserID, - InvitedBy: items[0].InvitedBy, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation{items[0]}, - }, - }, - { - desc: "retrieve invitations with domain_id, user id, invited_by and relation", - page: invitations.Page{ - DomainID: items[0].DomainID, - UserID: items[0].UserID, - InvitedBy: items[0].InvitedBy, - Relation: relation + "-0", - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation{items[0]}, - }, - }, - { - desc: "retrieve invitations with invalid domain", - page: invitations.Page{ - DomainID: invalidUUID, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 0, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation(nil), - }, - }, - { - desc: "retrieve invitations with invalid user id", - page: invitations.Page{ - UserID: testsutil.GenerateUUID(t), - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 0, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation(nil), - }, - }, - { - desc: "retrieve invitations with invalid invited_by", - page: invitations.Page{ - InvitedBy: invalidUUID, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 0, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation(nil), - }, - }, - { - desc: "retrieve invitations with invalid relation", - page: invitations.Page{ - Relation: invalidUUID, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 0, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation(nil), - }, - }, - { - desc: "retrieve invitations with accepted state", - page: invitations.Page{ - State: invitations.Accepted, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation{items[num-1]}, - }, - }, - { - desc: "retrieve invitations with pending state", - page: invitations.Page{ - State: invitations.Pending, - Offset: 0, - Limit: 10, - }, - response: invitations.InvitationPage{ - Total: uint64(num - 1), - Offset: 0, - Limit: 10, - Invitations: items[0:10], - }, - }, - } - for _, tc := range cases { - page, err := repo.RetrieveAll(context.Background(), tc.page) - assert.Equal(t, tc.response.Total, page.Total, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Total, page.Total)) - assert.Equal(t, tc.response.Offset, page.Offset, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Offset, page.Offset)) - assert.Equal(t, tc.response.Limit, page.Limit, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Limit, page.Limit)) - assert.ElementsMatch(t, page.Invitations, tc.response.Invitations, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response.Invitations, page.Invitations)) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } -} - -func TestInvitationUpdateToken(t *testing.T) { - t.Cleanup(func() { - _, err := db.Exec("DELETE FROM invitations") - require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err)) - }) - repo := postgres.NewRepository(database) - - invitation := invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - CreatedAt: time.Now(), - } - err := repo.Create(context.Background(), invitation) - require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) - - cases := []struct { - desc string - invitation invitations.Invitation - err error - }{ - { - desc: "update invitation successfully", - invitation: invitations.Invitation{ - DomainID: invitation.DomainID, - UserID: invitation.UserID, - Token: validToken, - UpdatedAt: time.Now(), - }, - err: nil, - }, - { - desc: "update invitation with invalid user id", - invitation: invitations.Invitation{ - UserID: testsutil.GenerateUUID(t), - DomainID: invitation.DomainID, - Token: validToken, - UpdatedAt: time.Now(), - }, - err: repoerr.ErrNotFound, - }, - { - desc: "update invitation with invalid domain_id", - invitation: invitations.Invitation{ - UserID: invitation.UserID, - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - UpdatedAt: time.Now(), - }, - err: repoerr.ErrNotFound, - }, - } - for _, tc := range cases { - err := repo.UpdateToken(context.Background(), tc.invitation) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } -} - -func TestInvitationUpdateConfirmation(t *testing.T) { - t.Cleanup(func() { - _, err := db.Exec("DELETE FROM invitations") - require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err)) - }) - repo := postgres.NewRepository(database) - - invitation := invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - CreatedAt: time.Now(), - } - err := repo.Create(context.Background(), invitation) - require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) - - cases := []struct { - desc string - invitation invitations.Invitation - err error - }{ - { - desc: "update invitation successfully", - invitation: invitations.Invitation{ - DomainID: invitation.DomainID, - UserID: invitation.UserID, - ConfirmedAt: time.Now(), - }, - err: nil, - }, - { - desc: "update invitation with invalid user id", - invitation: invitations.Invitation{ - UserID: testsutil.GenerateUUID(t), - DomainID: invitation.UserID, - ConfirmedAt: time.Now(), - }, - err: repoerr.ErrNotFound, - }, - { - desc: "update invitation with invalid domain", - invitation: invitations.Invitation{ - UserID: invitation.UserID, - DomainID: testsutil.GenerateUUID(t), - ConfirmedAt: time.Now(), - }, - err: repoerr.ErrNotFound, - }, - } - for _, tc := range cases { - err := repo.UpdateConfirmation(context.Background(), tc.invitation) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } -} - -func TestInvitationDelete(t *testing.T) { - t.Cleanup(func() { - _, err := db.Exec("DELETE FROM invitations") - require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err)) - }) - repo := postgres.NewRepository(database) - - invitation := invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - CreatedAt: time.Now(), - } - err := repo.Create(context.Background(), invitation) - require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err)) - - cases := []struct { - desc string - invitation invitations.Invitation - err error - }{ - { - desc: "delete invitation successfully", - invitation: invitations.Invitation{ - UserID: invitation.UserID, - DomainID: invitation.DomainID, - }, - err: nil, - }, - { - desc: "delete invitation with invalid invitation id", - invitation: invitations.Invitation{ - UserID: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - }, - err: repoerr.ErrNotFound, - }, - { - desc: "delete invitation with empty invitation id", - invitation: invitations.Invitation{}, - err: repoerr.ErrNotFound, - }, - } - for _, tc := range cases { - err := repo.Delete(context.Background(), tc.invitation.UserID, tc.invitation.DomainID) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) - } -} diff --git a/invitations/postgres/setup_test.go b/invitations/postgres/setup_test.go deleted file mode 100644 index 9fdf48d39..000000000 --- a/invitations/postgres/setup_test.go +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package postgres_test - -import ( - "database/sql" - "fmt" - "log" - "os" - "testing" - "time" - - ipostgres "github.com/absmach/supermq/invitations/postgres" - "github.com/absmach/supermq/pkg/postgres" - "github.com/jmoiron/sqlx" - dockertest "github.com/ory/dockertest/v3" - "github.com/ory/dockertest/v3/docker" - "go.opentelemetry.io/otel" -) - -var ( - db *sqlx.DB - database postgres.Database - tracer = otel.Tracer("repo_tests") -) - -func TestMain(m *testing.M) { - pool, err := dockertest.NewPool("") - if err != nil { - log.Fatalf("Could not connect to docker: %s", err) - } - - container, err := pool.RunWithOptions(&dockertest.RunOptions{ - Repository: "postgres", - Tag: "16.2-alpine", - Env: []string{ - "POSTGRES_USER=test", - "POSTGRES_PASSWORD=test", - "POSTGRES_DB=test", - "listen_addresses = '*'", - }, - }, func(config *docker.HostConfig) { - config.AutoRemove = true - config.RestartPolicy = docker.RestartPolicy{Name: "no"} - }) - if err != nil { - log.Fatalf("Could not start container: %s", err) - } - - port := container.GetPort("5432/tcp") - - // exponential backoff-retry, because the application in the container might not be ready to accept connections yet - pool.MaxWait = 120 * time.Second - if err := pool.Retry(func() error { - url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) - db, err := sql.Open("pgx", url) - if err != nil { - return err - } - return db.Ping() - }); err != nil { - log.Fatalf("Could not connect to docker: %s", err) - } - - dbConfig := postgres.Config{ - Host: "localhost", - Port: port, - User: "test", - Pass: "test", - Name: "test", - SSLMode: "disable", - SSLCert: "", - SSLKey: "", - SSLRootCert: "", - } - - if db, err = postgres.Setup(dbConfig, *ipostgres.Migration()); err != nil { - log.Fatalf("Could not setup test DB connection: %s", err) - } - - if db, err = postgres.Connect(dbConfig); err != nil { - log.Fatalf("Could not setup test DB connection: %s", err) - } - database = postgres.NewDatabase(db, dbConfig, tracer) - - code := m.Run() - - // Defers will not be run when using os.Exit - db.Close() - if err := pool.Purge(container); err != nil { - log.Fatalf("Could not purge container: %s", err) - } - - os.Exit(code) -} diff --git a/invitations/service.go b/invitations/service.go deleted file mode 100644 index b157d7b47..000000000 --- a/invitations/service.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package invitations - -import ( - "context" - "time" - - grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" - "github.com/absmach/supermq/auth" - "github.com/absmach/supermq/pkg/authn" - svcerr "github.com/absmach/supermq/pkg/errors/service" - mgsdk "github.com/absmach/supermq/pkg/sdk" -) - -type service struct { - token grpcTokenV1.TokenServiceClient - repo Repository - sdk mgsdk.SDK -} - -func NewService(token grpcTokenV1.TokenServiceClient, repo Repository, sdk mgsdk.SDK) Service { - return &service{ - token: token, - repo: repo, - sdk: sdk, - } -} - -func (svc *service) SendInvitation(ctx context.Context, session authn.Session, invitation Invitation) error { - if err := CheckRelation(invitation.Relation); err != nil { - return err - } - - invitation.InvitedBy = session.UserID - - joinToken, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: session.UserID, Type: uint32(auth.InvitationKey)}) - if err != nil { - return err - } - invitation.Token = joinToken.GetAccessToken() - - if invitation.Resend { - invitation.UpdatedAt = time.Now() - - return svc.repo.UpdateToken(ctx, invitation) - } - - invitation.CreatedAt = time.Now() - - if err := svc.repo.Create(ctx, invitation); err != nil { - return err - } - return nil -} - -func (svc *service) ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (invitation Invitation, err error) { - inv, err := svc.repo.Retrieve(ctx, userID, domainID) - if err != nil { - return Invitation{}, err - } - inv.Token = "" - - return inv, nil -} - -func (svc *service) ListInvitations(ctx context.Context, session authn.Session, page Page) (invitations InvitationPage, err error) { - ip, err := svc.repo.RetrieveAll(ctx, page) - if err != nil { - return InvitationPage{}, err - } - return ip, nil -} - -func (svc *service) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) error { - inv, err := svc.repo.Retrieve(ctx, session.UserID, domainID) - if err != nil { - return err - } - - if inv.UserID != session.UserID { - return svcerr.ErrAuthorization - } - - if !inv.ConfirmedAt.IsZero() { - return svcerr.ErrInvitationAlreadyAccepted - } - - if !inv.RejectedAt.IsZero() { - return svcerr.ErrInvitationAlreadyRejected - } - - if _, sdkerr := svc.sdk.AddDomainRoleMembers(inv.DomainID, "admin", []string{session.UserID}, inv.Token); sdkerr != nil { - return sdkerr - } - - inv.ConfirmedAt = time.Now() - inv.UpdatedAt = inv.ConfirmedAt - return svc.repo.UpdateConfirmation(ctx, inv) -} - -func (svc *service) RejectInvitation(ctx context.Context, session authn.Session, domainID string) error { - inv, err := svc.repo.Retrieve(ctx, session.UserID, domainID) - if err != nil { - return err - } - - if inv.UserID != session.UserID { - return svcerr.ErrAuthorization - } - - if !inv.ConfirmedAt.IsZero() { - return svcerr.ErrInvitationAlreadyAccepted - } - - if !inv.RejectedAt.IsZero() { - return svcerr.ErrInvitationAlreadyRejected - } - - inv.RejectedAt = time.Now() - inv.UpdatedAt = inv.RejectedAt - return svc.repo.UpdateRejection(ctx, inv) -} - -func (svc *service) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) error { - if session.UserID == userID { - return svc.repo.Delete(ctx, userID, domainID) - } - - inv, err := svc.repo.Retrieve(ctx, userID, domainID) - if err != nil { - return err - } - - if inv.InvitedBy == session.UserID { - return svc.repo.Delete(ctx, userID, domainID) - } - - return svc.repo.Delete(ctx, userID, domainID) -} diff --git a/invitations/service_test.go b/invitations/service_test.go deleted file mode 100644 index 3b28ddabd..000000000 --- a/invitations/service_test.go +++ /dev/null @@ -1,513 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package invitations_test - -import ( - "context" - "testing" - "time" - - grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" - apiutil "github.com/absmach/supermq/api/http/util" - authmocks "github.com/absmach/supermq/auth/mocks" - "github.com/absmach/supermq/internal/testsutil" - "github.com/absmach/supermq/invitations" - "github.com/absmach/supermq/invitations/mocks" - "github.com/absmach/supermq/pkg/authn" - "github.com/absmach/supermq/pkg/errors" - repoerr "github.com/absmach/supermq/pkg/errors/repository" - svcerr "github.com/absmach/supermq/pkg/errors/service" - "github.com/absmach/supermq/pkg/policies" - sdkmocks "github.com/absmach/supermq/pkg/sdk/mocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -var ( - validInvitation = invitations.Invitation{ - UserID: testsutil.GenerateUUID(&testing.T{}), - DomainID: testsutil.GenerateUUID(&testing.T{}), - Relation: policies.ContributorRelation, - } - validDomainUserID = "domain_user_id" - validUserID = "user_id" - validDomainID = "domain_id" - validToken = "valid_token" - invalidToken = "invalid" -) - -func TestSendInvitation(t *testing.T) { - repo := new(mocks.Repository) - token := new(authmocks.TokenServiceClient) - svc := invitations.NewService(token, repo, nil) - - cases := []struct { - desc string - token string - session authn.Session - tokenUserID string - req invitations.Invitation - err error - issueErr error - repoErr error - }{ - { - desc: "send invitation successful", - token: validToken, - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID}, - tokenUserID: testsutil.GenerateUUID(t), - req: validInvitation, - err: nil, - issueErr: nil, - repoErr: nil, - }, - { - desc: "failed to issue token", - token: invalidToken, - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID}, - tokenUserID: testsutil.GenerateUUID(t), - req: validInvitation, - err: svcerr.ErrCreateEntity, - issueErr: svcerr.ErrCreateEntity, - repoErr: nil, - }, - { - desc: "invalid relation", - token: validToken, - tokenUserID: testsutil.GenerateUUID(t), - req: invitations.Invitation{Relation: "invalid"}, - err: apiutil.ErrInvalidRelation, - issueErr: nil, - repoErr: nil, - }, - { - desc: "resend invitation", - token: invalidToken, - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID}, - tokenUserID: testsutil.GenerateUUID(t), - req: invitations.Invitation{ - UserID: validInvitation.UserID, - DomainID: validInvitation.DomainID, - Relation: validInvitation.Relation, - Resend: true, - }, - err: nil, - issueErr: nil, - repoErr: nil, - }, - { - desc: "error during token issuance", - token: validToken, - tokenUserID: testsutil.GenerateUUID(t), - req: validInvitation, - err: svcerr.ErrAuthentication, - issueErr: svcerr.ErrAuthentication, - repoErr: nil, - }, - } - - for _, tc := range cases { - repocall1 := token.On("Issue", context.Background(), mock.Anything).Return(&grpcTokenV1.Token{AccessToken: tc.req.Token}, tc.issueErr) - repocall2 := repo.On("Create", context.Background(), mock.Anything).Return(tc.repoErr) - if tc.req.Resend { - repocall2 = repo.On("UpdateToken", context.Background(), mock.Anything).Return(tc.repoErr) - } - err := svc.SendInvitation(context.Background(), tc.session, tc.req) - assert.Equal(t, tc.err, err, tc.desc) - repocall1.Unset() - repocall2.Unset() - } -} - -func TestViewInvitation(t *testing.T) { - repo := new(mocks.Repository) - token := new(authmocks.TokenServiceClient) - svc := invitations.NewService(token, repo, nil) - - validInvitation := invitations.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - Relation: policies.ContributorRelation, - CreatedAt: time.Now().Add(-time.Hour), - UpdatedAt: time.Now().Add(-time.Hour), - ConfirmedAt: time.Now().Add(-time.Hour), - } - cases := []struct { - desc string - token string - userID string - domainID string - session authn.Session - tokenUserID string - req invitations.Invitation - resp invitations.Invitation - err error - issueErr error - repoErr error - }{ - { - desc: "view invitation successful", - token: validToken, - tokenUserID: testsutil.GenerateUUID(t), - userID: validInvitation.UserID, - domainID: validInvitation.DomainID, - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID}, - resp: validInvitation, - err: nil, - repoErr: nil, - }, - - { - desc: "error retrieving invitation", - token: validToken, - userID: validInvitation.UserID, - domainID: validInvitation.DomainID, - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID}, - tokenUserID: testsutil.GenerateUUID(t), - err: svcerr.ErrNotFound, - repoErr: svcerr.ErrNotFound, - }, - { - desc: "valid invitation for the same user", - token: validToken, - userID: validInvitation.UserID, - domainID: validInvitation.DomainID, - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID}, - resp: validInvitation, - tokenUserID: validInvitation.UserID, - err: nil, - repoErr: nil, - }, - { - desc: "valid invitation for the invited user", - token: validToken, - userID: validInvitation.UserID, - domainID: validInvitation.DomainID, - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID}, - tokenUserID: validInvitation.InvitedBy, - resp: validInvitation, - err: nil, - repoErr: nil, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - repocall1 := repo.On("Retrieve", context.Background(), mock.Anything, mock.Anything).Return(tc.resp, tc.repoErr) - inv, err := svc.ViewInvitation(context.Background(), tc.session, tc.userID, tc.domainID) - assert.Equal(t, tc.err, err, tc.desc) - assert.Equal(t, tc.resp, inv, tc.desc) - repocall1.Unset() - }) - } -} - -func TestListInvitations(t *testing.T) { - repo := new(mocks.Repository) - token := new(authmocks.TokenServiceClient) - svc := invitations.NewService(token, repo, nil) - - validPage := invitations.Page{ - Offset: 0, - Limit: 10, - } - validResp := invitations.InvitationPage{ - Total: 1, - Offset: 0, - Limit: 10, - Invitations: []invitations.Invitation{ - { - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - Relation: policies.ContributorRelation, - CreatedAt: time.Now().Add(-time.Hour), - UpdatedAt: time.Now().Add(-time.Hour), - ConfirmedAt: time.Now().Add(-time.Hour), - }, - }, - } - - cases := []struct { - desc string - session authn.Session - page invitations.Page - resp invitations.InvitationPage - err error - repoErr error - }{ - { - desc: "list invitations successful", - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID}, - page: validPage, - resp: validResp, - err: nil, - repoErr: nil, - }, - - { - desc: "list invitations unsuccessful", - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID}, - page: validPage, - err: repoerr.ErrViewEntity, - resp: invitations.InvitationPage{}, - repoErr: repoerr.ErrViewEntity, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - repocall1 := repo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.resp, tc.repoErr) - resp, err := svc.ListInvitations(context.Background(), tc.session, tc.page) - assert.Equal(t, tc.err, err, tc.desc) - assert.Equal(t, tc.resp, resp, tc.desc) - repocall1.Unset() - }) - } -} - -func TestAcceptInvitation(t *testing.T) { - repo := new(mocks.Repository) - token := new(authmocks.TokenServiceClient) - sdksvc := new(sdkmocks.SDK) - svc := invitations.NewService(token, repo, sdksvc) - - userID := testsutil.GenerateUUID(t) - - cases := []struct { - desc string - token string - domainID string - session authn.Session - resp invitations.Invitation - err error - repoErr error - sdkErr errors.SDKError - repoErr1 error - }{ - { - desc: "accept invitation successful", - token: validToken, - domainID: "", - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID}, - resp: invitations.Invitation{ - UserID: userID, - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: policies.ContributorRelation, - }, - err: nil, - repoErr: nil, - }, - { - desc: "accept invitation with failed to retrieve all", - token: validToken, - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID}, - err: svcerr.ErrNotFound, - repoErr: svcerr.ErrNotFound, - }, - { - desc: "accept invitation with sdk err", - token: validToken, - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID}, - domainID: "", - resp: invitations.Invitation{ - UserID: userID, - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: policies.ContributorRelation, - }, - err: errors.NewSDKError(svcerr.ErrConflict), - repoErr: nil, - sdkErr: errors.NewSDKError(svcerr.ErrConflict), - }, - { - desc: "accept invitation with failed update confirmation", - token: validToken, - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID}, - domainID: "", - resp: invitations.Invitation{ - UserID: userID, - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: policies.ContributorRelation, - }, - err: svcerr.ErrUpdateEntity, - repoErr: nil, - repoErr1: svcerr.ErrUpdateEntity, - }, - { - desc: "accept invitation that is already confirmed", - token: validToken, - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID}, - domainID: "", - resp: invitations.Invitation{ - UserID: userID, - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: policies.ContributorRelation, - ConfirmedAt: time.Now(), - }, - err: svcerr.ErrInvitationAlreadyAccepted, - repoErr: nil, - }, - { - desc: "accept rejected invitation", - token: validToken, - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID}, - domainID: "", - resp: invitations.Invitation{ - UserID: userID, - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: policies.ContributorRelation, - RejectedAt: time.Now(), - }, - err: svcerr.ErrInvitationAlreadyRejected, - repoErr: nil, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - repocall1 := repo.On("Retrieve", context.Background(), mock.Anything, tc.domainID).Return(tc.resp, tc.repoErr) - sdkcall := sdksvc.On("AddDomainRoleMembers", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{}, tc.sdkErr) - repocall2 := repo.On("UpdateConfirmation", context.Background(), mock.Anything).Return(tc.repoErr1) - err := svc.AcceptInvitation(context.Background(), tc.session, tc.domainID) - assert.Equal(t, tc.err, err, tc.desc) - repocall1.Unset() - sdkcall.Unset() - repocall2.Unset() - }) - } -} - -func TestDeleteInvitation(t *testing.T) { - repo := new(mocks.Repository) - token := new(authmocks.TokenServiceClient) - svc := invitations.NewService(token, repo, nil) - - cases := []struct { - desc string - token string - userID string - domainID string - resp invitations.Invitation - err error - repoErr error - }{ - { - desc: "delete invitations successful", - userID: testsutil.GenerateUUID(t), - domainID: testsutil.GenerateUUID(t), - resp: validInvitation, - err: nil, - repoErr: nil, - }, - { - desc: "delete invitations for the same user", - token: validToken, - userID: validInvitation.UserID, - domainID: validInvitation.DomainID, - resp: validInvitation, - err: nil, - repoErr: nil, - }, - { - desc: "delete invitations for the invited user", - token: validToken, - userID: validInvitation.UserID, - domainID: validInvitation.DomainID, - resp: validInvitation, - err: nil, - repoErr: nil, - }, - { - desc: "error retrieving invitation", - token: validToken, - userID: validInvitation.UserID, - domainID: validInvitation.DomainID, - resp: invitations.Invitation{}, - err: svcerr.ErrNotFound, - repoErr: svcerr.ErrNotFound, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - repocall1 := repo.On("Retrieve", context.Background(), mock.Anything, mock.Anything).Return(tc.resp, tc.repoErr) - repocall2 := repo.On("Delete", context.Background(), mock.Anything, mock.Anything).Return(tc.repoErr) - err := svc.DeleteInvitation(context.Background(), authn.Session{}, tc.userID, tc.domainID) - assert.Equal(t, tc.err, err, tc.desc) - repocall1.Unset() - repocall2.Unset() - }) - } -} - -func TestRejectInvitation(t *testing.T) { - repo := new(mocks.Repository) - token := new(authmocks.TokenServiceClient) - svc := invitations.NewService(token, repo, nil) - userID := validInvitation.UserID - - cases := []struct { - desc string - session authn.Session - domainID string - resp invitations.Invitation - err error - repoErr error - repoErr1 error - }{ - { - desc: "reject invitations for the same user", - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID}, - domainID: validInvitation.DomainID, - resp: validInvitation, - err: nil, - repoErr: nil, - repoErr1: nil, - }, - { - desc: "reject invitations for the invited user", - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID}, - domainID: validInvitation.DomainID, - resp: invitations.Invitation{}, - err: svcerr.ErrAuthorization, - repoErr: nil, - repoErr1: nil, - }, - { - desc: "error retrieving invitation", - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID}, - domainID: validInvitation.DomainID, - resp: invitations.Invitation{}, - err: repoerr.ErrNotFound, - repoErr: repoerr.ErrNotFound, - repoErr1: nil, - }, - { - desc: "error updating rejection", - session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID}, - domainID: validInvitation.DomainID, - resp: validInvitation, - err: repoerr.ErrUpdateEntity, - repoErr: nil, - repoErr1: repoerr.ErrUpdateEntity, - }, - } - - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - repocall1 := repo.On("Retrieve", context.Background(), mock.Anything, mock.Anything).Return(tc.resp, tc.repoErr) - repocall3 := repo.On("UpdateRejection", context.Background(), mock.Anything).Return(tc.repoErr1) - err := svc.RejectInvitation(context.Background(), tc.session, tc.domainID) - assert.Equal(t, tc.err, err, tc.desc) - repocall1.Unset() - repocall3.Unset() - }) - } -} diff --git a/pkg/domains/events/consumer/stream.go b/pkg/domains/events/consumer/stream.go index 70412368c..63772777e 100644 --- a/pkg/domains/events/consumer/stream.go +++ b/pkg/domains/events/consumer/stream.go @@ -106,7 +106,7 @@ func (es *eventHandler) createDomainHandler(ctx context.Context, data map[string return errors.Wrap(errCreateDomainEvent, err) } - if _, err := es.repo.Save(ctx, d); err != nil { + if _, err := es.repo.SaveDomain(ctx, d); err != nil { return errors.Wrap(errCreateDomainEvent, err) } if _, err := es.repo.AddRoles(ctx, rps); err != nil { @@ -122,7 +122,7 @@ func (es *eventHandler) updateDomainHandler(ctx context.Context, data map[string return errors.Wrap(errUpdateDomainEvent, err) } - if _, err := es.repo.Update( + if _, err := es.repo.UpdateDomain( ctx, d.ID, domains.DomainReq{ @@ -147,7 +147,7 @@ func (es *eventHandler) enableDomainHandler(ctx context.Context, data map[string } enabled := domains.EnabledStatus - if _, err := es.repo.Update(ctx, d.ID, domains.DomainReq{Status: &enabled, UpdatedBy: &d.UpdatedBy, UpdatedAt: &d.UpdatedAt}); err != nil { + if _, err := es.repo.UpdateDomain(ctx, d.ID, domains.DomainReq{Status: &enabled, UpdatedBy: &d.UpdatedBy, UpdatedAt: &d.UpdatedAt}); err != nil { return errors.Wrap(errEnableDomainGroupEvent, err) } @@ -161,7 +161,7 @@ func (es *eventHandler) disableDomainHandler(ctx context.Context, data map[strin } disabled := domains.DisabledStatus - if _, err := es.repo.Update(ctx, d.ID, domains.DomainReq{Status: &disabled, UpdatedBy: &d.UpdatedBy, UpdatedAt: &d.UpdatedAt}); err != nil { + if _, err := es.repo.UpdateDomain(ctx, d.ID, domains.DomainReq{Status: &disabled, UpdatedBy: &d.UpdatedBy, UpdatedAt: &d.UpdatedAt}); err != nil { return errors.Wrap(errDisableDomainGroupEvent, err) } @@ -175,7 +175,7 @@ func (es *eventHandler) freezeDomainHandler(ctx context.Context, data map[string } freeze := domains.FreezeStatus - if _, err := es.repo.Update(ctx, d.ID, domains.DomainReq{Status: &freeze, UpdatedBy: &d.UpdatedBy, UpdatedAt: &d.UpdatedAt}); err != nil { + if _, err := es.repo.UpdateDomain(ctx, d.ID, domains.DomainReq{Status: &freeze, UpdatedBy: &d.UpdatedBy, UpdatedAt: &d.UpdatedAt}); err != nil { return errors.Wrap(errFreezeDomainGroupEvent, err) } @@ -197,7 +197,7 @@ func (es *eventHandler) deleteDomainHandler(ctx context.Context, data map[string return errors.Wrap(errDeleteDomainEvent, err) } - if err := es.repo.Delete(ctx, d.ID); err != nil { + if err := es.repo.DeleteDomain(ctx, d.ID); err != nil { return errors.Wrap(errDeleteDomainEvent, err) } diff --git a/pkg/sdk/domains_test.go b/pkg/sdk/domains_test.go index 667d4cdfa..b032fd590 100644 --- a/pkg/sdk/domains_test.go +++ b/pkg/sdk/domains_test.go @@ -51,8 +51,8 @@ func setupDomains() (*httptest.Server, *mocks.Service, *authnmocks.Authenticatio mux := chi.NewRouter() authn := new(authnmocks.Authentication) - mux = domainapi.MakeHandler(svc, authn, mux, logger, "") - return httptest.NewServer(mux), svc, authn + handler := domainapi.MakeHandler(svc, authn, mux, logger, "") + return httptest.NewServer(handler), svc, authn } func TestCreateDomain(t *testing.T) { diff --git a/pkg/sdk/invitations.go b/pkg/sdk/invitations.go index efedf17fa..71417f8ea 100644 --- a/pkg/sdk/invitations.go +++ b/pkg/sdk/invitations.go @@ -18,16 +18,16 @@ const ( ) type Invitation struct { - InvitedBy string `json:"invited_by"` - UserID string `json:"user_id"` - DomainID string `json:"domain_id"` - Token string `json:"token,omitempty"` - Relation string `json:"relation,omitempty"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at,omitempty"` - ConfirmedAt time.Time `json:"confirmed_at,omitempty"` - RejectedAt time.Time `json:"rejected_at,omitempty"` - Resend bool `json:"resend,omitempty"` + InvitedBy string `json:"invited_by"` + InviteeUserID string `json:"invitee_user_id"` + DomainID string `json:"domain_id"` + RoleID string `json:"role_id,omitempty"` + RoleName string `json:"role_name,omitempty"` + Actions []string `json:"actions,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at,omitempty"` + ConfirmedAt time.Time `json:"confirmed_at,omitempty"` + RejectedAt time.Time `json:"rejected_at,omitempty"` } type InvitationPage struct { @@ -43,7 +43,7 @@ func (sdk mgSDK) SendInvitation(invitation Invitation, token string) (err error) return errors.NewSDKError(err) } - url := sdk.invitationsURL + "/" + invitationsEndpoint + url := sdk.domainsURL + "/" + domainsEndpoint + "/" + invitation.DomainID + "/" + invitationsEndpoint _, _, sdkerr := sdk.processRequest(http.MethodPost, url, token, data, nil, http.StatusCreated) @@ -51,7 +51,7 @@ func (sdk mgSDK) SendInvitation(invitation Invitation, token string) (err error) } func (sdk mgSDK) Invitation(userID, domainID, token string) (invitation Invitation, err error) { - url := sdk.invitationsURL + "/" + invitationsEndpoint + "/" + userID + "/" + domainID + url := sdk.domainsURL + "/" + domainsEndpoint + "/" + domainID + "/" + invitationsEndpoint + "/" + userID _, body, sdkerr := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK) if sdkerr != nil { @@ -66,7 +66,7 @@ func (sdk mgSDK) Invitation(userID, domainID, token string) (invitation Invitati } func (sdk mgSDK) Invitations(pm PageMetadata, token string) (invitations InvitationPage, err error) { - url, err := sdk.withQueryParams(sdk.invitationsURL, invitationsEndpoint, pm) + url, err := sdk.withQueryParams(sdk.domainsURL, invitationsEndpoint, pm) if err != nil { return InvitationPage{}, errors.NewSDKError(err) } @@ -95,7 +95,7 @@ func (sdk mgSDK) AcceptInvitation(domainID, token string) (err error) { return errors.NewSDKError(err) } - url := sdk.invitationsURL + "/" + invitationsEndpoint + "/" + acceptEndpoint + url := sdk.domainsURL + "/" + invitationsEndpoint + "/" + acceptEndpoint _, _, sdkerr := sdk.processRequest(http.MethodPost, url, token, data, nil, http.StatusNoContent) @@ -113,7 +113,7 @@ func (sdk mgSDK) RejectInvitation(domainID, token string) (err error) { return errors.NewSDKError(err) } - url := sdk.invitationsURL + "/" + invitationsEndpoint + "/" + rejectEndpoint + url := sdk.domainsURL + "/" + invitationsEndpoint + "/" + rejectEndpoint _, _, sdkerr := sdk.processRequest(http.MethodPost, url, token, data, nil, http.StatusNoContent) @@ -121,7 +121,7 @@ func (sdk mgSDK) RejectInvitation(domainID, token string) (err error) { } func (sdk mgSDK) DeleteInvitation(userID, domainID, token string) (err error) { - url := sdk.invitationsURL + "/" + invitationsEndpoint + "/" + userID + "/" + domainID + url := sdk.domainsURL + "/" + domainsEndpoint + "/" + domainID + "/" + invitationsEndpoint + "/" + userID _, _, sdkerr := sdk.processRequest(http.MethodDelete, url, token, nil, nil, http.StatusNoContent) diff --git a/pkg/sdk/invitations_test.go b/pkg/sdk/invitations_test.go index 4d3413e69..4e77f5674 100644 --- a/pkg/sdk/invitations_test.go +++ b/pkg/sdk/invitations_test.go @@ -6,21 +6,15 @@ package sdk_test import ( "fmt" "net/http" - "net/http/httptest" "testing" "time" apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/domains" "github.com/absmach/supermq/internal/testsutil" - "github.com/absmach/supermq/invitations" - "github.com/absmach/supermq/invitations/api" - "github.com/absmach/supermq/invitations/mocks" - smqlog "github.com/absmach/supermq/logger" smqauthn "github.com/absmach/supermq/pkg/authn" - authnmocks "github.com/absmach/supermq/pkg/authn/mocks" "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" - policies "github.com/absmach/supermq/pkg/policies" sdk "github.com/absmach/supermq/pkg/sdk" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -31,29 +25,19 @@ var ( invitation = convertInvitation(sdkInvitation) ) -func setupInvitations() (*httptest.Server, *mocks.Service, *authnmocks.Authentication) { - svc := new(mocks.Service) - logger := smqlog.NewMock() - authn := new(authnmocks.Authentication) - mux := api.MakeHandler(svc, logger, authn, "test") - - return httptest.NewServer(mux), svc, authn -} - func TestSendInvitation(t *testing.T) { - is, svc, auth := setupInvitations() + is, svc, auth := setupDomains() defer is.Close() conf := sdk.Config{ - InvitationsURL: is.URL, + DomainsURL: is.URL, } mgsdk := sdk.NewSDK(conf) sendInvitationReq := sdk.Invitation{ - UserID: invitation.UserID, - DomainID: invitation.DomainID, - Relation: invitation.Relation, - Resend: invitation.Resend, + InviteeUserID: invitation.InviteeUserID, + DomainID: invitation.DomainID, + RoleID: invitation.RoleID, } cases := []struct { @@ -61,7 +45,7 @@ func TestSendInvitation(t *testing.T) { token string session smqauthn.Session sendInvitationReq sdk.Invitation - svcReq invitations.Invitation + svcReq domains.Invitation authenticateErr error svcErr error err error @@ -86,7 +70,7 @@ func TestSendInvitation(t *testing.T) { desc: "send invitation with empty token", token: "", sendInvitationReq: sendInvitationReq, - svcReq: invitations.Invitation{}, + svcReq: domains.Invitation{}, svcErr: nil, err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), }, @@ -94,42 +78,38 @@ func TestSendInvitation(t *testing.T) { desc: "send invitation with empty userID", token: validToken, sendInvitationReq: sdk.Invitation{ - UserID: "", - DomainID: invitation.DomainID, - Relation: invitation.Relation, - Resend: invitation.Resend, + InviteeUserID: "", + DomainID: invitation.DomainID, + RoleID: invitation.RoleID, }, - svcReq: invitations.Invitation{}, + svcReq: domains.Invitation{}, svcErr: nil, err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest), }, { - desc: "send invitation with invalid relation", + desc: "send invitation with empty role ID", token: validToken, sendInvitationReq: sdk.Invitation{ - UserID: invitation.UserID, - DomainID: invitation.DomainID, - Relation: "invalid", - Resend: invitation.Resend, + InviteeUserID: invitation.InviteeUserID, + DomainID: invitation.DomainID, + RoleID: "", }, - svcReq: invitations.Invitation{}, + svcReq: domains.Invitation{}, svcErr: nil, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidRelation), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest), }, { desc: "send inviation with invalid domainID", token: validToken, sendInvitationReq: sdk.Invitation{ - UserID: invitation.UserID, - DomainID: wrongID, - Relation: invitation.Relation, - Resend: invitation.Resend, + InviteeUserID: invitation.InviteeUserID, + DomainID: wrongID, + RoleID: invitation.RoleID, }, - svcReq: invitations.Invitation{ - UserID: invitation.UserID, - DomainID: wrongID, - Relation: invitation.Relation, - Resend: invitation.Resend, + svcReq: domains.Invitation{ + InviteeUserID: invitation.InviteeUserID, + DomainID: wrongID, + RoleID: invitation.RoleID, }, svcErr: svcerr.ErrCreateEntity, err: errors.NewSDKErrorWithStatus(svcerr.ErrCreateEntity, http.StatusUnprocessableEntity), @@ -138,7 +118,11 @@ func TestSendInvitation(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { if tc.token == valid { - tc.session = smqauthn.Session{UserID: tc.sendInvitationReq.UserID, DomainID: tc.sendInvitationReq.DomainID} + tc.session = smqauthn.Session{ + UserID: tc.sendInvitationReq.InviteeUserID, + DomainID: tc.sendInvitationReq.DomainID, + DomainUserID: tc.sendInvitationReq.DomainID + "_" + tc.sendInvitationReq.InviteeUserID, + } } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := svc.On("SendInvitation", mock.Anything, tc.session, tc.svcReq).Return(tc.svcErr) @@ -155,11 +139,11 @@ func TestSendInvitation(t *testing.T) { } func TestViewInvitation(t *testing.T) { - is, svc, auth := setupInvitations() + is, svc, auth := setupDomains() defer is.Close() conf := sdk.Config{ - InvitationsURL: is.URL, + DomainsURL: is.URL, } mgsdk := sdk.NewSDK(conf) @@ -169,7 +153,7 @@ func TestViewInvitation(t *testing.T) { session smqauthn.Session userID string domainID string - svcRes invitations.Invitation + svcRes domains.Invitation svcErr error authenticateErr error response sdk.Invitation @@ -178,7 +162,7 @@ func TestViewInvitation(t *testing.T) { { desc: "view invitation successfully", token: validToken, - userID: invitation.UserID, + userID: invitation.InviteeUserID, domainID: invitation.DomainID, svcRes: invitation, svcErr: nil, @@ -188,9 +172,9 @@ func TestViewInvitation(t *testing.T) { { desc: "view invitation with invalid token", token: invalidToken, - userID: invitation.UserID, + userID: invitation.InviteeUserID, domainID: invitation.DomainID, - svcRes: invitations.Invitation{}, + svcRes: domains.Invitation{}, authenticateErr: svcerr.ErrAuthentication, response: sdk.Invitation{}, err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), @@ -198,29 +182,29 @@ func TestViewInvitation(t *testing.T) { { desc: "view invitation with empty token", token: "", - userID: invitation.UserID, + userID: invitation.InviteeUserID, domainID: invitation.DomainID, - svcRes: invitations.Invitation{}, + svcRes: domains.Invitation{}, svcErr: nil, response: sdk.Invitation{}, err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), }, { - desc: "view invitation with empty userID", + desc: "view invitation with empty domainID", token: validToken, - userID: "", - domainID: invitation.DomainID, - svcRes: invitations.Invitation{}, + userID: invitation.InviteeUserID, + domainID: "", + svcRes: domains.Invitation{}, svcErr: nil, response: sdk.Invitation{}, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest), + err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingDomainID, http.StatusBadRequest), }, { desc: "view invitation with invalid domainID", token: validToken, - userID: invitation.UserID, + userID: invitation.InviteeUserID, domainID: wrongID, - svcRes: invitations.Invitation{}, + svcRes: domains.Invitation{}, svcErr: svcerr.ErrNotFound, response: sdk.Invitation{}, err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), @@ -229,7 +213,7 @@ func TestViewInvitation(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { if tc.token == valid { - tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID} + tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: tc.domainID + "_" + tc.userID} } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := svc.On("ViewInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(tc.svcRes, tc.svcErr) @@ -247,11 +231,11 @@ func TestViewInvitation(t *testing.T) { } func TestListInvitation(t *testing.T) { - is, svc, auth := setupInvitations() + is, svc, auth := setupDomains() defer is.Close() conf := sdk.Config{ - InvitationsURL: is.URL, + DomainsURL: is.URL, } mgsdk := sdk.NewSDK(conf) @@ -260,8 +244,8 @@ func TestListInvitation(t *testing.T) { token string session smqauthn.Session pageMeta sdk.PageMetadata - svcReq invitations.Page - svcRes invitations.InvitationPage + svcReq domains.InvitationPageMeta + svcRes domains.InvitationPage svcErr error authenticateErr error response sdk.InvitationPage @@ -274,13 +258,13 @@ func TestListInvitation(t *testing.T) { Offset: 0, Limit: 10, }, - svcReq: invitations.Page{ + svcReq: domains.InvitationPageMeta{ Offset: 0, Limit: 10, }, - svcRes: invitations.InvitationPage{ + svcRes: domains.InvitationPage{ Total: 1, - Invitations: []invitations.Invitation{invitation}, + Invitations: []domains.Invitation{invitation}, }, svcErr: nil, response: sdk.InvitationPage{ @@ -296,11 +280,11 @@ func TestListInvitation(t *testing.T) { Offset: 0, Limit: 10, }, - svcReq: invitations.Page{ + svcReq: domains.InvitationPageMeta{ Offset: 0, Limit: 10, }, - svcRes: invitations.InvitationPage{}, + svcRes: domains.InvitationPage{}, authenticateErr: svcerr.ErrAuthentication, response: sdk.InvitationPage{}, err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), @@ -309,7 +293,7 @@ func TestListInvitation(t *testing.T) { desc: "list invitations with empty token", token: "", pageMeta: sdk.PageMetadata{}, - svcRes: invitations.InvitationPage{}, + svcRes: domains.InvitationPage{}, svcErr: nil, response: sdk.InvitationPage{}, err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), @@ -321,8 +305,8 @@ func TestListInvitation(t *testing.T) { Offset: 0, Limit: 101, }, - svcReq: invitations.Page{}, - svcRes: invitations.InvitationPage{}, + svcReq: domains.InvitationPageMeta{}, + svcRes: domains.InvitationPage{}, svcErr: nil, response: sdk.InvitationPage{}, err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest), @@ -349,11 +333,11 @@ func TestListInvitation(t *testing.T) { } func TestAcceptInvitation(t *testing.T) { - is, svc, auth := setupInvitations() + is, svc, auth := setupDomains() defer is.Close() conf := sdk.Config{ - InvitationsURL: is.URL, + DomainsURL: is.URL, } mgsdk := sdk.NewSDK(conf) @@ -415,11 +399,11 @@ func TestAcceptInvitation(t *testing.T) { } func TestRejectInvitation(t *testing.T) { - is, svc, auth := setupInvitations() + is, svc, auth := setupDomains() defer is.Close() conf := sdk.Config{ - InvitationsURL: is.URL, + DomainsURL: is.URL, } mgsdk := sdk.NewSDK(conf) @@ -481,11 +465,11 @@ func TestRejectInvitation(t *testing.T) { } func TestDeleteInvitation(t *testing.T) { - is, svc, auth := setupInvitations() + is, svc, auth := setupDomains() defer is.Close() conf := sdk.Config{ - InvitationsURL: is.URL, + DomainsURL: is.URL, } mgsdk := sdk.NewSDK(conf) @@ -493,64 +477,64 @@ func TestDeleteInvitation(t *testing.T) { desc string token string session smqauthn.Session - userID string + inviteeUserID string domainID string authenticateErr error svcErr error err error }{ { - desc: "delete invitation successfully", - token: validToken, - userID: invitation.UserID, - domainID: invitation.DomainID, - svcErr: nil, - err: nil, + desc: "delete invitation successfully", + token: validToken, + inviteeUserID: invitation.InviteeUserID, + domainID: invitation.DomainID, + svcErr: nil, + err: nil, }, { desc: "delete invitation with invalid token", token: invalidToken, - userID: invitation.UserID, + inviteeUserID: invitation.InviteeUserID, domainID: invitation.DomainID, authenticateErr: svcerr.ErrAuthentication, err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), }, { - desc: "delete invitation with empty token", - token: "", - userID: invitation.UserID, - domainID: invitation.DomainID, - svcErr: nil, - err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + desc: "delete invitation with empty token", + token: "", + inviteeUserID: invitation.InviteeUserID, + domainID: invitation.DomainID, + svcErr: nil, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), }, { - desc: "delete invitation with empty userID", - token: validToken, - userID: "", - domainID: invitation.DomainID, - svcErr: nil, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest), + desc: "delete invitation with empty domainID", + token: validToken, + inviteeUserID: invitation.InviteeUserID, + domainID: "", + svcErr: nil, + err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingDomainID, http.StatusBadRequest), }, { - desc: "delete invitation with invalid domainID", - token: validToken, - userID: invitation.UserID, - domainID: wrongID, - svcErr: svcerr.ErrNotFound, - err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), + desc: "delete invitation with invalid domainID", + token: validToken, + inviteeUserID: invitation.InviteeUserID, + domainID: wrongID, + svcErr: svcerr.ErrNotFound, + err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound), }, } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { if tc.token == valid { - tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID} + tc.session = smqauthn.Session{UserID: tc.inviteeUserID, DomainID: tc.domainID, DomainUserID: tc.domainID + "_" + tc.inviteeUserID} } authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) - svcCall := svc.On("DeleteInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(tc.svcErr) - err := mgsdk.DeleteInvitation(tc.userID, tc.domainID, tc.token) + svcCall := svc.On("DeleteInvitation", mock.Anything, tc.session, tc.inviteeUserID, tc.domainID).Return(tc.svcErr) + err := mgsdk.DeleteInvitation(tc.inviteeUserID, tc.domainID, tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { - ok := svcCall.Parent.AssertCalled(t, "DeleteInvitation", mock.Anything, tc.session, tc.userID, tc.domainID) + ok := svcCall.Parent.AssertCalled(t, "DeleteInvitation", mock.Anything, tc.session, tc.inviteeUserID, tc.domainID) assert.True(t, ok) } svcCall.Unset() @@ -563,13 +547,13 @@ func generateTestInvitation(t *testing.T) sdk.Invitation { createdAt, err := time.Parse(time.RFC3339, "2024-01-01T00:00:00Z") assert.Nil(t, err, fmt.Sprintf("Unexpected error parsing time: %v", err)) return sdk.Invitation{ - InvitedBy: testsutil.GenerateUUID(t), - UserID: testsutil.GenerateUUID(t), - DomainID: testsutil.GenerateUUID(t), - Token: validToken, - Relation: policies.MemberRelation, - CreatedAt: createdAt, - UpdatedAt: createdAt, - Resend: false, + InvitedBy: testsutil.GenerateUUID(t), + InviteeUserID: testsutil.GenerateUUID(t), + DomainID: testsutil.GenerateUUID(t), + RoleID: testsutil.GenerateUUID(t), + RoleName: "admin", + Actions: []string{"read", "update"}, + CreatedAt: createdAt, + UpdatedAt: createdAt, } } diff --git a/pkg/sdk/sdk.go b/pkg/sdk/sdk.go index 9f257c1c7..d2889209b 100644 --- a/pkg/sdk/sdk.go +++ b/pkg/sdk/sdk.go @@ -1318,7 +1318,6 @@ type mgSDK struct { groupsURL string channelsURL string domainsURL string - invitationsURL string journalURL string HostURL string @@ -1336,7 +1335,6 @@ type Config struct { GroupsURL string ChannelsURL string DomainsURL string - InvitationsURL string JournalURL string HostURL string @@ -1355,7 +1353,6 @@ func NewSDK(conf Config) SDK { groupsURL: conf.GroupsURL, channelsURL: conf.ChannelsURL, domainsURL: conf.DomainsURL, - invitationsURL: conf.InvitationsURL, journalURL: conf.JournalURL, HostURL: conf.HostURL, diff --git a/pkg/sdk/setup_test.go b/pkg/sdk/setup_test.go index df71d5006..429b0f505 100644 --- a/pkg/sdk/setup_test.go +++ b/pkg/sdk/setup_test.go @@ -12,9 +12,9 @@ import ( mgchannels "github.com/absmach/supermq/channels" "github.com/absmach/supermq/clients" + "github.com/absmach/supermq/domains" groups "github.com/absmach/supermq/groups" "github.com/absmach/supermq/internal/testsutil" - "github.com/absmach/supermq/invitations" "github.com/absmach/supermq/journal" "github.com/absmach/supermq/pkg/roles" sdk "github.com/absmach/supermq/pkg/sdk" @@ -218,17 +218,18 @@ func convertChannel(g sdk.Channel) mgchannels.Channel { } } -func convertInvitation(i sdk.Invitation) invitations.Invitation { - return invitations.Invitation{ - InvitedBy: i.InvitedBy, - UserID: i.UserID, - DomainID: i.DomainID, - Token: i.Token, - Relation: i.Relation, - CreatedAt: i.CreatedAt, - UpdatedAt: i.UpdatedAt, - ConfirmedAt: i.ConfirmedAt, - Resend: i.Resend, +func convertInvitation(i sdk.Invitation) domains.Invitation { + return domains.Invitation{ + InvitedBy: i.InvitedBy, + InviteeUserID: i.InviteeUserID, + DomainID: i.DomainID, + RoleID: i.RoleID, + RoleName: i.RoleName, + Actions: i.Actions, + CreatedAt: i.CreatedAt, + UpdatedAt: i.UpdatedAt, + ConfirmedAt: i.ConfirmedAt, + RejectedAt: i.RejectedAt, } }