SMQ - 2435 - Merge invitations into domains service (#2676)

Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
Felix Gateru
2025-02-13 18:24:39 +03:00
committed by GitHub
parent 597ad23ba1
commit 7667eee725
81 changed files with 4387 additions and 5512 deletions
-17
View File
@@ -17,7 +17,6 @@ on:
- "domains/api/http/**"
- "groups/api/http/**"
- "http/api/**"
- "invitations/api/**"
- "journal/api/**"
- "users/api/**"
@@ -33,7 +32,6 @@ env:
CHANNELS_URL: http://localhost:9005
GROUPS_URL: http://localhost:9004
HTTP_ADAPTER_URL: http://localhost:8008
INVITATIONS_URL: http://localhost:9020
AUTH_URL: http://localhost:9001
CERTS_URL: http://localhost:9019
JOURNAL_URL: http://localhost:9021
@@ -95,11 +93,6 @@ jobs:
- "apidocs/openapi/http.yml"
- "http/api/**"
invitations:
- ".github/workflows/api-tests.yml"
- "apidocs/openapi/invitations.yml"
- "invitations/api/**"
clients:
- ".github/workflows/api-tests.yml"
- "apidocs/openapi/clients.yml"
@@ -170,16 +163,6 @@ jobs:
report: false
args: '--header "Authorization: Client ${{ env.CLIENT_SECRET }}" --contrib-openapi-formats-uuid --hypothesis-suppress-health-check=filter_too_much --stateful=links'
- name: Run Invitations API tests
if: steps.changes.outputs.invitations == 'true'
uses: schemathesis/action@v1
with:
schema: apidocs/openapi/invitations.yml
base-url: ${{ env.INVITATIONS_URL }}
checks: all
report: false
args: '--header "Authorization: Bearer ${{ env.USER_TOKEN }}" --contrib-openapi-formats-uuid --hypothesis-suppress-health-check=filter_too_much --stateful=links'
- name: Run Auth API tests
if: steps.changes.outputs.auth == 'true'
uses: schemathesis/action@v1
@@ -58,7 +58,6 @@ jobs:
- "auth/service.go"
- "pkg/events/events.go"
- "pkg/groups/groups.go"
- "invitations/invitations.go"
- "users/emailer.go"
- "users/hasher.go"
- "certs/certs.go"
@@ -119,8 +118,6 @@ jobs:
MOCKERY_VERSION=v2.43.2
go install github.com/vektra/mockery/v2@$MOCKERY_VERSION
mv ./invitations/mocks/repository.go ./invitations/mocks/repository.go.tmp
mv ./invitations/mocks/service.go ./invitations/mocks/service.go.tmp
mv ./auth/mocks/token_client.go ./auth/mocks/token_client.go.tmp
mv ./auth/mocks/authz.go ./auth/mocks/authz.go.tmp
mv ./auth/mocks/keys.go ./auth/mocks/keys.go.tmp
@@ -177,8 +174,6 @@ jobs:
fi
}
check_mock_changes ./invitations/mocks/repository.go " ./invitations/mocks/repository.go"
check_mock_changes ./invitations/mocks/service.go " ./invitations/mocks/service.go"
check_mock_changes ./auth/mocks/token_client.go " ./auth/mocks/token_client.go"
check_mock_changes ./auth/mocks/authz.go " ./auth/mocks/authz.go"
check_mock_changes ./auth/mocks/keys.go " ./auth/mocks/keys.go"
-13
View File
@@ -164,14 +164,6 @@ jobs:
internal:
- "internal/**"
invitations:
- "invitations/**"
- "cmd/invitations/**"
- "auth.pb.go"
- "auth_grpc.pb.go"
- "auth/**"
- "pkg/sdk/**"
journal:
- "journal/**"
- "cmd/journal/**"
@@ -292,11 +284,6 @@ jobs:
run: |
go test --race -v -count=1 -coverprofile=coverage/internal.out ./internal/...
- name: Run invitations tests
if: steps.changes.outputs.invitations == 'true' || steps.changes.outputs.workflow == 'true'
run: |
go test --race -v -count=1 -coverprofile=coverage/invitations.out ./invitations/...
- name: Run logger tests
if: steps.changes.outputs.logger == 'true' || steps.changes.outputs.workflow == 'true'
run: |
+2 -3
View File
@@ -3,8 +3,8 @@
SMQ_DOCKER_IMAGE_NAME_PREFIX ?= supermq
BUILD_DIR ?= build
SERVICES = auth users clients groups channels domains http coap ws cli mqtt certs invitations journal
TEST_API_SERVICES = journal auth certs http invitations clients users channels groups domains
SERVICES = auth users clients groups channels domains http coap ws cli mqtt certs journal
TEST_API_SERVICES = journal auth certs http clients users channels groups domains
TEST_API = $(addprefix test_api_,$(TEST_API_SERVICES))
DOCKERS = $(addprefix docker_,$(SERVICES))
DOCKERS_DEV = $(addprefix docker_dev_,$(SERVICES))
@@ -172,7 +172,6 @@ test_api_domains: TEST_API_URL := http://localhost:9003
test_api_channels: TEST_API_URL := http://localhost:9005
test_api_groups: TEST_API_URL := http://localhost:9004
test_api_http: TEST_API_URL := http://localhost:8008
test_api_invitations: TEST_API_URL := http://localhost:9020
test_api_auth: TEST_API_URL := http://localhost:9001
test_api_certs: TEST_API_URL := http://localhost:9019
test_api_journal: TEST_API_URL := http://localhost:9021
+416 -4
View File
@@ -26,10 +26,15 @@ tags:
description: Find out more about domains
url: https://docs.magistrala.abstractmachines.fr/
- name: Roles
description: All operations involving roles for clients
description: All operations involving roles for domains
externalDocs:
description: Find out more about roles
url: https://docs.supermq.abstractmachines.fr/
description: Find out more about roles
url: https://docs.supermq.abstractmachines.fr/
- name: Invitations
description: All operations involving invitations for domains
externalDocs:
description: Find out more about Invitations
url: https://docs.supermq.abstractmachines.fr/
- name: Health
description: Service health check endpoint.
externalDocs:
@@ -242,7 +247,7 @@ paths:
- bearerAuth: []
responses:
"201":
$ref: "./schemas/roles.yml#/components/responses/CreateRoleRes"
$ref: "./schemas/roles.yml#/components/responses/CreateRoleRes"
"400":
description: Failed due to malformed domain's ID.
"401":
@@ -683,6 +688,216 @@ paths:
"500":
$ref: "#/components/responses/ServiceError"
/domains/{domainID}/invitations:
post:
operationId: sendInvitation
tags:
- Invitations
summary: Send invitation
description: |
Send invitation to user to join domain.
parameters:
- $ref: "#/components/parameters/DomainID"
requestBody:
$ref: "#/components/requestBodies/SendInvitationReq"
security:
- bearerAuth: []
responses:
"201":
description: Invitation sent.
"400":
description: Failed due to malformed JSON.
"401":
description: Missing or invalid access token provided.
"403":
description: Failed to perform authorization over the entity.
"404":
description: A non-existent entity request.
"409":
description: Failed due to using an existing identity.
"415":
description: Missing or invalid content type.
"500":
$ref: "#/components/responses/ServiceError"
get:
operationId: listDomainInvitations
tags:
- Invitations
summary: List domain invitations
description: |
Retrieves a list of invitations for a given domain. Due to performance concerns, data
is retrieved in subsets. The API must ensure that the entire
dataset is consumed either by making subsequent requests, or by
increasing the subset size of the initial request.
parameters:
- $ref: "#/components/parameters/DomainID"
- $ref: "#/components/parameters/Limit"
- $ref: "#/components/parameters/Offset"
- $ref: "#/components/parameters/user_id"
- $ref: "#/components/parameters/InvitedBy"
- $ref: "#/components/parameters/State"
security:
- bearerAuth: []
responses:
"200":
$ref: "#/components/responses/InvitationPageRes"
"400":
description: Failed due to malformed query parameters.
"401":
description: |
Missing or invalid access token provided.
This endpoint is available only for administrators.
"403":
description: Failed to perform authorization over the entity.
"404":
description: A non-existent entity request.
"422":
description: Database can't process request.
"500":
$ref: "#/components/responses/ServiceError"
/domains/{domainID}/invitations/{userID}:
get:
operationId: getInvitation
summary: Retrieves a specific invitation
description: |
Retrieves a specific invitation that is identifier by the user ID and domain ID.
tags:
- Invitations
parameters:
- $ref: "#/components/parameters/DomainID"
- $ref: "#/components/parameters/UserID"
security:
- bearerAuth: []
responses:
"200":
$ref: "#/components/responses/InvitationRes"
"400":
description: Failed due to malformed query parameters.
"401":
description: Missing or invalid access token provided.
"403":
description: Failed to perform authorization over the entity.
"404":
description: A non-existent entity request.
"422":
description: Database can't process request.
"500":
$ref: "#/components/responses/ServiceError"
delete:
operationId: deleteInvitation
summary: Deletes a specific invitation
description: |
Deletes a specific invitation that is identifier by the user ID and domain ID.
tags:
- Invitations
parameters:
- $ref: "#/components/parameters/DomainID"
- $ref: "#/components/parameters/UserID"
security:
- bearerAuth: []
responses:
"204":
description: Invitation deleted.
"400":
description: Failed due to malformed JSON.
"403":
description: Failed to perform authorization over the entity.
"404":
description: Failed due to non existing user.
"401":
description: Missing or invalid access token provided.
"500":
$ref: "#/components/responses/ServiceError"
/invitations:
get:
operationId: listUserInvitations
tags:
- Invitations
summary: List user invitations
description: |
Retrieves a list of invitations for the current user. Due to performance concerns, data
is retrieved in subsets. The API must ensure that the entire
dataset is consumed either by making subsequent requests, or by
increasing the subset size of the initial request.
parameters:
- $ref: "#/components/parameters/domain_id"
- $ref: "#/components/parameters/Limit"
- $ref: "#/components/parameters/Offset"
- $ref: "#/components/parameters/user_id"
- $ref: "#/components/parameters/InvitedBy"
- $ref: "#/components/parameters/State"
security:
- bearerAuth: []
responses:
"200":
$ref: "#/components/responses/InvitationPageRes"
"400":
description: Failed due to malformed query parameters.
"401":
description: |
Missing or invalid access token provided.
This endpoint is available only for administrators.
"403":
description: Failed to perform authorization over the entity.
"404":
description: A non-existent entity request.
"422":
description: Database can't process request.
"500":
$ref: "#/components/responses/ServiceError"
/invitations/accept:
post:
operationId: acceptInvitation
summary: Accept invitation
description: |
Current logged in user accepts invitation to join domain.
tags:
- Invitations
security:
- bearerAuth: []
requestBody:
$ref: "#/components/requestBodies/AcceptInvitationReq"
responses:
"204":
description: Invitation accepted.
"400":
description: Failed due to malformed query parameters.
"401":
description: Missing or invalid access token provided.
"404":
description: A non-existent entity request.
"500":
$ref: "#/components/responses/ServiceError"
/invitations/reject:
post:
operationId: rejectInvitation
summary: Reject invitation
description: |
Current logged in user rejects invitation to join domain.
tags:
- Invitations
security:
- bearerAuth: []
requestBody:
$ref: "#/components/requestBodies/AcceptInvitationReq"
responses:
"204":
description: Invitation rejected.
"400":
description: Failed due to malformed query parameters.
"401":
description: Missing or invalid access token provided.
"404":
description: A non-existent entity request.
"500":
$ref: "#/components/responses/ServiceError"
/health:
get:
summary: Retrieves service health check info.
@@ -824,6 +1039,107 @@ components:
example: domain alias
description: Domain alias.
SendInvitationReqObj:
type: object
properties:
invitee_user_id:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: User unique identifier.
role_id:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: Identifier for the role to be assigned to the user.
required:
- invitee_user_id
- role_id
Invitation:
type: object
properties:
invited_by:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: User unique identifier.
invitee_user_id:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: Invitee user unique identifier.
domain_id:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: Domain unique identifier.
role_id:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: Role unique identifier.
role_name:
type: string
example: editor
description: Role name.
actions:
type: array
items:
type: string
example: ["read", "write"]
description: Actions allowed for the role.
created_at:
type: string
format: date-time
example: "2019-11-26 13:31:52"
description: Time when the group was created.
updated_at:
type: string
format: date-time
example: "2019-11-26 13:31:52"
description: Time when the group was created.
confirmed_at:
type: string
format: date-time
example: "2019-11-26 13:31:52"
description: Time when the group was created.
xml:
name: invitation
InvitationPage:
type: object
properties:
invitations:
type: array
minItems: 0
uniqueItems: true
items:
$ref: "#/components/schemas/Invitation"
total:
type: integer
example: 1
description: Total number of items.
offset:
type: integer
description: Number of items to skip during retrieval.
limit:
type: integer
example: 10
description: Maximum number of items to return in one page.
required:
- invitations
- total
- offset
Error:
type: object
properties:
error:
type: string
description: Error message
example: { "error": "malformed entity specification" }
parameters:
DomainID:
name: domainID
@@ -910,6 +1226,61 @@ components:
schema:
type: string
required: false
UserID:
name: userID
description: Unique user identifier.
in: path
schema:
type: string
format: uuid
required: true
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
user_id:
name: user_id
description: Unique user identifier.
in: query
schema:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
domain_id:
name: domain_id
description: Unique identifier for a domain.
in: query
schema:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
InvitedBy:
name: invited_by
description: Unique identifier for a user that invited the user.
in: query
schema:
type: string
format: uuid
required: false
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
State:
name: state
description: Invitation state.
in: query
schema:
type: string
enum:
- pending
- accepted
- all
required: false
example: accepted
RoleID:
name: roleID
description: Unique role identifier.
in: query
schema:
type: string
format: uuid
required: true
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
requestBodies:
DomainCreateReq:
@@ -926,6 +1297,28 @@ components:
application/json:
schema:
$ref: "#/components/schemas/DomainUpdate"
SendInvitationReq:
description: JSON-formatted document describing request for sending invitation
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/SendInvitationReqObj"
AcceptInvitationReq:
description: JSON-formatted document describing request for accepting invitation
required: true
content:
application/json:
schema:
type: object
properties:
domain_id:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: Domain unique identifier.
required:
- domain_id
responses:
ServiceError:
@@ -956,6 +1349,25 @@ components:
application/json:
schema:
$ref: "#/components/schemas/DomainsPage"
InvitationRes:
description: Data retrieved.
content:
application/json:
schema:
$ref: "#/components/schemas/Invitation"
links:
delete:
operationId: deleteInvitation
parameters:
user_id: $response.body#/user_id
domain_id: $response.body#/domain_id
InvitationPageRes:
description: Data retrieved.
content:
application/json:
schema:
$ref: "#/components/schemas/InvitationPage"
HealthRes:
description: Service Health Check.
content:
-537
View File
@@ -1,537 +0,0 @@
# Copyright (c) Abstract Machines
# SPDX-License-Identifier: Apache-2.0
openapi: 3.0.3
info:
title: SuperMQ Invitations Service
description: |
This is the Invitations Server based on the OpenAPI 3.0 specification. It is the HTTP API for managing platform invitations. You can now help us improve the API whether it's by making changes to the definition itself or to the code.
Some useful links:
- [The SuperMQ repository](https://github.com/absmach/supermq)
contact:
email: info@abstractmachines.fr
license:
name: Apache 2.0
url: https://github.com/absmach/supermq/blob/main/LICENSE
version: 0.15.1
servers:
- url: http://localhost:9020
- url: https://localhost:9020
tags:
- name: Invitations
description: Everything about your Invitations
externalDocs:
description: Find out more about Invitations
url: https://docs.supermq.abstractmachines.fr/
paths:
/invitations:
post:
operationId: sendInvitation
tags:
- Invitations
summary: Send invitation
description: |
Send invitation to user to join domain.
requestBody:
$ref: "#/components/requestBodies/SendInvitationReq"
security:
- bearerAuth: []
responses:
"201":
description: Invitation sent.
"400":
description: Failed due to malformed JSON.
"401":
description: Missing or invalid access token provided.
"403":
description: Failed to perform authorization over the entity.
"404":
description: A non-existent entity request.
"409":
description: Failed due to using an existing identity.
"415":
description: Missing or invalid content type.
"500":
$ref: "#/components/responses/ServiceError"
get:
operationId: listInvitations
tags:
- Invitations
summary: List invitations
description: |
Retrieves a list of invitations. Due to performance concerns, data
is retrieved in subsets. The API must ensure that the entire
dataset is consumed either by making subsequent requests, or by
increasing the subset size of the initial request.
parameters:
- $ref: "#/components/parameters/Limit"
- $ref: "#/components/parameters/Offset"
- $ref: "#/components/parameters/UserID"
- $ref: "#/components/parameters/InvitedBy"
- $ref: "#/components/parameters/DomainID"
- $ref: "#/components/parameters/Relation"
- $ref: "#/components/parameters/State"
security:
- bearerAuth: []
responses:
"200":
$ref: "#/components/responses/InvitationPageRes"
"400":
description: Failed due to malformed query parameters.
"401":
description: |
Missing or invalid access token provided.
This endpoint is available only for administrators.
"403":
description: Failed to perform authorization over the entity.
"404":
description: A non-existent entity request.
"422":
description: Database can't process request.
"500":
$ref: "#/components/responses/ServiceError"
/invitations/accept:
post:
operationId: acceptInvitation
summary: Accept invitation
description: |
Current logged in user accepts invitation to join domain.
tags:
- Invitations
security:
- bearerAuth: []
requestBody:
$ref: "#/components/requestBodies/AcceptInvitationReq"
responses:
"204":
description: Invitation accepted.
"400":
description: Failed due to malformed query parameters.
"401":
description: Missing or invalid access token provided.
"404":
description: A non-existent entity request.
"500":
$ref: "#/components/responses/ServiceError"
/invitations/reject:
post:
operationId: rejectInvitation
summary: Reject invitation
description: |
Current logged in user rejects invitation to join domain.
tags:
- Invitations
security:
- bearerAuth: []
requestBody:
$ref: "#/components/requestBodies/AcceptInvitationReq"
responses:
"204":
description: Invitation rejected.
"400":
description: Failed due to malformed query parameters.
"401":
description: Missing or invalid access token provided.
"404":
description: A non-existent entity request.
"500":
$ref: "#/components/responses/ServiceError"
/invitations/{user_id}/{domain_id}:
get:
operationId: getInvitation
summary: Retrieves a specific invitation
description: |
Retrieves a specific invitation that is identifier by the user ID and domain ID.
tags:
- Invitations
parameters:
- $ref: "#/components/parameters/user_id"
- $ref: "#/components/parameters/domain_id"
security:
- bearerAuth: []
responses:
"200":
$ref: "#/components/responses/InvitationRes"
"400":
description: Failed due to malformed query parameters.
"401":
description: Missing or invalid access token provided.
"403":
description: Failed to perform authorization over the entity.
"404":
description: A non-existent entity request.
"422":
description: Database can't process request.
"500":
$ref: "#/components/responses/ServiceError"
delete:
operationId: deleteInvitation
summary: Deletes a specific invitation
description: |
Deletes a specific invitation that is identifier by the user ID and domain ID.
tags:
- Invitations
parameters:
- $ref: "#/components/parameters/user_id"
- $ref: "#/components/parameters/domain_id"
security:
- bearerAuth: []
responses:
"204":
description: Invitation deleted.
"400":
description: Failed due to malformed JSON.
"403":
description: Failed to perform authorization over the entity.
"404":
description: Failed due to non existing user.
"401":
description: Missing or invalid access token provided.
"500":
$ref: "#/components/responses/ServiceError"
/health:
get:
summary: Retrieves service health check info.
tags:
- health
security: []
responses:
"200":
$ref: "#/components/responses/HealthRes"
"500":
$ref: "#/components/responses/ServiceError"
components:
schemas:
SendInvitationReqObj:
type: object
properties:
user_id:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: User unique identifier.
domain_id:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: Domain unique identifier.
relation:
type: string
enum:
- administrator
- editor
- contributor
- member
- guest
- domain
- parent_group
- role_group
- group
- platform
example: editor
description: Relation between user and domain.
resend:
type: boolean
example: true
description: Resend invitation.
required:
- user_id
- domain_id
- relation
Invitation:
type: object
properties:
invited_by:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: User unique identifier.
user_id:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: User unique identifier.
domain_id:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: Domain unique identifier.
relation:
type: string
enum:
- administrator
- editor
- contributor
- member
- guest
- domain
- parent_group
- role_group
- group
- platform
example: editor
description: Relation between user and domain.
created_at:
type: string
format: date-time
example: "2019-11-26 13:31:52"
description: Time when the group was created.
updated_at:
type: string
format: date-time
example: "2019-11-26 13:31:52"
description: Time when the group was created.
confirmed_at:
type: string
format: date-time
example: "2019-11-26 13:31:52"
description: Time when the group was created.
xml:
name: invitation
InvitationPage:
type: object
properties:
invitations:
type: array
minItems: 0
uniqueItems: true
items:
$ref: "#/components/schemas/Invitation"
total:
type: integer
example: 1
description: Total number of items.
offset:
type: integer
description: Number of items to skip during retrieval.
limit:
type: integer
example: 10
description: Maximum number of items to return in one page.
required:
- invitations
- total
- offset
Error:
type: object
properties:
error:
type: string
description: Error message
example: { "error": "malformed entity specification" }
HealthRes:
type: object
properties:
status:
type: string
description: Service status.
enum:
- pass
version:
type: string
description: Service version.
example: 0.14.0
commit:
type: string
description: Service commit hash.
example: 7d6f4dc4f7f0c1fa3dc24eddfb18bb5073ff4f62
description:
type: string
description: Service description.
example: <service_name> service
build_time:
type: string
description: Service build time.
example: 1970-01-01_00:00:00
parameters:
Offset:
name: offset
description: Number of items to skip during retrieval.
in: query
schema:
type: integer
default: 0
minimum: 0
required: false
example: "0"
Limit:
name: limit
description: Size of the subset to retrieve.
in: query
schema:
type: integer
default: 10
maximum: 10
minimum: 1
required: false
example: "10"
UserID:
name: user_id
description: Unique user identifier.
in: query
schema:
type: string
format: uuid
required: true
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
user_id:
name: user_id
description: Unique user identifier.
in: path
schema:
type: string
format: uuid
required: true
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
DomainID:
name: domain_id
description: Unique identifier for a domain.
in: query
schema:
type: string
format: uuid
required: false
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
domain_id:
name: domain_id
description: Unique identifier for a domain.
in: path
schema:
type: string
format: uuid
required: true
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
InvitedBy:
name: invited_by
description: Unique identifier for a user that invited the user.
in: query
schema:
type: string
format: uuid
required: false
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
Relation:
name: relation
description: Relation between user and domain.
in: query
schema:
type: string
enum:
- administrator
- editor
- contributor
- member
- guest
- domain
- parent_group
- role_group
- group
- platform
required: false
example: editor
State:
name: state
description: Invitation state.
in: query
schema:
type: string
enum:
- pending
- accepted
- all
required: false
example: accepted
requestBodies:
SendInvitationReq:
description: JSON-formatted document describing request for sending invitation
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/SendInvitationReqObj"
AcceptInvitationReq:
description: JSON-formatted document describing request for accepting invitation
required: true
content:
application/json:
schema:
type: object
properties:
domain_id:
type: string
format: uuid
example: bb7edb32-2eac-4aad-aebe-ed96fe073879
description: Domain unique identifier.
required:
- domain_id
responses:
InvitationRes:
description: Data retrieved.
content:
application/json:
schema:
$ref: "#/components/schemas/Invitation"
links:
delete:
operationId: deleteInvitation
parameters:
user_id: $response.body#/user_id
domain_id: $response.body#/domain_id
InvitationPageRes:
description: Data retrieved.
content:
application/json:
schema:
$ref: "#/components/schemas/InvitationPage"
HealthRes:
description: Service Health Check.
content:
application/health+json:
schema:
$ref: "#/components/schemas/HealthRes"
ServiceError:
description: Unexpected server-side error occurred.
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
securitySchemes:
bearerAuth:
type: http
scheme: bearer
bearerFormat: JWT
description: |
* User access: "Authorization: Bearer <user_access_token>"
security:
- bearerAuth: []
-7
View File
@@ -25,7 +25,6 @@ const (
defChannelsURL string = defURL + ":9005"
defGroupsURL string = defURL + ":9004"
defCertsURL string = defURL + ":9019"
defInvitationsURL string = defURL + ":9020"
defHTTPURL string = defURL + ":8008"
defJournalURL string = defURL + ":9021"
defTLSVerification bool = false
@@ -43,7 +42,6 @@ type remotes struct {
GroupsURL string `toml:"groups_url"`
HTTPAdapterURL string `toml:"http_adapter_url"`
CertsURL string `toml:"certs_url"`
InvitationsURL string `toml:"invitations_url"`
JournalURL string `toml:"journal_url"`
HostURL string `toml:"host_url"`
TLSVerification bool `toml:"tls_verification"`
@@ -114,7 +112,6 @@ func ParseConfig(sdkConf smqsdk.Config) (smqsdk.Config, error) {
GroupsURL: defGroupsURL,
HTTPAdapterURL: defHTTPURL,
CertsURL: defCertsURL,
InvitationsURL: defInvitationsURL,
JournalURL: defJournalURL,
HostURL: defURL,
TLSVerification: defTLSVerification,
@@ -199,10 +196,6 @@ func ParseConfig(sdkConf smqsdk.Config) (smqsdk.Config, error) {
sdkConf.CertsURL = config.Remotes.CertsURL
}
if sdkConf.InvitationsURL == "" && config.Remotes.InvitationsURL != "" {
sdkConf.InvitationsURL = config.Remotes.InvitationsURL
}
if sdkConf.JournalURL == "" && config.Remotes.JournalURL != "" {
sdkConf.JournalURL = config.Remotes.JournalURL
}
+5 -5
View File
@@ -10,20 +10,20 @@ import (
var cmdInvitations = []cobra.Command{
{
Use: "send <user_id> <domain_id> <relation> <user_auth_token>",
Use: "send <user_id> <domain_id> <role_id> <user_auth_token>",
Short: "Send invitation",
Long: "Send invitation to user\n" +
"For example:\n" +
"\tsupermq-cli invitations send 39f97daf-d6b6-40f4-b229-2697be8006ef 4ef09eff-d500-4d56-b04f-d23a512d6f2a administrator $USER_AUTH_TOKEN\n",
"\tsupermq-cli invitations send 39f97daf-d6b6-40f4-b229-2697be8006ef 4ef09eff-d500-4d56-b04f-d23a512d6f2a ba4c904c-e6d4-4978-9417-1694aac6793e $USER_AUTH_TOKEN\n",
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 4 {
logUsageCmd(*cmd, cmd.Use)
return
}
inv := smqsdk.Invitation{
UserID: args[0],
DomainID: args[1],
Relation: args[2],
InviteeUserID: args[0],
DomainID: args[1],
RoleID: args[2],
}
if err := sdk.SendInvitation(inv, args[3]); err != nil {
logErrorCmd(*cmd, err)
+3 -3
View File
@@ -21,9 +21,9 @@ import (
)
var invitation = mgsdk.Invitation{
InvitedBy: testsutil.GenerateUUID(&testing.T{}),
UserID: user.ID,
DomainID: domain.ID,
InvitedBy: testsutil.GenerateUUID(&testing.T{}),
InviteeUserID: user.ID,
DomainID: domain.ID,
}
func TestSendUserInvitationCmd(t *testing.T) {
+1 -1
View File
@@ -233,7 +233,7 @@ func main() {
}
ddatabase := pg.NewDatabase(db, dbConfig, tracer)
drepo := dpostgres.New(ddatabase)
drepo := dpostgres.NewRepository(ddatabase)
if err := dconsumer.DomainsEventsSubscribe(ctx, drepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil {
logger.Error(fmt.Sprintf("failed to create domains event store : %s", err))
-8
View File
@@ -100,14 +100,6 @@ func main() {
"HTTP adapter URL",
)
rootCmd.PersistentFlags().StringVarP(
&sdkConf.InvitationsURL,
"invitations-url",
"v",
sdkConf.InvitationsURL,
"Inivitations URL",
)
rootCmd.PersistentFlags().StringVarP(
&sdkConf.JournalURL,
"journal-url",
+1 -1
View File
@@ -250,7 +250,7 @@ func main() {
}
ddatabase := pg.NewDatabase(db, dbConfig, tracer)
drepo := dpostgres.New(ddatabase)
drepo := dpostgres.NewRepository(ddatabase)
if err := dconsumer.DomainsEventsSubscribe(ctx, drepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil {
logger.Error(fmt.Sprintf("failed to create domains event store : %s", err))
+1 -1
View File
@@ -159,7 +159,7 @@ func main() {
logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure())
database := postgres.NewDatabase(db, dbConfig, tracer)
domainsRepo := dpostgres.New(database)
domainsRepo := dpostgres.NewRepository(database)
cacheclient, err := redisclient.Connect(cfg.CacheURL)
if err != nil {
+1 -1
View File
@@ -230,7 +230,7 @@ func main() {
}
ddatabase := pg.NewDatabase(db, dbConfig, tracer)
drepo := dpostgres.New(ddatabase)
drepo := dpostgres.NewRepository(ddatabase)
if err := dconsumer.DomainsEventsSubscribe(ctx, drepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil {
logger.Error(fmt.Sprintf("failed to create domains event store : %s", err))
-213
View File
@@ -1,213 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package main contains invitations main function to start the invitations service.
package main
import (
"context"
"fmt"
"log"
"log/slog"
"net/url"
"os"
chclient "github.com/absmach/callhome/pkg/client"
"github.com/absmach/supermq"
grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
"github.com/absmach/supermq/invitations"
httpapi "github.com/absmach/supermq/invitations/api"
"github.com/absmach/supermq/invitations/middleware"
invitationspg "github.com/absmach/supermq/invitations/postgres"
smqlog "github.com/absmach/supermq/logger"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient"
"github.com/absmach/supermq/pkg/grpcclient"
"github.com/absmach/supermq/pkg/jaeger"
"github.com/absmach/supermq/pkg/postgres"
clientspg "github.com/absmach/supermq/pkg/postgres"
"github.com/absmach/supermq/pkg/prometheus"
mgsdk "github.com/absmach/supermq/pkg/sdk"
"github.com/absmach/supermq/pkg/server"
"github.com/absmach/supermq/pkg/server/http"
"github.com/absmach/supermq/pkg/uuid"
"github.com/caarlos0/env/v11"
"github.com/jmoiron/sqlx"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
)
const (
svcName = "invitations"
envPrefixDB = "SMQ_INVITATIONS_DB_"
envPrefixHTTP = "SMQ_INVITATIONS_HTTP_"
envPrefixAuth = "SMQ_AUTH_GRPC_"
envPrefixDomains = "SMQ_DOMAINS_GRPC_"
defDB = "invitations"
defSvcHTTPPort = "9020"
)
type config struct {
LogLevel string `env:"SMQ_INVITATIONS_LOG_LEVEL" envDefault:"info"`
UsersURL string `env:"SMQ_USERS_URL" envDefault:"http://localhost:9002"`
DomainsURL string `env:"SMQ_DOMAINS_URL" envDefault:"http://localhost:9003"`
InstanceID string `env:"SMQ_INVITATIONS_INSTANCE_ID" envDefault:""`
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
}
func main() {
ctx, cancel := context.WithCancel(context.Background())
g, ctx := errgroup.WithContext(ctx)
cfg := config{}
if err := env.Parse(&cfg); err != nil {
log.Fatalf("failed to load %s configuration : %s", svcName, err)
}
logger, err := smqlog.New(os.Stdout, cfg.LogLevel)
if err != nil {
log.Fatalf("failed to init logger: %s", err.Error())
}
var exitCode int
defer smqlog.ExitWithError(&exitCode)
if cfg.InstanceID == "" {
if cfg.InstanceID, err = uuid.New().ID(); err != nil {
logger.Error(fmt.Sprintf("failed to generate instanceID: %s", err))
exitCode = 1
return
}
}
dbConfig := clientspg.Config{Name: defDB}
if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s database configuration : %s", svcName, err))
exitCode = 1
return
}
db, err := clientspg.Setup(dbConfig, *invitationspg.Migration())
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer db.Close()
authClientCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&authClientCfg, env.Options{Prefix: envPrefixAuth}); err != nil {
logger.Error(fmt.Sprintf("failed to load auth gRPC client configuration : %s", err.Error()))
exitCode = 1
return
}
tokenClient, tokenHandler, err := grpcclient.SetupTokenClient(ctx, authClientCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer tokenHandler.Close()
logger.Info("Token service client successfully connected to auth gRPC server " + tokenHandler.Secure())
authn, authnHandler, err := authsvcAuthn.NewAuthentication(ctx, authClientCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authnHandler.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnHandler.Secure())
domsGrpcCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err))
exitCode = 1
return
}
domAuthz, _, domainsHandler, err := domainsAuthz.NewAuthorization(ctx, domsGrpcCfg)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer domainsHandler.Close()
authz, authzHandler, err := authsvcAuthz.NewAuthorization(ctx, authClientCfg, domAuthz)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer authzHandler.Close()
logger.Info("Authz successfully connected to auth gRPC server " + authzHandler.Secure())
tp, err := jaeger.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio)
if err != nil {
logger.Error(fmt.Sprintf("failed to init Jaeger: %s", err))
exitCode = 1
return
}
defer func() {
if err := tp.Shutdown(ctx); err != nil {
logger.Error(fmt.Sprintf("error shutting down tracer provider: %v", err))
}
}()
tracer := tp.Tracer(svcName)
svc, err := newService(db, dbConfig, authz, tokenClient, tracer, cfg, logger)
if err != nil {
logger.Error(fmt.Sprintf("failed to create %s service: %s", svcName, err))
exitCode = 1
return
}
httpServerConfig := server.Config{Port: defSvcHTTPPort}
if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil {
logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err))
exitCode = 1
return
}
httpSvr := http.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, logger, authn, cfg.InstanceID), logger)
if cfg.SendTelemetry {
chc := chclient.New(svcName, supermq.Version, logger, cancel)
go chc.CallHome(ctx)
}
g.Go(func() error {
return httpSvr.Start()
})
g.Go(func() error {
return server.StopSignalHandler(ctx, cancel, logger, svcName, httpSvr)
})
if err := g.Wait(); err != nil {
logger.Error(fmt.Sprintf("%s service terminated: %s", svcName, err))
}
}
func newService(db *sqlx.DB, dbConfig clientspg.Config, authz smqauthz.Authorization, token grpcTokenV1.TokenServiceClient, tracer trace.Tracer, conf config, logger *slog.Logger) (invitations.Service, error) {
database := postgres.NewDatabase(db, dbConfig, tracer)
repo := invitationspg.NewRepository(database)
config := mgsdk.Config{
UsersURL: conf.UsersURL,
DomainsURL: conf.DomainsURL,
}
sdk := mgsdk.NewSDK(config)
svc := invitations.NewService(token, repo, sdk)
svc = middleware.AuthorizationMiddleware(authz, svc)
svc = middleware.Tracing(svc, tracer)
svc = middleware.Logging(logger, svc)
counter, latency := prometheus.MakeMetrics(svcName, "api")
svc = middleware.Metrics(counter, latency, svc)
return svc, nil
}
-1
View File
@@ -17,7 +17,6 @@ user_token = ""
groups_url = "http://localhost:9004"
host_url = "http://localhost"
http_adapter_url = "http://localhost:8008"
invitations_url = "http://localhost:9020"
journal_url = "http://localhost:9021"
tls_verification = false
users_url = "http://localhost:9002"
-17
View File
@@ -143,23 +143,6 @@ SMQ_SPICEDB_HOST=supermq-spicedb
SMQ_SPICEDB_PORT=50051
SMQ_SPICEDB_DATASTORE_ENGINE=postgres
### Invitations
SMQ_INVITATIONS_LOG_LEVEL=info
SMQ_INVITATIONS_HTTP_HOST=invitations
SMQ_INVITATIONS_HTTP_PORT=9020
SMQ_INVITATIONS_HTTP_SERVER_CERT=
SMQ_INVITATIONS_HTTP_SERVER_KEY=
SMQ_INVITATIONS_DB_HOST=invitations-db
SMQ_INVITATIONS_DB_PORT=5432
SMQ_INVITATIONS_DB_USER=supermq
SMQ_INVITATIONS_DB_PASS=supermq
SMQ_INVITATIONS_DB_NAME=invitations
SMQ_INVITATIONS_DB_SSL_MODE=disable
SMQ_INVITATIONS_DB_SSL_CERT=
SMQ_INVITATIONS_DB_SSL_KEY=
SMQ_INVITATIONS_DB_SSL_ROOT_CERT=
SMQ_INVITATIONS_INSTANCE_ID=
### UI
SMQ_UI_LOG_LEVEL=debug
SMQ_UI_PORT=9095
-78
View File
@@ -19,7 +19,6 @@ volumes:
supermq-pat-db-volume:
supermq-domains-db-volume:
supermq-domains-redis-volume:
supermq-invitations-db-volume:
supermq-ui-db-volume:
services:
@@ -293,83 +292,6 @@ services:
bind:
create_host_path: true
invitations-db:
image: postgres:16.2-alpine
container_name: supermq-invitations-db
restart: on-failure
command: postgres -c "max_connections=${SMQ_POSTGRES_MAX_CONNECTIONS}"
environment:
POSTGRES_USER: ${SMQ_INVITATIONS_DB_USER}
POSTGRES_PASSWORD: ${SMQ_INVITATIONS_DB_PASS}
POSTGRES_DB: ${SMQ_INVITATIONS_DB_NAME}
SMQ_POSTGRES_MAX_CONNECTIONS: ${SMQ_POSTGRES_MAX_CONNECTIONS}
ports:
- 6021:5432
networks:
- supermq-base-net
volumes:
- supermq-invitations-db-volume:/var/lib/postgresql/data
invitations:
image: supermq/invitations:${SMQ_RELEASE_TAG}
container_name: supermq-invitations
restart: on-failure
depends_on:
- auth
- invitations-db
environment:
SMQ_INVITATIONS_LOG_LEVEL: ${SMQ_INVITATIONS_LOG_LEVEL}
SMQ_USERS_URL: ${SMQ_USERS_URL}
SMQ_DOMAINS_URL: ${SMQ_DOMAINS_URL}
SMQ_INVITATIONS_HTTP_HOST: ${SMQ_INVITATIONS_HTTP_HOST}
SMQ_INVITATIONS_HTTP_PORT: ${SMQ_INVITATIONS_HTTP_PORT}
SMQ_INVITATIONS_HTTP_SERVER_CERT: ${SMQ_INVITATIONS_HTTP_SERVER_CERT}
SMQ_INVITATIONS_HTTP_SERVER_KEY: ${SMQ_INVITATIONS_HTTP_SERVER_KEY}
SMQ_INVITATIONS_DB_HOST: ${SMQ_INVITATIONS_DB_HOST}
SMQ_INVITATIONS_DB_USER: ${SMQ_INVITATIONS_DB_USER}
SMQ_INVITATIONS_DB_PASS: ${SMQ_INVITATIONS_DB_PASS}
SMQ_INVITATIONS_DB_PORT: ${SMQ_INVITATIONS_DB_PORT}
SMQ_INVITATIONS_DB_NAME: ${SMQ_INVITATIONS_DB_NAME}
SMQ_INVITATIONS_DB_SSL_MODE: ${SMQ_INVITATIONS_DB_SSL_MODE}
SMQ_INVITATIONS_DB_SSL_CERT: ${SMQ_INVITATIONS_DB_SSL_CERT}
SMQ_INVITATIONS_DB_SSL_KEY: ${SMQ_INVITATIONS_DB_SSL_KEY}
SMQ_INVITATIONS_DB_SSL_ROOT_CERT: ${SMQ_INVITATIONS_DB_SSL_ROOT_CERT}
SMQ_AUTH_GRPC_URL: ${SMQ_AUTH_GRPC_URL}
SMQ_AUTH_GRPC_TIMEOUT: ${SMQ_AUTH_GRPC_TIMEOUT}
SMQ_AUTH_GRPC_CLIENT_CERT: ${SMQ_AUTH_GRPC_CLIENT_CERT:+/auth-grpc-client.crt}
SMQ_AUTH_GRPC_CLIENT_KEY: ${SMQ_AUTH_GRPC_CLIENT_KEY:+/auth-grpc-client.key}
SMQ_AUTH_GRPC_SERVER_CA_CERTS: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+/auth-grpc-server-ca.crt}
SMQ_DOMAINS_GRPC_URL: ${SMQ_DOMAINS_GRPC_URL}
SMQ_DOMAINS_GRPC_TIMEOUT: ${SMQ_DOMAINS_GRPC_TIMEOUT}
SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt}
SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key}
SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt}
SMQ_JAEGER_URL: ${SMQ_JAEGER_URL}
SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO}
SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY}
SMQ_INVITATIONS_INSTANCE_ID: ${SMQ_INVITATIONS_INSTANCE_ID}
ports:
- ${SMQ_INVITATIONS_HTTP_PORT}:${SMQ_INVITATIONS_HTTP_PORT}
networks:
- supermq-base-net
volumes:
# Auth gRPC client certificates
- type: bind
source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert}
target: /auth-grpc-client${SMQ_AUTH_GRPC_CLIENT_CERT:+.crt}
bind:
create_host_path: true
- type: bind
source: ${SMQ_AUTH_GRPC_CLIENT_KEY:-ssl/certs/dummy/client_key}
target: /auth-grpc-client${SMQ_AUTH_GRPC_CLIENT_KEY:+.key}
bind:
create_host_path: true
- type: bind
source: ${SMQ_AUTH_GRPC_SERVER_CA_CERTS:-ssl/certs/dummy/server_ca}
target: /auth-grpc-server-ca${SMQ_AUTH_GRPC_SERVER_CA_CERTS:+.crt}
bind:
create_host_path: true
nginx:
image: nginx:1.25.4-alpine
container_name: supermq-nginx
-1
View File
@@ -22,7 +22,6 @@ envsubst '
${SMQ_HTTP_ADAPTER_PORT}
${SMQ_NGINX_MQTT_PORT}
${SMQ_NGINX_MQTTS_PORT}
${SMQ_INVITATIONS_HTTP_PORT}
${SMQ_WS_ADAPTER_HTTP_PORT}' < /etc/nginx/nginx.conf.template > /etc/nginx/nginx.conf
exec nginx -g "daemon off;"
+1 -8
View File
@@ -85,20 +85,13 @@ http {
proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT};
}
# Proxy pass to domains service
# Proxy pass to channels service
location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(channels)" {
include snippets/proxy-headers.conf;
add_header Access-Control-Expose-Headers Location;
proxy_pass http://channels:${SMQ_CHANNELS_HTTP_PORT};
}
# Proxy pass to invitations service
location ~ ^/(invitations) {
include snippets/proxy-headers.conf;
add_header Access-Control-Expose-Headers Location;
proxy_pass http://invitations:${SMQ_INVITATIONS_HTTP_PORT};
}
location /health {
include snippets/proxy-headers.conf;
proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT};
+1 -8
View File
@@ -94,20 +94,13 @@ http {
proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT};
}
# Proxy pass to domains service
# Proxy pass to channels service
location ~ "^/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(channels)" {
include snippets/proxy-headers.conf;
add_header Access-Control-Expose-Headers Location;
proxy_pass http://channels:${SMQ_CHANNELS_HTTP_PORT};
}
# Proxy pass to invitations service
location ~ ^/(invitations) {
include snippets/proxy-headers.conf;
add_header Access-Control-Expose-Headers Location;
proxy_pass http://invitations:${SMQ_INVITATIONS_HTTP_PORT};
}
location /health {
include snippets/proxy-headers.conf;
proxy_pass http://clients:${SMQ_CLIENTS_HTTP_PORT};
+92
View File
@@ -16,6 +16,15 @@ import (
"github.com/go-chi/chi/v5"
)
const (
inviteeUserIDKey = "invitee_user_id"
domainIDKey = "domain_id"
invitedByKey = "invited_by"
roleIDKey = "role_id"
roleNameKey = "role_name"
stateKey = "state"
)
func decodeCreateDomainRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
@@ -157,3 +166,86 @@ func decodePageRequest(_ context.Context, r *http.Request) (page, error) {
status: st,
}, nil
}
func decodeSendInvitationReq(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
var req sendInvitationReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
}
return req, nil
}
func decodeListInvitationsReq(_ context.Context, r *http.Request) (interface{}, error) {
offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
inviteeUserID, err := apiutil.ReadStringQuery(r, inviteeUserIDKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
invitedBy, err := apiutil.ReadStringQuery(r, invitedByKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
roleID, err := apiutil.ReadStringQuery(r, roleIDKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
domainID, err := apiutil.ReadStringQuery(r, domainIDKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
st, err := apiutil.ReadStringQuery(r, stateKey, domains.AllState.String())
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
state, err := domains.ToState(st)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
req := listInvitationsReq{
InvitationPageMeta: domains.InvitationPageMeta{
Offset: offset,
Limit: limit,
InvitedBy: invitedBy,
InviteeUserID: inviteeUserID,
RoleID: roleID,
DomainID: domainID,
State: state,
},
}
return req, nil
}
func decodeAcceptInvitationReq(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
var req acceptInvitationReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
}
return req, nil
}
func decodeInvitationReq(_ context.Context, r *http.Request) (interface{}, error) {
req := invitationReq{
userID: chi.URLParam(r, "userID"),
domainID: chi.URLParam(r, "domainID"),
}
return req, nil
}
+161
View File
@@ -15,6 +15,9 @@ import (
"github.com/go-kit/kit/endpoint"
)
// InvitationSent is the message returned when an invitation is sent.
const InvitationSent = "invitation sent"
func createDomainEndpoint(svc domains.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(createDomainReq)
@@ -182,3 +185,161 @@ func freezeDomainEndpoint(svc domains.Service) endpoint.Endpoint {
return freezeDomainRes{}, nil
}
}
func sendInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(sendInvitationReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
invitation := domains.Invitation{
InviteeUserID: req.InviteeUserID,
DomainID: session.DomainID,
RoleID: req.RoleID,
}
if err := svc.SendInvitation(ctx, session, invitation); err != nil {
return nil, err
}
return sendInvitationRes{
Message: InvitationSent,
}, nil
}
}
func viewInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(invitationReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
session.DomainID = req.domainID
invitation, err := svc.ViewInvitation(ctx, session, req.userID, req.domainID)
if err != nil {
return nil, err
}
return viewInvitationRes{
Invitation: invitation,
}, nil
}
}
func listDomainInvitationsEndpoint(svc domains.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(listInvitationsReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
req.InvitationPageMeta.DomainID = session.DomainID
page, err := svc.ListInvitations(ctx, session, req.InvitationPageMeta)
if err != nil {
return nil, err
}
return listInvitationsRes{
page,
}, nil
}
}
func listUserInvitationsEndpoint(svc domains.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(listInvitationsReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
session.DomainID = req.DomainID
page, err := svc.ListInvitations(ctx, session, req.InvitationPageMeta)
if err != nil {
return nil, err
}
return listInvitationsRes{
page,
}, nil
}
}
func acceptInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(acceptInvitationReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
if err := svc.AcceptInvitation(ctx, session, req.DomainID); err != nil {
return nil, err
}
return acceptInvitationRes{}, nil
}
}
func rejectInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(acceptInvitationReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
if err := svc.RejectInvitation(ctx, session, req.DomainID); err != nil {
return nil, err
}
return rejectInvitationRes{}, nil
}
}
func deleteInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(invitationReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
session.DomainID = req.domainID
if err := svc.DeleteInvitation(ctx, session, req.userID, req.domainID); err != nil {
return nil, err
}
return deleteInvitationRes{}, nil
}
}
+583
View File
@@ -45,6 +45,8 @@ var (
inValidToken = "invalid"
invalid = "invalid"
userID = testsutil.GenerateUUID(&testing.T{})
validID = testsutil.GenerateUUID(&testing.T{})
domainID = testsutil.GenerateUUID(&testing.T{})
)
const (
@@ -1072,6 +1074,587 @@ func TestFreezeDomain(t *testing.T) {
}
}
func TestSendInvitation(t *testing.T) {
is, svc, auth := newDomainsServer()
cases := []struct {
desc string
token string
domainID string
data string
session authn.Session
contentType string
status int
authnErr error
svcErr error
}{
{
desc: "send invitation with valid request",
token: validToken,
domainID: domainID,
data: fmt.Sprintf(`{"invitee_user_id": "%s","role_id": "%s"}`, validID, validID),
status: http.StatusCreated,
contentType: contentType,
svcErr: nil,
},
{
desc: "send invitation with invalid token",
token: "",
domainID: domainID,
data: fmt.Sprintf(`{"invitee_user_id": "%s","role_id": "%s"}`, validID, validID),
status: http.StatusUnauthorized,
contentType: contentType,
svcErr: nil,
},
{
desc: "send invitation with empty domain_id",
token: validToken,
domainID: "",
data: fmt.Sprintf(`{"invitee_user_id": "%s","role_id": "%s"}`, validID, validID),
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
{
desc: "send invitation with invalid content type",
token: validToken,
domainID: domainID,
data: fmt.Sprintf(`{"invitee_user_id": "%s","role_id": "%s"}`, validID, validID),
status: http.StatusUnsupportedMediaType,
contentType: "text/plain",
svcErr: nil,
},
{
desc: "send invitation with invalid data",
token: validToken,
domainID: domainID,
data: `data`,
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
{
desc: "send invitation with service error",
token: validToken,
domainID: domainID,
data: fmt.Sprintf(`{"invitee_user_id": "%s","role_id": "%s"}`, validID, validID),
status: http.StatusForbidden,
contentType: contentType,
svcErr: svcerr.ErrAuthorization,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = authn.Session{UserID: userID, DomainID: domainID, DomainUserID: domainID + "_" + userID}
}
authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
repoCall := svc.On("SendInvitation", mock.Anything, tc.session, mock.Anything).Return(tc.svcErr)
req := testRequest{
client: is.Client(),
method: http.MethodPost,
url: fmt.Sprintf("%s/domains/%s/invitations", is.URL, tc.domainID),
token: tc.token,
contentType: tc.contentType,
body: strings.NewReader(tc.data),
}
res, err := req.make()
assert.Nil(t, err, tc.desc)
assert.Equal(t, tc.status, res.StatusCode, tc.desc)
repoCall.Unset()
authnCall.Unset()
})
}
}
func TestListInvitation(t *testing.T) {
is, svc, auth := newDomainsServer()
cases := []struct {
desc string
token string
session authn.Session
query string
contentType string
status int
svcErr error
authnErr error
}{
{
desc: "list invitations with valid request",
token: validToken,
status: http.StatusOK,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with invalid token",
token: "",
status: http.StatusUnauthorized,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with offset",
token: validToken,
query: "offset=1",
status: http.StatusOK,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with invalid offset",
token: validToken,
query: "offset=invalid",
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with limit",
token: validToken,
query: "limit=1",
status: http.StatusOK,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with invalid limit",
token: validToken,
query: "limit=invalid",
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with invitee_user_id",
token: validToken,
query: fmt.Sprintf("invitee_user_id=%s", validID),
status: http.StatusOK,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with duplicate invitee_user_id",
token: validToken,
query: "invitee_user_id=1&invitee_user_id=2",
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with invited_by",
token: validToken,
query: fmt.Sprintf("invited_by=%s", validID),
status: http.StatusOK,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with duplicate invited_by",
token: validToken,
query: "invited_by=1&invited_by=2",
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with state",
token: validToken,
query: "state=pending",
status: http.StatusOK,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with invalid state",
token: validToken,
query: "state=invalid",
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with duplicate state",
token: validToken,
query: "state=all&state=all",
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
{
desc: "list invitations with service error",
token: validToken,
status: http.StatusForbidden,
contentType: contentType,
svcErr: svcerr.ErrAuthorization,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = authn.Session{UserID: userID}
}
authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
repoCall := svc.On("ListInvitations", mock.Anything, tc.session, mock.Anything).Return(domains.InvitationPage{}, tc.svcErr)
req := testRequest{
client: is.Client(),
method: http.MethodGet,
url: is.URL + "/invitations?" + tc.query,
token: tc.token,
contentType: tc.contentType,
}
res, err := req.make()
assert.Nil(t, err, tc.desc)
assert.Equal(t, tc.status, res.StatusCode, tc.desc)
repoCall.Unset()
authnCall.Unset()
})
}
}
func TestViewInvitation(t *testing.T) {
is, svc, auth := newDomainsServer()
cases := []struct {
desc string
token string
session authn.Session
domainID string
userID string
contentType string
status int
svcErr error
authnErr error
}{
{
desc: "view invitation with valid request",
token: validToken,
userID: validID,
domainID: domainID,
status: http.StatusOK,
contentType: contentType,
svcErr: nil,
},
{
desc: "view invitation with invalid token",
token: "",
userID: validID,
domainID: domainID,
status: http.StatusUnauthorized,
contentType: contentType,
svcErr: nil,
},
{
desc: "view invitation with service error",
token: validToken,
userID: validID,
domainID: domainID,
status: http.StatusBadRequest,
contentType: contentType,
svcErr: svcerr.ErrViewEntity,
},
{
desc: "view invitation with empty domain",
token: validToken,
userID: validID,
domainID: "",
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
{
desc: "view invitation with empty invitee_user_id and domain_id",
token: validToken,
userID: "",
domainID: "",
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = authn.Session{UserID: userID, DomainID: domainID, DomainUserID: domainID + "_" + userID}
}
authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
repoCall := svc.On("ViewInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(domains.Invitation{}, tc.svcErr)
req := testRequest{
client: is.Client(),
method: http.MethodGet,
url: fmt.Sprintf("%s/domains/%s/invitations/%s", is.URL, tc.domainID, tc.userID),
token: tc.token,
contentType: tc.contentType,
}
res, err := req.make()
assert.Nil(t, err, tc.desc)
assert.Equal(t, tc.status, res.StatusCode, tc.desc)
repoCall.Unset()
authnCall.Unset()
})
}
}
func TestDeleteInvitation(t *testing.T) {
is, svc, auth := newDomainsServer()
cases := []struct {
desc string
token string
session authn.Session
domainID string
userID string
contentType string
status int
svcErr error
authnErr error
}{
{
desc: "delete invitation with valid request",
token: validToken,
userID: validID,
domainID: domainID,
status: http.StatusNoContent,
contentType: contentType,
svcErr: nil,
},
{
desc: "delete invitation with invalid token",
token: "",
userID: validID,
domainID: domainID,
status: http.StatusUnauthorized,
contentType: contentType,
svcErr: nil,
},
{
desc: "delete invitation with service error",
token: validToken,
userID: validID,
domainID: domainID,
status: http.StatusForbidden,
contentType: contentType,
svcErr: svcerr.ErrAuthorization,
},
{
desc: "delete invitation with empty invitee_user_id",
token: validToken,
userID: "",
domainID: domainID,
status: http.StatusMethodNotAllowed,
contentType: contentType,
svcErr: nil,
},
{
desc: "delete invitation with empty domain_id",
token: validToken,
userID: validID,
domainID: "",
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
{
desc: "delete invitation with empty invitee_user_id and domain_id",
token: validToken,
userID: "",
domainID: "",
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = authn.Session{UserID: userID, DomainID: domainID, DomainUserID: domainID + "_" + userID}
}
authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
repoCall := svc.On("DeleteInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(tc.svcErr)
req := testRequest{
client: is.Client(),
method: http.MethodDelete,
url: fmt.Sprintf("%s/domains/%s/invitations/%s", is.URL, tc.domainID, tc.userID),
token: tc.token,
contentType: tc.contentType,
}
res, err := req.make()
assert.Nil(t, err, tc.desc)
assert.Equal(t, tc.status, res.StatusCode, tc.desc)
repoCall.Unset()
authnCall.Unset()
})
}
}
func TestAcceptInvitation(t *testing.T) {
is, svc, auth := newDomainsServer()
cases := []struct {
desc string
token string
session authn.Session
data string
contentType string
status int
svcErr error
authnErr error
}{
{
desc: "accept invitation with valid request",
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
token: validToken,
status: http.StatusNoContent,
contentType: contentType,
svcErr: nil,
},
{
desc: "accept invitation with invalid token",
token: "",
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
status: http.StatusUnauthorized,
contentType: contentType,
svcErr: nil,
},
{
desc: "accept invitation with service error",
token: validToken,
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
status: http.StatusForbidden,
contentType: contentType,
svcErr: svcerr.ErrAuthorization,
},
{
desc: "accept invitation with invalid content type",
token: validToken,
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
status: http.StatusUnsupportedMediaType,
contentType: "text/plain",
svcErr: nil,
},
{
desc: "accept invitation with invalid data",
token: validToken,
data: `data`,
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = authn.Session{UserID: userID, DomainID: domainID}
}
authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
repoCall := svc.On("AcceptInvitation", mock.Anything, tc.session, mock.Anything).Return(tc.svcErr)
req := testRequest{
client: is.Client(),
method: http.MethodPost,
url: is.URL + "/invitations/accept",
token: tc.token,
contentType: tc.contentType,
body: strings.NewReader(tc.data),
}
res, err := req.make()
assert.Nil(t, err, tc.desc)
assert.Equal(t, tc.status, res.StatusCode, tc.desc)
repoCall.Unset()
authnCall.Unset()
})
}
}
func TestRejectInvitation(t *testing.T) {
is, svc, auth := newDomainsServer()
cases := []struct {
desc string
token string
session authn.Session
data string
contentType string
status int
svcErr error
authnErr error
}{
{
desc: "reject invitation with valid request",
token: validToken,
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
status: http.StatusNoContent,
contentType: contentType,
svcErr: nil,
},
{
desc: "reject invitation with invalid token",
token: "",
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
status: http.StatusUnauthorized,
contentType: contentType,
svcErr: nil,
},
{
desc: "reject invitation with unauthorized error",
token: validToken,
data: fmt.Sprintf(`{"domain_id": "%s"}`, "invalid"),
status: http.StatusForbidden,
contentType: contentType,
svcErr: svcerr.ErrAuthorization,
},
{
desc: "reject invitation with invalid content type",
token: validToken,
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
status: http.StatusUnsupportedMediaType,
contentType: "text/plain",
svcErr: nil,
},
{
desc: "reject invitation with invalid data",
token: validToken,
data: `data`,
status: http.StatusBadRequest,
contentType: contentType,
svcErr: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == validToken {
tc.session = authn.Session{UserID: userID, DomainID: domainID}
}
authnCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authnErr)
repoCall := svc.On("RejectInvitation", mock.Anything, tc.session, mock.Anything).Return(tc.svcErr)
req := testRequest{
client: is.Client(),
method: http.MethodPost,
url: is.URL + "/invitations/reject",
token: tc.token,
contentType: tc.contentType,
body: strings.NewReader(tc.data),
}
res, err := req.make()
assert.Nil(t, err, tc.desc)
assert.Equal(t, tc.status, res.StatusCode, tc.desc)
repoCall.Unset()
authnCall.Unset()
})
}
}
type respBody struct {
Err string `json:"error"`
Message string `json:"message"`
+55
View File
@@ -8,6 +8,8 @@ import (
"github.com/absmach/supermq/domains"
)
const maxLimitSize = 100
type page struct {
offset uint64
limit uint64
@@ -112,3 +114,56 @@ func (req freezeDomainReq) validate() error {
return nil
}
type sendInvitationReq struct {
InviteeUserID string `json:"invitee_user_id,omitempty"`
RoleID string `json:"role_id,omitempty"`
}
func (req *sendInvitationReq) validate() error {
if req.InviteeUserID == "" || req.RoleID == "" {
return apiutil.ErrMissingID
}
return nil
}
type listInvitationsReq struct {
domains.InvitationPageMeta
}
func (req *listInvitationsReq) validate() error {
if req.InvitationPageMeta.Limit > maxLimitSize || req.InvitationPageMeta.Limit < 1 {
return apiutil.ErrLimitSize
}
return nil
}
type acceptInvitationReq struct {
DomainID string `json:"domain_id,omitempty"`
}
func (req *acceptInvitationReq) validate() error {
if req.DomainID == "" {
return apiutil.ErrMissingDomainID
}
return nil
}
type invitationReq struct {
userID string
domainID string
}
func (req *invitationReq) validate() error {
if req.userID == "" {
return apiutil.ErrMissingID
}
if req.domainID == "" {
return apiutil.ErrMissingDomainID
}
return nil
}
+99
View File
@@ -14,6 +14,15 @@ var (
_ supermq.Response = (*createDomainRes)(nil)
_ supermq.Response = (*retrieveDomainRes)(nil)
_ supermq.Response = (*listDomainsRes)(nil)
_ supermq.Response = (*enableDomainRes)(nil)
_ supermq.Response = (*disableDomainRes)(nil)
_ supermq.Response = (*freezeDomainRes)(nil)
_ supermq.Response = (*sendInvitationRes)(nil)
_ supermq.Response = (*viewInvitationRes)(nil)
_ supermq.Response = (*listInvitationsRes)(nil)
_ supermq.Response = (*acceptInvitationRes)(nil)
_ supermq.Response = (*rejectInvitationRes)(nil)
_ supermq.Response = (*deleteInvitationRes)(nil)
)
type createDomainRes struct {
@@ -121,3 +130,93 @@ func (res freezeDomainRes) Headers() map[string]string {
func (res freezeDomainRes) Empty() bool {
return true
}
type sendInvitationRes struct {
Message string `json:"message"`
}
func (res sendInvitationRes) Code() int {
return http.StatusCreated
}
func (res sendInvitationRes) Headers() map[string]string {
return map[string]string{}
}
func (res sendInvitationRes) Empty() bool {
return true
}
type viewInvitationRes struct {
domains.Invitation `json:",inline"`
}
func (res viewInvitationRes) Code() int {
return http.StatusOK
}
func (res viewInvitationRes) Headers() map[string]string {
return map[string]string{}
}
func (res viewInvitationRes) Empty() bool {
return false
}
type listInvitationsRes struct {
domains.InvitationPage `json:",inline"`
}
func (res listInvitationsRes) Code() int {
return http.StatusOK
}
func (res listInvitationsRes) Headers() map[string]string {
return map[string]string{}
}
func (res listInvitationsRes) Empty() bool {
return false
}
type acceptInvitationRes struct{}
func (res acceptInvitationRes) Code() int {
return http.StatusNoContent
}
func (res acceptInvitationRes) Headers() map[string]string {
return map[string]string{}
}
func (res acceptInvitationRes) Empty() bool {
return true
}
type deleteInvitationRes struct{}
func (res deleteInvitationRes) Code() int {
return http.StatusNoContent
}
func (res deleteInvitationRes) Headers() map[string]string {
return map[string]string{}
}
func (res deleteInvitationRes) Empty() bool {
return true
}
type rejectInvitationRes struct{}
func (res rejectInvitationRes) Code() int {
return http.StatusNoContent
}
func (res rejectInvitationRes) Headers() map[string]string {
return map[string]string{}
}
func (res rejectInvitationRes) Empty() bool {
return true
}
+56 -2
View File
@@ -5,6 +5,7 @@ package http
import (
"log/slog"
"net/http"
"github.com/absmach/supermq"
api "github.com/absmach/supermq/api/http"
@@ -18,7 +19,8 @@ import (
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)
func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string) *chi.Mux {
// MakeHandler returns a HTTP handler for Domains and Invitations API endpoints.
func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string) http.Handler {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
@@ -82,9 +84,61 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux,
), "freeze_domain").ServeHTTP)
roleManagerHttp.EntityRoleMangerRouter(svc, d, r, opts)
})
r.Route("/{domainID}/invitations", func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
sendInvitationEndpoint(svc),
decodeSendInvitationReq,
api.EncodeResponse,
opts...,
), "send_invitation").ServeHTTP)
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
listDomainInvitationsEndpoint(svc),
decodeListInvitationsReq,
api.EncodeResponse,
opts...,
), "list_domain_invitations").ServeHTTP)
r.Route("/{userID}", func(r chi.Router) {
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
viewInvitationEndpoint(svc),
decodeInvitationReq,
api.EncodeResponse,
opts...,
), "view_invitation").ServeHTTP)
r.Delete("/", otelhttp.NewHandler(kithttp.NewServer(
deleteInvitationEndpoint(svc),
decodeInvitationReq,
api.EncodeResponse,
opts...,
), "delete_invitation").ServeHTTP)
})
})
})
mux.Get("/health", supermq.Health("auth", instanceID))
mux.Route("/invitations", func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, false))
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
listUserInvitationsEndpoint(svc),
decodeListInvitationsReq,
api.EncodeResponse,
opts...,
), "list_user_invitations").ServeHTTP)
r.Post("/accept", otelhttp.NewHandler(kithttp.NewServer(
acceptInvitationEndpoint(svc),
decodeAcceptInvitationReq,
api.EncodeResponse,
opts...,
), "accept_invitation").ServeHTTP)
r.Post("/reject", otelhttp.NewHandler(kithttp.NewServer(
rejectInvitationEndpoint(svc),
decodeAcceptInvitationReq,
api.EncodeResponse,
opts...,
), "reject_invitation").ServeHTTP)
})
mux.Get("/health", supermq.Health("domains", instanceID))
mux.Handle("/metrics", promhttp.Handler())
return mux
+83 -12
View File
@@ -105,6 +105,7 @@ type DomainReq struct {
UpdatedBy *string `json:"updated_by,omitempty"`
UpdatedAt *time.Time `json:"updated_at,omitempty"`
}
type Domain struct {
ID string `json:"id"`
Name string `json:"name"`
@@ -137,7 +138,7 @@ type Page struct {
ID string `json:"id,omitempty"`
IDs []string `json:"-"`
Identity string `json:"identity,omitempty"`
UserID string `json:"-"`
UserID string `json:"user_id,omitempty"`
}
type DomainsPage struct {
@@ -164,13 +165,64 @@ func (page DomainsPage) MarshalJSON() ([]byte, error) {
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines"
type Service interface {
// CreateDomain creates a new domain.
CreateDomain(ctx context.Context, sesssion authn.Session, d Domain) (Domain, []roles.RoleProvision, error)
// RetrieveDomain retrieves a domain specified by the provided ID.
RetrieveDomain(ctx context.Context, sesssion authn.Session, id string) (Domain, error)
// UpdateDomain updates the domain specified by the provided ID.
UpdateDomain(ctx context.Context, sesssion authn.Session, id string, d DomainReq) (Domain, error)
// EnableDomain enables the domain specified by the provided ID.
EnableDomain(ctx context.Context, sesssion authn.Session, id string) (Domain, error)
// DisableDomain disables the domain specified by the provided ID.
// Only platform administrators and domain admins can disable domains.
DisableDomain(ctx context.Context, sesssion authn.Session, id string) (Domain, error)
// FreezeDomain freezes the domain specified by the provided ID.
// Only platform administrators can freeze domains.
FreezeDomain(ctx context.Context, sesssion authn.Session, id string) (Domain, error)
// ListDomains returns a list of domains.
ListDomains(ctx context.Context, sesssion authn.Session, page Page) (DomainsPage, error)
// SendInvitation sends an invitation to the given user.
// Only domain administrators and platform administrators can send invitations.
SendInvitation(ctx context.Context, session authn.Session, invitation Invitation) (err error)
// ViewInvitation returns an invitation.
// People who can view invitations are:
// - the invited user: they can view their own invitations
// - the user who sent the invitation
// - domain administrators
// - platform administrators
ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (invitation Invitation, err error)
// ListInvitations returns a list of invitations.
// People who can list invitations are:
// - platform administrators can list all invitations
// - domain administrators can list invitations for their domain
// By default, it will list invitations the current user has sent or received.
ListInvitations(ctx context.Context, session authn.Session, page InvitationPageMeta) (invitations InvitationPage, err error)
// AcceptInvitation accepts an invitation by adding the user to the domain.
AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error)
// DeleteInvitation deletes an invitation.
// People who can delete invitations are:
// - the invited user: they can delete their own invitations
// - the user who sent the invitation
// - domain administrators
// - platform administrators
DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (err error)
// RejectInvitation rejects an invitation.
// People who can reject invitations are:
// - the invited user: they can reject their own invitations
RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error)
roles.RoleManager
}
@@ -178,26 +230,45 @@ type Service interface {
//
//go:generate mockery --name Repository --output=./mocks --filename repository.go --quiet --note "Copyright (c) Abstract Machines"
type Repository interface {
// Save creates db insert transaction for the given domain.
Save(ctx context.Context, d Domain) (Domain, error)
// SaveDomain creates db insert transaction for the given domain.
SaveDomain(ctx context.Context, d Domain) (Domain, error)
// RetrieveByID retrieves Domain by its unique ID.
RetrieveByID(ctx context.Context, id string) (Domain, error)
// RetrieveDomainByID retrieves a domain by its unique ID.
RetrieveDomainByID(ctx context.Context, id string) (Domain, error)
RetrieveByUserAndID(ctx context.Context, userID, id string) (Domain, error)
// RetrieveDomainByUserAndID retrieves a domain by its unique ID and user ID.
RetrieveDomainByUserAndID(ctx context.Context, userID, id string) (Domain, error)
// RetrieveAllByIDs retrieves for given Domain IDs.
RetrieveAllByIDs(ctx context.Context, pm Page) (DomainsPage, error)
// RetrieveAllDomainsByIDs retrieves for given Domain IDs.
RetrieveAllDomainsByIDs(ctx context.Context, pm Page) (DomainsPage, error)
// Update updates the client name and metadata.
Update(ctx context.Context, id string, d DomainReq) (Domain, error)
// UpdateDomain updates the domain name and metadata.
UpdateDomain(ctx context.Context, id string, d DomainReq) (Domain, error)
// Delete
Delete(ctx context.Context, id string) error
// DeleteDomain deletes the domain.
DeleteDomain(ctx context.Context, id string) error
// ListDomains list all the domains
ListDomains(ctx context.Context, pm Page) (DomainsPage, error)
// CreateInvitation creates an invitation.
SaveInvitation(ctx context.Context, invitation Invitation) (err error)
// RetrieveInvitation retrieves an invitation.
RetrieveInvitation(ctx context.Context, userID, domainID string) (Invitation, error)
// RetrieveAllInvitations retrieves all invitations.
RetrieveAllInvitations(ctx context.Context, page InvitationPageMeta) (invitations InvitationPage, err error)
// UpdateConfirmation updates an invitation by setting the confirmation time.
UpdateConfirmation(ctx context.Context, invitation Invitation) (err error)
// UpdateRejection updates an invitation by setting the rejection time.
UpdateRejection(ctx context.Context, invitation Invitation) (err error)
// Delete deletes an invitation.
DeleteInvitation(ctx context.Context, userID, domainID string) (err error)
roles.Repository
}
+140
View File
@@ -23,6 +23,13 @@ const (
domainFreeze = domainPrefix + "freeze"
domainList = domainPrefix + "list"
domainUserDelete = domainPrefix + "user_delete"
invitationPrefix = "invitation."
invitationSend = invitationPrefix + "send"
invitationAccept = invitationPrefix + "accept"
invitationReject = invitationPrefix + "reject"
invitationList = invitationPrefix + "list"
invitationRetrieve = invitationPrefix + "retrieve"
invitationDelete = invitationPrefix + "delete"
)
var (
@@ -34,6 +41,12 @@ var (
_ events.Event = (*disableDomainEvent)(nil)
_ events.Event = (*freezeDomainEvent)(nil)
_ events.Event = (*listDomainsEvent)(nil)
_ events.Event = (*sendInvitationEvent)(nil)
_ events.Event = (*viewInvitationEvent)(nil)
_ events.Event = (*listInvitationsEvent)(nil)
_ events.Event = (*acceptInvitationEvent)(nil)
_ events.Event = (*rejectInvitationEvent)(nil)
_ events.Event = (*deleteInvitationEvent)(nil)
)
type createDomainEvent struct {
@@ -275,3 +288,130 @@ func (lde listDomainsEvent) Encode() (map[string]interface{}, error) {
return val, nil
}
type sendInvitationEvent struct {
invitation domains.Invitation
session authn.Session
}
func (sie sendInvitationEvent) Encode() (map[string]interface{}, error) {
val := map[string]interface{}{
"operation": invitationSend,
"invitee_user_id": sie.invitation.InviteeUserID,
"domain_id": sie.invitation.DomainID,
"invited_by": sie.session.UserID,
"role_id": sie.invitation.RoleID,
"token_type": sie.session.Type.String(),
"super_admin": sie.session.SuperAdmin,
}
return val, nil
}
type viewInvitationEvent struct {
inviteeUserID string
domainID string
roleID string
roleName string
session authn.Session
}
func (vie viewInvitationEvent) Encode() (map[string]interface{}, error) {
val := map[string]interface{}{
"operation": invitationRetrieve,
"invitee_user_id": vie.inviteeUserID,
"domain_id": vie.domainID,
"role_id": vie.roleID,
"role_name": vie.roleName,
"token_type": vie.session.Type.String(),
"super_admin": vie.session.SuperAdmin,
}
return val, nil
}
type listInvitationsEvent struct {
domains.InvitationPageMeta
session authn.Session
}
func (lie listInvitationsEvent) Encode() (map[string]interface{}, error) {
val := map[string]interface{}{
"operation": invitationList,
"offset": lie.Offset,
"limit": lie.Limit,
"user_id": lie.session.UserID,
"token_type": lie.session.Type.String(),
"super_admin": lie.session.SuperAdmin,
}
if lie.InvitedBy != "" {
val["invited_by"] = lie.InvitedBy
}
if lie.InviteeUserID != "" {
val["invitee_user_id"] = lie.InviteeUserID
}
if lie.DomainID != "" {
val["domain_id"] = lie.DomainID
}
if lie.RoleID != "" {
val["role_id"] = lie.RoleID
}
if lie.State.String() != "" {
val["state"] = lie.State.String()
}
return val, nil
}
type acceptInvitationEvent struct {
domainID string
session authn.Session
}
func (aie acceptInvitationEvent) Encode() (map[string]interface{}, error) {
val := map[string]interface{}{
"operation": invitationAccept,
"domain_id": aie.domainID,
"invitee_user_id": aie.session.UserID,
"token_type": aie.session.Type.String(),
"super_admin": aie.session.SuperAdmin,
}
return val, nil
}
type rejectInvitationEvent struct {
domainID string
session authn.Session
}
func (rie rejectInvitationEvent) Encode() (map[string]interface{}, error) {
val := map[string]interface{}{
"operation": invitationReject,
"domain_id": rie.domainID,
"invitee_user_id": rie.session.UserID,
"token_type": rie.session.Type.String(),
"super_admin": rie.session.SuperAdmin,
}
return val, nil
}
type deleteInvitationEvent struct {
inviteeUserID string
domainID string
session authn.Session
}
func (die deleteInvitationEvent) Encode() (map[string]interface{}, error) {
val := map[string]interface{}{
"operation": invitationDelete,
"invitee_user_id": die.inviteeUserID,
"domain_id": die.domainID,
"token_type": die.session.Type.String(),
"super_admin": die.session.SuperAdmin,
}
return val, nil
}
+92
View File
@@ -176,3 +176,95 @@ func (es *eventStore) ListDomains(ctx context.Context, session authn.Session, p
return dp, nil
}
func (es *eventStore) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) error {
if err := es.svc.SendInvitation(ctx, session, invitation); err != nil {
return err
}
event := sendInvitationEvent{
invitation: invitation,
session: session,
}
return es.Publish(ctx, event)
}
func (es *eventStore) ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (domains.Invitation, error) {
invitation, err := es.svc.ViewInvitation(ctx, session, userID, domainID)
if err != nil {
return invitation, err
}
event := viewInvitationEvent{
inviteeUserID: userID,
domainID: domainID,
roleID: invitation.RoleID,
roleName: invitation.RoleName,
session: session,
}
if err := es.Publish(ctx, event); err != nil {
return invitation, err
}
return invitation, nil
}
func (es *eventStore) ListInvitations(ctx context.Context, session authn.Session, pm domains.InvitationPageMeta) (domains.InvitationPage, error) {
ip, err := es.svc.ListInvitations(ctx, session, pm)
if err != nil {
return ip, err
}
event := listInvitationsEvent{
InvitationPageMeta: pm,
session: session,
}
if err := es.Publish(ctx, event); err != nil {
return ip, err
}
return ip, nil
}
func (es *eventStore) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) error {
if err := es.svc.AcceptInvitation(ctx, session, domainID); err != nil {
return err
}
event := acceptInvitationEvent{
domainID: domainID,
session: session,
}
return es.Publish(ctx, event)
}
func (es *eventStore) RejectInvitation(ctx context.Context, session authn.Session, domainID string) error {
if err := es.svc.RejectInvitation(ctx, session, domainID); err != nil {
return err
}
event := rejectInvitationEvent{
domainID: domainID,
session: session,
}
return es.Publish(ctx, event)
}
func (es *eventStore) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) error {
if err := es.svc.DeleteInvitation(ctx, session, inviteeUserID, domainID); err != nil {
return err
}
event := deleteInvitationEvent{
inviteeUserID: inviteeUserID,
domainID: domainID,
session: session,
}
return es.Publish(ctx, event)
}
+57
View File
@@ -0,0 +1,57 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package domains
import (
"encoding/json"
"time"
)
// Invitation is an invitation to join a domain.
type Invitation struct {
InvitedBy string `json:"invited_by"`
InviteeUserID string `json:"invitee_user_id"`
DomainID string `json:"domain_id"`
RoleID string `json:"role_id,omitempty"`
RoleName string `json:"role_name,omitempty"`
Actions []string `json:"actions,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
ConfirmedAt time.Time `json:"confirmed_at,omitempty"`
RejectedAt time.Time `json:"rejected_at,omitempty"`
}
// InvitationPage is a page of invitations.
type InvitationPage struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Invitations []Invitation `json:"invitations"`
}
func (page InvitationPage) MarshalJSON() ([]byte, error) {
type Alias InvitationPage
a := struct {
Alias
}{
Alias: Alias(page),
}
if a.Invitations == nil {
a.Invitations = make([]Invitation, 0)
}
return json.Marshal(a)
}
type InvitationPageMeta struct {
Offset uint64 `json:"offset" db:"offset"`
Limit uint64 `json:"limit" db:"limit"`
InvitedBy string `json:"invited_by,omitempty" db:"invited_by,omitempty"`
InviteeUserID string `json:"invitee_user_id,omitempty" db:"invitee_user_id,omitempty"`
DomainID string `json:"domain_id,omitempty" db:"domain_id,omitempty"`
RoleID string `json:"role_id,omitempty" db:"role_id,omitempty"`
InvitedByOrUserID string `db:"invited_by_or_user_id,omitempty"`
State State `json:"state,omitempty"`
}
+50
View File
@@ -0,0 +1,50 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package domains_test
import (
"fmt"
"testing"
"github.com/absmach/supermq/domains"
"github.com/stretchr/testify/assert"
)
func TestInvitation_MarshalJSON(t *testing.T) {
cases := []struct {
desc string
page domains.InvitationPage
res string
}{
{
desc: "empty page",
page: domains.InvitationPage{
Invitations: []domains.Invitation(nil),
},
res: `{"total":0,"offset":0,"limit":0,"invitations":[]}`,
},
{
desc: "page with invitations",
page: domains.InvitationPage{
Total: 1,
Offset: 0,
Limit: 0,
Invitations: []domains.Invitation{
{
InvitedBy: "John",
InviteeUserID: "123",
DomainID: "123",
},
},
},
res: `{"total":1,"offset":0,"limit":0,"invitations":[{"invited_by":"John","invitee_user_id":"123","domain_id":"123","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z","confirmed_at":"0001-01-01T00:00:00Z","rejected_at":"0001-01-01T00:00:00Z"}]}`,
},
}
for _, tc := range cases {
data, err := tc.page.MarshalJSON()
assert.NoError(t, err, "Unexpected error: %v", err)
assert.Equal(t, tc.res, string(data), fmt.Sprintf("%s: expected %s, got %s", tc.desc, tc.res, string(data)))
}
}
+115
View File
@@ -6,10 +6,13 @@ package middleware
import (
"context"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/authz"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
"github.com/absmach/supermq/pkg/roles"
rmMW "github.com/absmach/supermq/pkg/roles/rolemanager/middleware"
@@ -18,6 +21,9 @@ import (
var _ domains.Service = (*authorizationMiddleware)(nil)
// ErrMemberExist indicates that the user is already a member of the domain.
var ErrMemberExist = errors.New("user is already a member of the domain")
type authorizationMiddleware struct {
svc domains.Service
authz smqauthz.Authorization
@@ -134,6 +140,69 @@ func (am *authorizationMiddleware) ListDomains(ctx context.Context, session auth
return am.svc.ListDomains(ctx, session, page)
}
func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) (err error) {
domainUserId := auth.EncodeDomainUserID(invitation.DomainID, invitation.InviteeUserID)
if err := am.extAuthorize(ctx, domainUserId, policies.MembershipPermission, policies.DomainType, invitation.DomainID); err == nil {
// return error if the user is already a member of the domain
return errors.Wrap(svcerr.ErrConflict, ErrMemberExist)
}
if err := am.checkAdmin(ctx, session); err != nil {
return err
}
return am.svc.SendInvitation(ctx, session, invitation)
}
func (am *authorizationMiddleware) ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID, domain string) (invitation domains.Invitation, err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
if session.UserID != inviteeUserID {
if err := am.checkAdmin(ctx, session); err != nil {
return domains.Invitation{}, err
}
}
return am.svc.ViewInvitation(ctx, session, inviteeUserID, domain)
}
func (am *authorizationMiddleware) ListInvitations(ctx context.Context, session authn.Session, page domains.InvitationPageMeta) (invs domains.InvitationPage, err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
if err := am.extAuthorize(ctx, session.UserID, policies.AdminPermission, policies.PlatformType, policies.SuperMQObject); err == nil {
session.SuperAdmin = true
page.DomainID = ""
}
if !session.SuperAdmin {
switch {
case page.DomainID != "":
if err := am.extAuthorize(ctx, session.DomainUserID, policies.AdminPermission, policies.DomainType, page.DomainID); err != nil {
return domains.InvitationPage{}, err
}
default:
page.InvitedByOrUserID = session.UserID
}
}
return am.svc.ListInvitations(ctx, session, page)
}
func (am *authorizationMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
return am.svc.AcceptInvitation(ctx, session, domainID)
}
func (am *authorizationMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
return am.svc.RejectInvitation(ctx, session, domainID)
}
func (am *authorizationMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
if err := am.checkAdmin(ctx, session); err != nil {
return err
}
return am.svc.DeleteInvitation(ctx, session, inviteeUserID, domainID)
}
func (am *authorizationMiddleware) authorize(ctx context.Context, op svcutil.Operation, authReq authz.PolicyReq) error {
perm, err := am.opp.GetPermission(op)
if err != nil {
@@ -147,3 +216,49 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, op svcutil.Ope
return nil
}
// checkAdmin checks if the given user is a domain or platform administrator.
func (am *authorizationMiddleware) checkAdmin(ctx context.Context, session authn.Session) error {
req := smqauthz.PolicyReq{
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: session.DomainUserID,
Permission: policies.AdminPermission,
ObjectType: policies.DomainType,
Object: session.DomainID,
}
if err := am.authz.Authorize(ctx, req); err == nil {
return nil
}
req = smqauthz.PolicyReq{
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: session.UserID,
Permission: policies.AdminPermission,
ObjectType: policies.PlatformType,
Object: policies.SuperMQObject,
}
if err := am.authz.Authorize(ctx, req); err == nil {
return nil
}
return svcerr.ErrAuthorization
}
func (am *authorizationMiddleware) extAuthorize(ctx context.Context, subj, perm, objType, obj string) error {
req := authz.PolicyReq{
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: subj,
Permission: perm,
ObjectType: objType,
Object: obj,
}
if err := am.authz.Authorize(ctx, req); err != nil {
return err
}
return nil
}
+103
View File
@@ -164,3 +164,106 @@ func (lm *loggingMiddleware) ListDomains(ctx context.Context, session authn.Sess
}(time.Now())
return lm.svc.ListDomains(ctx, session, page)
}
func (lm *loggingMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("invitee_user_id", invitation.InviteeUserID),
slog.String("domain_id", invitation.DomainID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Send invitation failed", args...)
return
}
lm.logger.Info("Send invitation completed successfully", args...)
}(time.Now())
return lm.svc.SendInvitation(ctx, session, invitation)
}
func (lm *loggingMiddleware) ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (invitation domains.Invitation, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("invitee_user_id", inviteeUserID),
slog.String("domain_id", domainID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("View invitation failed", args...)
return
}
lm.logger.Info("View invitation completed successfully", args...)
}(time.Now())
return lm.svc.ViewInvitation(ctx, session, inviteeUserID, domainID)
}
func (lm *loggingMiddleware) ListInvitations(ctx context.Context, session authn.Session, pm domains.InvitationPageMeta) (invs domains.InvitationPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.Group("page",
slog.Uint64("offset", pm.Offset),
slog.Uint64("limit", pm.Limit),
slog.Uint64("total", invs.Total),
),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("List invitations failed", args...)
return
}
lm.logger.Info("List invitations completed successfully", args...)
}(time.Now())
return lm.svc.ListInvitations(ctx, session, pm)
}
func (lm *loggingMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", domainID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Accept invitation failed", args...)
return
}
lm.logger.Info("Accept invitation completed successfully", args...)
}(time.Now())
return lm.svc.AcceptInvitation(ctx, session, domainID)
}
func (lm *loggingMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", domainID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Reject invitation failed", args...)
return
}
lm.logger.Info("Reject invitation completed successfully", args...)
}(time.Now())
return lm.svc.RejectInvitation(ctx, session, domainID)
}
func (lm *loggingMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("invitee_user_id", inviteeUserID),
slog.String("domain_id", domainID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Delete invitation failed", args...)
return
}
lm.logger.Info("Delete invitation completed successfully", args...)
}(time.Now())
return lm.svc.DeleteInvitation(ctx, session, inviteeUserID, domainID)
}
+48
View File
@@ -92,3 +92,51 @@ func (ms *metricsMiddleware) ListDomains(ctx context.Context, session authn.Sess
}(time.Now())
return ms.svc.ListDomains(ctx, session, page)
}
func (mm *metricsMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) (err error) {
defer func(begin time.Time) {
mm.counter.With("method", "send_invitation").Add(1)
mm.latency.With("method", "send_invitation").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.SendInvitation(ctx, session, invitation)
}
func (mm *metricsMiddleware) ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (invitation domains.Invitation, err error) {
defer func(begin time.Time) {
mm.counter.With("method", "view_invitation").Add(1)
mm.latency.With("method", "view_invitation").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.ViewInvitation(ctx, session, userID, domainID)
}
func (mm *metricsMiddleware) ListInvitations(ctx context.Context, session authn.Session, pm domains.InvitationPageMeta) (invs domains.InvitationPage, err error) {
defer func(begin time.Time) {
mm.counter.With("method", "list_invitations").Add(1)
mm.latency.With("method", "list_invitations").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.ListInvitations(ctx, session, pm)
}
func (mm *metricsMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
defer func(begin time.Time) {
mm.counter.With("method", "accept_invitation").Add(1)
mm.latency.With("method", "accept_invitation").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.AcceptInvitation(ctx, session, domainID)
}
func (mm *metricsMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
defer func(begin time.Time) {
mm.counter.With("method", "reject_invitation").Add(1)
mm.latency.With("method", "reject_invitation").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.RejectInvitation(ctx, session, domainID)
}
func (mm *metricsMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) {
defer func(begin time.Time) {
mm.counter.With("method", "delete_invitation").Add(1)
mm.latency.With("method", "delete_invitation").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.DeleteInvitation(ctx, session, userID, domainID)
}
+146 -18
View File
@@ -48,12 +48,12 @@ func (_m *Repository) AddRoles(ctx context.Context, rps []roles.RoleProvision) (
return r0, r1
}
// Delete provides a mock function with given fields: ctx, id
func (_m *Repository) Delete(ctx context.Context, id string) error {
// DeleteDomain provides a mock function with given fields: ctx, id
func (_m *Repository) DeleteDomain(ctx context.Context, id string) error {
ret := _m.Called(ctx, id)
if len(ret) == 0 {
panic("no return value specified for Delete")
panic("no return value specified for DeleteDomain")
}
var r0 error
@@ -66,6 +66,24 @@ func (_m *Repository) Delete(ctx context.Context, id string) error {
return r0
}
// DeleteInvitation provides a mock function with given fields: ctx, userID, domainID
func (_m *Repository) DeleteInvitation(ctx context.Context, userID string, domainID string) error {
ret := _m.Called(ctx, userID, domainID)
if len(ret) == 0 {
panic("no return value specified for DeleteInvitation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, userID, domainID)
} else {
r0 = ret.Error(0)
}
return r0
}
// ListDomains provides a mock function with given fields: ctx, pm
func (_m *Repository) ListDomains(ctx context.Context, pm domains.Page) (domains.DomainsPage, error) {
ret := _m.Called(ctx, pm)
@@ -176,12 +194,12 @@ func (_m *Repository) RemoveRoles(ctx context.Context, roleIDs []string) error {
return r0
}
// RetrieveAllByIDs provides a mock function with given fields: ctx, pm
func (_m *Repository) RetrieveAllByIDs(ctx context.Context, pm domains.Page) (domains.DomainsPage, error) {
// RetrieveAllDomainsByIDs provides a mock function with given fields: ctx, pm
func (_m *Repository) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.Page) (domains.DomainsPage, error) {
ret := _m.Called(ctx, pm)
if len(ret) == 0 {
panic("no return value specified for RetrieveAllByIDs")
panic("no return value specified for RetrieveAllDomainsByIDs")
}
var r0 domains.DomainsPage
@@ -204,6 +222,34 @@ func (_m *Repository) RetrieveAllByIDs(ctx context.Context, pm domains.Page) (do
return r0, r1
}
// RetrieveAllInvitations provides a mock function with given fields: ctx, page
func (_m *Repository) RetrieveAllInvitations(ctx context.Context, page domains.InvitationPageMeta) (domains.InvitationPage, error) {
ret := _m.Called(ctx, page)
if len(ret) == 0 {
panic("no return value specified for RetrieveAllInvitations")
}
var r0 domains.InvitationPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, domains.InvitationPageMeta) (domains.InvitationPage, error)); ok {
return rf(ctx, page)
}
if rf, ok := ret.Get(0).(func(context.Context, domains.InvitationPageMeta) domains.InvitationPage); ok {
r0 = rf(ctx, page)
} else {
r0 = ret.Get(0).(domains.InvitationPage)
}
if rf, ok := ret.Get(1).(func(context.Context, domains.InvitationPageMeta) error); ok {
r1 = rf(ctx, page)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RetrieveAllRoles provides a mock function with given fields: ctx, entityID, limit, offset
func (_m *Repository) RetrieveAllRoles(ctx context.Context, entityID string, limit uint64, offset uint64) (roles.RolePage, error) {
ret := _m.Called(ctx, entityID, limit, offset)
@@ -232,12 +278,12 @@ func (_m *Repository) RetrieveAllRoles(ctx context.Context, entityID string, lim
return r0, r1
}
// RetrieveByID provides a mock function with given fields: ctx, id
func (_m *Repository) RetrieveByID(ctx context.Context, id string) (domains.Domain, error) {
// RetrieveDomainByID provides a mock function with given fields: ctx, id
func (_m *Repository) RetrieveDomainByID(ctx context.Context, id string) (domains.Domain, error) {
ret := _m.Called(ctx, id)
if len(ret) == 0 {
panic("no return value specified for RetrieveByID")
panic("no return value specified for RetrieveDomainByID")
}
var r0 domains.Domain
@@ -260,12 +306,12 @@ func (_m *Repository) RetrieveByID(ctx context.Context, id string) (domains.Doma
return r0, r1
}
// RetrieveByUserAndID provides a mock function with given fields: ctx, userID, id
func (_m *Repository) RetrieveByUserAndID(ctx context.Context, userID string, id string) (domains.Domain, error) {
// RetrieveDomainByUserAndID provides a mock function with given fields: ctx, userID, id
func (_m *Repository) RetrieveDomainByUserAndID(ctx context.Context, userID string, id string) (domains.Domain, error) {
ret := _m.Called(ctx, userID, id)
if len(ret) == 0 {
panic("no return value specified for RetrieveByUserAndID")
panic("no return value specified for RetrieveDomainByUserAndID")
}
var r0 domains.Domain
@@ -355,6 +401,34 @@ func (_m *Repository) RetrieveEntityRole(ctx context.Context, entityID string, r
return r0, r1
}
// RetrieveInvitation provides a mock function with given fields: ctx, userID, domainID
func (_m *Repository) RetrieveInvitation(ctx context.Context, userID string, domainID string) (domains.Invitation, error) {
ret := _m.Called(ctx, userID, domainID)
if len(ret) == 0 {
panic("no return value specified for RetrieveInvitation")
}
var r0 domains.Invitation
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (domains.Invitation, error)); ok {
return rf(ctx, userID, domainID)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) domains.Invitation); ok {
r0 = rf(ctx, userID, domainID)
} else {
r0 = ret.Get(0).(domains.Invitation)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, userID, domainID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RetrieveRole provides a mock function with given fields: ctx, roleID
func (_m *Repository) RetrieveRole(ctx context.Context, roleID string) (roles.Role, error) {
ret := _m.Called(ctx, roleID)
@@ -629,12 +703,12 @@ func (_m *Repository) RoleRemoveMembers(ctx context.Context, role roles.Role, me
return r0
}
// Save provides a mock function with given fields: ctx, d
func (_m *Repository) Save(ctx context.Context, d domains.Domain) (domains.Domain, error) {
// SaveDomain provides a mock function with given fields: ctx, d
func (_m *Repository) SaveDomain(ctx context.Context, d domains.Domain) (domains.Domain, error) {
ret := _m.Called(ctx, d)
if len(ret) == 0 {
panic("no return value specified for Save")
panic("no return value specified for SaveDomain")
}
var r0 domains.Domain
@@ -657,12 +731,48 @@ func (_m *Repository) Save(ctx context.Context, d domains.Domain) (domains.Domai
return r0, r1
}
// Update provides a mock function with given fields: ctx, id, d
func (_m *Repository) Update(ctx context.Context, id string, d domains.DomainReq) (domains.Domain, error) {
// SaveInvitation provides a mock function with given fields: ctx, invitation
func (_m *Repository) SaveInvitation(ctx context.Context, invitation domains.Invitation) error {
ret := _m.Called(ctx, invitation)
if len(ret) == 0 {
panic("no return value specified for SaveInvitation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, domains.Invitation) error); ok {
r0 = rf(ctx, invitation)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateConfirmation provides a mock function with given fields: ctx, invitation
func (_m *Repository) UpdateConfirmation(ctx context.Context, invitation domains.Invitation) error {
ret := _m.Called(ctx, invitation)
if len(ret) == 0 {
panic("no return value specified for UpdateConfirmation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, domains.Invitation) error); ok {
r0 = rf(ctx, invitation)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateDomain provides a mock function with given fields: ctx, id, d
func (_m *Repository) UpdateDomain(ctx context.Context, id string, d domains.DomainReq) (domains.Domain, error) {
ret := _m.Called(ctx, id, d)
if len(ret) == 0 {
panic("no return value specified for Update")
panic("no return value specified for UpdateDomain")
}
var r0 domains.Domain
@@ -685,6 +795,24 @@ func (_m *Repository) Update(ctx context.Context, id string, d domains.DomainReq
return r0, r1
}
// UpdateRejection provides a mock function with given fields: ctx, invitation
func (_m *Repository) UpdateRejection(ctx context.Context, invitation domains.Invitation) error {
ret := _m.Called(ctx, invitation)
if len(ret) == 0 {
panic("no return value specified for UpdateRejection")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, domains.Invitation) error); ok {
r0 = rf(ctx, invitation)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateRole provides a mock function with given fields: ctx, ro
func (_m *Repository) UpdateRole(ctx context.Context, ro roles.Role) (roles.Role, error) {
ret := _m.Called(ctx, ro)
+128
View File
@@ -21,6 +21,24 @@ type Service struct {
mock.Mock
}
// AcceptInvitation provides a mock function with given fields: ctx, session, domainID
func (_m *Service) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) error {
ret := _m.Called(ctx, session, domainID)
if len(ret) == 0 {
panic("no return value specified for AcceptInvitation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok {
r0 = rf(ctx, session, domainID)
} else {
r0 = ret.Error(0)
}
return r0
}
// AddRole provides a mock function with given fields: ctx, session, entityID, roleName, optionalActions, optionalMembers
func (_m *Service) AddRole(ctx context.Context, session authn.Session, entityID string, roleName string, optionalActions []string, optionalMembers []string) (roles.RoleProvision, error) {
ret := _m.Called(ctx, session, entityID, roleName, optionalActions, optionalMembers)
@@ -86,6 +104,24 @@ func (_m *Service) CreateDomain(ctx context.Context, sesssion authn.Session, d d
return r0, r1, r2
}
// DeleteInvitation provides a mock function with given fields: ctx, session, inviteeUserID, domainID
func (_m *Service) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID string, domainID string) error {
ret := _m.Called(ctx, session, inviteeUserID, domainID)
if len(ret) == 0 {
panic("no return value specified for DeleteInvitation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) error); ok {
r0 = rf(ctx, session, inviteeUserID, domainID)
} else {
r0 = ret.Error(0)
}
return r0
}
// DisableDomain provides a mock function with given fields: ctx, sesssion, id
func (_m *Service) DisableDomain(ctx context.Context, sesssion authn.Session, id string) (domains.Domain, error) {
ret := _m.Called(ctx, sesssion, id)
@@ -256,6 +292,52 @@ func (_m *Service) ListEntityMembers(ctx context.Context, session authn.Session,
return r0, r1
}
// ListInvitations provides a mock function with given fields: ctx, session, page
func (_m *Service) ListInvitations(ctx context.Context, session authn.Session, page domains.InvitationPageMeta) (domains.InvitationPage, error) {
ret := _m.Called(ctx, session, page)
if len(ret) == 0 {
panic("no return value specified for ListInvitations")
}
var r0 domains.InvitationPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, domains.InvitationPageMeta) (domains.InvitationPage, error)); ok {
return rf(ctx, session, page)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, domains.InvitationPageMeta) domains.InvitationPage); ok {
r0 = rf(ctx, session, page)
} else {
r0 = ret.Get(0).(domains.InvitationPage)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, domains.InvitationPageMeta) error); ok {
r1 = rf(ctx, session, page)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RejectInvitation provides a mock function with given fields: ctx, session, domainID
func (_m *Service) RejectInvitation(ctx context.Context, session authn.Session, domainID string) error {
ret := _m.Called(ctx, session, domainID)
if len(ret) == 0 {
panic("no return value specified for RejectInvitation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok {
r0 = rf(ctx, session, domainID)
} else {
r0 = ret.Error(0)
}
return r0
}
// RemoveEntityMembers provides a mock function with given fields: ctx, session, entityID, members
func (_m *Service) RemoveEntityMembers(ctx context.Context, session authn.Session, entityID string, members []string) error {
ret := _m.Called(ctx, session, entityID, members)
@@ -640,6 +722,24 @@ func (_m *Service) RoleRemoveMembers(ctx context.Context, session authn.Session,
return r0
}
// SendInvitation provides a mock function with given fields: ctx, session, invitation
func (_m *Service) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) error {
ret := _m.Called(ctx, session, invitation)
if len(ret) == 0 {
panic("no return value specified for SendInvitation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, domains.Invitation) error); ok {
r0 = rf(ctx, session, invitation)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateDomain provides a mock function with given fields: ctx, sesssion, id, d
func (_m *Service) UpdateDomain(ctx context.Context, sesssion authn.Session, id string, d domains.DomainReq) (domains.Domain, error) {
ret := _m.Called(ctx, sesssion, id, d)
@@ -696,6 +796,34 @@ func (_m *Service) UpdateRoleName(ctx context.Context, session authn.Session, en
return r0, r1
}
// ViewInvitation provides a mock function with given fields: ctx, session, inviteeUserID, domainID
func (_m *Service) ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID string, domainID string) (domains.Invitation, error) {
ret := _m.Called(ctx, session, inviteeUserID, domainID)
if len(ret) == 0 {
panic("no return value specified for ViewInvitation")
}
var r0 domains.Invitation
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) (domains.Invitation, error)); ok {
return rf(ctx, session, inviteeUserID, domainID)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) domains.Invitation); ok {
r0 = rf(ctx, session, inviteeUserID, domainID)
} else {
r0 = ret.Get(0).(domains.Invitation)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, string) error); ok {
r1 = rf(ctx, session, inviteeUserID, domainID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewService(t interface {
+13 -13
View File
@@ -36,9 +36,9 @@ type domainRepo struct {
rolesPostgres.Repository
}
// New instantiates a PostgreSQL
// NewRepository instantiates a PostgreSQL
// implementation of Domain repository.
func New(db postgres.Database) domains.Repository {
func NewRepository(db postgres.Database) domains.Repository {
rmsvcRepo := rolesPostgres.NewRepository(db, policies.DomainType, rolesTableNamePrefix, entityTableName, entityIDColumnName)
return &domainRepo{
db: db,
@@ -46,7 +46,7 @@ func New(db postgres.Database) domains.Repository {
}
}
func (repo domainRepo) Save(ctx context.Context, d domains.Domain) (dd domains.Domain, err error) {
func (repo domainRepo) SaveDomain(ctx context.Context, d domains.Domain) (dd domains.Domain, err error) {
q := `INSERT INTO domains (id, name, tags, alias, metadata, created_at, updated_at, updated_by, created_by, status)
VALUES (:id, :name, :tags, :alias, :metadata, :created_at, :updated_at, :updated_by, :created_by, :status)
RETURNING id, name, tags, alias, metadata, created_at, updated_at, updated_by, created_by, status;`
@@ -76,8 +76,8 @@ func (repo domainRepo) Save(ctx context.Context, d domains.Domain) (dd domains.D
return domain, nil
}
// RetrieveByID retrieves Domain by its unique ID.
func (repo domainRepo) RetrieveByID(ctx context.Context, id string) (domains.Domain, error) {
// RetrieveDomainByID retrieves Domain by its unique ID.
func (repo domainRepo) RetrieveDomainByID(ctx context.Context, id string) (domains.Domain, error) {
q := `SELECT d.id as id, d.name as name, d.tags as tags, d.alias as alias, d.metadata as metadata, d.created_at as created_at, d.updated_at as updated_at, d.updated_by as updated_by, d.created_by as created_by, d.status as status
FROM domains d WHERE d.id = :id`
@@ -107,7 +107,7 @@ func (repo domainRepo) RetrieveByID(ctx context.Context, id string) (domains.Dom
return domains.Domain{}, repoerr.ErrNotFound
}
func (repo domainRepo) RetrieveByUserAndID(ctx context.Context, userID, id string) (domains.Domain, error) {
func (repo domainRepo) RetrieveDomainByUserAndID(ctx context.Context, userID, id string) (domains.Domain, error) {
q := repo.userDomainsBaseQuery() +
`SELECT
d.id as id,
@@ -156,7 +156,7 @@ func (repo domainRepo) RetrieveByUserAndID(ctx context.Context, userID, id strin
}
// RetrieveAllByIDs retrieves for given Domain IDs .
func (repo domainRepo) RetrieveAllByIDs(ctx context.Context, pm domains.Page) (domains.DomainsPage, error) {
func (repo domainRepo) RetrieveAllDomainsByIDs(ctx context.Context, pm domains.Page) (domains.DomainsPage, error) {
var q string
if len(pm.IDs) == 0 {
return domains.DomainsPage{}, nil
@@ -170,7 +170,7 @@ func (repo domainRepo) RetrieveAllByIDs(ctx context.Context, pm domains.Page) (d
FROM domains d`
q = fmt.Sprintf("%s %s LIMIT %d OFFSET %d;", q, query, pm.Limit, pm.Offset)
dbPage, err := toDBClientsPage(pm)
dbPage, err := toDBDomainsPage(pm)
if err != nil {
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
}
@@ -254,7 +254,7 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain
q = fmt.Sprintf(q, squery)
dbPage, err := toDBClientsPage(pm)
dbPage, err := toDBDomainsPage(pm)
if err != nil {
return domains.DomainsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err)
}
@@ -294,8 +294,8 @@ func (repo domainRepo) ListDomains(ctx context.Context, pm domains.Page) (domain
}, nil
}
// Update updates the client name and metadata.
func (repo domainRepo) Update(ctx context.Context, id string, dr domains.DomainReq) (domains.Domain, error) {
// UpdateDomain updates the client name and metadata.
func (repo domainRepo) UpdateDomain(ctx context.Context, id string, dr domains.DomainReq) (domains.Domain, error) {
var query []string
var upq string
d := domains.Domain{ID: id}
@@ -361,7 +361,7 @@ func (repo domainRepo) Update(ctx context.Context, id string, dr domains.DomainR
}
// Delete delete domain from database.
func (repo domainRepo) Delete(ctx context.Context, id string) error {
func (repo domainRepo) DeleteDomain(ctx context.Context, id string) error {
q := "DELETE FROM domains WHERE id = $1;"
res, err := repo.db.ExecContext(ctx, q, id)
@@ -553,7 +553,7 @@ type dbDomainsPage struct {
UserID string `db:"member_id"`
}
func toDBClientsPage(pm domains.Page) (dbDomainsPage, error) {
func toDBDomainsPage(pm domains.Page) (dbDomainsPage, error) {
_, data, err := postgres.CreateMetadataQuery("", pm.Metadata)
if err != nil {
return dbDomainsPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
+17 -17
View File
@@ -27,13 +27,13 @@ var (
userID = testsutil.GenerateUUID(&testing.T{})
)
func TestSave(t *testing.T) {
func TestSaveDomain(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM domains")
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
})
repo := postgres.New(database)
repo := postgres.NewRepository(database)
cases := []struct {
desc string
@@ -134,7 +134,7 @@ func TestSave(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
domain, err := repo.Save(context.Background(), tc.domain)
domain, err := repo.SaveDomain(context.Background(), tc.domain)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, tc.domain, domain, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.domain, domain))
@@ -149,7 +149,7 @@ func TestRetrieveByID(t *testing.T) {
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
})
repo := postgres.New(database)
repo := postgres.NewRepository(database)
domain := domains.Domain{
ID: domainID,
@@ -166,7 +166,7 @@ func TestRetrieveByID(t *testing.T) {
Status: domains.EnabledStatus,
}
_, err := repo.Save(context.Background(), domain)
_, err := repo.SaveDomain(context.Background(), domain)
require.Nil(t, err, fmt.Sprintf("failed to save client %s", domain.ID))
cases := []struct {
@@ -197,7 +197,7 @@ func TestRetrieveByID(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
d, err := repo.RetrieveByID(context.Background(), tc.domainID)
d, err := repo.RetrieveDomainByID(context.Background(), tc.domainID)
assert.Equal(t, tc.response, d, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, d))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err))
})
@@ -210,7 +210,7 @@ func TestRetrieveAllByIDs(t *testing.T) {
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
})
repo := postgres.New(database)
repo := postgres.NewRepository(database)
items := []domains.Domain{}
for i := 0; i < 10; i++ {
@@ -233,7 +233,7 @@ func TestRetrieveAllByIDs(t *testing.T) {
"test1": "test1",
}
}
_, err := repo.Save(context.Background(), domain)
_, err := repo.SaveDomain(context.Background(), domain)
require.Nil(t, err, fmt.Sprintf("save domain unexpected error: %s", err))
items = append(items, domain)
}
@@ -434,7 +434,7 @@ func TestRetrieveAllByIDs(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
dp, err := repo.RetrieveAllByIDs(context.Background(), tc.pm)
dp, err := repo.RetrieveAllDomainsByIDs(context.Background(), tc.pm)
assert.Equal(t, tc.response, dp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, dp))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err))
})
@@ -455,7 +455,7 @@ func TestUpdate(t *testing.T) {
updatedStatus := domains.DisabledStatus
updatedAlias := "test1"
repo := postgres.New(database)
repo := postgres.NewRepository(database)
domain := domains.Domain{
ID: domainID,
@@ -470,7 +470,7 @@ func TestUpdate(t *testing.T) {
Status: domains.EnabledStatus,
}
_, err := repo.Save(context.Background(), domain)
_, err := repo.SaveDomain(context.Background(), domain)
require.Nil(t, err, fmt.Sprintf("failed to save client %s", domain.ID))
cases := []struct {
@@ -561,7 +561,7 @@ func TestUpdate(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
d, err := repo.Update(context.Background(), tc.domainID, tc.d)
d, err := repo.UpdateDomain(context.Background(), tc.domainID, tc.d)
d.UpdatedAt = tc.response.UpdatedAt
assert.Equal(t, tc.response, d, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, d))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
@@ -575,7 +575,7 @@ func TestDelete(t *testing.T) {
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
})
repo := postgres.New(database)
repo := postgres.NewRepository(database)
domain := domains.Domain{
ID: domainID,
@@ -590,7 +590,7 @@ func TestDelete(t *testing.T) {
Status: domains.EnabledStatus,
}
_, err := repo.Save(context.Background(), domain)
_, err := repo.SaveDomain(context.Background(), domain)
require.Nil(t, err, fmt.Sprintf("failed to save client %s", domain.ID))
cases := []struct {
@@ -617,7 +617,7 @@ func TestDelete(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := repo.Delete(context.Background(), tc.domainID)
err := repo.DeleteDomain(context.Background(), tc.domainID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
})
}
@@ -629,7 +629,7 @@ func TestListDomains(t *testing.T) {
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
})
repo := postgres.New(database)
repo := postgres.NewRepository(database)
items := []domains.Domain{}
for i := 0; i < 10; i++ {
@@ -652,7 +652,7 @@ func TestListDomains(t *testing.T) {
"test1": "test1",
}
}
_, err := repo.Save(context.Background(), domain)
_, err := repo.SaveDomain(context.Background(), domain)
require.Nil(t, err, fmt.Sprintf("save domain unexpected error: %s", err))
items = append(items, domain)
}
+21
View File
@@ -40,6 +40,27 @@ func Migration() (*migrate.MemoryMigrationSource, error) {
`DROP TABLE IF EXISTS domains`,
},
},
{
Id: "domain_2",
Up: []string{
`CREATE TABLE IF NOT EXISTS invitations (
invited_by VARCHAR(36) NOT NULL,
invitee_user_id VARCHAR(36) NOT NULL,
domain_id VARCHAR(36) NOT NULL,
role_id VARCHAR(36) NOT NULL,
created_at TIMESTAMP NOT NULL,
updated_at TIMESTAMP,
confirmed_at TIMESTAMP,
rejected_at TIMESTAMP,
UNIQUE (invitee_user_id, domain_id),
PRIMARY KEY (invitee_user_id, domain_id),
FOREIGN KEY (domain_id) REFERENCES domains(id) ON DELETE CASCADE
);`,
},
Down: []string{
`DROP TABLE IF EXISTS invitations`,
},
},
},
}
+229
View File
@@ -0,0 +1,229 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/absmach/supermq/domains"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
"github.com/absmach/supermq/pkg/postgres"
)
func (repo domainRepo) SaveInvitation(ctx context.Context, invitation domains.Invitation) (err error) {
q := `INSERT INTO invitations (invited_by, invitee_user_id, domain_id, role_id, created_at)
VALUES (:invited_by, :invitee_user_id, :domain_id, :role_id, :created_at)`
dbInv := toDBInvitation(invitation)
if _, err = repo.db.NamedExecContext(ctx, q, dbInv); err != nil {
return postgres.HandleError(repoerr.ErrCreateEntity, err)
}
return nil
}
func (repo domainRepo) RetrieveInvitation(ctx context.Context, inviteeUserID, domainID string) (domains.Invitation, error) {
q := `SELECT invited_by, invitee_user_id, domain_id, role_id, created_at, updated_at, confirmed_at, rejected_at FROM invitations WHERE invitee_user_id = :invitee_user_id AND domain_id = :domain_id;`
dbinv := dbInvitation{
InviteeUserID: inviteeUserID,
DomainID: domainID,
}
rows, err := repo.db.NamedQueryContext(ctx, q, dbinv)
if err != nil {
return domains.Invitation{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
dbinv = dbInvitation{}
if rows.Next() {
if err = rows.StructScan(&dbinv); err != nil {
return domains.Invitation{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
return toInvitation(dbinv), nil
}
return domains.Invitation{}, repoerr.ErrNotFound
}
func (repo domainRepo) RetrieveAllInvitations(ctx context.Context, pm domains.InvitationPageMeta) (domains.InvitationPage, error) {
query := pageQuery(pm)
q := fmt.Sprintf("SELECT invited_by, invitee_user_id, domain_id, role_id, created_at, updated_at, confirmed_at, rejected_at FROM invitations %s LIMIT :limit OFFSET :offset;", query)
rows, err := repo.db.NamedQueryContext(ctx, q, pm)
if err != nil {
return domains.InvitationPage{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
var items []domains.Invitation
for rows.Next() {
var dbinv dbInvitation
if err = rows.StructScan(&dbinv); err != nil {
return domains.InvitationPage{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
items = append(items, toInvitation(dbinv))
}
tq := fmt.Sprintf(`SELECT COUNT(*) FROM invitations %s`, query)
total, err := postgres.Total(ctx, repo.db, tq, pm)
if err != nil {
return domains.InvitationPage{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
invPage := domains.InvitationPage{
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
Invitations: items,
}
return invPage, nil
}
func (repo domainRepo) UpdateConfirmation(ctx context.Context, invitation domains.Invitation) (err error) {
q := `UPDATE invitations SET confirmed_at = :confirmed_at, updated_at = :updated_at WHERE invitee_user_id = :invitee_user_id AND domain_id = :domain_id`
dbinv := toDBInvitation(invitation)
result, err := repo.db.NamedExecContext(ctx, q, dbinv)
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
}
return nil
}
func (repo domainRepo) UpdateRejection(ctx context.Context, invitation domains.Invitation) (err error) {
q := `UPDATE invitations SET rejected_at = :rejected_at, updated_at = :updated_at WHERE invitee_user_id = :invitee_user_id AND domain_id = :domain_id`
dbInv := toDBInvitation(invitation)
result, err := repo.db.NamedExecContext(ctx, q, dbInv)
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
}
return nil
}
func (repo domainRepo) DeleteInvitation(ctx context.Context, inviteeUserID, domain string) (err error) {
q := `DELETE FROM invitations WHERE invitee_user_id = $1 AND domain_id = $2`
result, err := repo.db.ExecContext(ctx, q, inviteeUserID, domain)
if err != nil {
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
}
return nil
}
func pageQuery(pm domains.InvitationPageMeta) string {
var query []string
var emq string
if pm.DomainID != "" {
query = append(query, "domain_id = :domain_id")
}
if pm.InviteeUserID != "" {
query = append(query, "invitee_user_id = :invitee_user_id")
}
if pm.InvitedBy != "" {
query = append(query, "invited_by = :invited_by")
}
if pm.RoleID != "" {
query = append(query, "role_id = :role_id")
}
if pm.InvitedByOrUserID != "" {
query = append(query, "(invited_by = :invited_by_or_user_id OR invitee_user_id = :invited_by_or_user_id)")
}
if pm.State == domains.Accepted {
query = append(query, "confirmed_at IS NOT NULL")
}
if pm.State == domains.Pending {
query = append(query, "confirmed_at IS NULL AND rejected_at IS NULL")
}
if pm.State == domains.Rejected {
query = append(query, "rejected_at IS NOT NULL")
}
if len(query) > 0 {
emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND "))
}
return emq
}
type dbInvitation struct {
InvitedBy string `db:"invited_by"`
InviteeUserID string `db:"invitee_user_id"`
DomainID string `db:"domain_id"`
RoleID string `db:"role_id,omitempty"`
Relation string `db:"relation"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
ConfirmedAt sql.NullTime `db:"confirmed_at,omitempty"`
RejectedAt sql.NullTime `db:"rejected_at,omitempty"`
}
func toDBInvitation(inv domains.Invitation) dbInvitation {
var updatedAt, confirmedAt, rejectedAt sql.NullTime
if inv.UpdatedAt != (time.Time{}) {
updatedAt = sql.NullTime{Time: inv.UpdatedAt, Valid: true}
}
if inv.ConfirmedAt != (time.Time{}) {
confirmedAt = sql.NullTime{Time: inv.ConfirmedAt, Valid: true}
}
if inv.RejectedAt != (time.Time{}) {
rejectedAt = sql.NullTime{Time: inv.RejectedAt, Valid: true}
}
return dbInvitation{
InvitedBy: inv.InvitedBy,
InviteeUserID: inv.InviteeUserID,
DomainID: inv.DomainID,
RoleID: inv.RoleID,
CreatedAt: inv.CreatedAt,
UpdatedAt: updatedAt,
ConfirmedAt: confirmedAt,
RejectedAt: rejectedAt,
}
}
func toInvitation(dbinv dbInvitation) domains.Invitation {
var updatedAt, confirmedAt, rejectedAt time.Time
if dbinv.UpdatedAt.Valid {
updatedAt = dbinv.UpdatedAt.Time
}
if dbinv.ConfirmedAt.Valid {
confirmedAt = dbinv.ConfirmedAt.Time
}
if dbinv.RejectedAt.Valid {
rejectedAt = dbinv.RejectedAt.Time
}
return domains.Invitation{
InvitedBy: dbinv.InvitedBy,
InviteeUserID: dbinv.InviteeUserID,
DomainID: dbinv.DomainID,
RoleID: dbinv.RoleID,
CreatedAt: dbinv.CreatedAt,
UpdatedAt: updatedAt,
ConfirmedAt: confirmedAt,
RejectedAt: rejectedAt,
}
}
+833
View File
@@ -0,0 +1,833 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres_test
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/domains/postgres"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var invalidUUID = strings.Repeat("a", 37)
func TestSaveInvitation(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM invitations")
require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err))
_, err = db.Exec("DELETE FROM domains")
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
domainID := saveDomain(t, repo)
userID := testsutil.GenerateUUID(t)
roleID := testsutil.GenerateUUID(t)
cases := []struct {
desc string
invitation domains.Invitation
err error
}{
{
desc: "add new invitation successfully",
invitation: domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: userID,
DomainID: domainID,
RoleID: roleID,
CreatedAt: time.Now(),
},
err: nil,
},
{
desc: "add new invitation with an confirmed_at date",
invitation: domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: domainID,
CreatedAt: time.Now(),
RoleID: roleID,
ConfirmedAt: time.Now(),
},
err: nil,
},
{
desc: "add invitation with duplicate invitation",
invitation: domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: userID,
DomainID: domainID,
RoleID: roleID,
CreatedAt: time.Now(),
},
err: repoerr.ErrConflict,
},
{
desc: "add invitation with invalid invitation invited_by",
invitation: domains.Invitation{
InvitedBy: invalidUUID,
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: domainID,
RoleID: roleID,
CreatedAt: time.Now(),
},
err: repoerr.ErrMalformedEntity,
},
{
desc: "add invitation with invalid invitation domain",
invitation: domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: invalidUUID,
RoleID: roleID,
CreatedAt: time.Now(),
},
err: repoerr.ErrMalformedEntity,
},
{
desc: "add invitation with invalid invitation invitee user id",
invitation: domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: invalidUUID,
DomainID: testsutil.GenerateUUID(t),
RoleID: roleID,
CreatedAt: time.Now(),
},
err: repoerr.ErrMalformedEntity,
},
{
desc: "add invitation with empty invitation domain",
invitation: domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
RoleID: roleID,
CreatedAt: time.Now(),
},
err: repoerr.ErrCreateEntity,
},
{
desc: "add invitation with empty invitation invitee user id",
invitation: domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
DomainID: domainID,
RoleID: roleID,
CreatedAt: time.Now(),
},
err: nil,
},
{
desc: "add invitation with empty invitation invited_by",
invitation: domains.Invitation{
DomainID: domainID,
InviteeUserID: testsutil.GenerateUUID(t),
RoleID: roleID,
CreatedAt: time.Now(),
},
err: nil,
},
{
desc: "add invitation with empty invitation role id",
invitation: domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: domainID,
CreatedAt: time.Now(),
},
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := repo.SaveInvitation(context.Background(), tc.invitation)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.err, err))
})
}
}
func TestInvitationRetrieve(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM invitations")
require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err))
_, err = db.Exec("DELETE FROM domains")
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
domainID := saveDomain(t, repo)
invitation := domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: domainID,
RoleID: testsutil.GenerateUUID(t),
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
}
err := repo.SaveInvitation(context.Background(), invitation)
require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err))
cases := []struct {
desc string
userID string
domainID string
response domains.Invitation
err error
}{
{
desc: "retrieve invitations successfully",
userID: invitation.InviteeUserID,
domainID: invitation.DomainID,
response: invitation,
err: nil,
},
{
desc: "retrieve invitations with invalid invitee user id",
userID: testsutil.GenerateUUID(t),
domainID: invitation.DomainID,
response: domains.Invitation{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve invitations with invalid invitation domain_id",
userID: invitation.InviteeUserID,
domainID: testsutil.GenerateUUID(t),
response: domains.Invitation{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve invitations with invalid invitee user id and domain_id",
userID: testsutil.GenerateUUID(t),
domainID: testsutil.GenerateUUID(t),
response: domains.Invitation{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve invitations with empty invitee user id",
userID: "",
domainID: invitation.DomainID,
response: domains.Invitation{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve invitations with empty invitation domain_id",
userID: invitation.InviteeUserID,
domainID: "",
response: domains.Invitation{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve invitations with empty invitation user id and domain_id",
userID: "",
domainID: "",
response: domains.Invitation{},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
inv, err := repo.RetrieveInvitation(context.Background(), tc.userID, tc.domainID)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, inv, fmt.Sprintf("desc: %s\n", tc.desc))
})
}
}
func TestInvitationRetrieveAll(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM invitations")
require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err))
_, err = db.Exec("DELETE FROM domains")
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
domainID := saveDomain(t, repo)
num := 200
var items []domains.Invitation
for i := 0; i < num; i++ {
invitation := domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: domainID,
RoleID: testsutil.GenerateUUID(t),
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
}
err := repo.SaveInvitation(context.Background(), invitation)
require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err))
items = append(items, invitation)
}
items[100].ConfirmedAt = time.Now().UTC().Truncate(time.Microsecond)
err := repo.UpdateConfirmation(context.Background(), items[100])
require.Nil(t, err, fmt.Sprintf("update invitation unexpected error: %s", err))
swap := items[100]
items = append(items[:100], items[101:]...)
items = append(items, swap)
cases := []struct {
desc string
page domains.InvitationPageMeta
response domains.InvitationPage
err error
}{
{
desc: "retrieve invitations successfully",
page: domains.InvitationPageMeta{
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: uint64(num),
Offset: 0,
Limit: 10,
Invitations: items[:10],
},
err: nil,
},
{
desc: "retrieve invitations with offset",
page: domains.InvitationPageMeta{
Offset: 10,
Limit: 10,
},
response: domains.InvitationPage{
Total: uint64(num),
Offset: 10,
Limit: 10,
Invitations: items[10:20],
},
},
{
desc: "retrieve invitations with limit",
page: domains.InvitationPageMeta{
Offset: 0,
Limit: 50,
},
response: domains.InvitationPage{
Total: uint64(num),
Offset: 0,
Limit: 50,
Invitations: items[:50],
},
},
{
desc: "retrieve invitations with offset and limit",
page: domains.InvitationPageMeta{
Offset: 10,
Limit: 50,
},
response: domains.InvitationPage{
Total: uint64(num),
Offset: 10,
Limit: 50,
Invitations: items[10:60],
},
},
{
desc: "retrieve invitations with offset out of range",
page: domains.InvitationPageMeta{
Offset: 1000,
Limit: 50,
},
response: domains.InvitationPage{
Total: uint64(num),
Offset: 1000,
Limit: 50,
Invitations: []domains.Invitation(nil),
},
},
{
desc: "retrieve invitations with offset and limit out of range",
page: domains.InvitationPageMeta{
Offset: 170,
Limit: 50,
},
response: domains.InvitationPage{
Total: uint64(num),
Offset: 170,
Limit: 50,
Invitations: items[170:200],
},
},
{
desc: "retrieve invitations with limit out of range",
page: domains.InvitationPageMeta{
Offset: 0,
Limit: 1000,
},
response: domains.InvitationPage{
Total: uint64(num),
Offset: 0,
Limit: 1000,
Invitations: items,
},
},
{
desc: "retrieve invitations with empty page",
page: domains.InvitationPageMeta{},
response: domains.InvitationPage{
Total: uint64(num),
Offset: 0,
Limit: 0,
Invitations: []domains.Invitation(nil),
},
},
{
desc: "retrieve invitations with domain",
page: domains.InvitationPageMeta{
DomainID: items[0].DomainID,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: uint64(num),
Offset: 0,
Limit: 10,
Invitations: items[:10],
},
},
{
desc: "retrieve invitations with invitee user id",
page: domains.InvitationPageMeta{
InviteeUserID: items[0].InviteeUserID,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with invited_by",
page: domains.InvitationPageMeta{
InvitedBy: items[0].InvitedBy,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with role_id",
page: domains.InvitationPageMeta{
RoleID: items[3].RoleID,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation{items[3]},
},
},
{
desc: "retrieve invitations with invited_by_or_user_id",
page: domains.InvitationPageMeta{
InvitedByOrUserID: items[0].InviteeUserID,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with domain_id and invitee user id",
page: domains.InvitationPageMeta{
DomainID: items[0].DomainID,
InviteeUserID: items[0].InviteeUserID,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with domain_id and invited_by",
page: domains.InvitationPageMeta{
DomainID: items[0].DomainID,
InvitedBy: items[0].InvitedBy,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with invitee user id and invited_by",
page: domains.InvitationPageMeta{
InviteeUserID: items[0].InviteeUserID,
InvitedBy: items[0].InvitedBy,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with domain_id, invitee user id and invited_by",
page: domains.InvitationPageMeta{
DomainID: items[0].DomainID,
InviteeUserID: items[0].InviteeUserID,
InvitedBy: items[0].InvitedBy,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with domain_id, invitee user id, invited_by and role_id",
page: domains.InvitationPageMeta{
DomainID: items[0].DomainID,
InviteeUserID: items[0].InviteeUserID,
InvitedBy: items[0].InvitedBy,
RoleID: items[0].RoleID,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with invalid domain",
page: domains.InvitationPageMeta{
DomainID: invalidUUID,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 0,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation(nil),
},
},
{
desc: "retrieve invitations with invalid invitee user id",
page: domains.InvitationPageMeta{
InviteeUserID: testsutil.GenerateUUID(t),
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 0,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation(nil),
},
},
{
desc: "retrieve invitations with invalid invited_by",
page: domains.InvitationPageMeta{
InvitedBy: invalidUUID,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 0,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation(nil),
},
},
{
desc: "retrieve invitations with invalid role_id",
page: domains.InvitationPageMeta{
RoleID: invalidUUID,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 0,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation(nil),
},
},
{
desc: "retrieve invitations with accepted state",
page: domains.InvitationPageMeta{
State: domains.Accepted,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation{items[num-1]},
},
},
{
desc: "retrieve invitations with pending state",
page: domains.InvitationPageMeta{
State: domains.Pending,
Offset: 0,
Limit: 10,
},
response: domains.InvitationPage{
Total: uint64(num - 1),
Offset: 0,
Limit: 10,
Invitations: items[0:10],
},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
page, err := repo.RetrieveAllInvitations(context.Background(), tc.page)
assert.Equal(t, tc.response.Total, page.Total, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Total, page.Total))
assert.Equal(t, tc.response.Offset, page.Offset, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Offset, page.Offset))
assert.Equal(t, tc.response.Limit, page.Limit, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Limit, page.Limit))
assert.ElementsMatch(t, page.Invitations, tc.response.Invitations, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response.Invitations, page.Invitations))
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
})
}
}
func TestInvitationUpdateConfirmation(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM invitations")
require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err))
_, err = db.Exec("DELETE FROM domains")
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
domainID := saveDomain(t, repo)
invitation := domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: domainID,
RoleID: testsutil.GenerateUUID(t),
CreatedAt: time.Now(),
}
err := repo.SaveInvitation(context.Background(), invitation)
require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err))
cases := []struct {
desc string
invitation domains.Invitation
err error
}{
{
desc: "update invitation successfully",
invitation: domains.Invitation{
DomainID: invitation.DomainID,
InviteeUserID: invitation.InviteeUserID,
ConfirmedAt: time.Now(),
},
err: nil,
},
{
desc: "update invitation with invalid invitee user id",
invitation: domains.Invitation{
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: invitation.InviteeUserID,
ConfirmedAt: time.Now(),
},
err: repoerr.ErrNotFound,
},
{
desc: "update invitation with invalid domain",
invitation: domains.Invitation{
InviteeUserID: invitation.InviteeUserID,
DomainID: testsutil.GenerateUUID(t),
ConfirmedAt: time.Now(),
},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := repo.UpdateConfirmation(context.Background(), tc.invitation)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
})
}
}
func TestInvitationUpdateRejection(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM invitations")
require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err))
_, err = db.Exec("DELETE FROM domains")
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
domainID := saveDomain(t, repo)
invitation := domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: domainID,
RoleID: testsutil.GenerateUUID(t),
CreatedAt: time.Now(),
}
err := repo.SaveInvitation(context.Background(), invitation)
require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err))
cases := []struct {
desc string
invitation domains.Invitation
err error
}{
{
desc: "update invitation successfully",
invitation: domains.Invitation{
DomainID: invitation.DomainID,
InviteeUserID: invitation.InviteeUserID,
RejectedAt: time.Now(),
},
err: nil,
},
{
desc: "update invitation with invalid invitee user id",
invitation: domains.Invitation{
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: invitation.InviteeUserID,
RejectedAt: time.Now(),
},
err: repoerr.ErrNotFound,
},
{
desc: "update invitation with invalid domain",
invitation: domains.Invitation{
InviteeUserID: invitation.InviteeUserID,
DomainID: testsutil.GenerateUUID(t),
RejectedAt: time.Now(),
},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := repo.UpdateRejection(context.Background(), tc.invitation)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
})
}
}
func TestInvitationDelete(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM invitations")
require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err))
_, err = db.Exec("DELETE FROM domains")
require.Nil(t, err, fmt.Sprintf("clean domains unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
domainID := saveDomain(t, repo)
invitation := domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: domainID,
RoleID: testsutil.GenerateUUID(t),
CreatedAt: time.Now(),
}
err := repo.SaveInvitation(context.Background(), invitation)
require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err))
cases := []struct {
desc string
invitation domains.Invitation
err error
}{
{
desc: "delete invitation successfully",
invitation: domains.Invitation{
InviteeUserID: invitation.InviteeUserID,
DomainID: invitation.DomainID,
},
err: nil,
},
{
desc: "delete invitation with invalid invitation id",
invitation: domains.Invitation{
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
},
err: repoerr.ErrNotFound,
},
{
desc: "delete invitation with empty invitation id",
invitation: domains.Invitation{},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := repo.DeleteInvitation(context.Background(), tc.invitation.InviteeUserID, tc.invitation.DomainID)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
})
}
}
func saveDomain(t *testing.T, repo domains.Repository) string {
domain := domains.Domain{
ID: testsutil.GenerateUUID(t),
Name: "test",
Alias: "test",
Tags: []string{"test"},
Metadata: map[string]interface{}{
"test": "test",
},
CreatedBy: userID,
UpdatedBy: userID,
CreatedAt: time.Now().UTC().Truncate(time.Millisecond),
UpdatedAt: time.Now().UTC().Truncate(time.Millisecond),
Status: domains.EnabledStatus,
}
_, err := repo.SaveDomain(context.Background(), domain)
require.Nil(t, err, fmt.Sprintf("failed to save domain %s", domain.ID))
return domain.ID
}
+1 -1
View File
@@ -38,7 +38,7 @@ func (svc service) RetrieveEntity(ctx context.Context, id string) (domains.Domai
if err == nil {
return domains.Domain{ID: id, Status: status}, nil
}
dom, err := svc.repo.RetrieveByID(ctx, id)
dom, err := svc.repo.RetrieveDomainByID(ctx, id)
if err != nil {
return domains.Domain{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
+138 -8
View File
@@ -61,13 +61,13 @@ func (svc service) CreateDomain(ctx context.Context, session authn.Session, d Do
d.CreatedAt = time.Now()
// Domain is created in repo first, because Roles table have foreign key relation with Domain ID
dom, err := svc.repo.Save(ctx, d)
dom, err := svc.repo.SaveDomain(ctx, d)
if err != nil {
return Domain{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrCreateEntity, err)
}
defer func() {
if retErr != nil {
if errRollBack := svc.repo.Delete(ctx, domainID); errRollBack != nil {
if errRollBack := svc.repo.DeleteDomain(ctx, domainID); errRollBack != nil {
retErr = errors.Wrap(retErr, errors.Wrap(errRollbackRepo, errRollBack))
}
}
@@ -100,9 +100,9 @@ func (svc service) RetrieveDomain(ctx context.Context, session authn.Session, id
var err error
switch session.SuperAdmin {
case true:
domain, err = svc.repo.RetrieveByID(ctx, id)
domain, err = svc.repo.RetrieveDomainByID(ctx, id)
default:
domain, err = svc.repo.RetrieveByUserAndID(ctx, session.UserID, id)
domain, err = svc.repo.RetrieveDomainByUserAndID(ctx, session.UserID, id)
}
if err != nil {
return Domain{}, errors.Wrap(svcerr.ErrViewEntity, err)
@@ -114,7 +114,7 @@ func (svc service) UpdateDomain(ctx context.Context, session authn.Session, id s
updatedAt := time.Now()
d.UpdatedAt = &updatedAt
d.UpdatedBy = &session.UserID
dom, err := svc.repo.Update(ctx, id, d)
dom, err := svc.repo.UpdateDomain(ctx, id, d)
if err != nil {
return Domain{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
@@ -124,7 +124,7 @@ func (svc service) UpdateDomain(ctx context.Context, session authn.Session, id s
func (svc service) EnableDomain(ctx context.Context, session authn.Session, id string) (Domain, error) {
status := EnabledStatus
updatedAt := time.Now()
dom, err := svc.repo.Update(ctx, id, DomainReq{Status: &status, UpdatedBy: &session.UserID, UpdatedAt: &updatedAt})
dom, err := svc.repo.UpdateDomain(ctx, id, DomainReq{Status: &status, UpdatedBy: &session.UserID, UpdatedAt: &updatedAt})
if err != nil {
return Domain{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
@@ -138,7 +138,7 @@ func (svc service) EnableDomain(ctx context.Context, session authn.Session, id s
func (svc service) DisableDomain(ctx context.Context, session authn.Session, id string) (Domain, error) {
status := DisabledStatus
updatedAt := time.Now()
dom, err := svc.repo.Update(ctx, id, DomainReq{Status: &status, UpdatedBy: &session.UserID, UpdatedAt: &updatedAt})
dom, err := svc.repo.UpdateDomain(ctx, id, DomainReq{Status: &status, UpdatedBy: &session.UserID, UpdatedAt: &updatedAt})
if err != nil {
return Domain{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
@@ -153,7 +153,7 @@ func (svc service) DisableDomain(ctx context.Context, session authn.Session, id
func (svc service) FreezeDomain(ctx context.Context, session authn.Session, id string) (Domain, error) {
status := FreezeStatus
updatedAt := time.Now()
dom, err := svc.repo.Update(ctx, id, DomainReq{Status: &status, UpdatedBy: &session.UserID, UpdatedAt: &updatedAt})
dom, err := svc.repo.UpdateDomain(ctx, id, DomainReq{Status: &status, UpdatedBy: &session.UserID, UpdatedAt: &updatedAt})
if err != nil {
return Domain{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
@@ -176,3 +176,133 @@ func (svc service) ListDomains(ctx context.Context, session authn.Session, p Pag
}
return dp, nil
}
func (svc *service) SendInvitation(ctx context.Context, session authn.Session, invitation Invitation) error {
if _, err := svc.repo.RetrieveRole(ctx, invitation.RoleID); err != nil {
return errors.Wrap(svcerr.ErrInvalidRole, err)
}
invitation.InvitedBy = session.UserID
invitation.CreatedAt = time.Now()
if err := svc.repo.SaveInvitation(ctx, invitation); err != nil {
return errors.Wrap(svcerr.ErrCreateEntity, err)
}
return nil
}
func (svc *service) ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (invitation Invitation, err error) {
inv, err := svc.repo.RetrieveInvitation(ctx, inviteeUserID, domainID)
if err != nil {
return Invitation{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
role, err := svc.repo.RetrieveRole(ctx, inv.RoleID)
if err != nil {
return Invitation{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
actions, err := svc.repo.RoleListActions(ctx, inv.RoleID)
if err != nil {
return Invitation{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
inv.Actions = actions
inv.RoleName = role.Name
return inv, nil
}
func (svc *service) ListInvitations(ctx context.Context, session authn.Session, page InvitationPageMeta) (invitations InvitationPage, err error) {
ip, err := svc.repo.RetrieveAllInvitations(ctx, page)
if err != nil {
return InvitationPage{}, err
}
return ip, nil
}
func (svc *service) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) error {
inv, err := svc.repo.RetrieveInvitation(ctx, session.UserID, domainID)
if err != nil {
return errors.Wrap(svcerr.ErrUpdateEntity, err)
}
if inv.InviteeUserID != session.UserID {
return svcerr.ErrAuthorization
}
if !inv.ConfirmedAt.IsZero() {
return svcerr.ErrInvitationAlreadyAccepted
}
if !inv.RejectedAt.IsZero() {
return svcerr.ErrInvitationAlreadyRejected
}
session.DomainID = domainID
if _, err := svc.RoleAddMembers(ctx, session, domainID, inv.RoleID, []string{session.UserID}); err != nil {
return errors.Wrap(svcerr.ErrUpdateEntity, err)
}
inv.ConfirmedAt = time.Now()
inv.UpdatedAt = inv.ConfirmedAt
if err := svc.repo.UpdateConfirmation(ctx, inv); err != nil {
return errors.Wrap(svcerr.ErrUpdateEntity, err)
}
return nil
}
func (svc *service) RejectInvitation(ctx context.Context, session authn.Session, domainID string) error {
inv, err := svc.repo.RetrieveInvitation(ctx, session.UserID, domainID)
if err != nil {
return errors.Wrap(svcerr.ErrUpdateEntity, err)
}
if inv.InviteeUserID != session.UserID {
return svcerr.ErrAuthorization
}
if !inv.ConfirmedAt.IsZero() {
return svcerr.ErrInvitationAlreadyAccepted
}
if !inv.RejectedAt.IsZero() {
return svcerr.ErrInvitationAlreadyRejected
}
inv.RejectedAt = time.Now()
inv.UpdatedAt = inv.RejectedAt
if err := svc.repo.UpdateRejection(ctx, inv); err != nil {
return errors.Wrap(svcerr.ErrUpdateEntity, err)
}
return nil
}
func (svc *service) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) error {
if session.UserID == inviteeUserID {
if err := svc.repo.DeleteInvitation(ctx, inviteeUserID, domainID); err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
return nil
}
inv, err := svc.repo.RetrieveInvitation(ctx, inviteeUserID, domainID)
if err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
if inv.InvitedBy == session.UserID {
if err := svc.repo.DeleteInvitation(ctx, inviteeUserID, domainID); err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
return nil
}
if err := svc.repo.DeleteInvitation(ctx, inviteeUserID, domainID); err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
return nil
}
+462 -10
View File
@@ -53,8 +53,13 @@ var (
CreatedBy: validID,
UpdatedBy: validID,
}
userID = testsutil.GenerateUUID(&testing.T{})
validSession = authn.Session{UserID: userID}
userID = testsutil.GenerateUUID(&testing.T{})
validSession = authn.Session{UserID: userID}
validInvitation = domains.Invitation{
InviteeUserID: testsutil.GenerateUUID(&testing.T{}),
DomainID: testsutil.GenerateUUID(&testing.T{}),
RoleID: testsutil.GenerateUUID(&testing.T{}),
}
)
var (
@@ -166,8 +171,8 @@ func TestCreateDomain(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("Save", mock.Anything, mock.Anything).Return(tc.d, tc.saveDomainErr)
repoCall1 := drepo.On("Delete", mock.Anything, mock.Anything).Return(tc.deleteDomainErr)
repoCall := drepo.On("SaveDomain", mock.Anything, mock.Anything).Return(tc.d, tc.saveDomainErr)
repoCall1 := drepo.On("DeleteDomain", mock.Anything, mock.Anything).Return(tc.deleteDomainErr)
repoCall2 := drepo.On("AddRoles", mock.Anything, mock.Anything).Return([]roles.RoleProvision{}, tc.addRolesErr)
policyCall := policy.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPoliciesErr)
policyCall1 := policy.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePoliciesErr)
@@ -228,8 +233,8 @@ func TestRetrieveDomain(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveByID", context.Background(), tc.domainID).Return(tc.retrieveDomainRes, tc.retrieveDomainErr)
repoCall1 := drepo.On("RetrieveByUserAndID", context.Background(), tc.session.UserID, tc.domainID).Return(tc.retrieveDomainRes, tc.retrieveDomainErr)
repoCall := drepo.On("RetrieveDomainByID", context.Background(), tc.domainID).Return(tc.retrieveDomainRes, tc.retrieveDomainErr)
repoCall1 := drepo.On("RetrieveDomainByUserAndID", context.Background(), tc.session.UserID, tc.domainID).Return(tc.retrieveDomainRes, tc.retrieveDomainErr)
domain, err := svc.RetrieveDomain(context.Background(), tc.session, tc.domainID)
assert.True(t, errors.Contains(err, tc.err))
assert.Equal(t, tc.retrieveDomainRes, domain)
@@ -292,7 +297,7 @@ func TestUpdateDomain(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("Update", context.Background(), tc.domainID, mock.Anything).Return(tc.updateRes, tc.updateErr)
repoCall := drepo.On("UpdateDomain", context.Background(), tc.domainID, mock.Anything).Return(tc.updateRes, tc.updateErr)
domain, err := svc.UpdateDomain(context.Background(), tc.session, tc.domainID, tc.updateReq)
assert.True(t, errors.Contains(err, tc.err))
assert.Equal(t, tc.updateRes, domain)
@@ -354,7 +359,7 @@ func TestEnableDomain(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("Update", context.Background(), tc.domainID, mock.Anything).Return(tc.enableRes, tc.enableErr)
repoCall := drepo.On("UpdateDomain", context.Background(), tc.domainID, mock.Anything).Return(tc.enableRes, tc.enableErr)
cacheCall := dcache.On("Remove", context.Background(), tc.domainID).Return(tc.cacheErr)
domain, err := svc.EnableDomain(context.Background(), tc.session, tc.domainID)
assert.True(t, errors.Contains(err, tc.err))
@@ -418,7 +423,7 @@ func TestDisableDomain(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("Update", context.Background(), tc.domainID, mock.Anything).Return(tc.disableRes, tc.disableErr)
repoCall := drepo.On("UpdateDomain", context.Background(), tc.domainID, mock.Anything).Return(tc.disableRes, tc.disableErr)
cacheCall := dcache.On("Remove", context.Background(), tc.domainID).Return(tc.cacheErr)
domain, err := svc.DisableDomain(context.Background(), tc.session, tc.domainID)
assert.True(t, errors.Contains(err, tc.err))
@@ -482,7 +487,7 @@ func TestFreezeDomain(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("Update", context.Background(), tc.domainID, mock.Anything).Return(tc.freezeRes, tc.freezeErr)
repoCall := drepo.On("UpdateDomain", context.Background(), tc.domainID, mock.Anything).Return(tc.freezeRes, tc.freezeErr)
cacheCall := dcache.On("Remove", context.Background(), tc.domainID).Return(tc.cacheErr)
domain, err := svc.FreezeDomain(context.Background(), tc.session, tc.domainID)
assert.True(t, errors.Contains(err, tc.err))
@@ -565,3 +570,450 @@ func TestListDomains(t *testing.T) {
})
}
}
func TestSendInvitation(t *testing.T) {
svc := newService()
cases := []struct {
desc string
session authn.Session
req domains.Invitation
retrieveRoleErr error
createInvitationErr error
err error
}{
{
desc: "send invitation successful",
session: validSession,
req: validInvitation,
err: nil,
},
{
desc: "send invitation with invalid role id",
session: validSession,
req: domains.Invitation{
DomainID: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
RoleID: inValid,
},
retrieveRoleErr: repoerr.ErrNotFound,
err: svcerr.ErrInvalidRole,
},
{
desc: "send invitations with failed to save invitation",
session: validSession,
req: validInvitation,
createInvitationErr: repoerr.ErrCreateEntity,
err: svcerr.ErrCreateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveRole", context.Background(), tc.req.RoleID).Return(roles.Role{}, tc.retrieveRoleErr)
repoCall1 := drepo.On("SaveInvitation", context.Background(), mock.Anything).Return(tc.createInvitationErr)
err := svc.SendInvitation(context.Background(), tc.session, tc.req)
assert.True(t, errors.Contains(err, tc.err))
repoCall.Unset()
repoCall1.Unset()
})
}
}
func TestViewInvitation(t *testing.T) {
svc := newService()
validInvitation := domains.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
RoleID: testsutil.GenerateUUID(t),
Actions: []string{"read", "delete"},
CreatedAt: time.Now().Add(-time.Hour),
UpdatedAt: time.Now().Add(-time.Hour),
ConfirmedAt: time.Now().Add(-time.Hour),
}
cases := []struct {
desc string
userID string
domainID string
session authn.Session
req domains.Invitation
resp domains.Invitation
retrieveInvitationErr error
listRolesErr error
retrieveRoleErr error
err error
}{
{
desc: "view invitation successful",
userID: validInvitation.InviteeUserID,
domainID: validInvitation.DomainID,
session: validSession,
resp: validInvitation,
err: nil,
},
{
desc: "view invitation with error retrieving invitation",
userID: validInvitation.InviteeUserID,
domainID: validInvitation.DomainID,
session: validSession,
retrieveInvitationErr: repoerr.ErrNotFound,
err: svcerr.ErrViewEntity,
},
{
desc: "view invitation with failed to retrieve role actions",
userID: validInvitation.InviteeUserID,
domainID: validInvitation.DomainID,
session: validSession,
listRolesErr: repoerr.ErrNotFound,
err: svcerr.ErrViewEntity,
},
{
desc: "view invitation with failed to retrieve role",
userID: validInvitation.InviteeUserID,
domainID: validInvitation.DomainID,
session: validSession,
retrieveRoleErr: repoerr.ErrNotFound,
err: svcerr.ErrViewEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveInvitation", context.Background(), mock.Anything, mock.Anything).Return(tc.resp, tc.retrieveInvitationErr)
repoCall1 := drepo.On("RoleListActions", context.Background(), tc.resp.RoleID).Return(tc.resp.Actions, tc.listRolesErr)
repoCall2 := drepo.On("RetrieveRole", context.Background(), tc.resp.RoleID).Return(roles.Role{}, tc.retrieveRoleErr)
inv, err := svc.ViewInvitation(context.Background(), tc.session, tc.userID, tc.domainID)
assert.True(t, errors.Contains(err, tc.err))
assert.Equal(t, tc.resp, inv, tc.desc)
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
})
}
}
func TestListInvitations(t *testing.T) {
svc := newService()
validPageMeta := domains.InvitationPageMeta{
Offset: 0,
Limit: 10,
}
validResp := domains.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []domains.Invitation{
{
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
RoleID: testsutil.GenerateUUID(t),
RoleName: "admin",
CreatedAt: time.Now().Add(-time.Hour),
UpdatedAt: time.Now().Add(-time.Hour),
ConfirmedAt: time.Now().Add(-time.Hour),
},
},
}
cases := []struct {
desc string
session authn.Session
page domains.InvitationPageMeta
resp domains.InvitationPage
err error
repoErr error
}{
{
desc: "list invitations successful",
session: validSession,
page: validPageMeta,
resp: validResp,
err: nil,
repoErr: nil,
},
{
desc: "list invitations unsuccessful",
session: validSession,
page: validPageMeta,
err: repoerr.ErrViewEntity,
resp: domains.InvitationPage{},
repoErr: repoerr.ErrViewEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveAllInvitations", context.Background(), mock.Anything).Return(tc.resp, tc.repoErr)
resp, err := svc.ListInvitations(context.Background(), tc.session, tc.page)
assert.Equal(t, tc.err, err, tc.desc)
assert.Equal(t, tc.resp, resp, tc.desc)
repoCall.Unset()
})
}
}
func TestAcceptInvitation(t *testing.T) {
svc := newService()
cases := []struct {
desc string
domainID string
session authn.Session
resp domains.Invitation
retrieveInvitationErr error
updateConfirmationErr error
addRoleMemberErr error
err error
}{
{
desc: "accept invitation successful",
domainID: validID,
session: validSession,
resp: domains.Invitation{
InviteeUserID: userID,
DomainID: testsutil.GenerateUUID(t),
RoleID: testsutil.GenerateUUID(t),
},
err: nil,
},
{
desc: "accept invitation with failed to retrieve invitation",
session: validSession,
retrieveInvitationErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
{
desc: "accept invitation with of different user",
session: validSession,
resp: domains.Invitation{
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
RoleID: testsutil.GenerateUUID(t),
},
err: svcerr.ErrAuthorization,
},
{
desc: "accept invitation with failed to add role member",
domainID: validID,
session: validSession,
resp: domains.Invitation{
InviteeUserID: userID,
DomainID: testsutil.GenerateUUID(t),
RoleID: testsutil.GenerateUUID(t),
},
addRoleMemberErr: repoerr.ErrMalformedEntity,
err: svcerr.ErrUpdateEntity,
},
{
desc: "accept invitation with failed update confirmation",
session: validSession,
domainID: validID,
resp: domains.Invitation{
InviteeUserID: userID,
DomainID: validID,
RoleID: testsutil.GenerateUUID(t),
},
updateConfirmationErr: repoerr.ErrNotFound,
err: svcerr.ErrUpdateEntity,
},
{
desc: "accept invitation that is already confirmed",
session: validSession,
domainID: validID,
resp: domains.Invitation{
InviteeUserID: userID,
DomainID: testsutil.GenerateUUID(t),
RoleID: testsutil.GenerateUUID(t),
ConfirmedAt: time.Now(),
},
err: svcerr.ErrInvitationAlreadyAccepted,
},
{
desc: "accept rejected invitation",
session: validSession,
domainID: validID,
resp: domains.Invitation{
InviteeUserID: userID,
DomainID: testsutil.GenerateUUID(t),
RoleID: testsutil.GenerateUUID(t),
RejectedAt: time.Now(),
},
err: svcerr.ErrInvitationAlreadyRejected,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveInvitation", context.Background(), tc.session.UserID, tc.domainID).Return(tc.resp, tc.retrieveInvitationErr)
repoCall1 := drepo.On("RetrieveEntityRole", context.Background(), tc.domainID, tc.resp.RoleID).Return(roles.Role{}, tc.addRoleMemberErr)
policyCall := policy.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addRoleMemberErr)
repoCall2 := drepo.On("RoleAddMembers", context.Background(), mock.Anything, []string{tc.resp.InviteeUserID}).Return([]string{}, tc.addRoleMemberErr)
repoCall3 := drepo.On("UpdateConfirmation", context.Background(), mock.Anything).Return(tc.updateConfirmationErr)
err := svc.AcceptInvitation(context.Background(), tc.session, tc.domainID)
assert.True(t, errors.Contains(err, tc.err))
repoCall.Unset()
repoCall1.Unset()
policyCall.Unset()
repoCall2.Unset()
repoCall3.Unset()
})
}
}
func TestRejectInvitation(t *testing.T) {
svc := newService()
cases := []struct {
desc string
domainID string
session authn.Session
resp domains.Invitation
retrieveInvitationErr error
updateConfirmationErr error
addRoleMemberErr error
err error
}{
{
desc: "reject invitation successful",
domainID: validID,
session: validSession,
resp: domains.Invitation{
InviteeUserID: userID,
DomainID: testsutil.GenerateUUID(t),
RoleID: testsutil.GenerateUUID(t),
},
err: nil,
},
{
desc: "reject invitation with failed to retrieve invitation",
session: validSession,
retrieveInvitationErr: repoerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
{
desc: "reject invitation with of different user",
session: validSession,
resp: domains.Invitation{
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
RoleID: testsutil.GenerateUUID(t),
},
err: svcerr.ErrAuthorization,
},
{
desc: "reject invitation with failed update confirmation",
session: validSession,
domainID: validID,
resp: domains.Invitation{
InviteeUserID: userID,
DomainID: validID,
RoleID: testsutil.GenerateUUID(t),
},
updateConfirmationErr: repoerr.ErrNotFound,
err: svcerr.ErrUpdateEntity,
},
{
desc: "reject invitation that is already confirmed",
session: validSession,
domainID: validID,
resp: domains.Invitation{
InviteeUserID: userID,
DomainID: testsutil.GenerateUUID(t),
RoleID: testsutil.GenerateUUID(t),
ConfirmedAt: time.Now(),
},
err: svcerr.ErrInvitationAlreadyAccepted,
},
{
desc: "reject rejected invitation",
session: validSession,
domainID: validID,
resp: domains.Invitation{
InviteeUserID: userID,
DomainID: testsutil.GenerateUUID(t),
RoleID: testsutil.GenerateUUID(t),
RejectedAt: time.Now(),
},
err: svcerr.ErrInvitationAlreadyRejected,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveInvitation", context.Background(), tc.session.UserID, tc.domainID).Return(tc.resp, tc.retrieveInvitationErr)
repoCall1 := drepo.On("UpdateRejection", context.Background(), mock.Anything).Return(tc.updateConfirmationErr)
err := svc.RejectInvitation(context.Background(), tc.session, tc.domainID)
assert.True(t, errors.Contains(err, tc.err))
repoCall.Unset()
repoCall1.Unset()
})
}
}
func TestDeleteInvitation(t *testing.T) {
svc := newService()
cases := []struct {
desc string
userID string
domainID string
resp domains.Invitation
retrieveInvitationErr error
deleteInvitationErr error
err error
}{
{
desc: "delete invitations successful",
userID: testsutil.GenerateUUID(t),
domainID: testsutil.GenerateUUID(t),
resp: validInvitation,
err: nil,
},
{
desc: "delete invitations for the same user",
userID: validInvitation.InviteeUserID,
domainID: validInvitation.DomainID,
resp: validInvitation,
err: nil,
},
{
desc: "delete invitations for the invited user",
userID: validInvitation.InviteeUserID,
domainID: validInvitation.DomainID,
resp: validInvitation,
err: nil,
},
{
desc: "delete invitation with error retrieving invitation",
userID: validInvitation.InviteeUserID,
domainID: validInvitation.DomainID,
resp: domains.Invitation{},
retrieveInvitationErr: repoerr.ErrNotFound,
err: svcerr.ErrRemoveEntity,
},
{
desc: "delete invitation with error deleting invitation",
userID: validInvitation.InviteeUserID,
domainID: validInvitation.DomainID,
resp: domains.Invitation{},
deleteInvitationErr: repoerr.ErrNotFound,
err: svcerr.ErrRemoveEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := drepo.On("RetrieveInvitation", context.Background(), mock.Anything, mock.Anything).Return(tc.resp, tc.retrieveInvitationErr)
repoCall1 := drepo.On("DeleteInvitation", context.Background(), mock.Anything, mock.Anything).Return(tc.deleteInvitationErr)
err := svc.DeleteInvitation(context.Background(), authn.Session{}, tc.userID, tc.domainID)
assert.True(t, errors.Contains(err, tc.err))
repoCall.Unset()
repoCall1.Unset()
})
}
}
+4 -4
View File
@@ -1,7 +1,7 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package invitations
package domains
import (
"encoding/json"
@@ -14,7 +14,7 @@ import (
type State uint8
const (
All State = iota // All is used for querying purposes to list invitations irrespective of their state - both pending and accepted.
AllState State = iota // All is used for querying purposes to list invitations irrespective of their state - both pending and accepted.
Pending // Pending is the state of an invitation that has not been accepted yet.
Accepted // Accepted is the state of an invitation that has been accepted.
Rejected // Rejected is the state of an invitation that has been rejected.
@@ -32,7 +32,7 @@ const (
// String converts invitation state to string literal.
func (s State) String() string {
switch s {
case All:
case AllState:
return all
case Pending:
return pending
@@ -49,7 +49,7 @@ func (s State) String() string {
func ToState(status string) (State, error) {
switch status {
case all:
return All, nil
return AllState, nil
case pending:
return Pending, nil
case accepted:
@@ -1,27 +1,27 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package invitations_test
package domains_test
import (
"testing"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/invitations"
"github.com/absmach/supermq/domains"
"github.com/stretchr/testify/assert"
)
func TestState_String(t *testing.T) {
tests := []struct {
name string
state invitations.State
state domains.State
expected string
}{
{"Pending", invitations.Pending, "pending"},
{"Accepted", invitations.Accepted, "accepted"},
{"Rejected", invitations.Rejected, "rejected"},
{"All", invitations.All, "all"},
{"Unknown", invitations.State(100), "unknown"},
{"Pending", domains.Pending, "pending"},
{"Accepted", domains.Accepted, "accepted"},
{"Rejected", domains.Rejected, "rejected"},
{"All", domains.AllState, "all"},
{"Unknown", domains.State(100), "unknown"},
}
for _, tt := range tests {
@@ -34,18 +34,18 @@ func TestToState(t *testing.T) {
tests := []struct {
name string
status string
state invitations.State
state domains.State
err error
}{
{"Pending", "pending", invitations.Pending, nil},
{"Accepted", "accepted", invitations.Accepted, nil},
{"Rejected", "rejected", invitations.Rejected, nil},
{"All", "all", invitations.All, nil},
{"Unknown", "unknown", invitations.State(0), apiutil.ErrInvitationState},
{"Pending", "pending", domains.Pending, nil},
{"Accepted", "accepted", domains.Accepted, nil},
{"Rejected", "rejected", domains.Rejected, nil},
{"All", "all", domains.AllState, nil},
{"Unknown", "unknown", domains.State(0), apiutil.ErrInvitationState},
}
for _, tt := range tests {
got, err := invitations.ToState(tt.status)
got, err := domains.ToState(tt.status)
assert.Equal(t, tt.err, err, "ToState() error = %v, expected %v", err, tt.err)
assert.Equal(t, tt.state, got, "ToState() = %v, expected %v", got, tt.state)
}
@@ -54,15 +54,15 @@ func TestToState(t *testing.T) {
func TestState_MarshalJSON(t *testing.T) {
tests := []struct {
name string
state invitations.State
state domains.State
expected []byte
err error
}{
{"Pending", invitations.Pending, []byte(`"pending"`), nil},
{"Accepted", invitations.Accepted, []byte(`"accepted"`), nil},
{"Rejected", invitations.Rejected, []byte(`"rejected"`), nil},
{"All", invitations.All, []byte(`"all"`), nil},
{"Unknown", invitations.State(100), []byte(`"unknown"`), nil},
{"Pending", domains.Pending, []byte(`"pending"`), nil},
{"Accepted", domains.Accepted, []byte(`"accepted"`), nil},
{"Rejected", domains.Rejected, []byte(`"rejected"`), nil},
{"All", domains.AllState, []byte(`"all"`), nil},
{"Unknown", domains.State(100), []byte(`"unknown"`), nil},
}
for _, tt := range tests {
@@ -76,18 +76,18 @@ func TestState_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
data []byte
state invitations.State
state domains.State
err error
}{
{"Pending", []byte(`"pending"`), invitations.Pending, nil},
{"Accepted", []byte(`"accepted"`), invitations.Accepted, nil},
{"Rejected", []byte(`"rejected"`), invitations.Rejected, nil},
{"All", []byte(`"all"`), invitations.All, nil},
{"Unknown", []byte(`"unknown"`), invitations.State(0), apiutil.ErrInvitationState},
{"Pending", []byte(`"pending"`), domains.Pending, nil},
{"Accepted", []byte(`"accepted"`), domains.Accepted, nil},
{"Rejected", []byte(`"rejected"`), domains.Rejected, nil},
{"All", []byte(`"all"`), domains.AllState, nil},
{"Unknown", []byte(`"unknown"`), domains.State(0), apiutil.ErrInvitationState},
}
for _, tt := range tests {
var state invitations.State
var state domains.State
err := state.UnmarshalJSON(tt.data)
assert.Equal(t, tt.err, err, "State.UnmarshalJSON() error = %v, expected %v", err, tt.err)
assert.Equal(t, tt.state, state, "State.UnmarshalJSON() = %v, expected %v", state, tt.state)
+61
View File
@@ -80,3 +80,64 @@ func (tm *tracingMiddleware) ListDomains(ctx context.Context, session authn.Sess
defer span.End()
return tm.svc.ListDomains(ctx, session, p)
}
func (tm *tracingMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) (err error) {
ctx, span := tm.tracer.Start(ctx, "send_invitation", trace.WithAttributes(
attribute.String("domain_id", invitation.DomainID),
attribute.String("invitee_user_id", invitation.InviteeUserID),
))
defer span.End()
return tm.svc.SendInvitation(ctx, session, invitation)
}
func (tm *tracingMiddleware) ViewInvitation(ctx context.Context, session authn.Session, inviteeUserID, domain string) (invitation domains.Invitation, err error) {
ctx, span := tm.tracer.Start(ctx, "view_invitation", trace.WithAttributes(
attribute.String("invitee_user_id", inviteeUserID),
attribute.String("domain_id", domain),
))
defer span.End()
return tm.svc.ViewInvitation(ctx, session, inviteeUserID, domain)
}
func (tm *tracingMiddleware) ListInvitations(ctx context.Context, session authn.Session, pm domains.InvitationPageMeta) (invs domains.InvitationPage, err error) {
ctx, span := tm.tracer.Start(ctx, "list_invitations", trace.WithAttributes(
attribute.Int("limit", int(pm.Limit)),
attribute.Int("offset", int(pm.Offset)),
attribute.String("invitee_user_id", pm.InviteeUserID),
attribute.String("domain_id", pm.DomainID),
attribute.String("invited_by", pm.InvitedBy),
))
defer span.End()
return tm.svc.ListInvitations(ctx, session, pm)
}
func (tm *tracingMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
ctx, span := tm.tracer.Start(ctx, "accept_invitation", trace.WithAttributes(
attribute.String("domain_id", domainID),
))
defer span.End()
return tm.svc.AcceptInvitation(ctx, session, domainID)
}
func (tm *tracingMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
ctx, span := tm.tracer.Start(ctx, "reject_invitation", trace.WithAttributes(
attribute.String("domain_id", domainID),
))
defer span.End()
return tm.svc.RejectInvitation(ctx, session, domainID)
}
func (tm *tracingMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, inviteeUserID, domainID string) (err error) {
ctx, span := tm.tracer.Start(ctx, "delete_invitation", trace.WithAttributes(
attribute.String("invitee_user_id", inviteeUserID),
attribute.String("domain_id", domainID),
))
defer span.End()
return tm.svc.DeleteInvitation(ctx, session, inviteeUserID, domainID)
}
-80
View File
@@ -1,80 +0,0 @@
# Invitation Service
Invitation service is responsible for sending invitations to users to join a domain.
## Configuration
The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values.
| Variable | Description | Default |
| ------------------------------- | ------------------------------------------------ | ----------------------- |
| SMQ_INVITATION_LOG_LEVEL | Log level for the Invitation service | debug |
| SMQ_USERS_URL | Users service URL | <http://localhost:9002> |
| SMQ_DOMAINS_URL | Domains service URL | <http://localhost:8189> |
| SMQ_INVITATIONS_HTTP_HOST | Invitation service HTTP listening host | localhost |
| SMQ_INVITATIONS_HTTP_PORT | Invitation service HTTP listening port | 9020 |
| SMQ_INVITATIONS_HTTP_SERVER_CERT | Invitation service server certificate | "" |
| SMQ_INVITATIONS_HTTP_SERVER_KEY | Invitation service server key | "" |
| SMQ_AUTH_GRPC_URL | Auth service gRPC URL | localhost:8181 |
| SMQ_AUTH_GRPC_TIMEOUT | Auth service gRPC request timeout in seconds | 1s |
| SMQ_AUTH_GRPC_CLIENT_CERT | Path to client certificate in PEM format | "" |
| SMQ_AUTH_GRPC_CLIENT_KEY | Path to client key in PEM format | "" |
| SMQ_AUTH_GRPC_CLIENT_CA_CERTS | Path to trusted CAs in PEM format | "" |
| SMQ_INVITATIONS_DB_HOST | Invitation service database host | localhost |
| SMQ_INVITATIONS_DB_USER | Invitation service database user | supermq |
| SMQ_INVITATIONS_DB_PASS | Invitation service database password | supermq |
| SMQ_INVITATIONS_DB_PORT | Invitation service database port | 5432 |
| SMQ_INVITATIONS_DB_NAME | Invitation service database name | invitations |
| SMQ_INVITATIONS_DB_SSL_MODE | Invitation service database SSL mode | disable |
| SMQ_INVITATIONS_DB_SSL_CERT | Invitation service database SSL certificate | "" |
| SMQ_INVITATIONS_DB_SSL_KEY | Invitation service database SSL key | "" |
| SMQ_INVITATIONS_DB_SSL_ROOT_CERT | Invitation service database SSL root certificate | "" |
| SMQ_INVITATIONS_INSTANCE_ID | Invitation service instance ID | |
## Deployment
The service itself is distributed as Docker container. Check the [`invitation`](https://github.com/absmach/amdm/blob/main/docker/docker-compose.yml) service section in docker-compose file to see how service is deployed.
To start the service outside of the container, execute the following shell script:
```bash
# download the latest version of the service
git clone https://github.com/absmach/supermq
cd supermq
# compile the http
make invitation
# copy binary to bin
make install
# set the environment variables and run the service
SMQ_INVITATION_LOG_LEVEL=info \
SMQ_INVITATIONS_ENDPOINT=/invitations \
SMQ_USERS_URL="http://localhost:9002" \
SMQ_DOMAINS_URL="http://localhost:8189" \
SMQ_INVITATIONS_HTTP_HOST=localhost \
SMQ_INVITATIONS_HTTP_PORT=9020 \
SMQ_INVITATIONS_HTTP_SERVER_CERT="" \
SMQ_INVITATIONS_HTTP_SERVER_KEY="" \
SMQ_AUTH_GRPC_URL=localhost:8181 \
SMQ_AUTH_GRPC_TIMEOUT=1s \
SMQ_AUTH_GRPC_CLIENT_CERT="" \
SMQ_AUTH_GRPC_CLIENT_KEY="" \
SMQ_AUTH_GRPC_CLIENT_CA_CERTS="" \
SMQ_INVITATIONS_DB_HOST=localhost \
SMQ_INVITATIONS_DB_USER=supermq \
SMQ_INVITATIONS_DB_PASS=supermq \
SMQ_INVITATIONS_DB_PORT=5432 \
SMQ_INVITATIONS_DB_NAME=invitations \
SMQ_INVITATIONS_DB_SSL_MODE=disable \
SMQ_INVITATIONS_DB_SSL_CERT="" \
SMQ_INVITATIONS_DB_SSL_KEY="" \
SMQ_INVITATIONS_DB_SSL_ROOT_CERT="" \
$GOBIN/supermq-invitation
```
## Usage
For more information about service capabilities and its usage, please check out the [API documentation](https://docs.api.supermq.abstractmachines.fr/?urls.primaryName=invitations.yml).
-4
View File
@@ -1,4 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
-154
View File
@@ -1,154 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
import (
"context"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/invitations"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/go-kit/kit/endpoint"
)
// InvitationSent is the message returned when an invitation is sent.
const InvitationSent = "invitation sent"
func sendInvitationEndpoint(svc invitations.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(sendInvitationReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
session.DomainID = req.DomainID
invitation := invitations.Invitation{
UserID: req.UserID,
DomainID: req.DomainID,
Relation: req.Relation,
Resend: req.Resend,
}
if err := svc.SendInvitation(ctx, session, invitation); err != nil {
return nil, err
}
return sendInvitationRes{
Message: InvitationSent,
}, nil
}
}
func viewInvitationEndpoint(svc invitations.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(invitationReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
session.DomainID = req.domainID
invitation, err := svc.ViewInvitation(ctx, session, req.userID, req.domainID)
if err != nil {
return nil, err
}
return viewInvitationRes{
Invitation: invitation,
}, nil
}
}
func listInvitationsEndpoint(svc invitations.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(listInvitationsReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
session.DomainID = req.DomainID
page, err := svc.ListInvitations(ctx, session, req.Page)
if err != nil {
return nil, err
}
return listInvitationsRes{
page,
}, nil
}
}
func acceptInvitationEndpoint(svc invitations.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(acceptInvitationReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
if err := svc.AcceptInvitation(ctx, session, req.DomainID); err != nil {
return nil, err
}
return acceptInvitationRes{}, nil
}
}
func rejectInvitationEndpoint(svc invitations.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(acceptInvitationReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
if err := svc.RejectInvitation(ctx, session, req.DomainID); err != nil {
return nil, err
}
return rejectInvitationRes{}, nil
}
}
func deleteInvitationEndpoint(svc invitations.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(invitationReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
session.DomainID = req.domainID
if err := svc.DeleteInvitation(ctx, session, req.userID, req.domainID); err != nil {
return nil, err
}
return deleteInvitationRes{}, nil
}
}
-672
View File
@@ -1,672 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api_test
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/invitations"
"github.com/absmach/supermq/invitations/api"
"github.com/absmach/supermq/invitations/mocks"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
var (
validToken = "valid"
validContenType = "application/json"
validID = testsutil.GenerateUUID(&testing.T{})
domainID = testsutil.GenerateUUID(&testing.T{})
)
type testRequest struct {
client *http.Client
method string
url string
token string
contentType string
body io.Reader
}
func (tr testRequest) make() (*http.Response, error) {
req, err := http.NewRequest(tr.method, tr.url, tr.body)
if err != nil {
return nil, err
}
if tr.token != "" {
req.Header.Set("Authorization", apiutil.BearerPrefix+tr.token)
}
if tr.contentType != "" {
req.Header.Set("Content-Type", tr.contentType)
}
return tr.client.Do(req)
}
func newIvitationsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentication) {
svc := new(mocks.Service)
logger := smqlog.NewMock()
authn := new(authnmocks.Authentication)
mux := api.MakeHandler(svc, logger, authn, "test")
return httptest.NewServer(mux), svc, authn
}
func TestSendInvitation(t *testing.T) {
is, svc, authn := newIvitationsServer()
cases := []struct {
desc string
token string
data string
contentType string
status int
authnRes smqauthn.Session
authnErr error
svcErr error
}{
{
desc: "valid request",
token: validToken,
data: fmt.Sprintf(`{"user_id": "%s","domain_id": "%s", "relation": "%s"}`, validID, domainID, "domain"),
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
status: http.StatusCreated,
contentType: validContenType,
svcErr: nil,
},
{
desc: "invalid token",
token: "",
data: fmt.Sprintf(`{"user_id": "%s","domain_id": "%s", "relation": "%s"}`, validID, validID, "domain"),
status: http.StatusUnauthorized,
contentType: validContenType,
svcErr: nil,
},
{
desc: "empty domain_id",
token: validToken,
data: fmt.Sprintf(`{"user_id": "%s","domain_id": "%s", "relation": "%s"}`, validID, "", "domain"),
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
{
desc: "invalid content type",
token: validToken,
data: fmt.Sprintf(`{"user_id": "%s","domain_id": "%s", "relation": "%s"}`, validID, validID, "domain"),
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
status: http.StatusUnsupportedMediaType,
contentType: "text/plain",
svcErr: nil,
},
{
desc: "invalid data",
token: validToken,
data: `data`,
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with service error",
token: validToken,
data: fmt.Sprintf(`{"user_id": "%s", "domain_id": "%s", "relation": "%s"}`, validID, domainID, "domain"),
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
status: http.StatusForbidden,
contentType: validContenType,
svcErr: svcerr.ErrAuthorization,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
repoCall := svc.On("SendInvitation", mock.Anything, tc.authnRes, mock.Anything).Return(tc.svcErr)
req := testRequest{
client: is.Client(),
method: http.MethodPost,
url: is.URL + "/invitations",
token: tc.token,
contentType: tc.contentType,
body: strings.NewReader(tc.data),
}
res, err := req.make()
assert.Nil(t, err, tc.desc)
assert.Equal(t, tc.status, res.StatusCode, tc.desc)
repoCall.Unset()
authnCall.Unset()
})
}
}
func TestListInvitation(t *testing.T) {
is, svc, authn := newIvitationsServer()
cases := []struct {
desc string
token string
query string
contentType string
status int
svcErr error
authnRes smqauthn.Session
authnErr error
}{
{
desc: "valid request",
authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID},
token: validToken,
status: http.StatusOK,
contentType: validContenType,
svcErr: nil,
},
{
desc: "invalid token",
token: "",
status: http.StatusUnauthorized,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with offset",
authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: "offset=1",
status: http.StatusOK,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with invalid offset",
token: validToken,
query: "offset=invalid",
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with limit",
authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: "limit=1",
status: http.StatusOK,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with invalid limit",
token: validToken,
query: "limit=invalid",
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with user_id",
authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: fmt.Sprintf("user_id=%s", validID),
status: http.StatusOK,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with duplicate user_id",
authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: "user_id=1&user_id=2",
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with invited_by",
authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: fmt.Sprintf("invited_by=%s", validID),
status: http.StatusOK,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with duplicate invited_by",
authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: "invited_by=1&invited_by=2",
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with relation",
authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: fmt.Sprintf("relation=%s", "relation"),
status: http.StatusOK,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with duplicate relation",
authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: "relation=1&relation=2",
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with state",
authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID},
token: validToken,
query: "state=pending",
status: http.StatusOK,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with invalid state",
token: validToken,
query: "state=invalid",
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with duplicate state",
token: validToken,
query: "state=all&state=all",
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with service error",
authnRes: smqauthn.Session{UserID: validID, DomainUserID: domainID + "_" + validID},
token: validToken,
status: http.StatusForbidden,
contentType: validContenType,
svcErr: svcerr.ErrAuthorization,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
repoCall := svc.On("ListInvitations", mock.Anything, tc.authnRes, mock.Anything).Return(invitations.InvitationPage{}, tc.svcErr)
req := testRequest{
client: is.Client(),
method: http.MethodGet,
url: is.URL + "/invitations?" + tc.query,
token: tc.token,
contentType: tc.contentType,
}
res, err := req.make()
assert.Nil(t, err, tc.desc)
assert.Equal(t, tc.status, res.StatusCode, tc.desc)
repoCall.Unset()
authnCall.Unset()
})
}
}
func TestViewInvitation(t *testing.T) {
is, svc, authn := newIvitationsServer()
cases := []struct {
desc string
token string
domainID string
userID string
contentType string
status int
svcErr error
authnRes smqauthn.Session
authnErr error
}{
{
desc: "valid request",
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
userID: validID,
domainID: domainID,
status: http.StatusOK,
contentType: validContenType,
svcErr: nil,
},
{
desc: "invalid token",
token: "",
userID: validID,
domainID: domainID,
status: http.StatusUnauthorized,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with service error",
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
userID: validID,
domainID: domainID,
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: svcerr.ErrViewEntity,
},
{
desc: "with empty user_id",
token: validToken,
userID: "",
domainID: domainID,
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with empty domain",
token: validToken,
userID: validID,
domainID: "",
status: http.StatusNotFound,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with empty user_id and domain_id",
token: validToken,
userID: "",
domainID: "",
status: http.StatusNotFound,
contentType: validContenType,
svcErr: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
repoCall := svc.On("ViewInvitation", mock.Anything, tc.authnRes, tc.userID, tc.domainID).Return(invitations.Invitation{}, tc.svcErr)
req := testRequest{
client: is.Client(),
method: http.MethodGet,
url: is.URL + "/invitations/" + tc.userID + "/" + tc.domainID,
token: tc.token,
contentType: tc.contentType,
}
res, err := req.make()
assert.Nil(t, err, tc.desc)
assert.Equal(t, tc.status, res.StatusCode, tc.desc)
repoCall.Unset()
authnCall.Unset()
})
}
}
func TestDeleteInvitation(t *testing.T) {
is, svc, authn := newIvitationsServer()
_ = authn
cases := []struct {
desc string
token string
domainID string
userID string
contentType string
status int
svcErr error
authnRes smqauthn.Session
authnErr error
}{
{
desc: "valid request",
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
userID: validID,
domainID: domainID,
status: http.StatusNoContent,
contentType: validContenType,
svcErr: nil,
},
{
desc: "invalid token",
token: "",
userID: validID,
domainID: domainID,
status: http.StatusUnauthorized,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with service error",
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
userID: validID,
domainID: domainID,
status: http.StatusForbidden,
contentType: validContenType,
svcErr: svcerr.ErrAuthorization,
},
{
desc: "with empty user_id",
token: validToken,
userID: "",
domainID: domainID,
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with empty domain_id",
token: validToken,
userID: validID,
domainID: "",
status: http.StatusNotFound,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with empty user_id and domain_id",
token: validToken,
userID: "",
domainID: "",
status: http.StatusNotFound,
contentType: validContenType,
svcErr: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
repoCall := svc.On("DeleteInvitation", mock.Anything, tc.authnRes, tc.userID, tc.domainID).Return(tc.svcErr)
req := testRequest{
client: is.Client(),
method: http.MethodDelete,
url: is.URL + "/invitations/" + tc.userID + "/" + tc.domainID,
token: tc.token,
contentType: tc.contentType,
}
res, err := req.make()
assert.Nil(t, err, tc.desc)
assert.Equal(t, tc.status, res.StatusCode, tc.desc)
repoCall.Unset()
authnCall.Unset()
})
}
}
func TestAcceptInvitation(t *testing.T) {
is, svc, authn := newIvitationsServer()
_ = authn
cases := []struct {
desc string
token string
data string
contentType string
status int
svcErr error
authnRes smqauthn.Session
authnErr error
}{
{
desc: "valid request",
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
token: validToken,
status: http.StatusNoContent,
contentType: validContenType,
svcErr: nil,
},
{
desc: "invalid token",
token: "",
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
status: http.StatusUnauthorized,
contentType: validContenType,
svcErr: nil,
},
{
desc: "with service error",
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
status: http.StatusForbidden,
contentType: validContenType,
svcErr: svcerr.ErrAuthorization,
},
{
desc: "invalid content type",
token: validToken,
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
status: http.StatusUnsupportedMediaType,
contentType: "text/plain",
svcErr: nil,
},
{
desc: "invalid data",
token: validToken,
data: `data`,
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
repoCall := svc.On("AcceptInvitation", mock.Anything, tc.authnRes, mock.Anything).Return(tc.svcErr)
req := testRequest{
client: is.Client(),
method: http.MethodPost,
url: is.URL + "/invitations/accept",
token: tc.token,
contentType: tc.contentType,
body: strings.NewReader(tc.data),
}
res, err := req.make()
assert.Nil(t, err, tc.desc)
assert.Equal(t, tc.status, res.StatusCode, tc.desc)
repoCall.Unset()
authnCall.Unset()
})
}
}
func TestRejectInvitation(t *testing.T) {
is, svc, authn := newIvitationsServer()
_ = authn
cases := []struct {
desc string
token string
data string
contentType string
status int
svcErr error
authnRes smqauthn.Session
authnErr error
}{
{
desc: "valid request",
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
status: http.StatusNoContent,
contentType: validContenType,
svcErr: nil,
},
{
desc: "invalid token",
token: "",
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
status: http.StatusUnauthorized,
contentType: validContenType,
svcErr: nil,
},
{
desc: "unauthorized error",
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID, DomainUserID: domainID + "_" + validID},
token: validToken,
data: fmt.Sprintf(`{"domain_id": "%s"}`, "invalid"),
status: http.StatusForbidden,
contentType: validContenType,
svcErr: svcerr.ErrAuthorization,
},
{
desc: "invalid content type",
token: validToken,
data: fmt.Sprintf(`{"domain_id": "%s"}`, validID),
status: http.StatusUnsupportedMediaType,
contentType: "text/plain",
svcErr: nil,
},
{
desc: "invalid data",
token: validToken,
data: `data`,
status: http.StatusBadRequest,
contentType: validContenType,
svcErr: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
repoCall := svc.On("RejectInvitation", mock.Anything, tc.authnRes, mock.Anything).Return(tc.svcErr)
req := testRequest{
client: is.Client(),
method: http.MethodPost,
url: is.URL + "/invitations/reject",
token: tc.token,
contentType: tc.contentType,
body: strings.NewReader(tc.data),
}
res, err := req.make()
assert.Nil(t, err, tc.desc)
assert.Equal(t, tc.status, res.StatusCode, tc.desc)
repoCall.Unset()
authnCall.Unset()
})
}
}
-72
View File
@@ -1,72 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
import (
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/invitations"
)
const maxLimitSize = 100
type sendInvitationReq struct {
UserID string `json:"user_id,omitempty"`
DomainID string `json:"domain_id,omitempty"`
Relation string `json:"relation,omitempty"`
Resend bool `json:"resend,omitempty"`
}
func (req *sendInvitationReq) validate() error {
if req.UserID == "" {
return apiutil.ErrMissingID
}
if req.DomainID == "" {
return apiutil.ErrMissingDomainID
}
if err := invitations.CheckRelation(req.Relation); err != nil {
return err
}
return nil
}
type listInvitationsReq struct {
invitations.Page
}
func (req *listInvitationsReq) validate() error {
if req.Page.Limit > maxLimitSize || req.Page.Limit < 1 {
return apiutil.ErrLimitSize
}
return nil
}
type acceptInvitationReq struct {
DomainID string `json:"domain_id,omitempty"`
}
func (req *acceptInvitationReq) validate() error {
if req.DomainID == "" {
return apiutil.ErrMissingDomainID
}
return nil
}
type invitationReq struct {
userID string
domainID string
}
func (req *invitationReq) validate() error {
if req.userID == "" {
return apiutil.ErrMissingID
}
if req.domainID == "" {
return apiutil.ErrMissingDomainID
}
return nil
}
-182
View File
@@ -1,182 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
import (
"fmt"
"testing"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/invitations"
"github.com/absmach/supermq/pkg/policies"
"github.com/stretchr/testify/assert"
)
var valid = "valid"
func TestSendInvitationReqValidation(t *testing.T) {
cases := []struct {
desc string
req sendInvitationReq
err error
}{
{
desc: "valid request",
req: sendInvitationReq{
UserID: valid,
DomainID: valid,
Relation: policies.DomainRelation,
Resend: true,
},
err: nil,
},
{
desc: "empty user ID",
req: sendInvitationReq{
UserID: "",
DomainID: valid,
Relation: policies.DomainRelation,
Resend: true,
},
err: apiutil.ErrMissingID,
},
{
desc: "empty domain_id",
req: sendInvitationReq{
UserID: valid,
DomainID: "",
Relation: policies.DomainRelation,
Resend: true,
},
err: apiutil.ErrMissingDomainID,
},
{
desc: "missing relation",
req: sendInvitationReq{
UserID: valid,
DomainID: valid,
Relation: "",
Resend: true,
},
err: apiutil.ErrMissingRelation,
},
{
desc: "invalid relation",
req: sendInvitationReq{
UserID: valid,
DomainID: valid,
Relation: "invalid",
Resend: true,
},
err: apiutil.ErrInvalidRelation,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
})
}
}
func TestListInvitationsReq(t *testing.T) {
cases := []struct {
desc string
req listInvitationsReq
err error
}{
{
desc: "valid request",
req: listInvitationsReq{
Page: invitations.Page{Limit: 1},
},
err: nil,
},
{
desc: "invalid limit",
req: listInvitationsReq{
Page: invitations.Page{Limit: 1000},
},
err: apiutil.ErrLimitSize,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
})
}
}
func TestAcceptInvitationReq(t *testing.T) {
cases := []struct {
desc string
req acceptInvitationReq
err error
}{
{
desc: "valid request",
req: acceptInvitationReq{
DomainID: valid,
},
err: nil,
},
{
desc: "empty domain_id",
req: acceptInvitationReq{
DomainID: "",
},
err: apiutil.ErrMissingDomainID,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
})
}
}
func TestInvitationReqValidation(t *testing.T) {
cases := []struct {
desc string
req invitationReq
err error
}{
{
desc: "valid request",
req: invitationReq{
userID: valid,
domainID: valid,
},
err: nil,
},
{
desc: "empty user ID",
req: invitationReq{
userID: "",
domainID: valid,
},
err: apiutil.ErrMissingID,
},
{
desc: "empty domain",
req: invitationReq{
userID: valid,
domainID: "",
},
err: apiutil.ErrMissingDomainID,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := tc.req.validate()
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
})
}
}
-110
View File
@@ -1,110 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
import (
"net/http"
"github.com/absmach/supermq"
"github.com/absmach/supermq/invitations"
)
var (
_ supermq.Response = (*sendInvitationRes)(nil)
_ supermq.Response = (*viewInvitationRes)(nil)
_ supermq.Response = (*listInvitationsRes)(nil)
_ supermq.Response = (*acceptInvitationRes)(nil)
_ supermq.Response = (*rejectInvitationRes)(nil)
_ supermq.Response = (*deleteInvitationRes)(nil)
)
type sendInvitationRes struct {
Message string `json:"message"`
}
func (res sendInvitationRes) Code() int {
return http.StatusCreated
}
func (res sendInvitationRes) Headers() map[string]string {
return map[string]string{}
}
func (res sendInvitationRes) Empty() bool {
return true
}
type viewInvitationRes struct {
invitations.Invitation `json:",inline"`
}
func (res viewInvitationRes) Code() int {
return http.StatusOK
}
func (res viewInvitationRes) Headers() map[string]string {
return map[string]string{}
}
func (res viewInvitationRes) Empty() bool {
return false
}
type listInvitationsRes struct {
invitations.InvitationPage `json:",inline"`
}
func (res listInvitationsRes) Code() int {
return http.StatusOK
}
func (res listInvitationsRes) Headers() map[string]string {
return map[string]string{}
}
func (res listInvitationsRes) Empty() bool {
return false
}
type acceptInvitationRes struct{}
func (res acceptInvitationRes) Code() int {
return http.StatusNoContent
}
func (res acceptInvitationRes) Headers() map[string]string {
return map[string]string{}
}
func (res acceptInvitationRes) Empty() bool {
return true
}
type deleteInvitationRes struct{}
func (res deleteInvitationRes) Code() int {
return http.StatusNoContent
}
func (res deleteInvitationRes) Headers() map[string]string {
return map[string]string{}
}
func (res deleteInvitationRes) Empty() bool {
return true
}
type rejectInvitationRes struct{}
func (res rejectInvitationRes) Code() int {
return http.StatusNoContent
}
func (res rejectInvitationRes) Headers() map[string]string {
return map[string]string{}
}
func (res rejectInvitationRes) Empty() bool {
return true
}
-172
View File
@@ -1,172 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package api
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"strings"
"github.com/absmach/supermq"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/invitations"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
"github.com/go-chi/chi/v5"
kithttp "github.com/go-kit/kit/transport/http"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)
const (
userIDKey = "user_id"
domainIDKey = "domain_id"
invitedByKey = "invited_by"
relationKey = "relation"
stateKey = "state"
)
func MakeHandler(svc invitations.Service, logger *slog.Logger, authn smqauthn.Authentication, instanceID string) http.Handler {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
mux := chi.NewRouter()
mux.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, false))
r.Route("/invitations", func(r chi.Router) {
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
sendInvitationEndpoint(svc),
decodeSendInvitationReq,
api.EncodeResponse,
opts...,
), "send_invitation").ServeHTTP)
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
listInvitationsEndpoint(svc),
decodeListInvitationsReq,
api.EncodeResponse,
opts...,
), "list_invitations").ServeHTTP)
r.Route("/{user_id}/{domain_id}", func(r chi.Router) {
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
viewInvitationEndpoint(svc),
decodeInvitationReq,
api.EncodeResponse,
opts...,
), "view_invitations").ServeHTTP)
r.Delete("/", otelhttp.NewHandler(kithttp.NewServer(
deleteInvitationEndpoint(svc),
decodeInvitationReq,
api.EncodeResponse,
opts...,
), "delete_invitation").ServeHTTP)
})
r.Post("/accept", otelhttp.NewHandler(kithttp.NewServer(
acceptInvitationEndpoint(svc),
decodeAcceptInvitationReq,
api.EncodeResponse,
opts...,
), "accept_invitation").ServeHTTP)
r.Post("/reject", otelhttp.NewHandler(kithttp.NewServer(
rejectInvitationEndpoint(svc),
decodeAcceptInvitationReq,
api.EncodeResponse,
opts...,
), "reject_invitation").ServeHTTP)
})
})
mux.Get("/health", supermq.Health("invitations", instanceID))
mux.Handle("/metrics", promhttp.Handler())
return mux
}
func decodeSendInvitationReq(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
var req sendInvitationReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
}
return req, nil
}
func decodeListInvitationsReq(_ context.Context, r *http.Request) (interface{}, error) {
offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
userID, err := apiutil.ReadStringQuery(r, userIDKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
invitedBy, err := apiutil.ReadStringQuery(r, invitedByKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
relation, err := apiutil.ReadStringQuery(r, relationKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
domainID, err := apiutil.ReadStringQuery(r, domainIDKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
st, err := apiutil.ReadStringQuery(r, stateKey, invitations.All.String())
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
state, err := invitations.ToState(st)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
req := listInvitationsReq{
Page: invitations.Page{
Offset: offset,
Limit: limit,
InvitedBy: invitedBy,
UserID: userID,
Relation: relation,
DomainID: domainID,
State: state,
},
}
return req, nil
}
func decodeAcceptInvitationReq(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
var req acceptInvitationReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity))
}
return req, nil
}
func decodeInvitationReq(_ context.Context, r *http.Request) (interface{}, error) {
req := invitationReq{
userID: chi.URLParam(r, "user_id"),
domainID: chi.URLParam(r, "domain_id"),
}
return req, nil
}
-7
View File
@@ -1,7 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package invitations provides the API to manage invitations.
//
// An invitation is a request to join a domain.
package invitations
-149
View File
@@ -1,149 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package invitations
import (
"context"
"encoding/json"
"time"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/policies"
)
// Invitation is an invitation to join a domain.
type Invitation struct {
InvitedBy string `json:"invited_by"`
UserID string `json:"user_id"`
DomainID string `json:"domain_id"`
Token string `json:"token,omitempty"`
Relation string `json:"relation,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
ConfirmedAt time.Time `json:"confirmed_at,omitempty"`
RejectedAt time.Time `json:"rejected_at,omitempty"`
Resend bool `json:"resend,omitempty"`
}
// Page is a page of invitations.
type Page struct {
Offset uint64 `json:"offset" db:"offset"`
Limit uint64 `json:"limit" db:"limit"`
InvitedBy string `json:"invited_by,omitempty" db:"invited_by,omitempty"`
UserID string `json:"user_id,omitempty" db:"user_id,omitempty"`
DomainID string `json:"domain_id,omitempty" db:"domain_id,omitempty"`
Relation string `json:"relation,omitempty" db:"relation,omitempty"`
InvitedByOrUserID string `db:"invited_by_or_user_id,omitempty"`
State State `json:"state,omitempty"`
}
// InvitationPage is a page of invitations.
type InvitationPage struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
Invitations []Invitation `json:"invitations"`
}
func (page InvitationPage) MarshalJSON() ([]byte, error) {
type Alias InvitationPage
a := struct {
Alias
}{
Alias: Alias(page),
}
if a.Invitations == nil {
a.Invitations = make([]Invitation, 0)
}
return json.Marshal(a)
}
// Service is an interface that defines methods for managing invitations.
//
//go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines"
type Service interface {
// SendInvitation sends an invitation to the given user.
// Only domain administrators and platform administrators can send invitations.
SendInvitation(ctx context.Context, session authn.Session, invitation Invitation) (err error)
// ViewInvitation returns an invitation.
// People who can view invitations are:
// - the invited user: they can view their own invitations
// - the user who sent the invitation
// - domain administrators
// - platform administrators
ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (invitation Invitation, err error)
// ListInvitations returns a list of invitations.
// People who can list invitations are:
// - platform administrators can list all invitations
// - domain administrators can list invitations for their domain
// By default, it will list invitations the current user has sent or received.
ListInvitations(ctx context.Context, session authn.Session, page Page) (invitations InvitationPage, err error)
// AcceptInvitation accepts an invitation by adding the user to the domain.
AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error)
// DeleteInvitation deletes an invitation.
// People who can delete invitations are:
// - the invited user: they can delete their own invitations
// - the user who sent the invitation
// - domain administrators
// - platform administrators
DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error)
// RejectInvitation rejects an invitation.
// People who can reject invitations are:
// - the invited user: they can reject their own invitations
RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error)
}
//go:generate mockery --name Repository --output=./mocks --filename repository.go --quiet --note "Copyright (c) Abstract Machines"
type Repository interface {
// Create creates an invitation.
Create(ctx context.Context, invitation Invitation) (err error)
// Retrieve returns an invitation.
Retrieve(ctx context.Context, userID, domainID string) (Invitation, error)
// RetrieveAll returns a list of invitations based on the given page.
RetrieveAll(ctx context.Context, page Page) (invitations InvitationPage, err error)
// UpdateToken updates an invitation by setting the token.
UpdateToken(ctx context.Context, invitation Invitation) (err error)
// UpdateConfirmation updates an invitation by setting the confirmation time.
UpdateConfirmation(ctx context.Context, invitation Invitation) (err error)
// UpdateRejection updates an invitation by setting the rejection time.
UpdateRejection(ctx context.Context, invitation Invitation) (err error)
// Delete deletes an invitation.
Delete(ctx context.Context, userID, domainID string) (err error)
}
// CheckRelation checks if the given relation is valid.
// It returns an error if the relation is empty or invalid.
func CheckRelation(relation string) error {
if relation == "" {
return apiutil.ErrMissingRelation
}
if relation != policies.AdministratorRelation &&
relation != policies.EditorRelation &&
relation != policies.ContributorRelation &&
relation != policies.MemberRelation &&
relation != policies.GuestRelation &&
relation != policies.DomainRelation &&
relation != policies.ParentGroupRelation &&
relation != policies.RoleGroupRelation &&
relation != policies.GroupRelation &&
relation != policies.PlatformRelation {
return apiutil.ErrInvalidRelation
}
return nil
}
-75
View File
@@ -1,75 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package invitations_test
import (
"fmt"
"testing"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/invitations"
"github.com/stretchr/testify/assert"
)
func TestInvitation_MarshalJSON(t *testing.T) {
cases := []struct {
desc string
page invitations.InvitationPage
res string
}{
{
desc: "empty page",
page: invitations.InvitationPage{
Invitations: []invitations.Invitation(nil),
},
res: `{"total":0,"offset":0,"limit":0,"invitations":[]}`,
},
{
desc: "page with invitations",
page: invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 0,
Invitations: []invitations.Invitation{
{
InvitedBy: "John",
UserID: "123",
DomainID: "123",
},
},
},
res: `{"total":1,"offset":0,"limit":0,"invitations":[{"invited_by":"John","user_id":"123","domain_id":"123","created_at":"0001-01-01T00:00:00Z","updated_at":"0001-01-01T00:00:00Z","confirmed_at":"0001-01-01T00:00:00Z","rejected_at":"0001-01-01T00:00:00Z"}]}`,
},
}
for _, tc := range cases {
data, err := tc.page.MarshalJSON()
assert.NoError(t, err, "Unexpected error: %v", err)
assert.Equal(t, tc.res, string(data), fmt.Sprintf("%s: expected %s, got %s", tc.desc, tc.res, string(data)))
}
}
func TestCheckRelation(t *testing.T) {
cases := []struct {
relation string
err error
}{
{"", apiutil.ErrMissingRelation},
{"admin", apiutil.ErrInvalidRelation},
{"editor", nil},
{"contributor", nil},
{"member", nil},
{"guest", nil},
{"domain", nil},
{"parent_group", nil},
{"role_group", nil},
{"group", nil},
{"platform", nil},
}
for _, tc := range cases {
err := invitations.CheckRelation(tc.relation)
assert.Equal(t, tc.err, err, "CheckRelation(%q) expected %v, got %v", tc.relation, tc.err, err)
}
}
-122
View File
@@ -1,122 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/invitations"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
)
// ErrMemberExist indicates that the user is already a member of the domain.
var ErrMemberExist = errors.New("user is already a member of the domain")
var _ invitations.Service = (*tracing)(nil)
type authorizationMiddleware struct {
authz authz.Authorization
svc invitations.Service
}
func AuthorizationMiddleware(authz authz.Authorization, svc invitations.Service) invitations.Service {
return &authorizationMiddleware{authz, svc}
}
func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation invitations.Invitation) (err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
domainUserId := auth.EncodeDomainUserID(invitation.DomainID, invitation.UserID)
if err := am.authorize(ctx, domainUserId, policies.MembershipPermission, policies.DomainType, invitation.DomainID); err == nil {
// return error if the user is already a member of the domain
return errors.Wrap(svcerr.ErrConflict, ErrMemberExist)
}
if err := am.checkAdmin(ctx, session); err != nil {
return err
}
return am.svc.SendInvitation(ctx, session, invitation)
}
func (am *authorizationMiddleware) ViewInvitation(ctx context.Context, session authn.Session, userID, domain string) (invitation invitations.Invitation, err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
if session.UserID != userID {
if err := am.checkAdmin(ctx, session); err != nil {
return invitations.Invitation{}, err
}
}
return am.svc.ViewInvitation(ctx, session, userID, domain)
}
func (am *authorizationMiddleware) ListInvitations(ctx context.Context, session authn.Session, page invitations.Page) (invs invitations.InvitationPage, err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
if err := am.authorize(ctx, session.UserID, policies.AdminPermission, policies.PlatformType, policies.SuperMQObject); err == nil {
session.SuperAdmin = true
}
if !session.SuperAdmin {
switch {
case page.DomainID != "":
if err := am.authorize(ctx, session.DomainUserID, policies.AdminPermission, policies.DomainType, page.DomainID); err != nil {
return invitations.InvitationPage{}, err
}
default:
page.InvitedByOrUserID = session.UserID
}
}
return am.svc.ListInvitations(ctx, session, page)
}
func (am *authorizationMiddleware) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
return am.svc.AcceptInvitation(ctx, session, domainID)
}
func (am *authorizationMiddleware) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
return am.svc.RejectInvitation(ctx, session, domainID)
}
func (am *authorizationMiddleware) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) {
session.DomainUserID = auth.EncodeDomainUserID(session.DomainID, session.UserID)
if err := am.checkAdmin(ctx, session); err != nil {
return err
}
return am.svc.DeleteInvitation(ctx, session, userID, domainID)
}
// checkAdmin checks if the given user is a domain or platform administrator.
func (am *authorizationMiddleware) checkAdmin(ctx context.Context, session authn.Session) error {
if err := am.authorize(ctx, session.DomainUserID, policies.AdminPermission, policies.DomainType, session.DomainID); err == nil {
return nil
}
if err := am.authorize(ctx, session.UserID, policies.AdminPermission, policies.PlatformType, policies.SuperMQObject); err == nil {
return nil
}
return svcerr.ErrAuthorization
}
func (am *authorizationMiddleware) authorize(ctx context.Context, subj, perm, objType, obj string) error {
req := authz.PolicyReq{
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: subj,
Permission: perm,
ObjectType: objType,
Object: obj,
}
if err := am.authz.Authorize(ctx, req); err != nil {
return err
}
return nil
}
-9
View File
@@ -1,9 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package middleware contains the middleware for the invitations service.
// It is responsible for the following:
// - Logging
// - Metrics
// - Tracing
package middleware
-127
View File
@@ -1,127 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"log/slog"
"time"
"github.com/absmach/supermq/invitations"
"github.com/absmach/supermq/pkg/authn"
)
var _ invitations.Service = (*logging)(nil)
type logging struct {
logger *slog.Logger
svc invitations.Service
}
func Logging(logger *slog.Logger, svc invitations.Service) invitations.Service {
return &logging{logger, svc}
}
func (lm *logging) SendInvitation(ctx context.Context, session authn.Session, invitation invitations.Invitation) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("user_id", invitation.UserID),
slog.String("domain_id", invitation.DomainID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Send invitation failed", args...)
return
}
lm.logger.Info("Send invitation completed successfully", args...)
}(time.Now())
return lm.svc.SendInvitation(ctx, session, invitation)
}
func (lm *logging) ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (invitation invitations.Invitation, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("user_id", userID),
slog.String("domain_id", domainID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("View invitation failed", args...)
return
}
lm.logger.Info("View invitation completed successfully", args...)
}(time.Now())
return lm.svc.ViewInvitation(ctx, session, userID, domainID)
}
func (lm *logging) ListInvitations(ctx context.Context, session authn.Session, page invitations.Page) (invs invitations.InvitationPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.Group("page",
slog.Uint64("offset", page.Offset),
slog.Uint64("limit", page.Limit),
slog.Uint64("total", invs.Total),
),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("List invitations failed", args...)
return
}
lm.logger.Info("List invitations completed successfully", args...)
}(time.Now())
return lm.svc.ListInvitations(ctx, session, page)
}
func (lm *logging) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", domainID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Accept invitation failed", args...)
return
}
lm.logger.Info("Accept invitation completed successfully", args...)
}(time.Now())
return lm.svc.AcceptInvitation(ctx, session, domainID)
}
func (lm *logging) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("domain_id", domainID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Reject invitation failed", args...)
return
}
lm.logger.Info("Reject invitation completed successfully", args...)
}(time.Now())
return lm.svc.RejectInvitation(ctx, session, domainID)
}
func (lm *logging) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("user_id", userID),
slog.String("domain_id", domainID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Delete invitation failed", args...)
return
}
lm.logger.Info("Delete invitation completed successfully", args...)
}(time.Now())
return lm.svc.DeleteInvitation(ctx, session, userID, domainID)
}
-77
View File
@@ -1,77 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"time"
"github.com/absmach/supermq/invitations"
"github.com/absmach/supermq/pkg/authn"
"github.com/go-kit/kit/metrics"
)
var _ invitations.Service = (*metricsmw)(nil)
type metricsmw struct {
counter metrics.Counter
latency metrics.Histogram
svc invitations.Service
}
func Metrics(counter metrics.Counter, latency metrics.Histogram, svc invitations.Service) invitations.Service {
return &metricsmw{
counter: counter,
latency: latency,
svc: svc,
}
}
func (mm *metricsmw) SendInvitation(ctx context.Context, session authn.Session, invitation invitations.Invitation) (err error) {
defer func(begin time.Time) {
mm.counter.With("method", "send_invitation").Add(1)
mm.latency.With("method", "send_invitation").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.SendInvitation(ctx, session, invitation)
}
func (mm *metricsmw) ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (invitation invitations.Invitation, err error) {
defer func(begin time.Time) {
mm.counter.With("method", "view_invitation").Add(1)
mm.latency.With("method", "view_invitation").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.ViewInvitation(ctx, session, userID, domainID)
}
func (mm *metricsmw) ListInvitations(ctx context.Context, session authn.Session, page invitations.Page) (invs invitations.InvitationPage, err error) {
defer func(begin time.Time) {
mm.counter.With("method", "list_invitations").Add(1)
mm.latency.With("method", "list_invitations").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.ListInvitations(ctx, session, page)
}
func (mm *metricsmw) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
defer func(begin time.Time) {
mm.counter.With("method", "accept_invitation").Add(1)
mm.latency.With("method", "accept_invitation").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.AcceptInvitation(ctx, session, domainID)
}
func (mm *metricsmw) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
defer func(begin time.Time) {
mm.counter.With("method", "reject_invitation").Add(1)
mm.latency.With("method", "reject_invitation").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.RejectInvitation(ctx, session, domainID)
}
func (mm *metricsmw) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) {
defer func(begin time.Time) {
mm.counter.With("method", "delete_invitation").Add(1)
mm.latency.With("method", "delete_invitation").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.DeleteInvitation(ctx, session, userID, domainID)
}
-85
View File
@@ -1,85 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package middleware
import (
"context"
"github.com/absmach/supermq/invitations"
"github.com/absmach/supermq/pkg/authn"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
var _ invitations.Service = (*tracing)(nil)
type tracing struct {
tracer trace.Tracer
svc invitations.Service
}
func Tracing(svc invitations.Service, tracer trace.Tracer) invitations.Service {
return &tracing{tracer, svc}
}
func (tm *tracing) SendInvitation(ctx context.Context, session authn.Session, invitation invitations.Invitation) (err error) {
ctx, span := tm.tracer.Start(ctx, "send_invitation", trace.WithAttributes(
attribute.String("domain_id", invitation.DomainID),
attribute.String("user_id", invitation.UserID),
))
defer span.End()
return tm.svc.SendInvitation(ctx, session, invitation)
}
func (tm *tracing) ViewInvitation(ctx context.Context, session authn.Session, userID, domain string) (invitation invitations.Invitation, err error) {
ctx, span := tm.tracer.Start(ctx, "view_invitation", trace.WithAttributes(
attribute.String("user_id", userID),
attribute.String("domain_id", domain),
))
defer span.End()
return tm.svc.ViewInvitation(ctx, session, userID, domain)
}
func (tm *tracing) ListInvitations(ctx context.Context, session authn.Session, page invitations.Page) (invs invitations.InvitationPage, err error) {
ctx, span := tm.tracer.Start(ctx, "list_invitations", trace.WithAttributes(
attribute.Int("limit", int(page.Limit)),
attribute.Int("offset", int(page.Offset)),
attribute.String("user_id", page.UserID),
attribute.String("domain_id", page.DomainID),
attribute.String("invited_by", page.InvitedBy),
))
defer span.End()
return tm.svc.ListInvitations(ctx, session, page)
}
func (tm *tracing) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
ctx, span := tm.tracer.Start(ctx, "accept_invitation", trace.WithAttributes(
attribute.String("domain_id", domainID),
))
defer span.End()
return tm.svc.AcceptInvitation(ctx, session, domainID)
}
func (tm *tracing) RejectInvitation(ctx context.Context, session authn.Session, domainID string) (err error) {
ctx, span := tm.tracer.Start(ctx, "reject_invitation", trace.WithAttributes(
attribute.String("domain_id", domainID),
))
defer span.End()
return tm.svc.RejectInvitation(ctx, session, domainID)
}
func (tm *tracing) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) (err error) {
ctx, span := tm.tracer.Start(ctx, "delete_invitation", trace.WithAttributes(
attribute.String("user_id", userID),
attribute.String("domain_id", domainID),
))
defer span.End()
return tm.svc.DeleteInvitation(ctx, session, userID, domainID)
}
-5
View File
@@ -1,5 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package mocks provides a mock implementation of the invitations repository.
package mocks
-177
View File
@@ -1,177 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
invitations "github.com/absmach/supermq/invitations"
mock "github.com/stretchr/testify/mock"
)
// Repository is an autogenerated mock type for the Repository type
type Repository struct {
mock.Mock
}
// Create provides a mock function with given fields: ctx, invitation
func (_m *Repository) Create(ctx context.Context, invitation invitations.Invitation) error {
ret := _m.Called(ctx, invitation)
if len(ret) == 0 {
panic("no return value specified for Create")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, invitations.Invitation) error); ok {
r0 = rf(ctx, invitation)
} else {
r0 = ret.Error(0)
}
return r0
}
// Delete provides a mock function with given fields: ctx, userID, domainID
func (_m *Repository) Delete(ctx context.Context, userID string, domainID string) error {
ret := _m.Called(ctx, userID, domainID)
if len(ret) == 0 {
panic("no return value specified for Delete")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, userID, domainID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Retrieve provides a mock function with given fields: ctx, userID, domainID
func (_m *Repository) Retrieve(ctx context.Context, userID string, domainID string) (invitations.Invitation, error) {
ret := _m.Called(ctx, userID, domainID)
if len(ret) == 0 {
panic("no return value specified for Retrieve")
}
var r0 invitations.Invitation
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) (invitations.Invitation, error)); ok {
return rf(ctx, userID, domainID)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string) invitations.Invitation); ok {
r0 = rf(ctx, userID, domainID)
} else {
r0 = ret.Get(0).(invitations.Invitation)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = rf(ctx, userID, domainID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RetrieveAll provides a mock function with given fields: ctx, page
func (_m *Repository) RetrieveAll(ctx context.Context, page invitations.Page) (invitations.InvitationPage, error) {
ret := _m.Called(ctx, page)
if len(ret) == 0 {
panic("no return value specified for RetrieveAll")
}
var r0 invitations.InvitationPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, invitations.Page) (invitations.InvitationPage, error)); ok {
return rf(ctx, page)
}
if rf, ok := ret.Get(0).(func(context.Context, invitations.Page) invitations.InvitationPage); ok {
r0 = rf(ctx, page)
} else {
r0 = ret.Get(0).(invitations.InvitationPage)
}
if rf, ok := ret.Get(1).(func(context.Context, invitations.Page) error); ok {
r1 = rf(ctx, page)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdateConfirmation provides a mock function with given fields: ctx, invitation
func (_m *Repository) UpdateConfirmation(ctx context.Context, invitation invitations.Invitation) error {
ret := _m.Called(ctx, invitation)
if len(ret) == 0 {
panic("no return value specified for UpdateConfirmation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, invitations.Invitation) error); ok {
r0 = rf(ctx, invitation)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateRejection provides a mock function with given fields: ctx, invitation
func (_m *Repository) UpdateRejection(ctx context.Context, invitation invitations.Invitation) error {
ret := _m.Called(ctx, invitation)
if len(ret) == 0 {
panic("no return value specified for UpdateRejection")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, invitations.Invitation) error); ok {
r0 = rf(ctx, invitation)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateToken provides a mock function with given fields: ctx, invitation
func (_m *Repository) UpdateToken(ctx context.Context, invitation invitations.Invitation) error {
ret := _m.Called(ctx, invitation)
if len(ret) == 0 {
panic("no return value specified for UpdateToken")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, invitations.Invitation) error); ok {
r0 = rf(ctx, invitation)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewRepository(t interface {
mock.TestingT
Cleanup(func())
}) *Repository {
mock := &Repository{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
-162
View File
@@ -1,162 +0,0 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
authn "github.com/absmach/supermq/pkg/authn"
invitations "github.com/absmach/supermq/invitations"
mock "github.com/stretchr/testify/mock"
)
// Service is an autogenerated mock type for the Service type
type Service struct {
mock.Mock
}
// AcceptInvitation provides a mock function with given fields: ctx, session, domainID
func (_m *Service) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) error {
ret := _m.Called(ctx, session, domainID)
if len(ret) == 0 {
panic("no return value specified for AcceptInvitation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok {
r0 = rf(ctx, session, domainID)
} else {
r0 = ret.Error(0)
}
return r0
}
// DeleteInvitation provides a mock function with given fields: ctx, session, userID, domainID
func (_m *Service) DeleteInvitation(ctx context.Context, session authn.Session, userID string, domainID string) error {
ret := _m.Called(ctx, session, userID, domainID)
if len(ret) == 0 {
panic("no return value specified for DeleteInvitation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) error); ok {
r0 = rf(ctx, session, userID, domainID)
} else {
r0 = ret.Error(0)
}
return r0
}
// ListInvitations provides a mock function with given fields: ctx, session, page
func (_m *Service) ListInvitations(ctx context.Context, session authn.Session, page invitations.Page) (invitations.InvitationPage, error) {
ret := _m.Called(ctx, session, page)
if len(ret) == 0 {
panic("no return value specified for ListInvitations")
}
var r0 invitations.InvitationPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, invitations.Page) (invitations.InvitationPage, error)); ok {
return rf(ctx, session, page)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, invitations.Page) invitations.InvitationPage); ok {
r0 = rf(ctx, session, page)
} else {
r0 = ret.Get(0).(invitations.InvitationPage)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, invitations.Page) error); ok {
r1 = rf(ctx, session, page)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RejectInvitation provides a mock function with given fields: ctx, session, domainID
func (_m *Service) RejectInvitation(ctx context.Context, session authn.Session, domainID string) error {
ret := _m.Called(ctx, session, domainID)
if len(ret) == 0 {
panic("no return value specified for RejectInvitation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok {
r0 = rf(ctx, session, domainID)
} else {
r0 = ret.Error(0)
}
return r0
}
// SendInvitation provides a mock function with given fields: ctx, session, invitation
func (_m *Service) SendInvitation(ctx context.Context, session authn.Session, invitation invitations.Invitation) error {
ret := _m.Called(ctx, session, invitation)
if len(ret) == 0 {
panic("no return value specified for SendInvitation")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, invitations.Invitation) error); ok {
r0 = rf(ctx, session, invitation)
} else {
r0 = ret.Error(0)
}
return r0
}
// ViewInvitation provides a mock function with given fields: ctx, session, userID, domainID
func (_m *Service) ViewInvitation(ctx context.Context, session authn.Session, userID string, domainID string) (invitations.Invitation, error) {
ret := _m.Called(ctx, session, userID, domainID)
if len(ret) == 0 {
panic("no return value specified for ViewInvitation")
}
var r0 invitations.Invitation
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) (invitations.Invitation, error)); ok {
return rf(ctx, session, userID, domainID)
}
if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, string) invitations.Invitation); ok {
r0 = rf(ctx, session, userID, domainID)
} else {
r0 = ret.Get(0).(invitations.Invitation)
}
if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, string) error); ok {
r1 = rf(ctx, session, userID, domainID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewService(t interface {
mock.TestingT
Cleanup(func())
}) *Service {
mock := &Service{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
-5
View File
@@ -1,5 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package postgres provides a postgres implementation of the invitations repository.
package postgres
-48
View File
@@ -1,48 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import (
_ "github.com/jackc/pgx/v5/stdlib" // required for SQL access
migrate "github.com/rubenv/sql-migrate"
)
func Migration() *migrate.MemoryMigrationSource {
return &migrate.MemoryMigrationSource{
Migrations: []*migrate.Migration{
{
Id: "invitations_01",
// VARCHAR(36) for colums with IDs as UUIDS have a maximum of 36 characters
Up: []string{
`CREATE TABLE IF NOT EXISTS invitations (
invited_by VARCHAR(36) NOT NULL,
user_id VARCHAR(36) NOT NULL,
domain_id VARCHAR(36) NOT NULL,
token TEXT NOT NULL,
relation VARCHAR(254) NOT NULL,
created_at TIMESTAMP NOT NULL,
updated_at TIMESTAMP,
confirmed_at TIMESTAMP,
UNIQUE (user_id, domain_id),
PRIMARY KEY (user_id, domain_id)
)`,
},
Down: []string{
`DROP TABLE IF EXISTS invitations`,
},
},
{
Id: "invitations_02_add_rejection",
Up: []string{
`ALTER TABLE invitations
ADD COLUMN rejected_at TIMESTAMP`,
},
Down: []string{
`ALTER TABLE invitations
DROP COLUMN rejected_at`,
},
},
},
}
}
-254
View File
@@ -1,254 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/absmach/supermq/invitations"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
"github.com/absmach/supermq/pkg/postgres"
)
type repository struct {
db postgres.Database
}
func NewRepository(db postgres.Database) invitations.Repository {
return &repository{db: db}
}
func (repo *repository) Create(ctx context.Context, invitation invitations.Invitation) (err error) {
q := `INSERT INTO invitations (invited_by, user_id, domain_id, token, relation, created_at)
VALUES (:invited_by, :user_id, :domain_id, :token, :relation, :created_at)`
dbInv := toDBInvitation(invitation)
if _, err = repo.db.NamedExecContext(ctx, q, dbInv); err != nil {
return postgres.HandleError(repoerr.ErrCreateEntity, err)
}
return nil
}
func (repo *repository) Retrieve(ctx context.Context, userID, domainID string) (invitations.Invitation, error) {
q := `SELECT invited_by, user_id, domain_id, token, relation, created_at, updated_at, confirmed_at, rejected_at FROM invitations WHERE user_id = :user_id AND domain_id = :domain_id;`
dbinv := dbInvitation{
UserID: userID,
DomainID: domainID,
}
rows, err := repo.db.NamedQueryContext(ctx, q, dbinv)
if err != nil {
return invitations.Invitation{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
dbinv = dbInvitation{}
if rows.Next() {
if err = rows.StructScan(&dbinv); err != nil {
return invitations.Invitation{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
return toInvitation(dbinv), nil
}
return invitations.Invitation{}, repoerr.ErrNotFound
}
func (repo *repository) RetrieveAll(ctx context.Context, page invitations.Page) (invitations.InvitationPage, error) {
query := pageQuery(page)
q := fmt.Sprintf("SELECT invited_by, user_id, domain_id, relation, created_at, updated_at, confirmed_at, rejected_at FROM invitations %s LIMIT :limit OFFSET :offset;", query)
rows, err := repo.db.NamedQueryContext(ctx, q, page)
if err != nil {
return invitations.InvitationPage{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
defer rows.Close()
var items []invitations.Invitation
for rows.Next() {
var dbinv dbInvitation
if err = rows.StructScan(&dbinv); err != nil {
return invitations.InvitationPage{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
items = append(items, toInvitation(dbinv))
}
tq := fmt.Sprintf(`SELECT COUNT(*) FROM invitations %s`, query)
total, err := postgres.Total(ctx, repo.db, tq, page)
if err != nil {
return invitations.InvitationPage{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
invPage := invitations.InvitationPage{
Total: total,
Offset: page.Offset,
Limit: page.Limit,
Invitations: items,
}
return invPage, nil
}
func (repo *repository) UpdateToken(ctx context.Context, invitation invitations.Invitation) (err error) {
q := `UPDATE invitations SET token = :token, updated_at = :updated_at WHERE user_id = :user_id AND domain_id = :domain_id`
dbinv := toDBInvitation(invitation)
result, err := repo.db.NamedExecContext(ctx, q, dbinv)
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
}
return nil
}
func (repo *repository) UpdateConfirmation(ctx context.Context, invitation invitations.Invitation) (err error) {
q := `UPDATE invitations SET confirmed_at = :confirmed_at, updated_at = :updated_at WHERE user_id = :user_id AND domain_id = :domain_id`
dbinv := toDBInvitation(invitation)
result, err := repo.db.NamedExecContext(ctx, q, dbinv)
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
}
return nil
}
func (repo *repository) UpdateRejection(ctx context.Context, invitation invitations.Invitation) (err error) {
q := `UPDATE invitations SET rejected_at = :rejected_at, updated_at = :updated_at WHERE user_id = :user_id AND domain_id = :domain_id`
dbInv := toDBInvitation(invitation)
result, err := repo.db.NamedExecContext(ctx, q, dbInv)
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
}
return nil
}
func (repo *repository) Delete(ctx context.Context, userID, domain string) (err error) {
q := `DELETE FROM invitations WHERE user_id = $1 AND domain_id = $2`
result, err := repo.db.ExecContext(ctx, q, userID, domain)
if err != nil {
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
}
if rows, _ := result.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
}
return nil
}
func pageQuery(pm invitations.Page) string {
var query []string
var emq string
if pm.DomainID != "" {
query = append(query, "domain_id = :domain_id")
}
if pm.UserID != "" {
query = append(query, "user_id = :user_id")
}
if pm.InvitedBy != "" {
query = append(query, "invited_by = :invited_by")
}
if pm.Relation != "" {
query = append(query, "relation = :relation")
}
if pm.InvitedByOrUserID != "" {
query = append(query, "(invited_by = :invited_by_or_user_id OR user_id = :invited_by_or_user_id)")
}
if pm.State == invitations.Accepted {
query = append(query, "confirmed_at IS NOT NULL")
}
if pm.State == invitations.Pending {
query = append(query, "confirmed_at IS NULL AND rejected_at IS NULL")
}
if pm.State == invitations.Rejected {
query = append(query, "rejected_at IS NOT NULL")
}
if len(query) > 0 {
emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND "))
}
return emq
}
type dbInvitation struct {
InvitedBy string `db:"invited_by"`
UserID string `db:"user_id"`
DomainID string `db:"domain_id"`
Token string `db:"token,omitempty"`
Relation string `db:"relation"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
ConfirmedAt sql.NullTime `db:"confirmed_at,omitempty"`
RejectedAt sql.NullTime `db:"rejected_at,omitempty"`
}
func toDBInvitation(inv invitations.Invitation) dbInvitation {
var updatedAt, confirmedAt, rejectedAt sql.NullTime
if inv.UpdatedAt != (time.Time{}) {
updatedAt = sql.NullTime{Time: inv.UpdatedAt, Valid: true}
}
if inv.ConfirmedAt != (time.Time{}) {
confirmedAt = sql.NullTime{Time: inv.ConfirmedAt, Valid: true}
}
if inv.RejectedAt != (time.Time{}) {
rejectedAt = sql.NullTime{Time: inv.RejectedAt, Valid: true}
}
return dbInvitation{
InvitedBy: inv.InvitedBy,
UserID: inv.UserID,
DomainID: inv.DomainID,
Token: inv.Token,
Relation: inv.Relation,
CreatedAt: inv.CreatedAt,
UpdatedAt: updatedAt,
ConfirmedAt: confirmedAt,
RejectedAt: rejectedAt,
}
}
func toInvitation(dbinv dbInvitation) invitations.Invitation {
var updatedAt, confirmedAt, rejectedAt time.Time
if dbinv.UpdatedAt.Valid {
updatedAt = dbinv.UpdatedAt.Time
}
if dbinv.ConfirmedAt.Valid {
confirmedAt = dbinv.ConfirmedAt.Time
}
if dbinv.RejectedAt.Valid {
rejectedAt = dbinv.RejectedAt.Time
}
return invitations.Invitation{
InvitedBy: dbinv.InvitedBy,
UserID: dbinv.UserID,
DomainID: dbinv.DomainID,
Token: dbinv.Token,
Relation: dbinv.Relation,
CreatedAt: dbinv.CreatedAt,
UpdatedAt: updatedAt,
ConfirmedAt: confirmedAt,
RejectedAt: rejectedAt,
}
}
-811
View File
@@ -1,811 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres_test
import (
"context"
"fmt"
"strings"
"testing"
"time"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/invitations"
"github.com/absmach/supermq/invitations/postgres"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
invalidUUID = strings.Repeat("a", 37)
validToken = strings.Repeat("a", 1024)
relation = "relation"
)
func TestInvitationCreate(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM invitations")
require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
domainID := testsutil.GenerateUUID(t)
userID := testsutil.GenerateUUID(t)
cases := []struct {
desc string
invitation invitations.Invitation
err error
}{
{
desc: "add new invitation successfully",
invitation: invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: userID,
DomainID: domainID,
Token: validToken,
Relation: relation,
CreatedAt: time.Now(),
},
err: nil,
},
{
desc: "add new invitation with an confirmed_at date",
invitation: invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: relation,
CreatedAt: time.Now(),
ConfirmedAt: time.Now(),
},
err: nil,
},
{
desc: "add invitation with duplicate invitation",
invitation: invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: userID,
DomainID: domainID,
Token: validToken,
Relation: relation,
CreatedAt: time.Now(),
},
err: repoerr.ErrConflict,
},
{
desc: "add invitation with invalid invitation invited_by",
invitation: invitations.Invitation{
InvitedBy: invalidUUID,
UserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: relation,
CreatedAt: time.Now(),
},
err: repoerr.ErrMalformedEntity,
},
{
desc: "add invitation with invalid invitation relation",
invitation: invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: strings.Repeat("a", 255),
CreatedAt: time.Now(),
},
err: repoerr.ErrMalformedEntity,
},
{
desc: "add invitation with invalid invitation domain",
invitation: invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
DomainID: invalidUUID,
Token: validToken,
Relation: relation,
CreatedAt: time.Now(),
},
err: repoerr.ErrMalformedEntity,
},
{
desc: "add invitation with invalid invitation user id",
invitation: invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: invalidUUID,
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: relation,
CreatedAt: time.Now(),
},
err: repoerr.ErrMalformedEntity,
},
{
desc: "add invitation with empty invitation domain",
invitation: invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: relation,
CreatedAt: time.Now(),
},
err: nil,
},
{
desc: "add invitation with empty invitation user id",
invitation: invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: relation,
CreatedAt: time.Now(),
},
err: nil,
},
{
desc: "add invitation with empty invitation invited_by",
invitation: invitations.Invitation{
DomainID: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: relation,
CreatedAt: time.Now(),
},
err: nil,
},
{
desc: "add invitation with empty invitation token",
invitation: invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
Relation: relation,
CreatedAt: time.Now(),
},
err: nil,
},
}
for _, tc := range cases {
switch err := repo.Create(context.Background(), tc.invitation); {
case err == nil:
assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
default:
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
}
func TestInvitationRetrieve(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM invitations")
require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
invitation := invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: relation,
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
}
err := repo.Create(context.Background(), invitation)
require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err))
cases := []struct {
desc string
userID string
domainID string
response invitations.Invitation
err error
}{
{
desc: "retrieve invitations successfully",
userID: invitation.UserID,
domainID: invitation.DomainID,
response: invitation,
err: nil,
},
{
desc: "retrieve invitations with invalid invitation user id",
userID: testsutil.GenerateUUID(t),
domainID: invitation.DomainID,
response: invitations.Invitation{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve invitations with invalid invitation domain_id",
userID: invitation.UserID,
domainID: testsutil.GenerateUUID(t),
response: invitations.Invitation{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve invitations with invalid invitation user id and domain_id",
userID: testsutil.GenerateUUID(t),
domainID: testsutil.GenerateUUID(t),
response: invitations.Invitation{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve invitations with empty invitation user id",
userID: "",
domainID: invitation.DomainID,
response: invitations.Invitation{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve invitations with empty invitation domain_id",
userID: invitation.UserID,
domainID: "",
response: invitations.Invitation{},
err: repoerr.ErrNotFound,
},
{
desc: "retrieve invitations with empty invitation user id and domain_id",
userID: "",
domainID: "",
response: invitations.Invitation{},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
page, err := repo.Retrieve(context.Background(), tc.userID, tc.domainID)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.response, page, fmt.Sprintf("desc: %s\n", tc.desc))
}
}
func TestInvitationRetrieveAll(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM invitations")
require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
num := 200
var items []invitations.Invitation
for i := 0; i < num; i++ {
invitation := invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: fmt.Sprintf("%s-%d", relation, i),
CreatedAt: time.Now().UTC().Truncate(time.Microsecond),
}
err := repo.Create(context.Background(), invitation)
require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err))
invitation.Token = ""
items = append(items, invitation)
}
items[100].ConfirmedAt = time.Now().UTC().Truncate(time.Microsecond)
err := repo.UpdateConfirmation(context.Background(), items[100])
require.Nil(t, err, fmt.Sprintf("update invitation unexpected error: %s", err))
swap := items[100]
items = append(items[:100], items[101:]...)
items = append(items, swap)
cases := []struct {
desc string
page invitations.Page
response invitations.InvitationPage
err error
}{
{
desc: "retrieve invitations successfully",
page: invitations.Page{
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: uint64(num),
Offset: 0,
Limit: 10,
Invitations: items[:10],
},
err: nil,
},
{
desc: "retrieve invitations with offset",
page: invitations.Page{
Offset: 10,
Limit: 10,
},
response: invitations.InvitationPage{
Total: uint64(num),
Offset: 10,
Limit: 10,
Invitations: items[10:20],
},
},
{
desc: "retrieve invitations with limit",
page: invitations.Page{
Offset: 0,
Limit: 50,
},
response: invitations.InvitationPage{
Total: uint64(num),
Offset: 0,
Limit: 50,
Invitations: items[:50],
},
},
{
desc: "retrieve invitations with offset and limit",
page: invitations.Page{
Offset: 10,
Limit: 50,
},
response: invitations.InvitationPage{
Total: uint64(num),
Offset: 10,
Limit: 50,
Invitations: items[10:60],
},
},
{
desc: "retrieve invitations with offset out of range",
page: invitations.Page{
Offset: 1000,
Limit: 50,
},
response: invitations.InvitationPage{
Total: uint64(num),
Offset: 1000,
Limit: 50,
Invitations: []invitations.Invitation(nil),
},
},
{
desc: "retrieve invitations with offset and limit out of range",
page: invitations.Page{
Offset: 170,
Limit: 50,
},
response: invitations.InvitationPage{
Total: uint64(num),
Offset: 170,
Limit: 50,
Invitations: items[170:200],
},
},
{
desc: "retrieve invitations with limit out of range",
page: invitations.Page{
Offset: 0,
Limit: 1000,
},
response: invitations.InvitationPage{
Total: uint64(num),
Offset: 0,
Limit: 1000,
Invitations: items,
},
},
{
desc: "retrieve invitations with empty page",
page: invitations.Page{},
response: invitations.InvitationPage{
Total: uint64(num),
Offset: 0,
Limit: 0,
Invitations: []invitations.Invitation(nil),
},
},
{
desc: "retrieve invitations with domain",
page: invitations.Page{
DomainID: items[0].DomainID,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with user id",
page: invitations.Page{
UserID: items[0].UserID,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with invited_by",
page: invitations.Page{
InvitedBy: items[0].InvitedBy,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with invited_by_or_user_id",
page: invitations.Page{
InvitedByOrUserID: items[0].UserID,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with relation",
page: invitations.Page{
Relation: relation + "-0",
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with domain_id and user id",
page: invitations.Page{
DomainID: items[0].DomainID,
UserID: items[0].UserID,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with domain_id and invited_by",
page: invitations.Page{
DomainID: items[0].DomainID,
InvitedBy: items[0].InvitedBy,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with user id and invited_by",
page: invitations.Page{
UserID: items[0].UserID,
InvitedBy: items[0].InvitedBy,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with domain_id, user id and invited_by",
page: invitations.Page{
DomainID: items[0].DomainID,
UserID: items[0].UserID,
InvitedBy: items[0].InvitedBy,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with domain_id, user id, invited_by and relation",
page: invitations.Page{
DomainID: items[0].DomainID,
UserID: items[0].UserID,
InvitedBy: items[0].InvitedBy,
Relation: relation + "-0",
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation{items[0]},
},
},
{
desc: "retrieve invitations with invalid domain",
page: invitations.Page{
DomainID: invalidUUID,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 0,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation(nil),
},
},
{
desc: "retrieve invitations with invalid user id",
page: invitations.Page{
UserID: testsutil.GenerateUUID(t),
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 0,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation(nil),
},
},
{
desc: "retrieve invitations with invalid invited_by",
page: invitations.Page{
InvitedBy: invalidUUID,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 0,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation(nil),
},
},
{
desc: "retrieve invitations with invalid relation",
page: invitations.Page{
Relation: invalidUUID,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 0,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation(nil),
},
},
{
desc: "retrieve invitations with accepted state",
page: invitations.Page{
State: invitations.Accepted,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation{items[num-1]},
},
},
{
desc: "retrieve invitations with pending state",
page: invitations.Page{
State: invitations.Pending,
Offset: 0,
Limit: 10,
},
response: invitations.InvitationPage{
Total: uint64(num - 1),
Offset: 0,
Limit: 10,
Invitations: items[0:10],
},
},
}
for _, tc := range cases {
page, err := repo.RetrieveAll(context.Background(), tc.page)
assert.Equal(t, tc.response.Total, page.Total, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Total, page.Total))
assert.Equal(t, tc.response.Offset, page.Offset, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Offset, page.Offset))
assert.Equal(t, tc.response.Limit, page.Limit, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.response.Limit, page.Limit))
assert.ElementsMatch(t, page.Invitations, tc.response.Invitations, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response.Invitations, page.Invitations))
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestInvitationUpdateToken(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM invitations")
require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
invitation := invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
CreatedAt: time.Now(),
}
err := repo.Create(context.Background(), invitation)
require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err))
cases := []struct {
desc string
invitation invitations.Invitation
err error
}{
{
desc: "update invitation successfully",
invitation: invitations.Invitation{
DomainID: invitation.DomainID,
UserID: invitation.UserID,
Token: validToken,
UpdatedAt: time.Now(),
},
err: nil,
},
{
desc: "update invitation with invalid user id",
invitation: invitations.Invitation{
UserID: testsutil.GenerateUUID(t),
DomainID: invitation.DomainID,
Token: validToken,
UpdatedAt: time.Now(),
},
err: repoerr.ErrNotFound,
},
{
desc: "update invitation with invalid domain_id",
invitation: invitations.Invitation{
UserID: invitation.UserID,
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
UpdatedAt: time.Now(),
},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
err := repo.UpdateToken(context.Background(), tc.invitation)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestInvitationUpdateConfirmation(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM invitations")
require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
invitation := invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
CreatedAt: time.Now(),
}
err := repo.Create(context.Background(), invitation)
require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err))
cases := []struct {
desc string
invitation invitations.Invitation
err error
}{
{
desc: "update invitation successfully",
invitation: invitations.Invitation{
DomainID: invitation.DomainID,
UserID: invitation.UserID,
ConfirmedAt: time.Now(),
},
err: nil,
},
{
desc: "update invitation with invalid user id",
invitation: invitations.Invitation{
UserID: testsutil.GenerateUUID(t),
DomainID: invitation.UserID,
ConfirmedAt: time.Now(),
},
err: repoerr.ErrNotFound,
},
{
desc: "update invitation with invalid domain",
invitation: invitations.Invitation{
UserID: invitation.UserID,
DomainID: testsutil.GenerateUUID(t),
ConfirmedAt: time.Now(),
},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
err := repo.UpdateConfirmation(context.Background(), tc.invitation)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
func TestInvitationDelete(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM invitations")
require.Nil(t, err, fmt.Sprintf("clean invitations unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
invitation := invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
CreatedAt: time.Now(),
}
err := repo.Create(context.Background(), invitation)
require.Nil(t, err, fmt.Sprintf("create invitation unexpected error: %s", err))
cases := []struct {
desc string
invitation invitations.Invitation
err error
}{
{
desc: "delete invitation successfully",
invitation: invitations.Invitation{
UserID: invitation.UserID,
DomainID: invitation.DomainID,
},
err: nil,
},
{
desc: "delete invitation with invalid invitation id",
invitation: invitations.Invitation{
UserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
},
err: repoerr.ErrNotFound,
},
{
desc: "delete invitation with empty invitation id",
invitation: invitations.Invitation{},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
err := repo.Delete(context.Background(), tc.invitation.UserID, tc.invitation.DomainID)
assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
}
-96
View File
@@ -1,96 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres_test
import (
"database/sql"
"fmt"
"log"
"os"
"testing"
"time"
ipostgres "github.com/absmach/supermq/invitations/postgres"
"github.com/absmach/supermq/pkg/postgres"
"github.com/jmoiron/sqlx"
dockertest "github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"go.opentelemetry.io/otel"
)
var (
db *sqlx.DB
database postgres.Database
tracer = otel.Tracer("repo_tests")
)
func TestMain(m *testing.M) {
pool, err := dockertest.NewPool("")
if err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
container, err := pool.RunWithOptions(&dockertest.RunOptions{
Repository: "postgres",
Tag: "16.2-alpine",
Env: []string{
"POSTGRES_USER=test",
"POSTGRES_PASSWORD=test",
"POSTGRES_DB=test",
"listen_addresses = '*'",
},
}, func(config *docker.HostConfig) {
config.AutoRemove = true
config.RestartPolicy = docker.RestartPolicy{Name: "no"}
})
if err != nil {
log.Fatalf("Could not start container: %s", err)
}
port := container.GetPort("5432/tcp")
// exponential backoff-retry, because the application in the container might not be ready to accept connections yet
pool.MaxWait = 120 * time.Second
if err := pool.Retry(func() error {
url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port)
db, err := sql.Open("pgx", url)
if err != nil {
return err
}
return db.Ping()
}); err != nil {
log.Fatalf("Could not connect to docker: %s", err)
}
dbConfig := postgres.Config{
Host: "localhost",
Port: port,
User: "test",
Pass: "test",
Name: "test",
SSLMode: "disable",
SSLCert: "",
SSLKey: "",
SSLRootCert: "",
}
if db, err = postgres.Setup(dbConfig, *ipostgres.Migration()); err != nil {
log.Fatalf("Could not setup test DB connection: %s", err)
}
if db, err = postgres.Connect(dbConfig); err != nil {
log.Fatalf("Could not setup test DB connection: %s", err)
}
database = postgres.NewDatabase(db, dbConfig, tracer)
code := m.Run()
// Defers will not be run when using os.Exit
db.Close()
if err := pool.Purge(container); err != nil {
log.Fatalf("Could not purge container: %s", err)
}
os.Exit(code)
}
-141
View File
@@ -1,141 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package invitations
import (
"context"
"time"
grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/authn"
svcerr "github.com/absmach/supermq/pkg/errors/service"
mgsdk "github.com/absmach/supermq/pkg/sdk"
)
type service struct {
token grpcTokenV1.TokenServiceClient
repo Repository
sdk mgsdk.SDK
}
func NewService(token grpcTokenV1.TokenServiceClient, repo Repository, sdk mgsdk.SDK) Service {
return &service{
token: token,
repo: repo,
sdk: sdk,
}
}
func (svc *service) SendInvitation(ctx context.Context, session authn.Session, invitation Invitation) error {
if err := CheckRelation(invitation.Relation); err != nil {
return err
}
invitation.InvitedBy = session.UserID
joinToken, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: session.UserID, Type: uint32(auth.InvitationKey)})
if err != nil {
return err
}
invitation.Token = joinToken.GetAccessToken()
if invitation.Resend {
invitation.UpdatedAt = time.Now()
return svc.repo.UpdateToken(ctx, invitation)
}
invitation.CreatedAt = time.Now()
if err := svc.repo.Create(ctx, invitation); err != nil {
return err
}
return nil
}
func (svc *service) ViewInvitation(ctx context.Context, session authn.Session, userID, domainID string) (invitation Invitation, err error) {
inv, err := svc.repo.Retrieve(ctx, userID, domainID)
if err != nil {
return Invitation{}, err
}
inv.Token = ""
return inv, nil
}
func (svc *service) ListInvitations(ctx context.Context, session authn.Session, page Page) (invitations InvitationPage, err error) {
ip, err := svc.repo.RetrieveAll(ctx, page)
if err != nil {
return InvitationPage{}, err
}
return ip, nil
}
func (svc *service) AcceptInvitation(ctx context.Context, session authn.Session, domainID string) error {
inv, err := svc.repo.Retrieve(ctx, session.UserID, domainID)
if err != nil {
return err
}
if inv.UserID != session.UserID {
return svcerr.ErrAuthorization
}
if !inv.ConfirmedAt.IsZero() {
return svcerr.ErrInvitationAlreadyAccepted
}
if !inv.RejectedAt.IsZero() {
return svcerr.ErrInvitationAlreadyRejected
}
if _, sdkerr := svc.sdk.AddDomainRoleMembers(inv.DomainID, "admin", []string{session.UserID}, inv.Token); sdkerr != nil {
return sdkerr
}
inv.ConfirmedAt = time.Now()
inv.UpdatedAt = inv.ConfirmedAt
return svc.repo.UpdateConfirmation(ctx, inv)
}
func (svc *service) RejectInvitation(ctx context.Context, session authn.Session, domainID string) error {
inv, err := svc.repo.Retrieve(ctx, session.UserID, domainID)
if err != nil {
return err
}
if inv.UserID != session.UserID {
return svcerr.ErrAuthorization
}
if !inv.ConfirmedAt.IsZero() {
return svcerr.ErrInvitationAlreadyAccepted
}
if !inv.RejectedAt.IsZero() {
return svcerr.ErrInvitationAlreadyRejected
}
inv.RejectedAt = time.Now()
inv.UpdatedAt = inv.RejectedAt
return svc.repo.UpdateRejection(ctx, inv)
}
func (svc *service) DeleteInvitation(ctx context.Context, session authn.Session, userID, domainID string) error {
if session.UserID == userID {
return svc.repo.Delete(ctx, userID, domainID)
}
inv, err := svc.repo.Retrieve(ctx, userID, domainID)
if err != nil {
return err
}
if inv.InvitedBy == session.UserID {
return svc.repo.Delete(ctx, userID, domainID)
}
return svc.repo.Delete(ctx, userID, domainID)
}
-513
View File
@@ -1,513 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package invitations_test
import (
"context"
"testing"
"time"
grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
apiutil "github.com/absmach/supermq/api/http/util"
authmocks "github.com/absmach/supermq/auth/mocks"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/invitations"
"github.com/absmach/supermq/invitations/mocks"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
sdkmocks "github.com/absmach/supermq/pkg/sdk/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
var (
validInvitation = invitations.Invitation{
UserID: testsutil.GenerateUUID(&testing.T{}),
DomainID: testsutil.GenerateUUID(&testing.T{}),
Relation: policies.ContributorRelation,
}
validDomainUserID = "domain_user_id"
validUserID = "user_id"
validDomainID = "domain_id"
validToken = "valid_token"
invalidToken = "invalid"
)
func TestSendInvitation(t *testing.T) {
repo := new(mocks.Repository)
token := new(authmocks.TokenServiceClient)
svc := invitations.NewService(token, repo, nil)
cases := []struct {
desc string
token string
session authn.Session
tokenUserID string
req invitations.Invitation
err error
issueErr error
repoErr error
}{
{
desc: "send invitation successful",
token: validToken,
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID},
tokenUserID: testsutil.GenerateUUID(t),
req: validInvitation,
err: nil,
issueErr: nil,
repoErr: nil,
},
{
desc: "failed to issue token",
token: invalidToken,
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID},
tokenUserID: testsutil.GenerateUUID(t),
req: validInvitation,
err: svcerr.ErrCreateEntity,
issueErr: svcerr.ErrCreateEntity,
repoErr: nil,
},
{
desc: "invalid relation",
token: validToken,
tokenUserID: testsutil.GenerateUUID(t),
req: invitations.Invitation{Relation: "invalid"},
err: apiutil.ErrInvalidRelation,
issueErr: nil,
repoErr: nil,
},
{
desc: "resend invitation",
token: invalidToken,
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID},
tokenUserID: testsutil.GenerateUUID(t),
req: invitations.Invitation{
UserID: validInvitation.UserID,
DomainID: validInvitation.DomainID,
Relation: validInvitation.Relation,
Resend: true,
},
err: nil,
issueErr: nil,
repoErr: nil,
},
{
desc: "error during token issuance",
token: validToken,
tokenUserID: testsutil.GenerateUUID(t),
req: validInvitation,
err: svcerr.ErrAuthentication,
issueErr: svcerr.ErrAuthentication,
repoErr: nil,
},
}
for _, tc := range cases {
repocall1 := token.On("Issue", context.Background(), mock.Anything).Return(&grpcTokenV1.Token{AccessToken: tc.req.Token}, tc.issueErr)
repocall2 := repo.On("Create", context.Background(), mock.Anything).Return(tc.repoErr)
if tc.req.Resend {
repocall2 = repo.On("UpdateToken", context.Background(), mock.Anything).Return(tc.repoErr)
}
err := svc.SendInvitation(context.Background(), tc.session, tc.req)
assert.Equal(t, tc.err, err, tc.desc)
repocall1.Unset()
repocall2.Unset()
}
}
func TestViewInvitation(t *testing.T) {
repo := new(mocks.Repository)
token := new(authmocks.TokenServiceClient)
svc := invitations.NewService(token, repo, nil)
validInvitation := invitations.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
Relation: policies.ContributorRelation,
CreatedAt: time.Now().Add(-time.Hour),
UpdatedAt: time.Now().Add(-time.Hour),
ConfirmedAt: time.Now().Add(-time.Hour),
}
cases := []struct {
desc string
token string
userID string
domainID string
session authn.Session
tokenUserID string
req invitations.Invitation
resp invitations.Invitation
err error
issueErr error
repoErr error
}{
{
desc: "view invitation successful",
token: validToken,
tokenUserID: testsutil.GenerateUUID(t),
userID: validInvitation.UserID,
domainID: validInvitation.DomainID,
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID},
resp: validInvitation,
err: nil,
repoErr: nil,
},
{
desc: "error retrieving invitation",
token: validToken,
userID: validInvitation.UserID,
domainID: validInvitation.DomainID,
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID},
tokenUserID: testsutil.GenerateUUID(t),
err: svcerr.ErrNotFound,
repoErr: svcerr.ErrNotFound,
},
{
desc: "valid invitation for the same user",
token: validToken,
userID: validInvitation.UserID,
domainID: validInvitation.DomainID,
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID},
resp: validInvitation,
tokenUserID: validInvitation.UserID,
err: nil,
repoErr: nil,
},
{
desc: "valid invitation for the invited user",
token: validToken,
userID: validInvitation.UserID,
domainID: validInvitation.DomainID,
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID},
tokenUserID: validInvitation.InvitedBy,
resp: validInvitation,
err: nil,
repoErr: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repocall1 := repo.On("Retrieve", context.Background(), mock.Anything, mock.Anything).Return(tc.resp, tc.repoErr)
inv, err := svc.ViewInvitation(context.Background(), tc.session, tc.userID, tc.domainID)
assert.Equal(t, tc.err, err, tc.desc)
assert.Equal(t, tc.resp, inv, tc.desc)
repocall1.Unset()
})
}
}
func TestListInvitations(t *testing.T) {
repo := new(mocks.Repository)
token := new(authmocks.TokenServiceClient)
svc := invitations.NewService(token, repo, nil)
validPage := invitations.Page{
Offset: 0,
Limit: 10,
}
validResp := invitations.InvitationPage{
Total: 1,
Offset: 0,
Limit: 10,
Invitations: []invitations.Invitation{
{
InvitedBy: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
Relation: policies.ContributorRelation,
CreatedAt: time.Now().Add(-time.Hour),
UpdatedAt: time.Now().Add(-time.Hour),
ConfirmedAt: time.Now().Add(-time.Hour),
},
},
}
cases := []struct {
desc string
session authn.Session
page invitations.Page
resp invitations.InvitationPage
err error
repoErr error
}{
{
desc: "list invitations successful",
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID},
page: validPage,
resp: validResp,
err: nil,
repoErr: nil,
},
{
desc: "list invitations unsuccessful",
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: validUserID},
page: validPage,
err: repoerr.ErrViewEntity,
resp: invitations.InvitationPage{},
repoErr: repoerr.ErrViewEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repocall1 := repo.On("RetrieveAll", context.Background(), mock.Anything).Return(tc.resp, tc.repoErr)
resp, err := svc.ListInvitations(context.Background(), tc.session, tc.page)
assert.Equal(t, tc.err, err, tc.desc)
assert.Equal(t, tc.resp, resp, tc.desc)
repocall1.Unset()
})
}
}
func TestAcceptInvitation(t *testing.T) {
repo := new(mocks.Repository)
token := new(authmocks.TokenServiceClient)
sdksvc := new(sdkmocks.SDK)
svc := invitations.NewService(token, repo, sdksvc)
userID := testsutil.GenerateUUID(t)
cases := []struct {
desc string
token string
domainID string
session authn.Session
resp invitations.Invitation
err error
repoErr error
sdkErr errors.SDKError
repoErr1 error
}{
{
desc: "accept invitation successful",
token: validToken,
domainID: "",
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID},
resp: invitations.Invitation{
UserID: userID,
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: policies.ContributorRelation,
},
err: nil,
repoErr: nil,
},
{
desc: "accept invitation with failed to retrieve all",
token: validToken,
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID},
err: svcerr.ErrNotFound,
repoErr: svcerr.ErrNotFound,
},
{
desc: "accept invitation with sdk err",
token: validToken,
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID},
domainID: "",
resp: invitations.Invitation{
UserID: userID,
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: policies.ContributorRelation,
},
err: errors.NewSDKError(svcerr.ErrConflict),
repoErr: nil,
sdkErr: errors.NewSDKError(svcerr.ErrConflict),
},
{
desc: "accept invitation with failed update confirmation",
token: validToken,
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID},
domainID: "",
resp: invitations.Invitation{
UserID: userID,
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: policies.ContributorRelation,
},
err: svcerr.ErrUpdateEntity,
repoErr: nil,
repoErr1: svcerr.ErrUpdateEntity,
},
{
desc: "accept invitation that is already confirmed",
token: validToken,
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID},
domainID: "",
resp: invitations.Invitation{
UserID: userID,
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: policies.ContributorRelation,
ConfirmedAt: time.Now(),
},
err: svcerr.ErrInvitationAlreadyAccepted,
repoErr: nil,
},
{
desc: "accept rejected invitation",
token: validToken,
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID},
domainID: "",
resp: invitations.Invitation{
UserID: userID,
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: policies.ContributorRelation,
RejectedAt: time.Now(),
},
err: svcerr.ErrInvitationAlreadyRejected,
repoErr: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repocall1 := repo.On("Retrieve", context.Background(), mock.Anything, tc.domainID).Return(tc.resp, tc.repoErr)
sdkcall := sdksvc.On("AddDomainRoleMembers", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{}, tc.sdkErr)
repocall2 := repo.On("UpdateConfirmation", context.Background(), mock.Anything).Return(tc.repoErr1)
err := svc.AcceptInvitation(context.Background(), tc.session, tc.domainID)
assert.Equal(t, tc.err, err, tc.desc)
repocall1.Unset()
sdkcall.Unset()
repocall2.Unset()
})
}
}
func TestDeleteInvitation(t *testing.T) {
repo := new(mocks.Repository)
token := new(authmocks.TokenServiceClient)
svc := invitations.NewService(token, repo, nil)
cases := []struct {
desc string
token string
userID string
domainID string
resp invitations.Invitation
err error
repoErr error
}{
{
desc: "delete invitations successful",
userID: testsutil.GenerateUUID(t),
domainID: testsutil.GenerateUUID(t),
resp: validInvitation,
err: nil,
repoErr: nil,
},
{
desc: "delete invitations for the same user",
token: validToken,
userID: validInvitation.UserID,
domainID: validInvitation.DomainID,
resp: validInvitation,
err: nil,
repoErr: nil,
},
{
desc: "delete invitations for the invited user",
token: validToken,
userID: validInvitation.UserID,
domainID: validInvitation.DomainID,
resp: validInvitation,
err: nil,
repoErr: nil,
},
{
desc: "error retrieving invitation",
token: validToken,
userID: validInvitation.UserID,
domainID: validInvitation.DomainID,
resp: invitations.Invitation{},
err: svcerr.ErrNotFound,
repoErr: svcerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repocall1 := repo.On("Retrieve", context.Background(), mock.Anything, mock.Anything).Return(tc.resp, tc.repoErr)
repocall2 := repo.On("Delete", context.Background(), mock.Anything, mock.Anything).Return(tc.repoErr)
err := svc.DeleteInvitation(context.Background(), authn.Session{}, tc.userID, tc.domainID)
assert.Equal(t, tc.err, err, tc.desc)
repocall1.Unset()
repocall2.Unset()
})
}
}
func TestRejectInvitation(t *testing.T) {
repo := new(mocks.Repository)
token := new(authmocks.TokenServiceClient)
svc := invitations.NewService(token, repo, nil)
userID := validInvitation.UserID
cases := []struct {
desc string
session authn.Session
domainID string
resp invitations.Invitation
err error
repoErr error
repoErr1 error
}{
{
desc: "reject invitations for the same user",
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID},
domainID: validInvitation.DomainID,
resp: validInvitation,
err: nil,
repoErr: nil,
repoErr1: nil,
},
{
desc: "reject invitations for the invited user",
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID},
domainID: validInvitation.DomainID,
resp: invitations.Invitation{},
err: svcerr.ErrAuthorization,
repoErr: nil,
repoErr1: nil,
},
{
desc: "error retrieving invitation",
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID},
domainID: validInvitation.DomainID,
resp: invitations.Invitation{},
err: repoerr.ErrNotFound,
repoErr: repoerr.ErrNotFound,
repoErr1: nil,
},
{
desc: "error updating rejection",
session: authn.Session{DomainUserID: validDomainUserID, DomainID: validDomainID, UserID: userID},
domainID: validInvitation.DomainID,
resp: validInvitation,
err: repoerr.ErrUpdateEntity,
repoErr: nil,
repoErr1: repoerr.ErrUpdateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repocall1 := repo.On("Retrieve", context.Background(), mock.Anything, mock.Anything).Return(tc.resp, tc.repoErr)
repocall3 := repo.On("UpdateRejection", context.Background(), mock.Anything).Return(tc.repoErr1)
err := svc.RejectInvitation(context.Background(), tc.session, tc.domainID)
assert.Equal(t, tc.err, err, tc.desc)
repocall1.Unset()
repocall3.Unset()
})
}
}
+6 -6
View File
@@ -106,7 +106,7 @@ func (es *eventHandler) createDomainHandler(ctx context.Context, data map[string
return errors.Wrap(errCreateDomainEvent, err)
}
if _, err := es.repo.Save(ctx, d); err != nil {
if _, err := es.repo.SaveDomain(ctx, d); err != nil {
return errors.Wrap(errCreateDomainEvent, err)
}
if _, err := es.repo.AddRoles(ctx, rps); err != nil {
@@ -122,7 +122,7 @@ func (es *eventHandler) updateDomainHandler(ctx context.Context, data map[string
return errors.Wrap(errUpdateDomainEvent, err)
}
if _, err := es.repo.Update(
if _, err := es.repo.UpdateDomain(
ctx,
d.ID,
domains.DomainReq{
@@ -147,7 +147,7 @@ func (es *eventHandler) enableDomainHandler(ctx context.Context, data map[string
}
enabled := domains.EnabledStatus
if _, err := es.repo.Update(ctx, d.ID, domains.DomainReq{Status: &enabled, UpdatedBy: &d.UpdatedBy, UpdatedAt: &d.UpdatedAt}); err != nil {
if _, err := es.repo.UpdateDomain(ctx, d.ID, domains.DomainReq{Status: &enabled, UpdatedBy: &d.UpdatedBy, UpdatedAt: &d.UpdatedAt}); err != nil {
return errors.Wrap(errEnableDomainGroupEvent, err)
}
@@ -161,7 +161,7 @@ func (es *eventHandler) disableDomainHandler(ctx context.Context, data map[strin
}
disabled := domains.DisabledStatus
if _, err := es.repo.Update(ctx, d.ID, domains.DomainReq{Status: &disabled, UpdatedBy: &d.UpdatedBy, UpdatedAt: &d.UpdatedAt}); err != nil {
if _, err := es.repo.UpdateDomain(ctx, d.ID, domains.DomainReq{Status: &disabled, UpdatedBy: &d.UpdatedBy, UpdatedAt: &d.UpdatedAt}); err != nil {
return errors.Wrap(errDisableDomainGroupEvent, err)
}
@@ -175,7 +175,7 @@ func (es *eventHandler) freezeDomainHandler(ctx context.Context, data map[string
}
freeze := domains.FreezeStatus
if _, err := es.repo.Update(ctx, d.ID, domains.DomainReq{Status: &freeze, UpdatedBy: &d.UpdatedBy, UpdatedAt: &d.UpdatedAt}); err != nil {
if _, err := es.repo.UpdateDomain(ctx, d.ID, domains.DomainReq{Status: &freeze, UpdatedBy: &d.UpdatedBy, UpdatedAt: &d.UpdatedAt}); err != nil {
return errors.Wrap(errFreezeDomainGroupEvent, err)
}
@@ -197,7 +197,7 @@ func (es *eventHandler) deleteDomainHandler(ctx context.Context, data map[string
return errors.Wrap(errDeleteDomainEvent, err)
}
if err := es.repo.Delete(ctx, d.ID); err != nil {
if err := es.repo.DeleteDomain(ctx, d.ID); err != nil {
return errors.Wrap(errDeleteDomainEvent, err)
}
+2 -2
View File
@@ -51,8 +51,8 @@ func setupDomains() (*httptest.Server, *mocks.Service, *authnmocks.Authenticatio
mux := chi.NewRouter()
authn := new(authnmocks.Authentication)
mux = domainapi.MakeHandler(svc, authn, mux, logger, "")
return httptest.NewServer(mux), svc, authn
handler := domainapi.MakeHandler(svc, authn, mux, logger, "")
return httptest.NewServer(handler), svc, authn
}
func TestCreateDomain(t *testing.T) {
+16 -16
View File
@@ -18,16 +18,16 @@ const (
)
type Invitation struct {
InvitedBy string `json:"invited_by"`
UserID string `json:"user_id"`
DomainID string `json:"domain_id"`
Token string `json:"token,omitempty"`
Relation string `json:"relation,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
ConfirmedAt time.Time `json:"confirmed_at,omitempty"`
RejectedAt time.Time `json:"rejected_at,omitempty"`
Resend bool `json:"resend,omitempty"`
InvitedBy string `json:"invited_by"`
InviteeUserID string `json:"invitee_user_id"`
DomainID string `json:"domain_id"`
RoleID string `json:"role_id,omitempty"`
RoleName string `json:"role_name,omitempty"`
Actions []string `json:"actions,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
ConfirmedAt time.Time `json:"confirmed_at,omitempty"`
RejectedAt time.Time `json:"rejected_at,omitempty"`
}
type InvitationPage struct {
@@ -43,7 +43,7 @@ func (sdk mgSDK) SendInvitation(invitation Invitation, token string) (err error)
return errors.NewSDKError(err)
}
url := sdk.invitationsURL + "/" + invitationsEndpoint
url := sdk.domainsURL + "/" + domainsEndpoint + "/" + invitation.DomainID + "/" + invitationsEndpoint
_, _, sdkerr := sdk.processRequest(http.MethodPost, url, token, data, nil, http.StatusCreated)
@@ -51,7 +51,7 @@ func (sdk mgSDK) SendInvitation(invitation Invitation, token string) (err error)
}
func (sdk mgSDK) Invitation(userID, domainID, token string) (invitation Invitation, err error) {
url := sdk.invitationsURL + "/" + invitationsEndpoint + "/" + userID + "/" + domainID
url := sdk.domainsURL + "/" + domainsEndpoint + "/" + domainID + "/" + invitationsEndpoint + "/" + userID
_, body, sdkerr := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK)
if sdkerr != nil {
@@ -66,7 +66,7 @@ func (sdk mgSDK) Invitation(userID, domainID, token string) (invitation Invitati
}
func (sdk mgSDK) Invitations(pm PageMetadata, token string) (invitations InvitationPage, err error) {
url, err := sdk.withQueryParams(sdk.invitationsURL, invitationsEndpoint, pm)
url, err := sdk.withQueryParams(sdk.domainsURL, invitationsEndpoint, pm)
if err != nil {
return InvitationPage{}, errors.NewSDKError(err)
}
@@ -95,7 +95,7 @@ func (sdk mgSDK) AcceptInvitation(domainID, token string) (err error) {
return errors.NewSDKError(err)
}
url := sdk.invitationsURL + "/" + invitationsEndpoint + "/" + acceptEndpoint
url := sdk.domainsURL + "/" + invitationsEndpoint + "/" + acceptEndpoint
_, _, sdkerr := sdk.processRequest(http.MethodPost, url, token, data, nil, http.StatusNoContent)
@@ -113,7 +113,7 @@ func (sdk mgSDK) RejectInvitation(domainID, token string) (err error) {
return errors.NewSDKError(err)
}
url := sdk.invitationsURL + "/" + invitationsEndpoint + "/" + rejectEndpoint
url := sdk.domainsURL + "/" + invitationsEndpoint + "/" + rejectEndpoint
_, _, sdkerr := sdk.processRequest(http.MethodPost, url, token, data, nil, http.StatusNoContent)
@@ -121,7 +121,7 @@ func (sdk mgSDK) RejectInvitation(domainID, token string) (err error) {
}
func (sdk mgSDK) DeleteInvitation(userID, domainID, token string) (err error) {
url := sdk.invitationsURL + "/" + invitationsEndpoint + "/" + userID + "/" + domainID
url := sdk.domainsURL + "/" + domainsEndpoint + "/" + domainID + "/" + invitationsEndpoint + "/" + userID
_, _, sdkerr := sdk.processRequest(http.MethodDelete, url, token, nil, nil, http.StatusNoContent)
+102 -118
View File
@@ -6,21 +6,15 @@ package sdk_test
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/invitations"
"github.com/absmach/supermq/invitations/api"
"github.com/absmach/supermq/invitations/mocks"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnmocks "github.com/absmach/supermq/pkg/authn/mocks"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
policies "github.com/absmach/supermq/pkg/policies"
sdk "github.com/absmach/supermq/pkg/sdk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
@@ -31,29 +25,19 @@ var (
invitation = convertInvitation(sdkInvitation)
)
func setupInvitations() (*httptest.Server, *mocks.Service, *authnmocks.Authentication) {
svc := new(mocks.Service)
logger := smqlog.NewMock()
authn := new(authnmocks.Authentication)
mux := api.MakeHandler(svc, logger, authn, "test")
return httptest.NewServer(mux), svc, authn
}
func TestSendInvitation(t *testing.T) {
is, svc, auth := setupInvitations()
is, svc, auth := setupDomains()
defer is.Close()
conf := sdk.Config{
InvitationsURL: is.URL,
DomainsURL: is.URL,
}
mgsdk := sdk.NewSDK(conf)
sendInvitationReq := sdk.Invitation{
UserID: invitation.UserID,
DomainID: invitation.DomainID,
Relation: invitation.Relation,
Resend: invitation.Resend,
InviteeUserID: invitation.InviteeUserID,
DomainID: invitation.DomainID,
RoleID: invitation.RoleID,
}
cases := []struct {
@@ -61,7 +45,7 @@ func TestSendInvitation(t *testing.T) {
token string
session smqauthn.Session
sendInvitationReq sdk.Invitation
svcReq invitations.Invitation
svcReq domains.Invitation
authenticateErr error
svcErr error
err error
@@ -86,7 +70,7 @@ func TestSendInvitation(t *testing.T) {
desc: "send invitation with empty token",
token: "",
sendInvitationReq: sendInvitationReq,
svcReq: invitations.Invitation{},
svcReq: domains.Invitation{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
@@ -94,42 +78,38 @@ func TestSendInvitation(t *testing.T) {
desc: "send invitation with empty userID",
token: validToken,
sendInvitationReq: sdk.Invitation{
UserID: "",
DomainID: invitation.DomainID,
Relation: invitation.Relation,
Resend: invitation.Resend,
InviteeUserID: "",
DomainID: invitation.DomainID,
RoleID: invitation.RoleID,
},
svcReq: invitations.Invitation{},
svcReq: domains.Invitation{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
},
{
desc: "send invitation with invalid relation",
desc: "send invitation with empty role ID",
token: validToken,
sendInvitationReq: sdk.Invitation{
UserID: invitation.UserID,
DomainID: invitation.DomainID,
Relation: "invalid",
Resend: invitation.Resend,
InviteeUserID: invitation.InviteeUserID,
DomainID: invitation.DomainID,
RoleID: "",
},
svcReq: invitations.Invitation{},
svcReq: domains.Invitation{},
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrInvalidRelation), http.StatusInternalServerError),
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
},
{
desc: "send inviation with invalid domainID",
token: validToken,
sendInvitationReq: sdk.Invitation{
UserID: invitation.UserID,
DomainID: wrongID,
Relation: invitation.Relation,
Resend: invitation.Resend,
InviteeUserID: invitation.InviteeUserID,
DomainID: wrongID,
RoleID: invitation.RoleID,
},
svcReq: invitations.Invitation{
UserID: invitation.UserID,
DomainID: wrongID,
Relation: invitation.Relation,
Resend: invitation.Resend,
svcReq: domains.Invitation{
InviteeUserID: invitation.InviteeUserID,
DomainID: wrongID,
RoleID: invitation.RoleID,
},
svcErr: svcerr.ErrCreateEntity,
err: errors.NewSDKErrorWithStatus(svcerr.ErrCreateEntity, http.StatusUnprocessableEntity),
@@ -138,7 +118,11 @@ func TestSendInvitation(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == valid {
tc.session = smqauthn.Session{UserID: tc.sendInvitationReq.UserID, DomainID: tc.sendInvitationReq.DomainID}
tc.session = smqauthn.Session{
UserID: tc.sendInvitationReq.InviteeUserID,
DomainID: tc.sendInvitationReq.DomainID,
DomainUserID: tc.sendInvitationReq.DomainID + "_" + tc.sendInvitationReq.InviteeUserID,
}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("SendInvitation", mock.Anything, tc.session, tc.svcReq).Return(tc.svcErr)
@@ -155,11 +139,11 @@ func TestSendInvitation(t *testing.T) {
}
func TestViewInvitation(t *testing.T) {
is, svc, auth := setupInvitations()
is, svc, auth := setupDomains()
defer is.Close()
conf := sdk.Config{
InvitationsURL: is.URL,
DomainsURL: is.URL,
}
mgsdk := sdk.NewSDK(conf)
@@ -169,7 +153,7 @@ func TestViewInvitation(t *testing.T) {
session smqauthn.Session
userID string
domainID string
svcRes invitations.Invitation
svcRes domains.Invitation
svcErr error
authenticateErr error
response sdk.Invitation
@@ -178,7 +162,7 @@ func TestViewInvitation(t *testing.T) {
{
desc: "view invitation successfully",
token: validToken,
userID: invitation.UserID,
userID: invitation.InviteeUserID,
domainID: invitation.DomainID,
svcRes: invitation,
svcErr: nil,
@@ -188,9 +172,9 @@ func TestViewInvitation(t *testing.T) {
{
desc: "view invitation with invalid token",
token: invalidToken,
userID: invitation.UserID,
userID: invitation.InviteeUserID,
domainID: invitation.DomainID,
svcRes: invitations.Invitation{},
svcRes: domains.Invitation{},
authenticateErr: svcerr.ErrAuthentication,
response: sdk.Invitation{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
@@ -198,29 +182,29 @@ func TestViewInvitation(t *testing.T) {
{
desc: "view invitation with empty token",
token: "",
userID: invitation.UserID,
userID: invitation.InviteeUserID,
domainID: invitation.DomainID,
svcRes: invitations.Invitation{},
svcRes: domains.Invitation{},
svcErr: nil,
response: sdk.Invitation{},
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
{
desc: "view invitation with empty userID",
desc: "view invitation with empty domainID",
token: validToken,
userID: "",
domainID: invitation.DomainID,
svcRes: invitations.Invitation{},
userID: invitation.InviteeUserID,
domainID: "",
svcRes: domains.Invitation{},
svcErr: nil,
response: sdk.Invitation{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingDomainID, http.StatusBadRequest),
},
{
desc: "view invitation with invalid domainID",
token: validToken,
userID: invitation.UserID,
userID: invitation.InviteeUserID,
domainID: wrongID,
svcRes: invitations.Invitation{},
svcRes: domains.Invitation{},
svcErr: svcerr.ErrNotFound,
response: sdk.Invitation{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
@@ -229,7 +213,7 @@ func TestViewInvitation(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == valid {
tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID}
tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID, DomainUserID: tc.domainID + "_" + tc.userID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("ViewInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(tc.svcRes, tc.svcErr)
@@ -247,11 +231,11 @@ func TestViewInvitation(t *testing.T) {
}
func TestListInvitation(t *testing.T) {
is, svc, auth := setupInvitations()
is, svc, auth := setupDomains()
defer is.Close()
conf := sdk.Config{
InvitationsURL: is.URL,
DomainsURL: is.URL,
}
mgsdk := sdk.NewSDK(conf)
@@ -260,8 +244,8 @@ func TestListInvitation(t *testing.T) {
token string
session smqauthn.Session
pageMeta sdk.PageMetadata
svcReq invitations.Page
svcRes invitations.InvitationPage
svcReq domains.InvitationPageMeta
svcRes domains.InvitationPage
svcErr error
authenticateErr error
response sdk.InvitationPage
@@ -274,13 +258,13 @@ func TestListInvitation(t *testing.T) {
Offset: 0,
Limit: 10,
},
svcReq: invitations.Page{
svcReq: domains.InvitationPageMeta{
Offset: 0,
Limit: 10,
},
svcRes: invitations.InvitationPage{
svcRes: domains.InvitationPage{
Total: 1,
Invitations: []invitations.Invitation{invitation},
Invitations: []domains.Invitation{invitation},
},
svcErr: nil,
response: sdk.InvitationPage{
@@ -296,11 +280,11 @@ func TestListInvitation(t *testing.T) {
Offset: 0,
Limit: 10,
},
svcReq: invitations.Page{
svcReq: domains.InvitationPageMeta{
Offset: 0,
Limit: 10,
},
svcRes: invitations.InvitationPage{},
svcRes: domains.InvitationPage{},
authenticateErr: svcerr.ErrAuthentication,
response: sdk.InvitationPage{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
@@ -309,7 +293,7 @@ func TestListInvitation(t *testing.T) {
desc: "list invitations with empty token",
token: "",
pageMeta: sdk.PageMetadata{},
svcRes: invitations.InvitationPage{},
svcRes: domains.InvitationPage{},
svcErr: nil,
response: sdk.InvitationPage{},
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
@@ -321,8 +305,8 @@ func TestListInvitation(t *testing.T) {
Offset: 0,
Limit: 101,
},
svcReq: invitations.Page{},
svcRes: invitations.InvitationPage{},
svcReq: domains.InvitationPageMeta{},
svcRes: domains.InvitationPage{},
svcErr: nil,
response: sdk.InvitationPage{},
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest),
@@ -349,11 +333,11 @@ func TestListInvitation(t *testing.T) {
}
func TestAcceptInvitation(t *testing.T) {
is, svc, auth := setupInvitations()
is, svc, auth := setupDomains()
defer is.Close()
conf := sdk.Config{
InvitationsURL: is.URL,
DomainsURL: is.URL,
}
mgsdk := sdk.NewSDK(conf)
@@ -415,11 +399,11 @@ func TestAcceptInvitation(t *testing.T) {
}
func TestRejectInvitation(t *testing.T) {
is, svc, auth := setupInvitations()
is, svc, auth := setupDomains()
defer is.Close()
conf := sdk.Config{
InvitationsURL: is.URL,
DomainsURL: is.URL,
}
mgsdk := sdk.NewSDK(conf)
@@ -481,11 +465,11 @@ func TestRejectInvitation(t *testing.T) {
}
func TestDeleteInvitation(t *testing.T) {
is, svc, auth := setupInvitations()
is, svc, auth := setupDomains()
defer is.Close()
conf := sdk.Config{
InvitationsURL: is.URL,
DomainsURL: is.URL,
}
mgsdk := sdk.NewSDK(conf)
@@ -493,64 +477,64 @@ func TestDeleteInvitation(t *testing.T) {
desc string
token string
session smqauthn.Session
userID string
inviteeUserID string
domainID string
authenticateErr error
svcErr error
err error
}{
{
desc: "delete invitation successfully",
token: validToken,
userID: invitation.UserID,
domainID: invitation.DomainID,
svcErr: nil,
err: nil,
desc: "delete invitation successfully",
token: validToken,
inviteeUserID: invitation.InviteeUserID,
domainID: invitation.DomainID,
svcErr: nil,
err: nil,
},
{
desc: "delete invitation with invalid token",
token: invalidToken,
userID: invitation.UserID,
inviteeUserID: invitation.InviteeUserID,
domainID: invitation.DomainID,
authenticateErr: svcerr.ErrAuthentication,
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized),
},
{
desc: "delete invitation with empty token",
token: "",
userID: invitation.UserID,
domainID: invitation.DomainID,
svcErr: nil,
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
desc: "delete invitation with empty token",
token: "",
inviteeUserID: invitation.InviteeUserID,
domainID: invitation.DomainID,
svcErr: nil,
err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized),
},
{
desc: "delete invitation with empty userID",
token: validToken,
userID: "",
domainID: invitation.DomainID,
svcErr: nil,
err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest),
desc: "delete invitation with empty domainID",
token: validToken,
inviteeUserID: invitation.InviteeUserID,
domainID: "",
svcErr: nil,
err: errors.NewSDKErrorWithStatus(apiutil.ErrMissingDomainID, http.StatusBadRequest),
},
{
desc: "delete invitation with invalid domainID",
token: validToken,
userID: invitation.UserID,
domainID: wrongID,
svcErr: svcerr.ErrNotFound,
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
desc: "delete invitation with invalid domainID",
token: validToken,
inviteeUserID: invitation.InviteeUserID,
domainID: wrongID,
svcErr: svcerr.ErrNotFound,
err: errors.NewSDKErrorWithStatus(svcerr.ErrNotFound, http.StatusNotFound),
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
if tc.token == valid {
tc.session = smqauthn.Session{UserID: tc.userID, DomainID: tc.domainID}
tc.session = smqauthn.Session{UserID: tc.inviteeUserID, DomainID: tc.domainID, DomainUserID: tc.domainID + "_" + tc.inviteeUserID}
}
authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr)
svcCall := svc.On("DeleteInvitation", mock.Anything, tc.session, tc.userID, tc.domainID).Return(tc.svcErr)
err := mgsdk.DeleteInvitation(tc.userID, tc.domainID, tc.token)
svcCall := svc.On("DeleteInvitation", mock.Anything, tc.session, tc.inviteeUserID, tc.domainID).Return(tc.svcErr)
err := mgsdk.DeleteInvitation(tc.inviteeUserID, tc.domainID, tc.token)
assert.Equal(t, tc.err, err)
if tc.err == nil {
ok := svcCall.Parent.AssertCalled(t, "DeleteInvitation", mock.Anything, tc.session, tc.userID, tc.domainID)
ok := svcCall.Parent.AssertCalled(t, "DeleteInvitation", mock.Anything, tc.session, tc.inviteeUserID, tc.domainID)
assert.True(t, ok)
}
svcCall.Unset()
@@ -563,13 +547,13 @@ func generateTestInvitation(t *testing.T) sdk.Invitation {
createdAt, err := time.Parse(time.RFC3339, "2024-01-01T00:00:00Z")
assert.Nil(t, err, fmt.Sprintf("Unexpected error parsing time: %v", err))
return sdk.Invitation{
InvitedBy: testsutil.GenerateUUID(t),
UserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
Token: validToken,
Relation: policies.MemberRelation,
CreatedAt: createdAt,
UpdatedAt: createdAt,
Resend: false,
InvitedBy: testsutil.GenerateUUID(t),
InviteeUserID: testsutil.GenerateUUID(t),
DomainID: testsutil.GenerateUUID(t),
RoleID: testsutil.GenerateUUID(t),
RoleName: "admin",
Actions: []string{"read", "update"},
CreatedAt: createdAt,
UpdatedAt: createdAt,
}
}
-3
View File
@@ -1318,7 +1318,6 @@ type mgSDK struct {
groupsURL string
channelsURL string
domainsURL string
invitationsURL string
journalURL string
HostURL string
@@ -1336,7 +1335,6 @@ type Config struct {
GroupsURL string
ChannelsURL string
DomainsURL string
InvitationsURL string
JournalURL string
HostURL string
@@ -1355,7 +1353,6 @@ func NewSDK(conf Config) SDK {
groupsURL: conf.GroupsURL,
channelsURL: conf.ChannelsURL,
domainsURL: conf.DomainsURL,
invitationsURL: conf.InvitationsURL,
journalURL: conf.JournalURL,
HostURL: conf.HostURL,
+13 -12
View File
@@ -12,9 +12,9 @@ import (
mgchannels "github.com/absmach/supermq/channels"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/domains"
groups "github.com/absmach/supermq/groups"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/invitations"
"github.com/absmach/supermq/journal"
"github.com/absmach/supermq/pkg/roles"
sdk "github.com/absmach/supermq/pkg/sdk"
@@ -218,17 +218,18 @@ func convertChannel(g sdk.Channel) mgchannels.Channel {
}
}
func convertInvitation(i sdk.Invitation) invitations.Invitation {
return invitations.Invitation{
InvitedBy: i.InvitedBy,
UserID: i.UserID,
DomainID: i.DomainID,
Token: i.Token,
Relation: i.Relation,
CreatedAt: i.CreatedAt,
UpdatedAt: i.UpdatedAt,
ConfirmedAt: i.ConfirmedAt,
Resend: i.Resend,
func convertInvitation(i sdk.Invitation) domains.Invitation {
return domains.Invitation{
InvitedBy: i.InvitedBy,
InviteeUserID: i.InviteeUserID,
DomainID: i.DomainID,
RoleID: i.RoleID,
RoleName: i.RoleName,
Actions: i.Actions,
CreatedAt: i.CreatedAt,
UpdatedAt: i.UpdatedAt,
ConfirmedAt: i.ConfirmedAt,
RejectedAt: i.RejectedAt,
}
}