mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
SMQ-3093 - User email verification (#3101)
Signed-off-by: Arvindh <arvindh91@gmail.com>
This commit is contained in:
@@ -73,6 +73,7 @@ type AuthNRes struct {
|
||||
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // token id
|
||||
UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // user id
|
||||
UserRole uint32 `protobuf:"varint,3,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"` // user role
|
||||
Verified bool `protobuf:"varint,4,opt,name=verified,proto3" json:"verified,omitempty"` // verified user
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -128,6 +129,13 @@ func (x *AuthNRes) GetUserRole() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *AuthNRes) GetVerified() bool {
|
||||
if x != nil {
|
||||
return x.Verified
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type AuthZReq struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` // Domain
|
||||
@@ -378,11 +386,12 @@ const file_auth_v1_auth_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x12auth/v1/auth.proto\x12\aauth.v1\" \n" +
|
||||
"\bAuthNReq\x12\x14\n" +
|
||||
"\x05token\x18\x01 \x01(\tR\x05token\"P\n" +
|
||||
"\x05token\x18\x01 \x01(\tR\x05token\"l\n" +
|
||||
"\bAuthNRes\x12\x0e\n" +
|
||||
"\x02id\x18\x01 \x01(\tR\x02id\x12\x17\n" +
|
||||
"\auser_id\x18\x02 \x01(\tR\x06userId\x12\x1b\n" +
|
||||
"\tuser_role\x18\x03 \x01(\rR\buserRole\"\xa2\x02\n" +
|
||||
"\tuser_role\x18\x03 \x01(\rR\buserRole\x12\x1a\n" +
|
||||
"\bverified\x18\x04 \x01(\bR\bverified\"\xa2\x02\n" +
|
||||
"\bAuthZReq\x12\x16\n" +
|
||||
"\x06domain\x18\x01 \x01(\tR\x06domain\x12!\n" +
|
||||
"\fsubject_type\x18\x02 \x01(\tR\vsubjectType\x12!\n" +
|
||||
|
||||
@@ -29,6 +29,7 @@ type IssueReq struct {
|
||||
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
|
||||
UserRole uint32 `protobuf:"varint,2,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"`
|
||||
Type uint32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"`
|
||||
Verified bool `protobuf:"varint,4,opt,name=verified,proto3" json:"verified,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -84,9 +85,17 @@ func (x *IssueReq) GetType() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *IssueReq) GetVerified() bool {
|
||||
if x != nil {
|
||||
return x.Verified
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type RefreshReq struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
RefreshToken string `protobuf:"bytes,1,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"`
|
||||
Verified bool `protobuf:"varint,2,opt,name=verified,proto3" json:"verified,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -128,6 +137,13 @@ func (x *RefreshReq) GetRefreshToken() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RefreshReq) GetVerified() bool {
|
||||
if x != nil {
|
||||
return x.Verified
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// If a token is not carrying any information itself, the type
|
||||
// field can be used to determine how to validate the token.
|
||||
// Also, different tokens can be encoded in different ways.
|
||||
@@ -195,14 +211,16 @@ var File_token_v1_token_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_token_v1_token_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x14token/v1/token.proto\x12\btoken.v1\"T\n" +
|
||||
"\x14token/v1/token.proto\x12\btoken.v1\"p\n" +
|
||||
"\bIssueReq\x12\x17\n" +
|
||||
"\auser_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n" +
|
||||
"\tuser_role\x18\x02 \x01(\rR\buserRole\x12\x12\n" +
|
||||
"\x04type\x18\x03 \x01(\rR\x04type\"1\n" +
|
||||
"\x04type\x18\x03 \x01(\rR\x04type\x12\x1a\n" +
|
||||
"\bverified\x18\x04 \x01(\bR\bverified\"M\n" +
|
||||
"\n" +
|
||||
"RefreshReq\x12#\n" +
|
||||
"\rrefresh_token\x18\x01 \x01(\tR\frefreshToken\"\x87\x01\n" +
|
||||
"\rrefresh_token\x18\x01 \x01(\tR\frefreshToken\x12\x1a\n" +
|
||||
"\bverified\x18\x02 \x01(\bR\bverified\"\x87\x01\n" +
|
||||
"\x05Token\x12!\n" +
|
||||
"\faccess_token\x18\x01 \x01(\tR\vaccessToken\x12(\n" +
|
||||
"\rrefresh_token\x18\x02 \x01(\tH\x00R\frefreshToken\x88\x01\x01\x12\x1f\n" +
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/auth"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
type sessionKeyType string
|
||||
|
||||
const SessionKey = sessionKeyType("session")
|
||||
|
||||
func AuthenticateMiddleware(authn smqauthn.Authentication, domainCheck bool) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := apiutil.ExtractBearerToken(r)
|
||||
if token == "" {
|
||||
EncodeError(r.Context(), apiutil.ErrBearerToken, w)
|
||||
return
|
||||
}
|
||||
resp, err := authn.Authenticate(r.Context(), token)
|
||||
if err != nil {
|
||||
EncodeError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if domainCheck {
|
||||
domain := chi.URLParam(r, "domainID")
|
||||
if domain == "" {
|
||||
EncodeError(r.Context(), apiutil.ErrMissingDomainID, w)
|
||||
return
|
||||
}
|
||||
resp.DomainID = domain
|
||||
switch resp.Role {
|
||||
case smqauthn.AdminRole:
|
||||
resp.DomainUserID = resp.UserID
|
||||
case smqauthn.UserRole:
|
||||
resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), SessionKey, resp)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
@@ -110,6 +111,13 @@ func ValidateUUID(extID string) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateEmail(email string) (err error) {
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return apiutil.ErrInvalidEmail
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateName validates name format.
|
||||
func ValidateName(id string) error {
|
||||
if !nameRegExp.MatchString(id) {
|
||||
@@ -201,6 +209,7 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
|
||||
errors.Contains(err, apiutil.ErrMissingSecret),
|
||||
errors.Contains(err, errors.ErrMalformedEntity),
|
||||
errors.Contains(err, apiutil.ErrMissingID),
|
||||
errors.Contains(err, apiutil.ErrInvalidVerification),
|
||||
errors.Contains(err, apiutil.ErrMissingName),
|
||||
errors.Contains(err, apiutil.ErrMissingEmail),
|
||||
errors.Contains(err, apiutil.ErrInvalidEmail),
|
||||
|
||||
@@ -268,4 +268,10 @@ var (
|
||||
|
||||
// ErrMissingUsernameEmail indicates missing user name / email.
|
||||
ErrMissingUsernameEmail = errors.New("missing username / email")
|
||||
|
||||
// ErrInvalidVerification indicates invalid email verification.
|
||||
ErrInvalidVerification = errors.New("invalid verification")
|
||||
|
||||
// ErrEmailNotVerified indicates invalid email not verified.
|
||||
ErrEmailNotVerified = errors.New("email not verified")
|
||||
)
|
||||
|
||||
@@ -607,6 +607,54 @@ paths:
|
||||
"500":
|
||||
$ref: "#/components/responses/ServiceError"
|
||||
|
||||
/users/send-verification:
|
||||
post:
|
||||
operationId: sendVerification
|
||||
tags:
|
||||
- Users
|
||||
summary: Sends a verification email
|
||||
description: |
|
||||
Sends a verification email to the user.
|
||||
security:
|
||||
- bearerAuth: []
|
||||
responses:
|
||||
"200":
|
||||
description: Sent verification email if registered.
|
||||
"400":
|
||||
description: Failed due to malformed JSON.
|
||||
"401":
|
||||
description: Missing or invalid access token provided.
|
||||
"415":
|
||||
description: Missing or invalid content type.
|
||||
"422":
|
||||
description: Database can't process request.
|
||||
"500":
|
||||
$ref: "#/components/responses/ServiceError"
|
||||
|
||||
/verify-email:
|
||||
get:
|
||||
operationId: verifyEmail
|
||||
tags:
|
||||
- Users
|
||||
summary: Verify user's email
|
||||
description: |
|
||||
Verify user's email using the token from the verification link.
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/VerificationToken"
|
||||
responses:
|
||||
"200":
|
||||
description: Email verified successfully.
|
||||
"400":
|
||||
description: Failed due to malformed query parameters.
|
||||
"401":
|
||||
description: Missing or invalid access token provided.
|
||||
"404":
|
||||
description: A non-existent entity request.
|
||||
"422":
|
||||
description: Database can't process request.
|
||||
"500":
|
||||
$ref: "#/components/responses/ServiceError"
|
||||
|
||||
/health:
|
||||
get:
|
||||
operationId: health
|
||||
@@ -1240,6 +1288,14 @@ components:
|
||||
required: false
|
||||
example: "0"
|
||||
|
||||
VerificationToken:
|
||||
name: token
|
||||
description: Verification token.
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
required: true
|
||||
|
||||
requestBodies:
|
||||
UserCreateReq:
|
||||
description: JSON-formatted document describing the new user to be registered
|
||||
|
||||
@@ -75,7 +75,7 @@ func (client authGrpcClient) Authenticate(ctx context.Context, token *grpcAuthV1
|
||||
return &grpcAuthV1.AuthNRes{}, grpcapi.DecodeError(err)
|
||||
}
|
||||
ir := res.(authenticateRes)
|
||||
return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID, UserRole: uint32(ir.userRole)}, nil
|
||||
return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID, UserRole: uint32(ir.userRole), Verified: ir.verified}, nil
|
||||
}
|
||||
|
||||
func encodeIdentifyRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
@@ -85,7 +85,7 @@ func encodeIdentifyRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
|
||||
func decodeIdentifyResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
res := grpcRes.(*grpcAuthV1.AuthNRes)
|
||||
return authenticateRes{id: res.GetId(), userID: res.GetUserId(), userRole: auth.Role(res.UserRole)}, nil
|
||||
return authenticateRes{id: res.GetId(), userID: res.GetUserId(), userRole: auth.Role(res.UserRole), verified: res.GetVerified()}, nil
|
||||
}
|
||||
|
||||
func (client authGrpcClient) AuthenticatePAT(ctx context.Context, token *grpcAuthV1.AuthNReq, _ ...grpc.CallOption) (*grpcAuthV1.AuthNRes, error) {
|
||||
|
||||
@@ -23,7 +23,7 @@ func authenticateEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
return authenticateRes{}, err
|
||||
}
|
||||
|
||||
return authenticateRes{userID: key.Subject, userRole: key.Role}, nil
|
||||
return authenticateRes{userID: key.Subject, userRole: key.Role, verified: key.Verified}, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ type authenticateRes struct {
|
||||
id string
|
||||
userID string
|
||||
userRole smqauth.Role
|
||||
verified bool
|
||||
}
|
||||
|
||||
type authorizeRes struct {
|
||||
|
||||
@@ -82,7 +82,7 @@ func decodeAuthenticateRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
|
||||
func encodeAuthenticateResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
res := grpcRes.(authenticateRes)
|
||||
return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, UserRole: uint32(res.userRole)}, nil
|
||||
return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, UserRole: uint32(res.userRole), Verified: res.verified}, nil
|
||||
}
|
||||
|
||||
func encodeAuthenticatePATResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
|
||||
@@ -56,6 +56,7 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR
|
||||
userID: req.GetUserId(),
|
||||
userRole: auth.Role(req.GetUserRole()),
|
||||
keyType: auth.KeyType(req.GetType()),
|
||||
verified: req.GetVerified(),
|
||||
})
|
||||
if err != nil {
|
||||
return &grpcTokenV1.Token{}, grpcapi.DecodeError(err)
|
||||
@@ -69,6 +70,7 @@ func encodeIssueRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
UserId: req.userID,
|
||||
UserRole: uint32(req.userRole),
|
||||
Type: uint32(req.keyType),
|
||||
Verified: req.verified,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -80,7 +82,7 @@ func (client tokenGrpcClient) Refresh(ctx context.Context, req *grpcTokenV1.Refr
|
||||
ctx, cancel := context.WithTimeout(ctx, client.timeout)
|
||||
defer cancel()
|
||||
|
||||
res, err := client.refresh(ctx, refreshReq{refreshToken: req.GetRefreshToken()})
|
||||
res, err := client.refresh(ctx, refreshReq{refreshToken: req.GetRefreshToken(), verified: req.GetVerified()})
|
||||
if err != nil {
|
||||
return &grpcTokenV1.Token{}, grpcapi.DecodeError(err)
|
||||
}
|
||||
@@ -89,7 +91,7 @@ func (client tokenGrpcClient) Refresh(ctx context.Context, req *grpcTokenV1.Refr
|
||||
|
||||
func encodeRefreshRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(refreshReq)
|
||||
return &grpcTokenV1.RefreshReq{RefreshToken: req.refreshToken}, nil
|
||||
return &grpcTokenV1.RefreshReq{RefreshToken: req.refreshToken, Verified: req.verified}, nil
|
||||
}
|
||||
|
||||
func decodeRefreshResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
|
||||
@@ -18,9 +18,10 @@ func issueEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
}
|
||||
|
||||
key := auth.Key{
|
||||
Type: req.keyType,
|
||||
Subject: req.userID,
|
||||
Role: req.userRole,
|
||||
Type: req.keyType,
|
||||
Subject: req.userID,
|
||||
Role: req.userRole,
|
||||
Verified: req.verified,
|
||||
}
|
||||
tkn, err := svc.Issue(ctx, "", key)
|
||||
if err != nil {
|
||||
@@ -42,7 +43,7 @@ func refreshEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
return issueRes{}, err
|
||||
}
|
||||
|
||||
key := auth.Key{Type: auth.RefreshKey}
|
||||
key := auth.Key{Type: auth.RefreshKey, Verified: req.verified}
|
||||
tkn, err := svc.Issue(ctx, req.refreshToken, key)
|
||||
if err != nil {
|
||||
return issueRes{}, err
|
||||
|
||||
@@ -12,6 +12,7 @@ type issueReq struct {
|
||||
userID string
|
||||
userRole auth.Role
|
||||
keyType auth.KeyType
|
||||
verified bool
|
||||
}
|
||||
|
||||
func (req issueReq) validate() error {
|
||||
@@ -27,6 +28,7 @@ func (req issueReq) validate() error {
|
||||
|
||||
type refreshReq struct {
|
||||
refreshToken string
|
||||
verified bool
|
||||
}
|
||||
|
||||
func (req refreshReq) validate() error {
|
||||
|
||||
@@ -58,12 +58,13 @@ func decodeIssueRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
userID: req.GetUserId(),
|
||||
userRole: auth.Role(req.GetUserRole()),
|
||||
keyType: auth.KeyType(req.GetType()),
|
||||
verified: req.Verified,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeRefreshRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*grpcTokenV1.RefreshReq)
|
||||
return refreshReq{refreshToken: req.GetRefreshToken()}, nil
|
||||
return refreshReq{refreshToken: req.GetRefreshToken(), verified: req.Verified}, nil
|
||||
}
|
||||
|
||||
func encodeIssueResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
|
||||
@@ -21,6 +21,8 @@ var (
|
||||
errInvalidType = errors.New("invalid token type")
|
||||
// errInvalidRole is returned when the role is invalid.
|
||||
errInvalidRole = errors.New("invalid role")
|
||||
// errInvalidVerified is returned when the verified is invalid.
|
||||
errInvalidVerified = errors.New("invalid verified")
|
||||
// errJWTExpiryKey is used to check if the token is expired.
|
||||
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
|
||||
// ErrSignJWT indicates an error in signing jwt token.
|
||||
@@ -36,6 +38,7 @@ const (
|
||||
tokenType = "type"
|
||||
userField = "user"
|
||||
RoleField = "role"
|
||||
VerifiedField = "verified"
|
||||
oauthProviderField = "oauth_provider"
|
||||
oauthAccessTokenField = "access_token"
|
||||
oauthRefreshTokenField = "refresh_token"
|
||||
@@ -62,6 +65,7 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
Claim(tokenType, key.Type).
|
||||
Expiration(key.ExpiresAt)
|
||||
builder.Claim(RoleField, key.Role)
|
||||
builder.Claim(VerifiedField, key.Verified)
|
||||
if key.Subject != "" {
|
||||
builder.Subject(key.Subject)
|
||||
}
|
||||
@@ -150,6 +154,16 @@ func toKey(tkn jwt.Token) (auth.Key, error) {
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidRole
|
||||
}
|
||||
|
||||
tVerified, ok := tkn.Get(VerifiedField)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidVerified
|
||||
}
|
||||
kVerified, ok := tVerified.(bool)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidVerified
|
||||
}
|
||||
|
||||
kr := auth.Role(kRole)
|
||||
if !kr.Validate() {
|
||||
return auth.Key{}, errInvalidRole
|
||||
@@ -162,6 +176,7 @@ func toKey(tkn jwt.Token) (auth.Key, error) {
|
||||
key.Subject = tkn.Subject()
|
||||
key.IssuedAt = tkn.IssuedAt()
|
||||
key.ExpiresAt = tkn.Expiration()
|
||||
key.Verified = kVerified
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
@@ -88,6 +88,7 @@ type Key struct {
|
||||
Role Role `json:"role,omitempty"`
|
||||
IssuedAt time.Time `json:"issued_at,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||
Verified bool `json:"verified,omitempty"`
|
||||
}
|
||||
|
||||
func (key Key) String() string {
|
||||
|
||||
@@ -72,7 +72,8 @@ func newCertServer() (*httptest.Server, *mocks.Service, *authnmocks.Authenticati
|
||||
logger := smqlog.NewMock()
|
||||
idp := uuid.NewMock()
|
||||
authn := new(authnmocks.Authentication)
|
||||
mux := api.MakeHandler(svc, authn, logger, "", idp)
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
mux := api.MakeHandler(svc, am, logger, "", idp)
|
||||
|
||||
return httptest.NewServer(mux), svc, authn
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ const (
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP handler for API endpoints.
|
||||
func MakeHandler(svc certs.Service, authn smqauthn.Authentication, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler {
|
||||
func MakeHandler(svc certs.Service, authn smqauthn.AuthNMiddleware, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler {
|
||||
opts := []kithttp.ServerOption{
|
||||
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
|
||||
}
|
||||
@@ -39,7 +39,7 @@ func MakeHandler(svc certs.Service, authn smqauthn.Authentication, logger *slog.
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(api.AuthenticateMiddleware(authn, true))
|
||||
r.Use(authn.Middleware())
|
||||
r.Use(api.RequestIDMiddleware(idp))
|
||||
|
||||
r.Route("/{domainID}", func(r chi.Router) {
|
||||
|
||||
@@ -58,7 +58,8 @@ func newChannelsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authenti
|
||||
mux := chi.NewRouter()
|
||||
idp := uuid.NewMock()
|
||||
logger := smqlog.NewMock()
|
||||
mux = MakeHandler(svc, authn, mux, logger, "", idp)
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
mux = MakeHandler(svc, am, mux, logger, "", idp)
|
||||
|
||||
return httptest.NewServer(mux), svc, authn
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ package http
|
||||
import (
|
||||
"context"
|
||||
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/channels"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
@@ -22,7 +21,7 @@ func createChannelEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -46,7 +45,7 @@ func createChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -77,7 +76,7 @@ func viewChannelEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -98,7 +97,7 @@ func listChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -138,7 +137,7 @@ func updateChannelEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -164,7 +163,7 @@ func updateChannelTagsEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -189,7 +188,7 @@ func setChannelParentGroupEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -209,7 +208,7 @@ func removeChannelParentGroupEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -229,7 +228,7 @@ func enableChannelEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -250,7 +249,7 @@ func disableChannelEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -271,7 +270,7 @@ func connectChannelClientEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -291,7 +290,7 @@ func disconnectChannelClientsEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -311,7 +310,7 @@ func connectEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -331,7 +330,7 @@ func disconnectEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -351,7 +350,7 @@ func deleteChannelEndpoint(svc channels.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP handler for Channels API endpoints.
|
||||
func MakeHandler(svc channels.Service, authn smqauthn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) *chi.Mux {
|
||||
func MakeHandler(svc channels.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) *chi.Mux {
|
||||
opts := []kithttp.ServerOption{
|
||||
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
|
||||
}
|
||||
@@ -27,7 +27,7 @@ func MakeHandler(svc channels.Service, authn smqauthn.Authentication, mux *chi.M
|
||||
d := roleManagerHttp.NewDecoder("channelID")
|
||||
|
||||
mux.Route("/{domainID}/channels", func(r chi.Router) {
|
||||
r.Use(api.AuthenticateMiddleware(authn, true))
|
||||
r.Use(authn.Middleware())
|
||||
r.Use(api.RequestIDMiddleware(idp))
|
||||
|
||||
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
|
||||
+10
-8
@@ -12,14 +12,16 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
all = "all"
|
||||
create = "create"
|
||||
get = "get"
|
||||
update = "update"
|
||||
delete = "delete"
|
||||
enable = "enable"
|
||||
disable = "disable"
|
||||
users = "users"
|
||||
all = "all"
|
||||
create = "create"
|
||||
get = "get"
|
||||
update = "update"
|
||||
delete = "delete"
|
||||
enable = "enable"
|
||||
disable = "disable"
|
||||
users = "users"
|
||||
sendVerification = "send-verification"
|
||||
verifyEmail = "verify-email"
|
||||
|
||||
usageCreate = "cli channels <channel_id> create <JSON_channel> <domain_id> <user_auth_token>"
|
||||
usageGet = "cli channels <channel_id|all> get <domain_id> <user_auth_token>"
|
||||
|
||||
+34
-1
@@ -72,11 +72,16 @@ Examples:
|
||||
logUsageCmd(*cmd, cmd.Use)
|
||||
return
|
||||
}
|
||||
|
||||
switch args[0] {
|
||||
case create:
|
||||
handleUserCreate(cmd, args[1:])
|
||||
return
|
||||
case sendVerification:
|
||||
handleSendVerification(cmd, args[1])
|
||||
return
|
||||
case verifyEmail:
|
||||
handleVerify(cmd, args[1])
|
||||
return
|
||||
case token:
|
||||
handleUserToken(cmd, args[1], args[2:])
|
||||
return
|
||||
@@ -152,6 +157,34 @@ func handleUserCreate(cmd *cobra.Command, args []string) {
|
||||
logJSONCmd(*cmd, user)
|
||||
}
|
||||
|
||||
func handleSendVerification(cmd *cobra.Command, token string) {
|
||||
if token == "" {
|
||||
logUsageCmd(*cmd, usageUserToken)
|
||||
return
|
||||
}
|
||||
|
||||
if err := sdk.SendVerification(cmd.Context(), token); err != nil {
|
||||
logErrorCmd(*cmd, err)
|
||||
return
|
||||
}
|
||||
|
||||
logJSONCmd(*cmd, "sent verification successfully")
|
||||
}
|
||||
|
||||
func handleVerify(cmd *cobra.Command, token string) {
|
||||
if token == "" {
|
||||
logUsageCmd(*cmd, usageUserToken)
|
||||
return
|
||||
}
|
||||
|
||||
if err := sdk.VerifyEmail(cmd.Context(), token); err != nil {
|
||||
logErrorCmd(*cmd, err)
|
||||
return
|
||||
}
|
||||
|
||||
logJSONCmd(*cmd, "verified successfully")
|
||||
}
|
||||
|
||||
func handleUserGet(cmd *cobra.Command, userParams string, args []string) {
|
||||
if len(args) != 1 {
|
||||
logUsageCmd(*cmd, usageUserGet)
|
||||
|
||||
@@ -17,14 +17,14 @@ import (
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
)
|
||||
|
||||
func clientsHandler(svc clients.Service, authn smqauthn.Authentication, r *chi.Mux, logger *slog.Logger, idp supermq.IDProvider) *chi.Mux {
|
||||
func clientsHandler(svc clients.Service, authn smqauthn.AuthNMiddleware, r *chi.Mux, logger *slog.Logger, idp supermq.IDProvider) *chi.Mux {
|
||||
opts := []kithttp.ServerOption{
|
||||
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
|
||||
}
|
||||
d := roleManagerHttp.NewDecoder("clientID")
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(api.AuthenticateMiddleware(authn, true))
|
||||
r.Use(authn.Middleware())
|
||||
r.Use(api.RequestIDMiddleware(idp))
|
||||
|
||||
r.Route("/{domainID}/clients", func(r chi.Router) {
|
||||
|
||||
@@ -6,7 +6,6 @@ package http
|
||||
import (
|
||||
"context"
|
||||
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/clients"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
@@ -22,7 +21,7 @@ func createClientEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -46,7 +45,7 @@ func createClientsEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -77,7 +76,7 @@ func viewClientEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -98,7 +97,7 @@ func listClientsEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -138,7 +137,7 @@ func updateClientEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -164,7 +163,7 @@ func updateClientTagsEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -189,7 +188,7 @@ func updateClientSecretEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -210,7 +209,7 @@ func enableClientEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -231,7 +230,7 @@ func disableClientEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -252,7 +251,7 @@ func setClientParentGroupEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -271,7 +270,7 @@ func removeClientParentGroupEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -290,7 +289,7 @@ func deleteClientEndpoint(svc clients.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
@@ -96,7 +96,8 @@ func newClientsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentic
|
||||
logger := smqlog.NewMock()
|
||||
mux := chi.NewRouter()
|
||||
idp := uuid.NewMock()
|
||||
clientsapi.MakeHandler(svc, authn, mux, logger, "", idp)
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
clientsapi.MakeHandler(svc, am, mux, logger, "", idp)
|
||||
|
||||
return httptest.NewServer(mux), svc, authn
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP handler for clients and Groups API endpoints.
|
||||
func MakeHandler(tsvc clients.Service, authn smqauthn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler {
|
||||
func MakeHandler(tsvc clients.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler {
|
||||
mux = clientsHandler(tsvc, authn, mux, logger, idp)
|
||||
|
||||
mux.Get("/health", supermq.Health("clients", instanceID))
|
||||
|
||||
+3
-1
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/absmach/supermq/certs/postgres"
|
||||
"github.com/absmach/supermq/certs/tracing"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
"github.com/absmach/supermq/pkg/grpcclient"
|
||||
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
|
||||
@@ -138,6 +139,7 @@ func main() {
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio)
|
||||
if err != nil {
|
||||
@@ -163,7 +165,7 @@ func main() {
|
||||
|
||||
idp := uuid.New()
|
||||
|
||||
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, logger, cfg.InstanceID, idp), logger)
|
||||
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, logger, cfg.InstanceID, idp), logger)
|
||||
|
||||
if cfg.SendTelemetry {
|
||||
chc := chclient.New(svcName, supermq.Version, logger, cancel)
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
gpostgres "github.com/absmach/supermq/groups/postgres"
|
||||
redisclient "github.com/absmach/supermq/internal/clients/redis"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
@@ -179,6 +180,7 @@ func main() {
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
domsGrpcCfg := grpcclient.Config{}
|
||||
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
|
||||
@@ -304,7 +306,7 @@ func main() {
|
||||
}
|
||||
mux := chi.NewRouter()
|
||||
idp := uuid.New()
|
||||
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID, idp), logger)
|
||||
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, mux, logger, cfg.InstanceID, idp), logger)
|
||||
|
||||
if cfg.SendTelemetry {
|
||||
chc := chclient.New(svcName, supermq.Version, logger, cancel)
|
||||
|
||||
+3
-1
@@ -31,6 +31,7 @@ import (
|
||||
gpostgres "github.com/absmach/supermq/groups/postgres"
|
||||
redisclient "github.com/absmach/supermq/internal/clients/redis"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
@@ -190,6 +191,7 @@ func main() {
|
||||
}
|
||||
defer authnClient.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
domsGrpcCfg := grpcclient.Config{}
|
||||
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
|
||||
@@ -292,7 +294,7 @@ func main() {
|
||||
}
|
||||
mux := chi.NewRouter()
|
||||
idp := uuid.New()
|
||||
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID, idp), logger)
|
||||
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, mux, logger, cfg.InstanceID, idp), logger)
|
||||
|
||||
grpcServerConfig := server.Config{Port: defSvcAuthGRPCPort}
|
||||
if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGRPC}); err != nil {
|
||||
|
||||
+3
-1
@@ -27,6 +27,7 @@ import (
|
||||
dtracing "github.com/absmach/supermq/domains/tracing"
|
||||
redisclient "github.com/absmach/supermq/internal/clients/redis"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
"github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
@@ -159,6 +160,7 @@ func main() {
|
||||
}
|
||||
defer authnHandler.Close()
|
||||
logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure())
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
database := postgres.NewDatabase(db, dbConfig, tracer)
|
||||
domainsRepo := dpostgres.NewRepository(database)
|
||||
@@ -239,7 +241,7 @@ func main() {
|
||||
}
|
||||
mux := chi.NewMux()
|
||||
idp := uuid.New()
|
||||
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID, idp), logger)
|
||||
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, mux, logger, cfg.InstanceID, idp), logger)
|
||||
|
||||
g.Go(func() error {
|
||||
return hs.Start()
|
||||
|
||||
+3
-1
@@ -28,6 +28,7 @@ import (
|
||||
pgroups "github.com/absmach/supermq/groups/private"
|
||||
"github.com/absmach/supermq/groups/tracing"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
@@ -162,6 +163,7 @@ func main() {
|
||||
}
|
||||
defer authnHandler.Close()
|
||||
logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure())
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
domsGrpcCfg := grpcclient.Config{}
|
||||
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
|
||||
@@ -263,7 +265,7 @@ func main() {
|
||||
|
||||
mux := chi.NewRouter()
|
||||
idp := uuid.New()
|
||||
httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID, idp), logger)
|
||||
httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, mux, logger, cfg.InstanceID, idp), logger)
|
||||
|
||||
grpcServerConfig := server.Config{}
|
||||
if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixgRPC}); err != nil {
|
||||
|
||||
+3
-1
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/absmach/supermq/journal/middleware"
|
||||
journalpg "github.com/absmach/supermq/journal/postgres"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
@@ -112,6 +113,7 @@ func main() {
|
||||
}
|
||||
defer authnHandler.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnHandler.Secure())
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
domsGrpcCfg := grpcclient.Config{}
|
||||
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
|
||||
@@ -173,7 +175,7 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
hs := http.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, logger, svcName, cfg.InstanceID), logger)
|
||||
hs := http.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, logger, svcName, cfg.InstanceID), logger)
|
||||
|
||||
if cfg.SendTelemetry {
|
||||
chc := chclient.New(svcName, supermq.Version, logger, cancel)
|
||||
|
||||
+51
-29
@@ -20,6 +20,7 @@ import (
|
||||
grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
|
||||
"github.com/absmach/supermq/internal/email"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
|
||||
smqauthz "github.com/absmach/supermq/pkg/authz"
|
||||
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
|
||||
@@ -67,28 +68,31 @@ const (
|
||||
)
|
||||
|
||||
type config struct {
|
||||
LogLevel string `env:"SMQ_USERS_LOG_LEVEL" envDefault:"info"`
|
||||
AdminEmail string `env:"SMQ_USERS_ADMIN_EMAIL" envDefault:"admin@example.com"`
|
||||
AdminPassword string `env:"SMQ_USERS_ADMIN_PASSWORD" envDefault:"12345678"`
|
||||
AdminUsername string `env:"SMQ_USERS_ADMIN_USERNAME" envDefault:"admin"`
|
||||
AdminFirstName string `env:"SMQ_USERS_ADMIN_FIRST_NAME" envDefault:"super"`
|
||||
AdminLastName string `env:"SMQ_USERS_ADMIN_LAST_NAME" envDefault:"admin"`
|
||||
PassRegexText string `env:"SMQ_USERS_PASS_REGEX" envDefault:"^.{8,}$"`
|
||||
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
|
||||
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
|
||||
InstanceID string `env:"SMQ_USERS_INSTANCE_ID" envDefault:""`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
SelfRegister bool `env:"SMQ_USERS_ALLOW_SELF_REGISTER" envDefault:"false"`
|
||||
OAuthUIRedirectURL string `env:"SMQ_OAUTH_UI_REDIRECT_URL" envDefault:"http://localhost:9095/domains"`
|
||||
OAuthUIErrorURL string `env:"SMQ_OAUTH_UI_ERROR_URL" envDefault:"http://localhost:9095/error"`
|
||||
DeleteInterval time.Duration `env:"SMQ_USERS_DELETE_INTERVAL" envDefault:"24h"`
|
||||
DeleteAfter time.Duration `env:"SMQ_USERS_DELETE_AFTER" envDefault:"720h"`
|
||||
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
|
||||
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
|
||||
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
|
||||
PasswordResetURLPrefix string `env:"SMQ_PASSWORD_RESET_URL_PREFIX" envDefault:"http://localhost:8080"`
|
||||
PassRegex *regexp.Regexp
|
||||
LogLevel string `env:"SMQ_USERS_LOG_LEVEL" envDefault:"info"`
|
||||
AdminEmail string `env:"SMQ_USERS_ADMIN_EMAIL" envDefault:"admin@example.com"`
|
||||
AdminPassword string `env:"SMQ_USERS_ADMIN_PASSWORD" envDefault:"12345678"`
|
||||
AdminUsername string `env:"SMQ_USERS_ADMIN_USERNAME" envDefault:"admin"`
|
||||
AdminFirstName string `env:"SMQ_USERS_ADMIN_FIRST_NAME" envDefault:"super"`
|
||||
AdminLastName string `env:"SMQ_USERS_ADMIN_LAST_NAME" envDefault:"admin"`
|
||||
PassRegexText string `env:"SMQ_USERS_PASS_REGEX" envDefault:"^.{8,}$"`
|
||||
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
|
||||
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
|
||||
InstanceID string `env:"SMQ_USERS_INSTANCE_ID" envDefault:""`
|
||||
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
|
||||
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
|
||||
SelfRegister bool `env:"SMQ_USERS_ALLOW_SELF_REGISTER" envDefault:"false"`
|
||||
OAuthUIRedirectURL string `env:"SMQ_OAUTH_UI_REDIRECT_URL" envDefault:"http://localhost:9095/domains"`
|
||||
OAuthUIErrorURL string `env:"SMQ_OAUTH_UI_ERROR_URL" envDefault:"http://localhost:9095/error"`
|
||||
DeleteInterval time.Duration `env:"SMQ_USERS_DELETE_INTERVAL" envDefault:"24h"`
|
||||
DeleteAfter time.Duration `env:"SMQ_USERS_DELETE_AFTER" envDefault:"720h"`
|
||||
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
|
||||
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
|
||||
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
|
||||
PasswordResetURLPrefix string `env:"SMQ_PASSWORD_RESET_URL_PREFIX" envDefault:"http://localhost/password/reset"`
|
||||
PasswordResetEmailTemplate string `env:"SMQ_PASSWORD_RESET_EMAIL_TEMPLATE" envDefault:"reset-password-email.tmpl"`
|
||||
VerificationURLPrefix string `env:"SMQ_VERIFICATION_URL_PREFIX" envDefault:"http://localhost/verify-email"`
|
||||
VerificationEmailTemplate string `env:"SMQ_VERIFICATION_EMAIL_TEMPLATE" envDefault:"verification-email.tmpl"`
|
||||
PassRegex *regexp.Regexp
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -121,12 +125,21 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
ec := email.Config{}
|
||||
if err := env.Parse(&ec); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load email configuration : %s", err.Error()))
|
||||
resetPasswordEmailConfig := email.Config{}
|
||||
if err := env.Parse(&resetPasswordEmailConfig); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load reset password email configuration : %s", err.Error()))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
resetPasswordEmailConfig.Template = cfg.PasswordResetEmailTemplate
|
||||
|
||||
verificationEmailConfig := email.Config{}
|
||||
if err := env.Parse(&verificationEmailConfig); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load verification password email configuration : %s", err.Error()))
|
||||
exitCode = 1
|
||||
return
|
||||
}
|
||||
verificationEmailConfig.Template = cfg.VerificationEmailTemplate
|
||||
|
||||
dbConfig := pgclient.Config{Name: defDB}
|
||||
if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil {
|
||||
@@ -182,6 +195,8 @@ func main() {
|
||||
defer authnHandler.Close()
|
||||
logger.Info("AuthN successfully connected to auth gRPC server " + authnHandler.Secure())
|
||||
|
||||
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
|
||||
|
||||
domsGrpcCfg := grpcclient.Config{}
|
||||
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err))
|
||||
@@ -213,7 +228,7 @@ func main() {
|
||||
}
|
||||
logger.Info("Policy client successfully connected to spicedb gRPC server")
|
||||
|
||||
csvc, err := newService(ctx, authz, tokenClient, policyService, domainsClient, db, dbConfig, tracer, cfg, ec, logger)
|
||||
csvc, err := newService(ctx, authz, tokenClient, policyService, domainsClient, db, dbConfig, tracer, cfg, resetPasswordEmailConfig, verificationEmailConfig, logger)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to setup service: %s", err))
|
||||
exitCode = 1
|
||||
@@ -237,7 +252,7 @@ func main() {
|
||||
|
||||
mux := chi.NewRouter()
|
||||
idp := uuid.New()
|
||||
httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(csvc, authn, tokenClient, cfg.SelfRegister, mux, logger, cfg.InstanceID, cfg.PassRegex, idp, oauthProvider), logger)
|
||||
httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(csvc, authnMiddleware, tokenClient, cfg.SelfRegister, mux, logger, cfg.InstanceID, cfg.PassRegex, idp, oauthProvider), logger)
|
||||
|
||||
if cfg.SendTelemetry {
|
||||
chc := chclient.New(svcName, supermq.Version, logger, cancel)
|
||||
@@ -257,16 +272,23 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func newService(ctx context.Context, authz smqauthz.Authorization, token grpcTokenV1.TokenServiceClient, policyService policies.Service, domainsClient grpcDomainsV1.DomainsServiceClient, db *sqlx.DB, dbConfig pgclient.Config, tracer trace.Tracer, c config, ec email.Config, logger *slog.Logger) (users.Service, error) {
|
||||
func newService(ctx context.Context, authz smqauthz.Authorization, token grpcTokenV1.TokenServiceClient, policyService policies.Service, domainsClient grpcDomainsV1.DomainsServiceClient, db *sqlx.DB, dbConfig pgclient.Config, tracer trace.Tracer, c config, resetPasswordEmailConfig, verificationEmailConfig email.Config, logger *slog.Logger) (users.Service, error) {
|
||||
database := pg.NewDatabase(db, dbConfig, tracer)
|
||||
idp := uuid.New()
|
||||
hsr := hasher.New()
|
||||
|
||||
// Creating users service
|
||||
repo := postgres.NewRepository(database)
|
||||
emailerClient, err := emailer.New(fmt.Sprintf("%s/reset-request", c.PasswordResetURLPrefix), &ec)
|
||||
|
||||
emailerClient, err := emailer.New(
|
||||
c.PasswordResetURLPrefix,
|
||||
c.VerificationURLPrefix,
|
||||
&resetPasswordEmailConfig,
|
||||
&verificationEmailConfig,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Error(fmt.Sprintf("failed to configure e-mailing util: %s", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc := users.NewService(token, repo, policyService, emailerClient, hsr, idp)
|
||||
|
||||
+10
-6
@@ -248,7 +248,6 @@ SMQ_USERS_DB_SSL_MODE=disable
|
||||
SMQ_USERS_DB_SSL_CERT=
|
||||
SMQ_USERS_DB_SSL_KEY=
|
||||
SMQ_USERS_DB_SSL_ROOT_CERT=
|
||||
SMQ_USERS_RESET_PWD_TEMPLATE=users.tmpl
|
||||
SMQ_USERS_INSTANCE_ID=
|
||||
SMQ_USERS_SECRET_KEY=HyE2D4RUt9nnKG6v8zKEqAp6g6ka8hhZsqUpzgKvnwpXrNVQSH
|
||||
SMQ_USERS_ADMIN_EMAIL=admin@example.com
|
||||
@@ -261,19 +260,21 @@ SMQ_OAUTH_UI_REDIRECT_URL=http://localhost:9095${SMQ_UI_PATH_PREFIX}/tokens/secu
|
||||
SMQ_OAUTH_UI_ERROR_URL=http://localhost:9095${SMQ_UI_PATH_PREFIX}/error
|
||||
SMQ_USERS_DELETE_INTERVAL=24h
|
||||
SMQ_USERS_DELETE_AFTER=720h
|
||||
SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost:9001
|
||||
SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost/password-reset
|
||||
SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=reset-password-email.tmpl
|
||||
SMQ_VERIFICATION_URL_PREFIX=http://localhost/verify-email
|
||||
SMQ_VERIFICATION_EMAIL_TEMPLATE=verification-email.tmpl
|
||||
|
||||
#### Users Client Config
|
||||
SMQ_USERS_URL=users:9002
|
||||
|
||||
### Email utility
|
||||
SMQ_EMAIL_HOST=smtp.mailtrap.io
|
||||
SMQ_EMAIL_HOST=host.docker.internal
|
||||
SMQ_EMAIL_PORT=2525
|
||||
SMQ_EMAIL_USERNAME=18bf7f70705139
|
||||
SMQ_EMAIL_PASSWORD=2b0d302e775b1e
|
||||
SMQ_EMAIL_USERNAME=from@example.com
|
||||
SMQ_EMAIL_PASSWORD=password
|
||||
SMQ_EMAIL_FROM_ADDRESS=from@example.com
|
||||
SMQ_EMAIL_FROM_NAME=Example
|
||||
SMQ_EMAIL_TEMPLATE=email.tmpl
|
||||
|
||||
### Google OAuth2
|
||||
SMQ_GOOGLE_CLIENT_ID=
|
||||
@@ -514,5 +515,8 @@ SMQ_GRAFANA_PORT=3000
|
||||
SMQ_GRAFANA_ADMIN_USER=supermq
|
||||
SMQ_GRAFANA_ADMIN_PASSWORD=supermq
|
||||
|
||||
## Allow unverified user to access
|
||||
SMQ_ALLOW_UNVERIFIED_USER=true
|
||||
|
||||
# Docker image tag
|
||||
SMQ_RELEASE_TAG=latest
|
||||
|
||||
@@ -74,6 +74,7 @@ services:
|
||||
SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO}
|
||||
SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY}
|
||||
SMQ_CERTS_INSTANCE_ID: ${SMQ_CERTS_INSTANCE_ID}
|
||||
SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER}
|
||||
volumes:
|
||||
- ../../ssl/certs/ca.key:/etc/ssl/certs/ca.key
|
||||
- ../../ssl/certs/ca.crt:/etc/ssl/certs/ca.crt
|
||||
|
||||
@@ -69,6 +69,7 @@ services:
|
||||
SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt}
|
||||
SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key}
|
||||
SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt}
|
||||
SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER}
|
||||
ports:
|
||||
- ${SMQ_JOURNAL_HTTP_PORT}:${SMQ_JOURNAL_HTTP_PORT}
|
||||
networks:
|
||||
|
||||
@@ -285,6 +285,7 @@ services:
|
||||
SMQ_DOMAINS_CALLOUT_CERT: ${SMQ_DOMAINS_CALLOUT_CERT}
|
||||
SMQ_DOMAINS_CALLOUT_KEY: ${SMQ_DOMAINS_CALLOUT_KEY}
|
||||
SMQ_DOMAINS_CALLOUT_OPERATIONS: ${SMQ_DOMAINS_CALLOUT_OPERATIONS}
|
||||
SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER}
|
||||
ports:
|
||||
- ${SMQ_DOMAINS_HTTP_PORT}:${SMQ_DOMAINS_HTTP_PORT}
|
||||
- ${SMQ_DOMAINS_GRPC_PORT}:${SMQ_DOMAINS_GRPC_PORT}
|
||||
@@ -516,6 +517,7 @@ services:
|
||||
SMQ_CLIENTS_CALLOUT_CERT: ${SMQ_CLIENTS_CALLOUT_CERT}
|
||||
SMQ_CLIENTS_CALLOUT_KEY: ${SMQ_CLIENTS_CALLOUT_KEY}
|
||||
SMQ_CLIENTS_CALLOUT_OPERATIONS: ${SMQ_CLIENTS_CALLOUT_OPERATIONS}
|
||||
SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER}
|
||||
ports:
|
||||
- ${SMQ_CLIENTS_HTTP_PORT}:${SMQ_CLIENTS_HTTP_PORT}
|
||||
- ${SMQ_CLIENTS_GRPC_PORT}:${SMQ_CLIENTS_GRPC_PORT}
|
||||
@@ -706,6 +708,7 @@ services:
|
||||
SMQ_CHANNELS_CALLOUT_CERT: ${SMQ_CHANNELS_CALLOUT_CERT}
|
||||
SMQ_CHANNELS_CALLOUT_KEY: ${SMQ_CHANNELS_CALLOUT_KEY}
|
||||
SMQ_CHANNELS_CALLOUT_OPERATIONS: ${SMQ_CHANNELS_CALLOUT_OPERATIONS}
|
||||
SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER}
|
||||
ports:
|
||||
- ${SMQ_CHANNELS_HTTP_PORT}:${SMQ_CHANNELS_HTTP_PORT}
|
||||
- ${SMQ_CHANNELS_GRPC_PORT}:${SMQ_CHANNELS_GRPC_PORT}
|
||||
@@ -855,7 +858,6 @@ services:
|
||||
SMQ_EMAIL_PASSWORD: ${SMQ_EMAIL_PASSWORD}
|
||||
SMQ_EMAIL_FROM_ADDRESS: ${SMQ_EMAIL_FROM_ADDRESS}
|
||||
SMQ_EMAIL_FROM_NAME: ${SMQ_EMAIL_FROM_NAME}
|
||||
SMQ_EMAIL_TEMPLATE: ${SMQ_EMAIL_TEMPLATE}
|
||||
SMQ_ES_URL: ${SMQ_ES_URL}
|
||||
SMQ_JAEGER_URL: ${SMQ_JAEGER_URL}
|
||||
SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO}
|
||||
@@ -882,12 +884,16 @@ services:
|
||||
SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST}
|
||||
SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT}
|
||||
SMQ_PASSWORD_RESET_URL_PREFIX: ${SMQ_PASSWORD_RESET_URL_PREFIX}
|
||||
SMQ_PASSWORD_RESET_EMAIL_TEMPLATE: ${SMQ_PASSWORD_RESET_EMAIL_TEMPLATE}
|
||||
SMQ_VERIFICATION_URL_PREFIX: ${SMQ_VERIFICATION_URL_PREFIX}
|
||||
SMQ_VERIFICATION_EMAIL_TEMPLATE: ${SMQ_VERIFICATION_EMAIL_TEMPLATE}
|
||||
ports:
|
||||
- ${SMQ_USERS_HTTP_PORT}:${SMQ_USERS_HTTP_PORT}
|
||||
networks:
|
||||
- supermq-base-net
|
||||
volumes:
|
||||
- ./templates/${SMQ_USERS_RESET_PWD_TEMPLATE}:/email.tmpl
|
||||
- ./templates/${SMQ_PASSWORD_RESET_EMAIL_TEMPLATE}:/${SMQ_PASSWORD_RESET_EMAIL_TEMPLATE}
|
||||
- ./templates/${SMQ_VERIFICATION_EMAIL_TEMPLATE}:/${SMQ_VERIFICATION_EMAIL_TEMPLATE}
|
||||
# Auth gRPC client certificates
|
||||
- type: bind
|
||||
source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert}
|
||||
@@ -1007,6 +1013,7 @@ services:
|
||||
SMQ_GROUPS_CALLOUT_CERT: ${SMQ_GROUPS_CALLOUT_CERT}
|
||||
SMQ_GROUPS_CALLOUT_KEY: ${SMQ_GROUPS_CALLOUT_KEY}
|
||||
SMQ_GROUPS_CALLOUT_OPERATIONS: ${SMQ_GROUPS_CALLOUT_OPERATIONS}
|
||||
SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER}
|
||||
ports:
|
||||
- ${SMQ_GROUPS_HTTP_PORT}:${SMQ_GROUPS_HTTP_PORT}
|
||||
- ${SMQ_GROUPS_GRPC_PORT}:${SMQ_GROUPS_GRPC_PORT}
|
||||
|
||||
@@ -72,7 +72,7 @@ http {
|
||||
}
|
||||
|
||||
# Proxy pass to users service
|
||||
location ~ ^/(users|password|authorize|oauth/callback/[^/]+) {
|
||||
location ~ ^/(users|password|verify-email|authorize|oauth/callback/[^/]+) {
|
||||
include snippets/proxy-headers.conf;
|
||||
add_header Access-Control-Expose-Headers Location;
|
||||
proxy_pass http://users:${SMQ_USERS_HTTP_PORT};
|
||||
|
||||
@@ -81,7 +81,7 @@ http {
|
||||
}
|
||||
|
||||
# Proxy pass to users service
|
||||
location ~ ^/(users|groups|password|authorize|oauth/callback/[^/]+) {
|
||||
location ~ ^/(users|password|verify-email|authorize|oauth/callback/[^/]+) {
|
||||
include snippets/proxy-headers.conf;
|
||||
add_header Access-Control-Expose-Headers Location;
|
||||
proxy_pass http://users:${SMQ_USERS_HTTP_PORT};
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
Dear {{.User}},
|
||||
|
||||
Welcome to {{.Host}}! To complete your account registration, please verify your email address by clicking on the link below:
|
||||
|
||||
{{.Content}}
|
||||
|
||||
This verification link will expire in 24 hours. If you did not create an account on {{.Host}}, please disregard this message.
|
||||
|
||||
Once verified, you'll have full access to all features on {{.Host}}.
|
||||
|
||||
Thank you for joining {{.Host}}!
|
||||
|
||||
Best regards,
|
||||
|
||||
{{.Footer}}
|
||||
@@ -6,7 +6,6 @@ package http
|
||||
import (
|
||||
"context"
|
||||
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/domains"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
@@ -25,7 +24,7 @@ func createDomainEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -53,7 +52,7 @@ func retrieveDomainEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -73,7 +72,7 @@ func updateDomainEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -103,7 +102,7 @@ func listDomainsEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -123,7 +122,7 @@ func enableDomainEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -142,7 +141,7 @@ func disableDomainEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -161,7 +160,7 @@ func freezeDomainEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -179,7 +178,7 @@ func sendInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -208,7 +207,7 @@ func listDomainInvitationsEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -231,7 +230,7 @@ func listUserInvitationsEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -254,7 +253,7 @@ func acceptInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -274,7 +273,7 @@ func rejectInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -294,7 +293,7 @@ func deleteInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/absmach/supermq/internal/testsutil"
|
||||
smqlog "github.com/absmach/supermq/logger"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
authnmock "github.com/absmach/supermq/pkg/authn/mocks"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
@@ -98,7 +99,8 @@ func newDomainsServer() (*httptest.Server, *mocks.Service, *authnmock.Authentica
|
||||
authn := new(authnmock.Authentication)
|
||||
mux := chi.NewMux()
|
||||
idp := uuid.NewMock()
|
||||
domainsapi.MakeHandler(svc, authn, mux, logger, "", idp)
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
domainsapi.MakeHandler(svc, am, mux, logger, "", idp)
|
||||
return httptest.NewServer(mux), svc, authn
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/domains"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
roleManagerHttp "github.com/absmach/supermq/pkg/roles/rolemanager/api"
|
||||
"github.com/go-chi/chi/v5"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP handler for Domains and Invitations API endpoints.
|
||||
func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler {
|
||||
func MakeHandler(svc domains.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler {
|
||||
opts := []kithttp.ServerOption{
|
||||
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux,
|
||||
r.Use(api.RequestIDMiddleware(idp))
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(api.AuthenticateMiddleware(authn, false))
|
||||
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(false)).Middleware())
|
||||
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
createDomainEndpoint(svc),
|
||||
decodeCreateDomainRequest,
|
||||
@@ -49,7 +49,7 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux,
|
||||
})
|
||||
|
||||
r.Route("/{domainID}", func(r chi.Router) {
|
||||
r.Use(api.AuthenticateMiddleware(authn, true))
|
||||
r.Use(authn.Middleware())
|
||||
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
retrieveDomainEndpoint(svc),
|
||||
decodeRetrieveDomainRequest,
|
||||
@@ -88,7 +88,7 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux,
|
||||
})
|
||||
|
||||
r.Route("/{domainID}/invitations", func(r chi.Router) {
|
||||
r.Use(api.AuthenticateMiddleware(authn, true))
|
||||
r.Use(authn.Middleware())
|
||||
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
sendInvitationEndpoint(svc),
|
||||
decodeSendInvitationReq,
|
||||
@@ -111,7 +111,7 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux,
|
||||
})
|
||||
|
||||
mux.Route("/invitations", func(r chi.Router) {
|
||||
r.Use(api.AuthenticateMiddleware(authn, false))
|
||||
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(false)).Middleware())
|
||||
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
listUserInvitationsEndpoint(svc),
|
||||
decodeListInvitationsReq,
|
||||
|
||||
@@ -60,7 +60,8 @@ func newGroupsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentica
|
||||
mux := chi.NewRouter()
|
||||
idp := uuid.NewMock()
|
||||
logger := smqlog.NewMock()
|
||||
mux = MakeHandler(svc, authn, mux, logger, "", idp)
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
mux = MakeHandler(svc, am, mux, logger, "", idp)
|
||||
|
||||
return httptest.NewServer(mux), svc, authn
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ package api
|
||||
import (
|
||||
"context"
|
||||
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/groups"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
@@ -22,7 +21,7 @@ func CreateGroupEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return createGroupRes{created: false}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return createGroupRes{created: false}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -43,7 +42,7 @@ func ViewGroupEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return viewGroupRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return viewGroupRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -64,7 +63,7 @@ func UpdateGroupEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return updateGroupRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return updateGroupRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -92,7 +91,7 @@ func updateGroupTagsEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -117,7 +116,7 @@ func EnableGroupEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return changeStatusRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return changeStatusRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -137,7 +136,7 @@ func DisableGroupEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return changeStatusRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return changeStatusRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -158,7 +157,7 @@ func ListGroupsEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return groupPageRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return groupPageRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -198,7 +197,7 @@ func DeleteGroupEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return deleteGroupRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return deleteGroupRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -216,7 +215,7 @@ func retrieveGroupHierarchyEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return retrieveGroupHierarchyRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return changeStatusRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -244,7 +243,7 @@ func addParentGroupEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return addParentGroupRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return changeStatusRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -263,7 +262,7 @@ func removeParentGroupEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return removeParentGroupRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return changeStatusRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -282,7 +281,7 @@ func addChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return addChildrenGroupsRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return changeStatusRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -301,7 +300,7 @@ func removeChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return removeChildrenGroupsRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return changeStatusRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -320,7 +319,7 @@ func removeAllChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return removeAllChildrenGroupsRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return changeStatusRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -339,7 +338,7 @@ func listChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint {
|
||||
return listChildrenGroupsRes{}, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return changeStatusRes{}, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/groups"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
smqauthn "github.com/absmach/supermq/pkg/authn"
|
||||
roleManagerHttp "github.com/absmach/supermq/pkg/roles/rolemanager/api"
|
||||
"github.com/go-chi/chi/v5"
|
||||
kithttp "github.com/go-kit/kit/transport/http"
|
||||
@@ -19,14 +19,14 @@ import (
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP handler for Groups API endpoints.
|
||||
func MakeHandler(svc groups.Service, authn authn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) *chi.Mux {
|
||||
func MakeHandler(svc groups.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) *chi.Mux {
|
||||
opts := []kithttp.ServerOption{
|
||||
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
|
||||
}
|
||||
d := roleManagerHttp.NewDecoder("groupID")
|
||||
|
||||
mux.Route("/{domainID}/groups", func(r chi.Router) {
|
||||
r.Use(api.AuthenticateMiddleware(authn, true))
|
||||
r.Use(authn.Middleware())
|
||||
r.Use(api.RequestIDMiddleware(idp))
|
||||
|
||||
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
|
||||
@@ -6,7 +6,7 @@ syntax = "proto3";
|
||||
package auth.v1;
|
||||
option go_package = "github.com/absmach/supermq/api/grpc/auth/v1";
|
||||
|
||||
// AuthService is a service that provides authentication
|
||||
// AuthService is a service that provides authentication
|
||||
// and authorization functionalities for SuperMQ services.
|
||||
service AuthService {
|
||||
rpc Authorize(AuthZReq) returns (AuthZRes) {}
|
||||
@@ -23,7 +23,8 @@ message AuthNReq {
|
||||
message AuthNRes {
|
||||
string id = 1; // token id
|
||||
string user_id = 2; // user id
|
||||
uint32 user_role = 3; // user role
|
||||
uint32 user_role = 3; // user role
|
||||
bool verified = 4; // verified user
|
||||
}
|
||||
|
||||
message AuthZReq {
|
||||
|
||||
@@ -15,10 +15,12 @@ message IssueReq {
|
||||
string user_id = 1;
|
||||
uint32 user_role = 2;
|
||||
uint32 type = 3;
|
||||
bool verified = 4;
|
||||
}
|
||||
|
||||
message RefreshReq {
|
||||
string refresh_token = 1;
|
||||
bool verified = 2;
|
||||
}
|
||||
|
||||
// If a token is not carrying any information itself, the type
|
||||
|
||||
@@ -6,7 +6,6 @@ package api
|
||||
import (
|
||||
"context"
|
||||
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/journal"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
@@ -22,7 +21,7 @@ func retrieveJournalsEndpoint(svc journal.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -45,7 +44,7 @@ func retrieveClientTelemetryEndpoint(svc journal.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
@@ -57,7 +57,8 @@ func newjournalServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentic
|
||||
|
||||
logger := smqlog.NewMock()
|
||||
authn := new(authnmocks.Authentication)
|
||||
mux := httpapi.MakeHandler(svc, authn, logger, "journal-log", "test")
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
mux := httpapi.MakeHandler(svc, am, logger, "journal-log", "test")
|
||||
return httptest.NewServer(mux), svc, authn
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ const (
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP API handler with health check and metrics.
|
||||
func MakeHandler(svc journal.Service, authn smqauthn.Authentication, logger *slog.Logger, svcName, instanceID string) http.Handler {
|
||||
func MakeHandler(svc journal.Service, authn smqauthn.AuthNMiddleware, logger *slog.Logger, svcName, instanceID string) http.Handler {
|
||||
opts := []kithttp.ServerOption{
|
||||
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func MakeHandler(svc journal.Service, authn smqauthn.Authentication, logger *slo
|
||||
idp := uuid.New()
|
||||
mux.Use(api.RequestIDMiddleware(idp))
|
||||
|
||||
mux.With(api.AuthenticateMiddleware(authn, false)).Get("/journal/user/{userID}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
mux.With(authn.WithOptions(smqauthn.WithDomainCheck(false)).Middleware()).Get("/journal/user/{userID}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
retrieveJournalsEndpoint(svc),
|
||||
decodeRetrieveUserJournalReq,
|
||||
api.EncodeResponse,
|
||||
@@ -51,7 +51,7 @@ func MakeHandler(svc journal.Service, authn smqauthn.Authentication, logger *slo
|
||||
), "list_user_journals").ServeHTTP)
|
||||
|
||||
mux.Route("/{domainID}/journal", func(r chi.Router) {
|
||||
r.Use(api.AuthenticateMiddleware(authn, true))
|
||||
r.Use(authn.Middleware())
|
||||
|
||||
r.Get("/{entityType}/{entityID}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
retrieveJournalsEndpoint(svc),
|
||||
|
||||
@@ -44,6 +44,7 @@ type Session struct {
|
||||
DomainID string
|
||||
DomainUserID string
|
||||
SuperAdmin bool
|
||||
Verified bool
|
||||
Role Role
|
||||
}
|
||||
|
||||
|
||||
@@ -54,5 +54,5 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S
|
||||
return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
return authn.Session{Type: authn.AccessToken, UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil
|
||||
return authn.Session{Type: authn.AccessToken, UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole()), Verified: res.GetVerified()}, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package authn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
type sessionKeyType string
|
||||
|
||||
const (
|
||||
allowUnverifiedUserEnv = "SMQ_ALLOW_UNVERIFIED_USER"
|
||||
jsonContentType = "application/json"
|
||||
|
||||
SessionKey = sessionKeyType("session")
|
||||
)
|
||||
|
||||
// middlewareOptions contains configuration for authentication middleware.
|
||||
type middlewareOptions struct {
|
||||
domainCheck bool
|
||||
allowUnverifiedUser bool
|
||||
}
|
||||
|
||||
// defaultMiddlewareOptions returns the default middleware configuration.
|
||||
func defaultMiddlewareOptions() *middlewareOptions {
|
||||
return &middlewareOptions{
|
||||
domainCheck: true,
|
||||
allowUnverifiedUser: false,
|
||||
}
|
||||
}
|
||||
|
||||
// MiddlewareOption is a function that modifies middleware options.
|
||||
type MiddlewareOption func(*middlewareOptions)
|
||||
|
||||
// WithDomainCheck sets whether domain checking is enabled.
|
||||
func WithDomainCheck(enabled bool) MiddlewareOption {
|
||||
return func(opts *middlewareOptions) {
|
||||
opts.domainCheck = enabled
|
||||
}
|
||||
}
|
||||
|
||||
// WithAllowUnverifiedUser sets whether unverified users are allowed.
|
||||
func WithAllowUnverifiedUser(allowed bool) MiddlewareOption {
|
||||
return func(opts *middlewareOptions) {
|
||||
opts.allowUnverifiedUser = allowed
|
||||
}
|
||||
}
|
||||
|
||||
// WithDefaultMiddlewareOptions resets options to default values.
|
||||
func WithDefaultMiddlewareOptions() MiddlewareOption {
|
||||
return func(opts *middlewareOptions) {
|
||||
defaults := defaultMiddlewareOptions()
|
||||
opts.domainCheck = defaults.domainCheck
|
||||
opts.allowUnverifiedUser = defaults.allowUnverifiedUser
|
||||
}
|
||||
}
|
||||
|
||||
// AuthNMiddleware defines the interface for authenticated services with middleware.
|
||||
type AuthNMiddleware interface {
|
||||
Authentication
|
||||
WithOptions(options ...MiddlewareOption) AuthNMiddleware
|
||||
Middleware() func(http.Handler) http.Handler
|
||||
}
|
||||
|
||||
// authnMiddleware wraps Authentication with middleware functionality.
|
||||
type authnMiddleware struct {
|
||||
Authentication
|
||||
options []MiddlewareOption
|
||||
}
|
||||
|
||||
// NewAuthNMiddleware creates a new authenticated service with middleware support.
|
||||
// The order of precedence for options is as follows, with later options overriding earlier ones:
|
||||
// 1. Default options (lowest precedence).
|
||||
// 2. Options from environment variables (e.g., SMQ_ALLOW_UNVERIFIED_USER).
|
||||
// 3. Options passed as arguments to this function (highest precedence).
|
||||
//
|
||||
// For example, consider the 'allowUnverifiedUser' option:
|
||||
// - By default, it is 'false'.
|
||||
// - If the SMQ_ALLOW_UNVERIFIED_USER environment variable is set to "true",
|
||||
// it becomes 'true'.
|
||||
// - If NewAuthNMiddleware is called with WithAllowUnverifiedUser(false), it will be 'false',
|
||||
// regardless of the environment variable, as function arguments have the highest precedence.
|
||||
func NewAuthNMiddleware(authnSvc Authentication, options ...MiddlewareOption) AuthNMiddleware {
|
||||
allOptions := []MiddlewareOption{WithDefaultMiddlewareOptions()}
|
||||
if val, ok := os.LookupEnv(allowUnverifiedUserEnv); ok {
|
||||
allowUnverifiedUser, err := strconv.ParseBool(val)
|
||||
if err == nil && allowUnverifiedUser {
|
||||
allOptions = append(allOptions, WithAllowUnverifiedUser(true))
|
||||
}
|
||||
}
|
||||
allOptions = append(allOptions, options...)
|
||||
return &authnMiddleware{
|
||||
Authentication: authnSvc,
|
||||
options: allOptions,
|
||||
}
|
||||
}
|
||||
|
||||
// WithOptions returns a new service with additional options.
|
||||
func (a *authnMiddleware) WithOptions(options ...MiddlewareOption) AuthNMiddleware {
|
||||
return &authnMiddleware{
|
||||
Authentication: a.Authentication,
|
||||
options: append(a.options, options...),
|
||||
}
|
||||
}
|
||||
|
||||
// getMiddlewareOptions returns the configured middleware options.
|
||||
func (a *authnMiddleware) getMiddlewareOptions() *middlewareOptions {
|
||||
opts := defaultMiddlewareOptions()
|
||||
for _, option := range a.options {
|
||||
option(opts)
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware function that handles authentication.
|
||||
func (a *authnMiddleware) Middleware() func(http.Handler) http.Handler {
|
||||
opts := a.getMiddlewareOptions()
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := apiutil.ExtractBearerToken(r)
|
||||
if token == "" {
|
||||
encodeError(w, apiutil.ErrBearerToken, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
resp, err := a.Authenticate(r.Context(), token)
|
||||
if err != nil {
|
||||
encodeError(w, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Type == AccessToken && !opts.allowUnverifiedUser && resp.Role != AdminRole && !resp.Verified {
|
||||
encodeError(w, apiutil.ErrEmailNotVerified, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if opts.domainCheck {
|
||||
domain := chi.URLParam(r, "domainID")
|
||||
if domain == "" {
|
||||
encodeError(w, apiutil.ErrMissingDomainID, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
resp.DomainID = domain
|
||||
switch resp.Role {
|
||||
case AdminRole:
|
||||
resp.DomainUserID = resp.UserID
|
||||
case UserRole:
|
||||
resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), SessionKey, resp)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func encodeError(w http.ResponseWriter, err error, statusCode int) {
|
||||
if errorVal, ok := err.(errors.Error); ok {
|
||||
w.Header().Set("Content-Type", jsonContentType)
|
||||
w.WriteHeader(statusCode)
|
||||
if err := json.NewEncoder(w).Encode(errorVal); err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
http.Error(w, err.Error(), statusCode)
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewAuthNMiddleware creates a new instance of AuthNMiddleware. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewAuthNMiddleware(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *AuthNMiddleware {
|
||||
mock := &AuthNMiddleware{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// AuthNMiddleware is an autogenerated mock type for the AuthNMiddleware type
|
||||
type AuthNMiddleware struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type AuthNMiddleware_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *AuthNMiddleware) EXPECT() *AuthNMiddleware_Expecter {
|
||||
return &AuthNMiddleware_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Authenticate provides a mock function for the type AuthNMiddleware
|
||||
func (_mock *AuthNMiddleware) Authenticate(ctx context.Context, token string) (authn.Session, error) {
|
||||
ret := _mock.Called(ctx, token)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Authenticate")
|
||||
}
|
||||
|
||||
var r0 authn.Session
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) (authn.Session, error)); ok {
|
||||
return returnFunc(ctx, token)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) authn.Session); ok {
|
||||
r0 = returnFunc(ctx, token)
|
||||
} else {
|
||||
r0 = ret.Get(0).(authn.Session)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = returnFunc(ctx, token)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// AuthNMiddleware_Authenticate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Authenticate'
|
||||
type AuthNMiddleware_Authenticate_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Authenticate is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - token string
|
||||
func (_e *AuthNMiddleware_Expecter) Authenticate(ctx interface{}, token interface{}) *AuthNMiddleware_Authenticate_Call {
|
||||
return &AuthNMiddleware_Authenticate_Call{Call: _e.mock.On("Authenticate", ctx, token)}
|
||||
}
|
||||
|
||||
func (_c *AuthNMiddleware_Authenticate_Call) Run(run func(ctx context.Context, token string)) *AuthNMiddleware_Authenticate_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AuthNMiddleware_Authenticate_Call) Return(session authn.Session, err error) *AuthNMiddleware_Authenticate_Call {
|
||||
_c.Call.Return(session, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AuthNMiddleware_Authenticate_Call) RunAndReturn(run func(ctx context.Context, token string) (authn.Session, error)) *AuthNMiddleware_Authenticate_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Middleware provides a mock function for the type AuthNMiddleware
|
||||
func (_mock *AuthNMiddleware) Middleware() func(http.Handler) http.Handler {
|
||||
ret := _mock.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Middleware")
|
||||
}
|
||||
|
||||
var r0 func(http.Handler) http.Handler
|
||||
if returnFunc, ok := ret.Get(0).(func() func(http.Handler) http.Handler); ok {
|
||||
r0 = returnFunc()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(func(http.Handler) http.Handler)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AuthNMiddleware_Middleware_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Middleware'
|
||||
type AuthNMiddleware_Middleware_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Middleware is a helper method to define mock.On call
|
||||
func (_e *AuthNMiddleware_Expecter) Middleware() *AuthNMiddleware_Middleware_Call {
|
||||
return &AuthNMiddleware_Middleware_Call{Call: _e.mock.On("Middleware")}
|
||||
}
|
||||
|
||||
func (_c *AuthNMiddleware_Middleware_Call) Run(run func()) *AuthNMiddleware_Middleware_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AuthNMiddleware_Middleware_Call) Return(fn func(http.Handler) http.Handler) *AuthNMiddleware_Middleware_Call {
|
||||
_c.Call.Return(fn)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AuthNMiddleware_Middleware_Call) RunAndReturn(run func() func(http.Handler) http.Handler) *AuthNMiddleware_Middleware_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// WithOptions provides a mock function for the type AuthNMiddleware
|
||||
func (_mock *AuthNMiddleware) WithOptions(options ...authn.MiddlewareOption) authn.AuthNMiddleware {
|
||||
var tmpRet mock.Arguments
|
||||
if len(options) > 0 {
|
||||
tmpRet = _mock.Called(options)
|
||||
} else {
|
||||
tmpRet = _mock.Called()
|
||||
}
|
||||
ret := tmpRet
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for WithOptions")
|
||||
}
|
||||
|
||||
var r0 authn.AuthNMiddleware
|
||||
if returnFunc, ok := ret.Get(0).(func(...authn.MiddlewareOption) authn.AuthNMiddleware); ok {
|
||||
r0 = returnFunc(options...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(authn.AuthNMiddleware)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// AuthNMiddleware_WithOptions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithOptions'
|
||||
type AuthNMiddleware_WithOptions_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// WithOptions is a helper method to define mock.On call
|
||||
// - options ...authn.MiddlewareOption
|
||||
func (_e *AuthNMiddleware_Expecter) WithOptions(options ...interface{}) *AuthNMiddleware_WithOptions_Call {
|
||||
return &AuthNMiddleware_WithOptions_Call{Call: _e.mock.On("WithOptions",
|
||||
append([]interface{}{}, options...)...)}
|
||||
}
|
||||
|
||||
func (_c *AuthNMiddleware_WithOptions_Call) Run(run func(options ...authn.MiddlewareOption)) *AuthNMiddleware_WithOptions_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 []authn.MiddlewareOption
|
||||
var variadicArgs []authn.MiddlewareOption
|
||||
if len(args) > 0 {
|
||||
variadicArgs = args[0].([]authn.MiddlewareOption)
|
||||
}
|
||||
arg0 = variadicArgs
|
||||
run(
|
||||
arg0...,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AuthNMiddleware_WithOptions_Call) Return(authNMiddleware authn.AuthNMiddleware) *AuthNMiddleware_WithOptions_Call {
|
||||
_c.Call.Return(authNMiddleware)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *AuthNMiddleware_WithOptions_Call) RunAndReturn(run func(options ...authn.MiddlewareOption) authn.AuthNMiddleware) *AuthNMiddleware_WithOptions_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -93,4 +93,13 @@ var (
|
||||
|
||||
// ErrSuperAdminAction indicates that the user is not a super admin.
|
||||
ErrSuperAdminAction = errors.New("not authorized to perform admin action")
|
||||
|
||||
// ErrUserAlreadyVerified indicates user is already verified.
|
||||
ErrUserAlreadyVerified = errors.New("user already verified")
|
||||
|
||||
// ErrInvalidUserVerification indicates user verification is invalid.
|
||||
ErrInvalidUserVerification = errors.New("invalid verification")
|
||||
|
||||
// ErrUserVerificationExpired indicates user verification is expired.
|
||||
ErrUserVerificationExpired = errors.New("verification expired, please generate new verification")
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ package http
|
||||
import (
|
||||
"context"
|
||||
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
@@ -22,7 +21,7 @@ func CreateRoleEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -42,7 +41,7 @@ func ListRolesEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -62,7 +61,7 @@ func ListEntityMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -94,7 +93,7 @@ func RemoveEntityMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -113,7 +112,7 @@ func ViewRoleEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -133,7 +132,7 @@ func UpdateRoleEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -153,7 +152,7 @@ func DeleteRoleEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -172,7 +171,7 @@ func ListAvailableActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -192,7 +191,7 @@ func AddRoleActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -212,7 +211,7 @@ func ListRoleActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -232,7 +231,7 @@ func DeleteRoleActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -251,7 +250,7 @@ func DeleteAllRoleActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -270,7 +269,7 @@ func AddRoleMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -290,7 +289,7 @@ func ListRoleMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -310,7 +309,7 @@ func DeleteRoleMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -329,7 +328,7 @@ func DeleteAllRoleMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
@@ -67,7 +67,8 @@ func setupCerts() (*httptest.Server, *mocks.Service, *authnmocks.Authentication)
|
||||
logger := smqlog.NewMock()
|
||||
authn := new(authnmocks.Authentication)
|
||||
idp := uuid.NewMock()
|
||||
mux := httpapi.MakeHandler(svc, authn, logger, instanceID, idp)
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
mux := httpapi.MakeHandler(svc, am, logger, instanceID, idp)
|
||||
|
||||
return httptest.NewServer(mux), svc, authn
|
||||
}
|
||||
|
||||
@@ -43,7 +43,8 @@ func setupChannels() (*httptest.Server, *chmocks.Service, *authnmocks.Authentica
|
||||
authn := new(authnmocks.Authentication)
|
||||
mux := chi.NewRouter()
|
||||
idp := uuid.NewMock()
|
||||
chapi.MakeHandler(svc, authn, mux, logger, "", idp)
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
chapi.MakeHandler(svc, am, mux, logger, "", idp)
|
||||
|
||||
return httptest.NewServer(mux), svc, authn
|
||||
}
|
||||
@@ -1094,7 +1095,7 @@ func TestUpdateChannelTags(t *testing.T) {
|
||||
svcRes: channels.Channel{},
|
||||
authenticateErr: svcerr.ErrAuthorization,
|
||||
response: sdk.Channel{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "update channel tags with empty token",
|
||||
|
||||
@@ -37,7 +37,8 @@ func setupClients() (*httptest.Server, *mocks.Service, *authnmocks.Authenticatio
|
||||
mux := chi.NewRouter()
|
||||
idp := uuid.NewMock()
|
||||
authn := new(authnmocks.Authentication)
|
||||
api.MakeHandler(tsvc, authn, mux, logger, "", idp)
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
api.MakeHandler(tsvc, am, mux, logger, "", idp)
|
||||
|
||||
return httptest.NewServer(mux), tsvc, authn
|
||||
}
|
||||
@@ -647,7 +648,7 @@ func TestViewClient(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
authenticateErr: svcerr.ErrAuthorization,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "view client with empty token",
|
||||
@@ -786,7 +787,7 @@ func TestUpdateClient(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
authenticateErr: svcerr.ErrAuthorization,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "update client with empty token",
|
||||
@@ -940,7 +941,7 @@ func TestUpdateClientTags(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
authenticateErr: svcerr.ErrAuthorization,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "update client tags with empty token",
|
||||
@@ -1089,7 +1090,7 @@ func TestUpdateClientSecret(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
authenticateErr: svcerr.ErrAuthorization,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "update client secret with empty token",
|
||||
@@ -1217,7 +1218,7 @@ func TestEnableClient(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
authenticateErr: svcerr.ErrAuthorization,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "enable client with an invalid client id",
|
||||
@@ -1320,7 +1321,7 @@ func TestDisableClient(t *testing.T) {
|
||||
svcRes: clients.Client{},
|
||||
authenticateErr: svcerr.ErrAuthorization,
|
||||
response: sdk.Client{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "disable client with an invalid client id",
|
||||
@@ -1415,7 +1416,7 @@ func TestDeleteClient(t *testing.T) {
|
||||
token: invalidToken,
|
||||
clientID: client.ID,
|
||||
authenticateErr: svcerr.ErrAuthorization,
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "delete client with empty token",
|
||||
|
||||
@@ -62,8 +62,9 @@ func setupDomains() (*httptest.Server, *mocks.Service, *authnmocks.Authenticatio
|
||||
mux := chi.NewRouter()
|
||||
idp := uuid.NewMock()
|
||||
authn := new(authnmocks.Authentication)
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
|
||||
handler := domainapi.MakeHandler(svc, authn, mux, logger, "", idp)
|
||||
handler := domainapi.MakeHandler(svc, am, mux, logger, "", idp)
|
||||
return httptest.NewServer(handler), svc, authn
|
||||
}
|
||||
|
||||
|
||||
@@ -48,7 +48,8 @@ func setupGroups() (*httptest.Server, *mocks.Service, *authnmocks.Authentication
|
||||
provider := new(oauth2mocks.Provider)
|
||||
provider.On("Name").Return(roleName)
|
||||
authn := new(authnmocks.Authentication)
|
||||
httpapi.MakeHandler(svc, authn, mux, logger, "", idp)
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
httpapi.MakeHandler(svc, am, mux, logger, "", idp)
|
||||
|
||||
return httptest.NewServer(mux), svc, authn
|
||||
}
|
||||
@@ -922,7 +923,7 @@ func TestUpdateGroupTags(t *testing.T) {
|
||||
svcRes: groups.Group{},
|
||||
authenticateErr: svcerr.ErrAuthorization,
|
||||
response: sdk.Group{},
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
|
||||
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
|
||||
},
|
||||
{
|
||||
desc: "update group tags with empty token",
|
||||
|
||||
@@ -29,7 +29,8 @@ func setupJournal() (*httptest.Server, *mocks.Service, *authnmocks.Authenticatio
|
||||
svc := new(mocks.Service)
|
||||
authn := new(authnmocks.Authentication)
|
||||
logger := smqlog.NewMock()
|
||||
mux := api.MakeHandler(svc, authn, logger, "journal-log", "test")
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
mux := api.MakeHandler(svc, am, logger, "journal-log", "test")
|
||||
|
||||
return httptest.NewServer(mux), svc, authn
|
||||
}
|
||||
|
||||
@@ -8040,6 +8040,65 @@ func (_c *SDK_SendMessage_Call) RunAndReturn(run func(ctx context.Context, domai
|
||||
return _c
|
||||
}
|
||||
|
||||
// SendVerification provides a mock function for the type SDK
|
||||
func (_mock *SDK) SendVerification(ctx context.Context, token string) errors.SDKError {
|
||||
ret := _mock.Called(ctx, token)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SendVerification")
|
||||
}
|
||||
|
||||
var r0 errors.SDKError
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) errors.SDKError); ok {
|
||||
r0 = returnFunc(ctx, token)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(errors.SDKError)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// SDK_SendVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendVerification'
|
||||
type SDK_SendVerification_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SendVerification is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - token string
|
||||
func (_e *SDK_Expecter) SendVerification(ctx interface{}, token interface{}) *SDK_SendVerification_Call {
|
||||
return &SDK_SendVerification_Call{Call: _e.mock.On("SendVerification", ctx, token)}
|
||||
}
|
||||
|
||||
func (_c *SDK_SendVerification_Call) Run(run func(ctx context.Context, token string)) *SDK_SendVerification_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *SDK_SendVerification_Call) Return(sDKError errors.SDKError) *SDK_SendVerification_Call {
|
||||
_c.Call.Return(sDKError)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *SDK_SendVerification_Call) RunAndReturn(run func(ctx context.Context, token string) errors.SDKError) *SDK_SendVerification_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetChannelParent provides a mock function for the type SDK
|
||||
func (_mock *SDK) SetChannelParent(ctx context.Context, id string, domainID string, groupID string, token string) errors.SDKError {
|
||||
ret := _mock.Called(ctx, id, domainID, groupID, token)
|
||||
@@ -9974,6 +10033,65 @@ func (_c *SDK_Users_Call) RunAndReturn(run func(ctx context.Context, pm sdk.Page
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifyEmail provides a mock function for the type SDK
|
||||
func (_mock *SDK) VerifyEmail(ctx context.Context, verificationToken string) errors.SDKError {
|
||||
ret := _mock.Called(ctx, verificationToken)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifyEmail")
|
||||
}
|
||||
|
||||
var r0 errors.SDKError
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) errors.SDKError); ok {
|
||||
r0 = returnFunc(ctx, verificationToken)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(errors.SDKError)
|
||||
}
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// SDK_VerifyEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyEmail'
|
||||
type SDK_VerifyEmail_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifyEmail is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - verificationToken string
|
||||
func (_e *SDK_Expecter) VerifyEmail(ctx interface{}, verificationToken interface{}) *SDK_VerifyEmail_Call {
|
||||
return &SDK_VerifyEmail_Call{Call: _e.mock.On("VerifyEmail", ctx, verificationToken)}
|
||||
}
|
||||
|
||||
func (_c *SDK_VerifyEmail_Call) Run(run func(ctx context.Context, verificationToken string)) *SDK_VerifyEmail_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *SDK_VerifyEmail_Call) Return(sDKError errors.SDKError) *SDK_VerifyEmail_Call {
|
||||
_c.Call.Return(sDKError)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *SDK_VerifyEmail_Call) RunAndReturn(run func(ctx context.Context, verificationToken string) errors.SDKError) *SDK_VerifyEmail_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ViewCert provides a mock function for the type SDK
|
||||
func (_mock *SDK) ViewCert(ctx context.Context, certID string, domainID string, token string) (sdk.Cert, errors.SDKError) {
|
||||
ret := _mock.Called(ctx, certID, domainID, token)
|
||||
|
||||
@@ -178,6 +178,20 @@ type SDK interface {
|
||||
// fmt.Println(user)
|
||||
CreateUser(ctx context.Context, user User, token string) (User, errors.SDKError)
|
||||
|
||||
// SendVerification sends a verification email to the user.
|
||||
//
|
||||
// example:
|
||||
// err := sdk.SendVerification("token")
|
||||
// fmt.Println(err)
|
||||
SendVerification(ctx context.Context, token string) errors.SDKError
|
||||
|
||||
// VerifyEmail verifies the user's email address using the provided token.
|
||||
//
|
||||
// example:
|
||||
// err := sdk.VerifyEmail("verificationToken")
|
||||
// fmt.Println(user)
|
||||
VerifyEmail(ctx context.Context, verificationToken string) errors.SDKError
|
||||
|
||||
// User returns user object by id.
|
||||
//
|
||||
// example:
|
||||
|
||||
+28
-7
@@ -15,13 +15,16 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
usersEndpoint = "users"
|
||||
enableEndpoint = "enable"
|
||||
disableEndpoint = "disable"
|
||||
issueTokenEndpoint = "tokens/issue"
|
||||
refreshTokenEndpoint = "tokens/refresh"
|
||||
membersEndpoint = "members"
|
||||
PasswordResetEndpoint = "password"
|
||||
usersEndpoint = "users"
|
||||
enableEndpoint = "enable"
|
||||
disableEndpoint = "disable"
|
||||
issueTokenEndpoint = "tokens/issue"
|
||||
refreshTokenEndpoint = "tokens/refresh"
|
||||
membersEndpoint = "members"
|
||||
PasswordResetEndpoint = "password"
|
||||
sendVerificationEndpoint = "send-verification"
|
||||
verifyEmailEndpoint = "verify-email"
|
||||
tokenQueryParamKey = "token"
|
||||
)
|
||||
|
||||
// User represents supermq user its credentials.
|
||||
@@ -61,6 +64,24 @@ func (sdk mgSDK) CreateUser(ctx context.Context, user User, token string) (User,
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (sdk mgSDK) SendVerification(ctx context.Context, token string) errors.SDKError {
|
||||
url := fmt.Sprintf("%s/%s/%s", sdk.usersURL, usersEndpoint, sendVerificationEndpoint)
|
||||
|
||||
_, _, sdkErr := sdk.processRequest(ctx, http.MethodPost, url, token, nil, nil, http.StatusOK)
|
||||
|
||||
return sdkErr
|
||||
}
|
||||
|
||||
func (sdk mgSDK) VerifyEmail(ctx context.Context, verificationToken string) errors.SDKError {
|
||||
url := fmt.Sprintf("%s/%s?%s=%s", sdk.usersURL, verifyEmailEndpoint, tokenQueryParamKey, verificationToken)
|
||||
|
||||
_, _, sdkErr := sdk.processRequest(ctx, http.MethodGet, url, "", nil, nil, http.StatusOK)
|
||||
if sdkErr != nil {
|
||||
return sdkErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sdk mgSDK) Users(ctx context.Context, pm PageMetadata, token string) (UsersPage, errors.SDKError) {
|
||||
url, err := sdk.withQueryParams(sdk.usersURL, usersEndpoint, pm)
|
||||
if err != nil {
|
||||
|
||||
@@ -45,8 +45,9 @@ func setupUsers() (*httptest.Server, *umocks.Service, *authnmocks.Authentication
|
||||
provider := new(oauth2mocks.Provider)
|
||||
provider.On("Name").Return("test")
|
||||
authn := new(authnmocks.Authentication)
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithDomainCheck(false), smqauthn.WithAllowUnverifiedUser(true))
|
||||
token := new(authmocks.TokenServiceClient)
|
||||
httpapi.MakeHandler(usvc, authn, token, true, mux, logger, "", passRegex, idp, provider)
|
||||
httpapi.MakeHandler(usvc, am, token, true, mux, logger, "", passRegex, idp, provider)
|
||||
|
||||
return httptest.NewServer(mux), usvc, authn
|
||||
}
|
||||
@@ -567,6 +568,10 @@ func TestListUsers(t *testing.T) {
|
||||
resp, err := mgsdk.Users(context.Background(), tc.pageMeta, tc.token)
|
||||
assert.Equal(t, tc.err, err)
|
||||
assert.Equal(t, tc.response, resp)
|
||||
if tc.token != "" {
|
||||
ok := authCall.Parent.AssertCalled(t, "Authenticate", mock.Anything, tc.token)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
if tc.err == nil {
|
||||
ok := svcCall.Parent.AssertCalled(t, "ListUsers", mock.Anything, tc.session, tc.svcReq)
|
||||
assert.True(t, ok)
|
||||
|
||||
+1
-1
@@ -38,7 +38,7 @@ done
|
||||
###
|
||||
# Users
|
||||
###
|
||||
SMQ_USERS_LOG_LEVEL=info SMQ_USERS_HTTP_PORT=9002 SMQ_USERS_GRPC_PORT=7001 SMQ_USERS_ADMIN_EMAIL=admin@supermq.com SMQ_USERS_ADMIN_PASSWORD=12345678 SMQ_USERS_ADMIN_USERNAME=admin SMQ_EMAIL_TEMPLATE=../docker/templates/users.tmpl $BUILD_DIR/supermq-users &
|
||||
SMQ_USERS_LOG_LEVEL=info SMQ_USERS_HTTP_PORT=9002 SMQ_USERS_GRPC_PORT=7001 SMQ_USERS_ADMIN_EMAIL=admin@supermq.com SMQ_USERS_ADMIN_PASSWORD=12345678 SMQ_USERS_ADMIN_USERNAME=admin SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost:9002/password/reset SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=../docker/templates/reset-password-email.tmpl SMQ_VERIFICATION_URL_PREFIX=http://localhost:9002/users/verify-email SMQ_VERIFICATION_EMAIL_TEMPLATE=../docker/templates/verification-email.tmpl $BUILD_DIR/supermq-users &
|
||||
|
||||
###
|
||||
# Clients
|
||||
|
||||
@@ -114,6 +114,7 @@ packages:
|
||||
github.com/absmach/supermq/pkg/authn:
|
||||
interfaces:
|
||||
Authentication:
|
||||
AuthNMiddleware:
|
||||
github.com/absmach/supermq/pkg/authz:
|
||||
interfaces:
|
||||
Authorization:
|
||||
|
||||
+49
-43
@@ -12,48 +12,51 @@ For in-depth explanation of the aforementioned scenarios, as well as thorough un
|
||||
|
||||
The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values.
|
||||
|
||||
| Variable | Description | Default |
|
||||
| ------------------------------ | ----------------------------------------------------------------------- | --------------------------------- |
|
||||
| SMQ_USERS_LOG_LEVEL | Log level for users service (debug, info, warn, error) | info |
|
||||
| SMQ_USERS_ADMIN_EMAIL | Default user, created on startup | <admin@example.com> |
|
||||
| SMQ_USERS_ADMIN_PASSWORD | Default user password, created on startup | 12345678 |
|
||||
| SMQ_USERS_PASS_REGEX | Password regex | ^.{8,}$ |
|
||||
| SMQ_USERS_HTTP_HOST | Users service HTTP host | localhost |
|
||||
| SMQ_USERS_HTTP_PORT | Users service HTTP port | 9002 |
|
||||
| SMQ_USERS_HTTP_SERVER_CERT | Path to the PEM encoded server certificate file | "" |
|
||||
| SMQ_USERS_HTTP_SERVER_KEY | Path to the PEM encoded server key file | "" |
|
||||
| SMQ_USERS_HTTP_SERVER_CA_CERTS | Path to the PEM encoded server CA certificate file | "" |
|
||||
| SMQ_USERS_HTTP_CLIENT_CA_CERTS | Path to the PEM encoded client CA certificate file | "" |
|
||||
| SMQ_AUTH_GRPC_URL | Auth service GRPC URL | localhost:8181 |
|
||||
| SMQ_AUTH_GRPC_TIMEOUT | Auth service GRPC timeout | 1s |
|
||||
| SMQ_AUTH_GRPC_CLIENT_CERT | Path to the PEM encoded client certificate file | "" |
|
||||
| SMQ_AUTH_GRPC_CLIENT_KEY | Path to the PEM encoded client key file | "" |
|
||||
| SMQ_AUTH_GRPC_SERVER_CA_CERTS | Path to the PEM encoded server CA certificate file | "" |
|
||||
| SMQ_USERS_DB_HOST | Database host address | localhost |
|
||||
| SMQ_USERS_DB_PORT | Database host port | 5432 |
|
||||
| SMQ_USERS_DB_USER | Database user | supermq |
|
||||
| SMQ_USERS_DB_PASS | Database password | supermq |
|
||||
| SMQ_USERS_DB_NAME | Name of the database used by the service | users |
|
||||
| SMQ_USERS_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable |
|
||||
| SMQ_USERS_DB_SSL_CERT | Path to the PEM encoded certificate file | "" |
|
||||
| SMQ_USERS_DB_SSL_KEY | Path to the PEM encoded key file | "" |
|
||||
| SMQ_USERS_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" |
|
||||
| SMQ_EMAIL_HOST | Mail server host | localhost |
|
||||
| SMQ_EMAIL_PORT | Mail server port | 25 |
|
||||
| SMQ_EMAIL_USERNAME | Mail server username | "" |
|
||||
| SMQ_EMAIL_PASSWORD | Mail server password | "" |
|
||||
| SMQ_EMAIL_FROM_ADDRESS | Email "from" address | "" |
|
||||
| SMQ_EMAIL_FROM_NAME | Email "from" name | "" |
|
||||
| SMQ_EMAIL_TEMPLATE | Email template for sending emails with password reset link | email.tmpl |
|
||||
| SMQ_USERS_ES_URL | Event store URL | <nats://localhost:4222> |
|
||||
| SMQ_JAEGER_URL | Jaeger server URL | <http://localhost:4318/v1/traces> |
|
||||
| SMQ_OAUTH_UI_REDIRECT_URL | OAuth UI redirect URL | <http://localhost:9095/domains> |
|
||||
| SMQ_OAUTH_UI_ERROR_URL | OAuth UI error URL | <http://localhost:9095/error> |
|
||||
| SMQ_USERS_DELETE_INTERVAL | Interval for deleting users | 24h |
|
||||
| SMQ_USERS_DELETE_AFTER | Time after which users are deleted | 720h |
|
||||
| SMQ_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 |
|
||||
| SMQ_SEND_TELEMETRY | Send telemetry to supermq call home server. | true |
|
||||
| SMQ_USERS_INSTANCE_ID | SuperMQ instance ID | "" |
|
||||
| Variable | Description | Default |
|
||||
| --------------------------------- | ----------------------------------------------------------------------- | --------------------------------- |
|
||||
| SMQ_USERS_LOG_LEVEL | Log level for users service (debug, info, warn, error) | info |
|
||||
| SMQ_USERS_ADMIN_EMAIL | Default user, created on startup | <admin@example.com> |
|
||||
| SMQ_USERS_ADMIN_PASSWORD | Default user password, created on startup | 12345678 |
|
||||
| SMQ_USERS_PASS_REGEX | Password regex | ^.{8,}$ |
|
||||
| SMQ_USERS_HTTP_HOST | Users service HTTP host | localhost |
|
||||
| SMQ_USERS_HTTP_PORT | Users service HTTP port | 9002 |
|
||||
| SMQ_USERS_HTTP_SERVER_CERT | Path to the PEM encoded server certificate file | "" |
|
||||
| SMQ_USERS_HTTP_SERVER_KEY | Path to the PEM encoded server key file | "" |
|
||||
| SMQ_USERS_HTTP_SERVER_CA_CERTS | Path to the PEM encoded server CA certificate file | "" |
|
||||
| SMQ_USERS_HTTP_CLIENT_CA_CERTS | Path to the PEM encoded client CA certificate file | "" |
|
||||
| SMQ_AUTH_GRPC_URL | Auth service GRPC URL | localhost:8181 |
|
||||
| SMQ_AUTH_GRPC_TIMEOUT | Auth service GRPC timeout | 1s |
|
||||
| SMQ_AUTH_GRPC_CLIENT_CERT | Path to the PEM encoded client certificate file | "" |
|
||||
| SMQ_AUTH_GRPC_CLIENT_KEY | Path to the PEM encoded client key file | "" |
|
||||
| SMQ_AUTH_GRPC_SERVER_CA_CERTS | Path to the PEM encoded server CA certificate file | "" |
|
||||
| SMQ_USERS_DB_HOST | Database host address | localhost |
|
||||
| SMQ_USERS_DB_PORT | Database host port | 5432 |
|
||||
| SMQ_USERS_DB_USER | Database user | supermq |
|
||||
| SMQ_USERS_DB_PASS | Database password | supermq |
|
||||
| SMQ_USERS_DB_NAME | Name of the database used by the service | users |
|
||||
| SMQ_USERS_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable |
|
||||
| SMQ_USERS_DB_SSL_CERT | Path to the PEM encoded certificate file | "" |
|
||||
| SMQ_USERS_DB_SSL_KEY | Path to the PEM encoded key file | "" |
|
||||
| SMQ_USERS_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" |
|
||||
| SMQ_EMAIL_HOST | Mail server host | localhost |
|
||||
| SMQ_EMAIL_PORT | Mail server port | 25 |
|
||||
| SMQ_EMAIL_USERNAME | Mail server username | "" |
|
||||
| SMQ_EMAIL_PASSWORD | Mail server password | "" |
|
||||
| SMQ_EMAIL_FROM_ADDRESS | Email "from" address | "" |
|
||||
| SMQ_EMAIL_FROM_NAME | Email "from" name | "" |
|
||||
| SMQ_PASSWORD_RESET_URL_PREFIX | Password reset URL prefix | <http://localhost/password/reset> |
|
||||
| SMQ_PASSWORD_RESET_EMAIL_TEMPLATE | Password reset email template | reset-password-email.tmpl |
|
||||
| SMQ_VERIFICATION_URL_PREFIX | Verification URL prefix | <http://localhost/verify-email> |
|
||||
| SMQ_VERIFICATION_EMAIL_TEMPLATE | Verification email template | verification-email.tmpl |
|
||||
| SMQ_USERS_ES_URL | Event store URL | <nats://localhost:4222> |
|
||||
| SMQ_JAEGER_URL | Jaeger server URL | <http://localhost:4318/v1/traces> |
|
||||
| SMQ_OAUTH_UI_REDIRECT_URL | OAuth UI redirect URL | <http://localhost:9095/domains> |
|
||||
| SMQ_OAUTH_UI_ERROR_URL | OAuth UI error URL | <http://localhost:9095/error> |
|
||||
| SMQ_USERS_DELETE_INTERVAL | Interval for deleting users | 24h |
|
||||
| SMQ_USERS_DELETE_AFTER | Time after which users are deleted | 720h |
|
||||
| SMQ_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 |
|
||||
| SMQ_SEND_TELEMETRY | Send telemetry to supermq call home server. | true |
|
||||
| SMQ_USERS_INSTANCE_ID | SuperMQ instance ID | "" |
|
||||
|
||||
## Deployment
|
||||
|
||||
@@ -104,7 +107,10 @@ SMQ_EMAIL_USERNAME="18bf7f7070513" \
|
||||
SMQ_EMAIL_PASSWORD="2b0d302e775b1e" \
|
||||
SMQ_EMAIL_FROM_ADDRESS=from@example.com \
|
||||
SMQ_EMAIL_FROM_NAME=Example \
|
||||
SMQ_EMAIL_TEMPLATE="docker/templates/users.tmpl" \
|
||||
SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost:9002/password/reset \
|
||||
SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=docker/templates/reset-password-email.tmpl \
|
||||
SMQ_VERIFICATION_URL_PREFIX=http://localhost:9002/users/verify-email \
|
||||
SMQ_VERIFICATION_EMAIL_TEMPLATE=docker/templates/verification-email.tmpl \
|
||||
SMQ_USERS_ES_URL=nats://localhost:4222 \
|
||||
SMQ_JAEGER_URL=http://localhost:14268/api/traces \
|
||||
SMQ_JAEGER_TRACE_RATIO=1.0 \
|
||||
|
||||
+147
-1
@@ -94,8 +94,9 @@ func newUsersServer() (*httptest.Server, *mocks.Service, *authnmocks.Authenticat
|
||||
provider := new(oauth2mocks.Provider)
|
||||
provider.On("Name").Return("test")
|
||||
authn := new(authnmocks.Authentication)
|
||||
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
|
||||
token := new(authmocks.TokenServiceClient)
|
||||
usersapi.MakeHandler(svc, authn, token, true, mux, logger, "", passRegex, idp, provider)
|
||||
usersapi.MakeHandler(svc, am, token, true, mux, logger, "", passRegex, idp, provider)
|
||||
|
||||
return httptest.NewServer(mux), svc, authn
|
||||
}
|
||||
@@ -1750,6 +1751,151 @@ func TestPasswordResetRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendVerification(t *testing.T) {
|
||||
us, svc, authn := newUsersServer()
|
||||
defer us.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
status int
|
||||
authnRes smqauthn.Session
|
||||
authnErr error
|
||||
svcErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "send verification with valid token",
|
||||
token: validToken,
|
||||
status: http.StatusOK,
|
||||
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "send verification with invalid token",
|
||||
token: inValidToken,
|
||||
status: http.StatusUnauthorized,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
authnRes: smqauthn.Session{},
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "send verification with empty token",
|
||||
token: "",
|
||||
status: http.StatusUnauthorized,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
authnRes: smqauthn.Session{},
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "send verification with service error",
|
||||
token: validToken,
|
||||
status: http.StatusUnprocessableEntity,
|
||||
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID},
|
||||
svcErr: svcerr.ErrCreateEntity,
|
||||
err: svcerr.ErrCreateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
user: us.Client(),
|
||||
method: http.MethodPost,
|
||||
url: fmt.Sprintf("%s/users/send-verification", us.URL),
|
||||
token: tc.token,
|
||||
}
|
||||
|
||||
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("SendVerification", mock.Anything, tc.authnRes).Return(tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while reading response body: %s", tc.desc, err))
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var errRes respBody
|
||||
if len(body) > 0 {
|
||||
if err := json.Unmarshal(body, &errRes); err != nil {
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while unmarshal response body: %s", tc.desc, err))
|
||||
}
|
||||
}
|
||||
if errRes.Err != "" || errRes.Message != "" {
|
||||
err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authnCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyEmail(t *testing.T) {
|
||||
us, svc, _ := newUsersServer()
|
||||
defer us.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
status int
|
||||
svcErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "verify email with valid token",
|
||||
token: validToken,
|
||||
status: http.StatusOK,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "verify email with empty token",
|
||||
token: "",
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrValidation,
|
||||
},
|
||||
{
|
||||
desc: "verify email with service error",
|
||||
token: validToken,
|
||||
status: http.StatusBadRequest,
|
||||
svcErr: svcerr.ErrMalformedEntity,
|
||||
err: svcerr.ErrMalformedEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
user: us.Client(),
|
||||
method: http.MethodGet,
|
||||
url: fmt.Sprintf("%s/verify-email?token=%s", us.URL, tc.token),
|
||||
}
|
||||
|
||||
svcCall := svc.On("VerifyEmail", mock.Anything, mock.Anything).Return(users.User{}, tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while reading response body: %s", tc.desc, err))
|
||||
}
|
||||
defer res.Body.Close()
|
||||
var errRes respBody
|
||||
if len(body) > 0 {
|
||||
if err := json.Unmarshal(body, &errRes); err != nil {
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while unmarshal response body: %s", tc.desc, err))
|
||||
}
|
||||
}
|
||||
if errRes.Err != "" || errRes.Message != "" {
|
||||
err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordReset(t *testing.T) {
|
||||
us, svc, authn := newUsersServer()
|
||||
defer us.Close()
|
||||
|
||||
+48
-17
@@ -6,7 +6,6 @@ package api
|
||||
import (
|
||||
"context"
|
||||
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
apiutil "github.com/absmach/supermq/api/http/util"
|
||||
"github.com/absmach/supermq/pkg/authn"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
@@ -25,7 +24,7 @@ func registrationEndpoint(svc users.Service, selfRegister bool) endpoint.Endpoin
|
||||
|
||||
var ok bool
|
||||
if !selfRegister {
|
||||
session, ok = ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok = ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -43,6 +42,38 @@ func registrationEndpoint(svc users.Service, selfRegister bool) endpoint.Endpoin
|
||||
}
|
||||
}
|
||||
|
||||
func sendVerificationEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
_ = request.(sendVerificationReq)
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
if err := svc.SendVerification(ctx, session); err != nil {
|
||||
return sendVerificationRes{}, err
|
||||
}
|
||||
|
||||
return sendVerificationRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func verifyEmailEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(verifyEmailReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
if _, err := svc.VerifyEmail(ctx, req.token); err != nil {
|
||||
return verifyEmailRes{}, err
|
||||
}
|
||||
|
||||
return verifyEmailRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func viewEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(viewUserReq)
|
||||
@@ -50,7 +81,7 @@ func viewEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -65,7 +96,7 @@ func viewEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
|
||||
func viewProfileEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -85,7 +116,7 @@ func listUsersEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -172,7 +203,7 @@ func updateEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -199,7 +230,7 @@ func updateTagsEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -224,7 +255,7 @@ func updateEmailEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -273,7 +304,7 @@ func passwordResetEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -292,7 +323,7 @@ func updateSecretEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -312,7 +343,7 @@ func updateUsernameEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -337,7 +368,7 @@ func updateProfilePictureEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
ProfilePicture: req.ProfilePicture,
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthorization
|
||||
}
|
||||
@@ -363,7 +394,7 @@ func updateRoleEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
Role: req.role,
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -404,7 +435,7 @@ func refreshTokenEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -429,7 +460,7 @@ func enableEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -450,7 +481,7 @@ func disableEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
@@ -471,7 +502,7 @@ func deleteEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(api.SessionKey).(authn.Session)
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
+20
-6
@@ -4,7 +4,6 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"net/mail"
|
||||
"net/url"
|
||||
|
||||
api "github.com/absmach/supermq/api/http"
|
||||
@@ -39,15 +38,16 @@ func (req createUserReq) validate() error {
|
||||
return err
|
||||
}
|
||||
// Username must not be a valid email format due to username/email login.
|
||||
if _, err := mail.ParseAddress(req.User.Credentials.Username); err == nil {
|
||||
if err := api.ValidateEmail(req.User.Credentials.Username); err == nil {
|
||||
return apiutil.ErrInvalidUsername
|
||||
}
|
||||
|
||||
if req.User.Email == "" {
|
||||
return apiutil.ErrMissingEmail
|
||||
}
|
||||
// Email must be in a valid format.
|
||||
if _, err := mail.ParseAddress(req.User.Email); err != nil {
|
||||
return apiutil.ErrInvalidEmail
|
||||
if err := api.ValidateEmail(req.User.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
if req.User.Credentials.Secret == "" {
|
||||
return apiutil.ErrMissingPass
|
||||
@@ -67,6 +67,20 @@ func (req createUserReq) validate() error {
|
||||
return req.User.Validate()
|
||||
}
|
||||
|
||||
type sendVerificationReq struct{}
|
||||
|
||||
type verifyEmailReq struct {
|
||||
token string
|
||||
}
|
||||
|
||||
func (req verifyEmailReq) validate() error {
|
||||
if req.token == "" {
|
||||
return apiutil.ErrInvalidVerification
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type viewUserReq struct {
|
||||
id string
|
||||
}
|
||||
@@ -183,8 +197,8 @@ func (req updateEmailReq) validate() error {
|
||||
if req.id == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
if _, err := mail.ParseAddress(req.Email); err != nil {
|
||||
return apiutil.ErrInvalidEmail
|
||||
if err := api.ValidateEmail(req.Email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -16,6 +16,8 @@ const MailSent = "Email with reset link is sent"
|
||||
|
||||
var (
|
||||
_ supermq.Response = (*tokenRes)(nil)
|
||||
_ supermq.Response = (*sendVerificationRes)(nil)
|
||||
_ supermq.Response = (*verifyEmailRes)(nil)
|
||||
_ supermq.Response = (*viewUserRes)(nil)
|
||||
_ supermq.Response = (*createUserRes)(nil)
|
||||
_ supermq.Response = (*changeUserStatusRes)(nil)
|
||||
@@ -79,6 +81,34 @@ func (res tokenRes) Empty() bool {
|
||||
return res.AccessToken == "" || res.RefreshToken == ""
|
||||
}
|
||||
|
||||
type sendVerificationRes struct{}
|
||||
|
||||
func (res sendVerificationRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res sendVerificationRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res sendVerificationRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type verifyEmailRes struct{}
|
||||
|
||||
func (res verifyEmailRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res verifyEmailRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res verifyEmailRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type updateUserRes struct {
|
||||
users.User `json:",inline"`
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
)
|
||||
|
||||
// MakeHandler returns a HTTP handler for Users and Groups API endpoints.
|
||||
func MakeHandler(cls users.Service, authn smqauthn.Authentication, tokensvc grpcTokenV1.TokenServiceClient, selfRegister bool, mux *chi.Mux, logger *slog.Logger, instanceID string, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) http.Handler {
|
||||
func MakeHandler(cls users.Service, authn smqauthn.AuthNMiddleware, tokensvc grpcTokenV1.TokenServiceClient, selfRegister bool, mux *chi.Mux, logger *slog.Logger, instanceID string, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) http.Handler {
|
||||
mux = usersHandler(cls, authn, tokensvc, selfRegister, mux, logger, pr, idp, providers...)
|
||||
|
||||
mux.Get("/health", supermq.Health("users", instanceID))
|
||||
|
||||
+52
-19
@@ -28,13 +28,15 @@ import (
|
||||
var passRegex = regexp.MustCompile("^.{8,}$")
|
||||
|
||||
// usersHandler returns a HTTP handler for API endpoints.
|
||||
func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient grpcTokenV1.TokenServiceClient, selfRegister bool, r *chi.Mux, logger *slog.Logger, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) *chi.Mux {
|
||||
func usersHandler(svc users.Service, authn smqauthn.AuthNMiddleware, tokenClient grpcTokenV1.TokenServiceClient, selfRegister bool, r *chi.Mux, logger *slog.Logger, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) *chi.Mux {
|
||||
passRegex = pr
|
||||
|
||||
opts := []kithttp.ServerOption{
|
||||
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
|
||||
}
|
||||
|
||||
// All endpoints in users service don't required Domain check
|
||||
authn = authn.WithOptions(smqauthn.WithDomainCheck(false))
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Use(api.RequestIDMiddleware(idp))
|
||||
|
||||
@@ -47,16 +49,22 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient
|
||||
opts...,
|
||||
), "register_user").ServeHTTP)
|
||||
default:
|
||||
r.With(api.AuthenticateMiddleware(authn, false)).Post("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
r.With(authn.Middleware()).Post("/", otelhttp.NewHandler(kithttp.NewServer(
|
||||
registrationEndpoint(svc, selfRegister),
|
||||
decodeCreateUserReq,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "register_user").ServeHTTP)
|
||||
}
|
||||
|
||||
// Endpoints which are allowed for unverified user
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(api.AuthenticateMiddleware(authn, false))
|
||||
r.Use(authn.WithOptions(smqauthn.WithAllowUnverifiedUser(true)).Middleware())
|
||||
r.Post("/send-verification", otelhttp.NewHandler(kithttp.NewServer(
|
||||
sendVerificationEndpoint(svc),
|
||||
decodeSendVerification,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "send_verification").ServeHTTP)
|
||||
|
||||
r.Get("/profile", otelhttp.NewHandler(kithttp.NewServer(
|
||||
viewProfileEndpoint(svc),
|
||||
@@ -64,6 +72,22 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "view_profile").ServeHTTP)
|
||||
r.Post("/tokens/refresh", otelhttp.NewHandler(kithttp.NewServer(
|
||||
refreshTokenEndpoint(svc),
|
||||
decodeRefreshToken,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "refresh_token").ServeHTTP)
|
||||
r.Patch("/{id}/email", otelhttp.NewHandler(kithttp.NewServer(
|
||||
updateEmailEndpoint(svc),
|
||||
decodeUpdateUserEmail,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "update_user_email").ServeHTTP)
|
||||
})
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(authn.Middleware())
|
||||
|
||||
r.Get("/{id}", otelhttp.NewHandler(kithttp.NewServer(
|
||||
viewEndpoint(svc),
|
||||
@@ -121,13 +145,6 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient
|
||||
opts...,
|
||||
), "update_user_tags").ServeHTTP)
|
||||
|
||||
r.Patch("/{id}/email", otelhttp.NewHandler(kithttp.NewServer(
|
||||
updateEmailEndpoint(svc),
|
||||
decodeUpdateUserEmail,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "update_user_email").ServeHTTP)
|
||||
|
||||
r.Patch("/{id}/role", otelhttp.NewHandler(kithttp.NewServer(
|
||||
updateRoleEndpoint(svc),
|
||||
decodeUpdateUserRole,
|
||||
@@ -155,18 +172,11 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "delete_user").ServeHTTP)
|
||||
|
||||
r.Post("/tokens/refresh", otelhttp.NewHandler(kithttp.NewServer(
|
||||
refreshTokenEndpoint(svc),
|
||||
decodeRefreshToken,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "refresh_token").ServeHTTP)
|
||||
})
|
||||
})
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(api.AuthenticateMiddleware(authn, false))
|
||||
r.Use(authn.Middleware())
|
||||
r.Put("/password/reset", otelhttp.NewHandler(kithttp.NewServer(
|
||||
passwordResetEndpoint(svc),
|
||||
decodePasswordReset,
|
||||
@@ -189,6 +199,13 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient
|
||||
opts...,
|
||||
), "password_reset_req").ServeHTTP)
|
||||
|
||||
r.Get("/verify-email", otelhttp.NewHandler(kithttp.NewServer(
|
||||
verifyEmailEndpoint(svc),
|
||||
decodeVerifyEmail,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "verify_email").ServeHTTP)
|
||||
|
||||
for _, provider := range providers {
|
||||
r.HandleFunc("/oauth/callback/"+provider.Name(), oauth2CallbackHandler(provider, svc, tokenClient))
|
||||
}
|
||||
@@ -196,6 +213,22 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient
|
||||
return r
|
||||
}
|
||||
|
||||
func decodeSendVerification(_ context.Context, r *http.Request) (any, error) {
|
||||
req := sendVerificationReq{}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeVerifyEmail(_ context.Context, r *http.Request) (any, error) {
|
||||
token, err := apiutil.ReadStringQuery(r, api.TokenKey, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
return verifyEmailReq{
|
||||
token: token,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func decodeViewUser(_ context.Context, r *http.Request) (any, error) {
|
||||
req := viewUserReq{
|
||||
id: chi.URLParam(r, "id"),
|
||||
|
||||
@@ -7,4 +7,7 @@ package users
|
||||
type Emailer interface {
|
||||
// SendPasswordReset sends an email to the user with a link to reset the password.
|
||||
SendPasswordReset(To []string, user, token string) error
|
||||
|
||||
// SendVerification sends an email to the user with a verification token.
|
||||
SendVerification(To []string, user, verificationToken string) error
|
||||
}
|
||||
|
||||
@@ -13,20 +13,38 @@ import (
|
||||
var _ users.Emailer = (*emailer)(nil)
|
||||
|
||||
type emailer struct {
|
||||
resetURL string
|
||||
agent *email.Agent
|
||||
resetURL string
|
||||
verificationURL string
|
||||
resetAgent *email.Agent
|
||||
verifyAgent *email.Agent
|
||||
}
|
||||
|
||||
// New creates new emailer utility.
|
||||
func New(url string, c *email.Config) (users.Emailer, error) {
|
||||
e, err := email.New(c)
|
||||
func New(resetURL, verificationURL string, resetConfig, verifyConfig *email.Config) (users.Emailer, error) {
|
||||
resetAgent, err := email.New(resetConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
verifyAgent, err := email.New(verifyConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &emailer{
|
||||
resetURL: url,
|
||||
agent: e,
|
||||
}, err
|
||||
resetURL: resetURL,
|
||||
verificationURL: verificationURL,
|
||||
resetAgent: resetAgent,
|
||||
verifyAgent: verifyAgent,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *emailer) SendPasswordReset(to []string, user, token string) error {
|
||||
url := fmt.Sprintf("%s?token=%s", e.resetURL, token)
|
||||
return e.agent.Send(to, "", "Password Reset Request", "", user, url, "")
|
||||
return e.resetAgent.Send(to, "", "Password Reset Request", "", user, url, "")
|
||||
}
|
||||
|
||||
func (e *emailer) SendVerification(to []string, user, verificationToken string) error {
|
||||
url := fmt.Sprintf("%s?token=%s", e.verificationURL, verificationToken)
|
||||
return e.verifyAgent.Send(to, "", "Email Verification", "", user, url, "")
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
const (
|
||||
userPrefix = "user."
|
||||
userCreate = userPrefix + "create"
|
||||
userSendVerification = userPrefix + "send_verification"
|
||||
userVerifyEmail = userPrefix + "verify_email"
|
||||
userUpdate = userPrefix + "update"
|
||||
userUpdateRole = userPrefix + "update_role"
|
||||
userUpdateTags = userPrefix + "update_tags"
|
||||
@@ -41,6 +43,8 @@ const (
|
||||
|
||||
var (
|
||||
_ events.Event = (*createUserEvent)(nil)
|
||||
_ events.Event = (*sendVerificationEvent)(nil)
|
||||
_ events.Event = (*verifyEmailEvent)(nil)
|
||||
_ events.Event = (*updateUserEvent)(nil)
|
||||
_ events.Event = (*updateProfilePictureEvent)(nil)
|
||||
_ events.Event = (*updateUsernameEvent)(nil)
|
||||
@@ -99,6 +103,37 @@ func (uce createUserEvent) Encode() (map[string]any, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
type sendVerificationEvent struct {
|
||||
authn.Session
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (sve sendVerificationEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"operation": userSendVerification,
|
||||
"user_id": sve.UserID,
|
||||
"token_type": sve.Type.String(),
|
||||
"request_id": sve.requestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type verifyEmailEvent struct {
|
||||
requestID string
|
||||
email string
|
||||
userID string
|
||||
verifiedAt time.Time
|
||||
}
|
||||
|
||||
func (vee verifyEmailEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"operation": userVerifyEmail,
|
||||
"request_id": vee.requestID,
|
||||
"email": vee.email,
|
||||
"user_id": vee.userID,
|
||||
"verified_at": vee.verifiedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type updateUserEvent struct {
|
||||
users.User
|
||||
operation string
|
||||
|
||||
@@ -17,6 +17,8 @@ import (
|
||||
const (
|
||||
supermqPrefix = "supermq."
|
||||
createStream = supermqPrefix + userCreate
|
||||
sendVerificationStream = supermqPrefix + userSendVerification
|
||||
verifyEmailStream = supermqPrefix + userVerifyEmail
|
||||
updateStream = supermqPrefix + userUpdate
|
||||
updateRoleStream = supermqPrefix + userUpdateRole
|
||||
updateTagsStream = supermqPrefix + userUpdateTags
|
||||
@@ -82,6 +84,38 @@ func (es *eventStore) Register(ctx context.Context, session authn.Session, user
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) SendVerification(ctx context.Context, session authn.Session) error {
|
||||
err := es.svc.SendVerification(ctx, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
event := sendVerificationEvent{
|
||||
session,
|
||||
middleware.GetReqID(ctx),
|
||||
}
|
||||
|
||||
return es.Publish(ctx, sendVerificationStream, event)
|
||||
}
|
||||
|
||||
func (es *eventStore) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) {
|
||||
user, err := es.svc.VerifyEmail(ctx, verificationToken)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
event := verifyEmailEvent{
|
||||
email: user.Email,
|
||||
userID: user.ID,
|
||||
verifiedAt: user.VerifiedAt,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
if err := es.Publish(ctx, verifyEmailStream, event); err != nil {
|
||||
return user, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) Update(ctx context.Context, session authn.Session, id string, usr users.UserReq) (users.User, error) {
|
||||
user, err := es.svc.Update(ctx, session, id, usr)
|
||||
if err != nil {
|
||||
|
||||
@@ -26,11 +26,15 @@ type authorizationMiddleware struct {
|
||||
|
||||
// AuthorizationMiddleware adds authorization to the clients service.
|
||||
func AuthorizationMiddleware(svc users.Service, authz smqauthz.Authorization, selfRegister bool) users.Service {
|
||||
return &authorizationMiddleware{
|
||||
svc: svc,
|
||||
authz: authz,
|
||||
selfRegister: selfRegister,
|
||||
}
|
||||
return &authorizationMiddleware{svc: svc, authz: authz, selfRegister: selfRegister}
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) SendVerification(ctx context.Context, session authn.Session) error {
|
||||
return am.svc.SendVerification(ctx, session)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) {
|
||||
return am.svc.VerifyEmail(ctx, verificationToken)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) Register(ctx context.Context, session authn.Session, user users.User, selfRegister bool) (users.User, error) {
|
||||
|
||||
@@ -50,6 +50,43 @@ func (lm *loggingMiddleware) Register(ctx context.Context, session authn.Session
|
||||
return lm.svc.Register(ctx, session, user, selfRegister)
|
||||
}
|
||||
|
||||
// SendVerification logs the send_verification request. It logs the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) SendVerification(ctx context.Context, session authn.Session) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Send verification failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Send verification completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.SendVerification(ctx, session)
|
||||
}
|
||||
|
||||
// VerifyEmail logs the verify_email request. It logs the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) VerifyEmail(ctx context.Context, verificationToken string) (user users.User, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("user_id", user.ID),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Verify email failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Verify email completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.VerifyEmail(ctx, verificationToken)
|
||||
}
|
||||
|
||||
// IssueToken logs the issue_token request. It logs the username type and the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) IssueToken(ctx context.Context, username, secret string) (t *grpcTokenV1.Token, err error) {
|
||||
|
||||
@@ -39,6 +39,24 @@ func (ms *metricsMiddleware) Register(ctx context.Context, session authn.Session
|
||||
return ms.svc.Register(ctx, session, user, selfRegister)
|
||||
}
|
||||
|
||||
// SendVerification instruments SendVerification method with metrics.
|
||||
func (ms *metricsMiddleware) SendVerification(ctx context.Context, session authn.Session) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "send_verification").Add(1)
|
||||
ms.latency.With("method", "send_verification").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.SendVerification(ctx, session)
|
||||
}
|
||||
|
||||
// VerifyEmail instruments VerifyEmail method with metrics.
|
||||
func (ms *metricsMiddleware) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "verify_email").Add(1)
|
||||
ms.latency.With("method", "verify_email").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.VerifyEmail(ctx, verificationToken)
|
||||
}
|
||||
|
||||
// IssueToken instruments IssueToken method with metrics.
|
||||
func (ms *metricsMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) {
|
||||
defer func(begin time.Time) {
|
||||
|
||||
@@ -100,3 +100,66 @@ func (_c *Emailer_SendPasswordReset_Call) RunAndReturn(run func(To []string, use
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SendVerification provides a mock function for the type Emailer
|
||||
func (_mock *Emailer) SendVerification(To []string, user string, verificationToken string) error {
|
||||
ret := _mock.Called(To, user, verificationToken)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SendVerification")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func([]string, string, string) error); ok {
|
||||
r0 = returnFunc(To, user, verificationToken)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Emailer_SendVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendVerification'
|
||||
type Emailer_SendVerification_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SendVerification is a helper method to define mock.On call
|
||||
// - To []string
|
||||
// - user string
|
||||
// - verificationToken string
|
||||
func (_e *Emailer_Expecter) SendVerification(To interface{}, user interface{}, verificationToken interface{}) *Emailer_SendVerification_Call {
|
||||
return &Emailer_SendVerification_Call{Call: _e.mock.On("SendVerification", To, user, verificationToken)}
|
||||
}
|
||||
|
||||
func (_c *Emailer_SendVerification_Call) Run(run func(To []string, user string, verificationToken string)) *Emailer_SendVerification_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 []string
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].([]string)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Emailer_SendVerification_Call) Return(err error) *Emailer_SendVerification_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Emailer_SendVerification_Call) RunAndReturn(run func(To []string, user string, verificationToken string) error) *Emailer_SendVerification_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -41,6 +41,63 @@ func (_m *Repository) EXPECT() *Repository_Expecter {
|
||||
return &Repository_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// AddUserVerification provides a mock function for the type Repository
|
||||
func (_mock *Repository) AddUserVerification(ctx context.Context, uv users.UserVerification) error {
|
||||
ret := _mock.Called(ctx, uv)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for AddUserVerification")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, users.UserVerification) error); ok {
|
||||
r0 = returnFunc(ctx, uv)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Repository_AddUserVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddUserVerification'
|
||||
type Repository_AddUserVerification_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// AddUserVerification is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - uv users.UserVerification
|
||||
func (_e *Repository_Expecter) AddUserVerification(ctx interface{}, uv interface{}) *Repository_AddUserVerification_Call {
|
||||
return &Repository_AddUserVerification_Call{Call: _e.mock.On("AddUserVerification", ctx, uv)}
|
||||
}
|
||||
|
||||
func (_c *Repository_AddUserVerification_Call) Run(run func(ctx context.Context, uv users.UserVerification)) *Repository_AddUserVerification_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 users.UserVerification
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(users.UserVerification)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_AddUserVerification_Call) Return(err error) *Repository_AddUserVerification_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_AddUserVerification_Call) RunAndReturn(run func(ctx context.Context, uv users.UserVerification) error) *Repository_AddUserVerification_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ChangeStatus provides a mock function for the type Repository
|
||||
func (_mock *Repository) ChangeStatus(ctx context.Context, user users.User) (users.User, error) {
|
||||
ret := _mock.Called(ctx, user)
|
||||
@@ -551,6 +608,78 @@ func (_c *Repository_RetrieveByUsername_Call) RunAndReturn(run func(ctx context.
|
||||
return _c
|
||||
}
|
||||
|
||||
// RetrieveUserVerification provides a mock function for the type Repository
|
||||
func (_mock *Repository) RetrieveUserVerification(ctx context.Context, userID string, email string) (users.UserVerification, error) {
|
||||
ret := _mock.Called(ctx, userID, email)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RetrieveUserVerification")
|
||||
}
|
||||
|
||||
var r0 users.UserVerification
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (users.UserVerification, error)); ok {
|
||||
return returnFunc(ctx, userID, email)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) users.UserVerification); ok {
|
||||
r0 = returnFunc(ctx, userID, email)
|
||||
} else {
|
||||
r0 = ret.Get(0).(users.UserVerification)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
|
||||
r1 = returnFunc(ctx, userID, email)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Repository_RetrieveUserVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveUserVerification'
|
||||
type Repository_RetrieveUserVerification_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RetrieveUserVerification is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - userID string
|
||||
// - email string
|
||||
func (_e *Repository_Expecter) RetrieveUserVerification(ctx interface{}, userID interface{}, email interface{}) *Repository_RetrieveUserVerification_Call {
|
||||
return &Repository_RetrieveUserVerification_Call{Call: _e.mock.On("RetrieveUserVerification", ctx, userID, email)}
|
||||
}
|
||||
|
||||
func (_c *Repository_RetrieveUserVerification_Call) Run(run func(ctx context.Context, userID string, email string)) *Repository_RetrieveUserVerification_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_RetrieveUserVerification_Call) Return(userVerification users.UserVerification, err error) *Repository_RetrieveUserVerification_Call {
|
||||
_c.Call.Return(userVerification, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_RetrieveUserVerification_Call) RunAndReturn(run func(ctx context.Context, userID string, email string) (users.UserVerification, error)) *Repository_RetrieveUserVerification_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Save provides a mock function for the type Repository
|
||||
func (_mock *Repository) Save(ctx context.Context, user users.User) (users.User, error) {
|
||||
ret := _mock.Called(ctx, user)
|
||||
@@ -755,6 +884,138 @@ func (_c *Repository_Update_Call) RunAndReturn(run func(ctx context.Context, id
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateEmail provides a mock function for the type Repository
|
||||
func (_mock *Repository) UpdateEmail(ctx context.Context, user users.User) (users.User, error) {
|
||||
ret := _mock.Called(ctx, user)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateEmail")
|
||||
}
|
||||
|
||||
var r0 users.User
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) (users.User, error)); ok {
|
||||
return returnFunc(ctx, user)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) users.User); ok {
|
||||
r0 = returnFunc(ctx, user)
|
||||
} else {
|
||||
r0 = ret.Get(0).(users.User)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, users.User) error); ok {
|
||||
r1 = returnFunc(ctx, user)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Repository_UpdateEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateEmail'
|
||||
type Repository_UpdateEmail_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UpdateEmail is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - user users.User
|
||||
func (_e *Repository_Expecter) UpdateEmail(ctx interface{}, user interface{}) *Repository_UpdateEmail_Call {
|
||||
return &Repository_UpdateEmail_Call{Call: _e.mock.On("UpdateEmail", ctx, user)}
|
||||
}
|
||||
|
||||
func (_c *Repository_UpdateEmail_Call) Run(run func(ctx context.Context, user users.User)) *Repository_UpdateEmail_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 users.User
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(users.User)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_UpdateEmail_Call) Return(user1 users.User, err error) *Repository_UpdateEmail_Call {
|
||||
_c.Call.Return(user1, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_UpdateEmail_Call) RunAndReturn(run func(ctx context.Context, user users.User) (users.User, error)) *Repository_UpdateEmail_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateRole provides a mock function for the type Repository
|
||||
func (_mock *Repository) UpdateRole(ctx context.Context, user users.User) (users.User, error) {
|
||||
ret := _mock.Called(ctx, user)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateRole")
|
||||
}
|
||||
|
||||
var r0 users.User
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) (users.User, error)); ok {
|
||||
return returnFunc(ctx, user)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) users.User); ok {
|
||||
r0 = returnFunc(ctx, user)
|
||||
} else {
|
||||
r0 = ret.Get(0).(users.User)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, users.User) error); ok {
|
||||
r1 = returnFunc(ctx, user)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Repository_UpdateRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRole'
|
||||
type Repository_UpdateRole_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UpdateRole is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - user users.User
|
||||
func (_e *Repository_Expecter) UpdateRole(ctx interface{}, user interface{}) *Repository_UpdateRole_Call {
|
||||
return &Repository_UpdateRole_Call{Call: _e.mock.On("UpdateRole", ctx, user)}
|
||||
}
|
||||
|
||||
func (_c *Repository_UpdateRole_Call) Run(run func(ctx context.Context, user users.User)) *Repository_UpdateRole_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 users.User
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(users.User)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_UpdateRole_Call) Return(user1 users.User, err error) *Repository_UpdateRole_Call {
|
||||
_c.Call.Return(user1, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_UpdateRole_Call) RunAndReturn(run func(ctx context.Context, user users.User) (users.User, error)) *Repository_UpdateRole_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateSecret provides a mock function for the type Repository
|
||||
func (_mock *Repository) UpdateSecret(ctx context.Context, user users.User) (users.User, error) {
|
||||
ret := _mock.Called(ctx, user)
|
||||
@@ -821,6 +1082,63 @@ func (_c *Repository_UpdateSecret_Call) RunAndReturn(run func(ctx context.Contex
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateUserVerification provides a mock function for the type Repository
|
||||
func (_mock *Repository) UpdateUserVerification(ctx context.Context, uv users.UserVerification) error {
|
||||
ret := _mock.Called(ctx, uv)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateUserVerification")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, users.UserVerification) error); ok {
|
||||
r0 = returnFunc(ctx, uv)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Repository_UpdateUserVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateUserVerification'
|
||||
type Repository_UpdateUserVerification_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UpdateUserVerification is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - uv users.UserVerification
|
||||
func (_e *Repository_Expecter) UpdateUserVerification(ctx interface{}, uv interface{}) *Repository_UpdateUserVerification_Call {
|
||||
return &Repository_UpdateUserVerification_Call{Call: _e.mock.On("UpdateUserVerification", ctx, uv)}
|
||||
}
|
||||
|
||||
func (_c *Repository_UpdateUserVerification_Call) Run(run func(ctx context.Context, uv users.UserVerification)) *Repository_UpdateUserVerification_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 users.UserVerification
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(users.UserVerification)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_UpdateUserVerification_Call) Return(err error) *Repository_UpdateUserVerification_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_UpdateUserVerification_Call) RunAndReturn(run func(ctx context.Context, uv users.UserVerification) error) *Repository_UpdateUserVerification_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateUsername provides a mock function for the type Repository
|
||||
func (_mock *Repository) UpdateUsername(ctx context.Context, user users.User) (users.User, error) {
|
||||
ret := _mock.Called(ctx, user)
|
||||
@@ -886,3 +1204,69 @@ func (_c *Repository_UpdateUsername_Call) RunAndReturn(run func(ctx context.Cont
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateVerifiedAt provides a mock function for the type Repository
|
||||
func (_mock *Repository) UpdateVerifiedAt(ctx context.Context, user users.User) (users.User, error) {
|
||||
ret := _mock.Called(ctx, user)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for UpdateVerifiedAt")
|
||||
}
|
||||
|
||||
var r0 users.User
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) (users.User, error)); ok {
|
||||
return returnFunc(ctx, user)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) users.User); ok {
|
||||
r0 = returnFunc(ctx, user)
|
||||
} else {
|
||||
r0 = ret.Get(0).(users.User)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, users.User) error); ok {
|
||||
r1 = returnFunc(ctx, user)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Repository_UpdateVerifiedAt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateVerifiedAt'
|
||||
type Repository_UpdateVerifiedAt_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UpdateVerifiedAt is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - user users.User
|
||||
func (_e *Repository_Expecter) UpdateVerifiedAt(ctx interface{}, user interface{}) *Repository_UpdateVerifiedAt_Call {
|
||||
return &Repository_UpdateVerifiedAt_Call{Call: _e.mock.On("UpdateVerifiedAt", ctx, user)}
|
||||
}
|
||||
|
||||
func (_c *Repository_UpdateVerifiedAt_Call) Run(run func(ctx context.Context, user users.User)) *Repository_UpdateVerifiedAt_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 users.User
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(users.User)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_UpdateVerifiedAt_Call) Return(user1 users.User, err error) *Repository_UpdateVerifiedAt_Call {
|
||||
_c.Call.Return(user1, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Repository_UpdateVerifiedAt_Call) RunAndReturn(run func(ctx context.Context, user users.User) (users.User, error)) *Repository_UpdateVerifiedAt_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -923,6 +923,63 @@ func (_c *Service_SendPasswordReset_Call) RunAndReturn(run func(ctx context.Cont
|
||||
return _c
|
||||
}
|
||||
|
||||
// SendVerification provides a mock function for the type Service
|
||||
func (_mock *Service) SendVerification(ctx context.Context, session authn.Session) error {
|
||||
ret := _mock.Called(ctx, session)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SendVerification")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) error); ok {
|
||||
r0 = returnFunc(ctx, session)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_SendVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendVerification'
|
||||
type Service_SendVerification_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SendVerification is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - session authn.Session
|
||||
func (_e *Service_Expecter) SendVerification(ctx interface{}, session interface{}) *Service_SendVerification_Call {
|
||||
return &Service_SendVerification_Call{Call: _e.mock.On("SendVerification", ctx, session)}
|
||||
}
|
||||
|
||||
func (_c *Service_SendVerification_Call) Run(run func(ctx context.Context, session authn.Session)) *Service_SendVerification_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 authn.Session
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(authn.Session)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_SendVerification_Call) Return(err error) *Service_SendVerification_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_SendVerification_Call) RunAndReturn(run func(ctx context.Context, session authn.Session) error) *Service_SendVerification_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Update provides a mock function for the type Service
|
||||
func (_mock *Service) Update(ctx context.Context, session authn.Session, id string, user users.UserReq) (users.User, error) {
|
||||
ret := _mock.Called(ctx, session, id, user)
|
||||
@@ -1463,6 +1520,72 @@ func (_c *Service_UpdateUsername_Call) RunAndReturn(run func(ctx context.Context
|
||||
return _c
|
||||
}
|
||||
|
||||
// VerifyEmail provides a mock function for the type Service
|
||||
func (_mock *Service) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) {
|
||||
ret := _mock.Called(ctx, verificationToken)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for VerifyEmail")
|
||||
}
|
||||
|
||||
var r0 users.User
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) (users.User, error)); ok {
|
||||
return returnFunc(ctx, verificationToken)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) users.User); ok {
|
||||
r0 = returnFunc(ctx, verificationToken)
|
||||
} else {
|
||||
r0 = ret.Get(0).(users.User)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = returnFunc(ctx, verificationToken)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Service_VerifyEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyEmail'
|
||||
type Service_VerifyEmail_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// VerifyEmail is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - verificationToken string
|
||||
func (_e *Service_Expecter) VerifyEmail(ctx interface{}, verificationToken interface{}) *Service_VerifyEmail_Call {
|
||||
return &Service_VerifyEmail_Call{Call: _e.mock.On("VerifyEmail", ctx, verificationToken)}
|
||||
}
|
||||
|
||||
func (_c *Service_VerifyEmail_Call) Run(run func(ctx context.Context, verificationToken string)) *Service_VerifyEmail_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_VerifyEmail_Call) Return(user users.User, err error) *Service_VerifyEmail_Call {
|
||||
_c.Call.Return(user, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_VerifyEmail_Call) RunAndReturn(run func(ctx context.Context, verificationToken string) (users.User, error)) *Service_VerifyEmail_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// View provides a mock function for the type Service
|
||||
func (_mock *Service) View(ctx context.Context, session authn.Session, id string) (users.User, error) {
|
||||
ret := _mock.Called(ctx, session, id)
|
||||
|
||||
+24
-3
@@ -14,7 +14,7 @@ func Migration() *migrate.MemoryMigrationSource {
|
||||
Migrations: []*migrate.Migration{
|
||||
{
|
||||
Id: "clients_01",
|
||||
// VARCHAR(36) for colums with IDs as UUIDS have a maximum of 36 characters
|
||||
// VARCHAR(36) for column with IDs as UUIDS have a maximum of 36 characters
|
||||
// STATUS 0 to imply enabled and 1 to imply disabled
|
||||
// Role 0 to imply user role and 1 to imply admin role
|
||||
Up: []string{
|
||||
@@ -50,8 +50,8 @@ func Migration() *migrate.MemoryMigrationSource {
|
||||
Up: []string{
|
||||
`ALTER TABLE clients
|
||||
ADD COLUMN username VARCHAR(254) UNIQUE,
|
||||
ADD COLUMN first_name VARCHAR(254) NOT NULL DEFAULT '',
|
||||
ADD COLUMN last_name VARCHAR(254) NOT NULL DEFAULT '',
|
||||
ADD COLUMN first_name VARCHAR(254) NOT NULL DEFAULT '',
|
||||
ADD COLUMN last_name VARCHAR(254) NOT NULL DEFAULT '',
|
||||
ADD COLUMN profile_picture TEXT`,
|
||||
`ALTER TABLE clients RENAME COLUMN identity TO email`,
|
||||
`ALTER TABLE clients DROP COLUMN name`,
|
||||
@@ -97,6 +97,27 @@ func Migration() *migrate.MemoryMigrationSource {
|
||||
`ALTER TABLE users ALTER COLUMN updated_at TYPE TIMESTAMP;`,
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "clients_07",
|
||||
Up: []string{
|
||||
`ALTER TABLE users ADD COLUMN verified_at TIMESTAMPTZ DEFAULT NULL;`,
|
||||
`CREATE TABLE users_verifications (
|
||||
user_id VARCHAR(36) NOT NULL,
|
||||
email VARCHAR(254) NOT NULL,
|
||||
otp VARCHAR(255),
|
||||
created_at TIMESTAMPTZ,
|
||||
expires_at TIMESTAMPTZ,
|
||||
used_at TIMESTAMPTZ,
|
||||
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX idx_users_verifications_lookup ON users_verifications (user_id, email, created_at DESC);
|
||||
`,
|
||||
},
|
||||
Down: []string{
|
||||
`ALTER TABLE users DROP COLUMN verified_at;`,
|
||||
`DROP TABLE users_verifications;`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
+64
-63
@@ -33,7 +33,7 @@ func NewRepository(db postgres.Database) users.Repository {
|
||||
func (repo *userRepo) Save(ctx context.Context, c users.User) (users.User, error) {
|
||||
q := `INSERT INTO users (id, tags, email, secret, metadata, created_at, status, role, first_name, last_name, username, profile_picture)
|
||||
VALUES (:id, :tags, :email, :secret, :metadata, :created_at, :status, :role, :first_name, :last_name, :username, :profile_picture)
|
||||
RETURNING id, tags, email, metadata, created_at, status, role, first_name, last_name, username, profile_picture`
|
||||
RETURNING id, tags, email, metadata, created_at, status, role, first_name, last_name, username, profile_picture, verified_at`
|
||||
|
||||
dbu, err := toDBUser(c)
|
||||
if err != nil {
|
||||
@@ -81,7 +81,7 @@ func (repo *userRepo) CheckSuperAdmin(ctx context.Context, adminID string) error
|
||||
}
|
||||
|
||||
func (repo *userRepo) RetrieveByID(ctx context.Context, id string) (users.User, error) {
|
||||
q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username, profile_picture
|
||||
q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username, profile_picture, verified_at
|
||||
FROM users WHERE id = :id`
|
||||
|
||||
dbu := DBUser{
|
||||
@@ -95,20 +95,20 @@ func (repo *userRepo) RetrieveByID(ctx context.Context, id string) (users.User,
|
||||
defer rows.Close()
|
||||
|
||||
dbu = DBUser{}
|
||||
if rows.Next() {
|
||||
if err = rows.StructScan(&dbu); err != nil {
|
||||
return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
user, err := ToUser(dbu)
|
||||
if err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
if !rows.Next() {
|
||||
return users.User{}, repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
return users.User{}, repoerr.ErrNotFound
|
||||
if err = rows.StructScan(&dbu); err != nil {
|
||||
return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
user, err := ToUser(dbu)
|
||||
if err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.UsersPage, error) {
|
||||
@@ -127,7 +127,7 @@ func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.Use
|
||||
}
|
||||
|
||||
q := fmt.Sprintf(`SELECT u.id, u.tags, u.email, u.metadata, u.status, u.role, u.first_name, u.last_name, u.username,
|
||||
u.created_at, u.updated_at, u.profile_picture, COALESCE(u.updated_by, '') AS updated_by
|
||||
u.created_at, u.updated_at, u.profile_picture, COALESCE(u.updated_by, '') AS updated_by, u.verified_at
|
||||
FROM users u %s %s LIMIT :limit OFFSET :offset;`, query, orderClause)
|
||||
|
||||
dbPage, err := ToDBUsersPage(pm)
|
||||
@@ -180,35 +180,9 @@ func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.Use
|
||||
func (repo *userRepo) UpdateUsername(ctx context.Context, user users.User) (users.User, error) {
|
||||
q := `UPDATE users SET username = :username, updated_at = :updated_at, updated_by = :updated_by
|
||||
WHERE id = :id AND status = :status
|
||||
RETURNING id, tags, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, email`
|
||||
RETURNING id, tags, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, email, role, verified_at`
|
||||
|
||||
dbu, err := toDBUser(user)
|
||||
if err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbu)
|
||||
if err != nil {
|
||||
return users.User{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
defer row.Close()
|
||||
|
||||
dbu = DBUser{
|
||||
ID: user.ID,
|
||||
Username: stringToNullString(user.Credentials.Username),
|
||||
UpdatedAt: sql.NullTime{Time: time.Now().UTC(), Valid: true},
|
||||
}
|
||||
|
||||
if ok := row.Next(); !ok {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrNotFound, row.Err())
|
||||
}
|
||||
|
||||
if err := row.StructScan(&dbu); err != nil {
|
||||
return users.User{}, err
|
||||
}
|
||||
|
||||
return ToUser(dbu)
|
||||
return repo.update(ctx, user, q)
|
||||
}
|
||||
|
||||
func (repo *userRepo) Update(ctx context.Context, id string, ur users.UserReq) (users.User, error) {
|
||||
@@ -231,18 +205,10 @@ func (repo *userRepo) Update(ctx context.Context, id string, ur users.UserReq) (
|
||||
query = append(query, "tags = :tags")
|
||||
u.Tags = *ur.Tags
|
||||
}
|
||||
if ur.Role != nil && *ur.Role != users.AllRole {
|
||||
query = append(query, "role = :role")
|
||||
u.Role = *ur.Role
|
||||
}
|
||||
if ur.ProfilePicture != nil {
|
||||
query = append(query, "profile_picture = :profile_picture")
|
||||
u.ProfilePicture = *ur.ProfilePicture
|
||||
}
|
||||
if ur.Email != nil && *ur.Email != "" {
|
||||
query = append(query, "email = :email")
|
||||
u.Email = *ur.Email
|
||||
}
|
||||
u.UpdatedAt = time.Now().UTC()
|
||||
if ur.UpdatedAt != nil {
|
||||
query = append(query, "updated_at = :updated_at")
|
||||
@@ -259,7 +225,7 @@ func (repo *userRepo) Update(ctx context.Context, id string, ur users.UserReq) (
|
||||
|
||||
q := fmt.Sprintf(`UPDATE users SET %s
|
||||
WHERE id = :id AND status = :status
|
||||
RETURNING id, tags, metadata, status, created_at, updated_at, updated_by, last_name, first_name, username, profile_picture, email, role`, upq)
|
||||
RETURNING id, tags, metadata, status, created_at, updated_at, updated_by, last_name, first_name, username, profile_picture, email, role, verified_at`, upq)
|
||||
|
||||
u.Status = users.EnabledStatus
|
||||
return repo.update(ctx, u, q)
|
||||
@@ -278,21 +244,37 @@ func (repo *userRepo) update(ctx context.Context, user users.User, query string)
|
||||
defer row.Close()
|
||||
|
||||
dbu = DBUser{}
|
||||
if row.Next() {
|
||||
if err := row.StructScan(&dbu); err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
return ToUser(dbu)
|
||||
if !row.Next() {
|
||||
return users.User{}, repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
return users.User{}, repoerr.ErrNotFound
|
||||
if err := row.StructScan(&dbu); err != nil {
|
||||
return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
return ToUser(dbu)
|
||||
}
|
||||
|
||||
func (repo *userRepo) UpdateEmail(ctx context.Context, user users.User) (users.User, error) {
|
||||
q := `UPDATE users SET email = :email, verified_at = NULL, updated_at = :updated_at, updated_by = :updated_by
|
||||
WHERE id = :id AND status = :status
|
||||
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at`
|
||||
user.Status = users.EnabledStatus
|
||||
return repo.update(ctx, user, q)
|
||||
}
|
||||
|
||||
func (repo *userRepo) UpdateRole(ctx context.Context, user users.User) (users.User, error) {
|
||||
q := `UPDATE users SET role = :role, updated_at = :updated_at, updated_by = :updated_by
|
||||
WHERE id = :id AND status = :status
|
||||
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at`
|
||||
user.Status = users.EnabledStatus
|
||||
return repo.update(ctx, user, q)
|
||||
}
|
||||
|
||||
func (repo *userRepo) UpdateSecret(ctx context.Context, user users.User) (users.User, error) {
|
||||
q := `UPDATE users SET secret = :secret, updated_at = :updated_at, updated_by = :updated_by
|
||||
WHERE id = :id AND status = :status
|
||||
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username`
|
||||
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at`
|
||||
user.Status = users.EnabledStatus
|
||||
return repo.update(ctx, user, q)
|
||||
}
|
||||
@@ -300,7 +282,15 @@ func (repo *userRepo) UpdateSecret(ctx context.Context, user users.User) (users.
|
||||
func (repo *userRepo) ChangeStatus(ctx context.Context, user users.User) (users.User, error) {
|
||||
q := `UPDATE users SET status = :status, updated_at = :updated_at, updated_by = :updated_by
|
||||
WHERE id = :id
|
||||
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username`
|
||||
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at`
|
||||
|
||||
return repo.update(ctx, user, q)
|
||||
}
|
||||
|
||||
func (repo *userRepo) UpdateVerifiedAt(ctx context.Context, user users.User) (users.User, error) {
|
||||
q := `UPDATE users SET verified_at = :verified_at
|
||||
WHERE id = :id and email = :email
|
||||
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at`
|
||||
|
||||
return repo.update(ctx, user, q)
|
||||
}
|
||||
@@ -434,7 +424,7 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
|
||||
}
|
||||
|
||||
func (repo *userRepo) RetrieveByEmail(ctx context.Context, email string) (users.User, error) {
|
||||
q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username
|
||||
q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username, verified_at
|
||||
FROM users WHERE email = :email AND status = :status`
|
||||
|
||||
dbu := DBUser{
|
||||
@@ -461,7 +451,7 @@ func (repo *userRepo) RetrieveByEmail(ctx context.Context, email string) (users.
|
||||
}
|
||||
|
||||
func (repo *userRepo) RetrieveByUsername(ctx context.Context, username string) (users.User, error) {
|
||||
q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username
|
||||
q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username, verified_at
|
||||
FROM users WHERE username = :username AND status = :status`
|
||||
|
||||
dbu := DBUser{
|
||||
@@ -504,6 +494,7 @@ type DBUser struct {
|
||||
LastName sql.NullString `db:"last_name, omitempty"`
|
||||
ProfilePicture sql.NullString `db:"profile_picture, omitempty"`
|
||||
Email string `db:"email,omitempty"`
|
||||
VerifiedAt sql.NullTime `db:"verified_at,omitempty"`
|
||||
}
|
||||
|
||||
func toDBUser(u users.User) (DBUser, error) {
|
||||
@@ -527,6 +518,10 @@ func toDBUser(u users.User) (DBUser, error) {
|
||||
if u.UpdatedAt != (time.Time{}) {
|
||||
updatedAt = sql.NullTime{Time: u.UpdatedAt, Valid: true}
|
||||
}
|
||||
var verifiedAt sql.NullTime
|
||||
if u.VerifiedAt != (time.Time{}) {
|
||||
verifiedAt = sql.NullTime{Time: u.VerifiedAt, Valid: true}
|
||||
}
|
||||
|
||||
return DBUser{
|
||||
ID: u.ID,
|
||||
@@ -543,6 +538,7 @@ func toDBUser(u users.User) (DBUser, error) {
|
||||
Username: stringToNullString(u.Credentials.Username),
|
||||
ProfilePicture: stringToNullString(u.ProfilePicture),
|
||||
Email: u.Email,
|
||||
VerifiedAt: verifiedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -565,6 +561,10 @@ func ToUser(dbu DBUser) (users.User, error) {
|
||||
if dbu.UpdatedAt.Valid {
|
||||
updatedAt = dbu.UpdatedAt.Time.UTC()
|
||||
}
|
||||
var verifiedAt time.Time
|
||||
if dbu.VerifiedAt.Valid {
|
||||
verifiedAt = dbu.VerifiedAt.Time.UTC()
|
||||
}
|
||||
|
||||
user := users.User{
|
||||
ID: dbu.ID,
|
||||
@@ -582,6 +582,7 @@ func ToUser(dbu DBUser) (users.User, error) {
|
||||
Status: dbu.Status,
|
||||
Tags: tags,
|
||||
ProfilePicture: nullStringString(dbu.ProfilePicture),
|
||||
VerifiedAt: verifiedAt,
|
||||
}
|
||||
if dbu.Role != nil {
|
||||
user.Role = *dbu.Role
|
||||
|
||||
+141
-75
@@ -1048,6 +1048,147 @@ func TestSearch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRole(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM users")
|
||||
require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err))
|
||||
})
|
||||
repo := cpostgres.NewRepository(database)
|
||||
user1 := generateUser(t, users.EnabledStatus, repo)
|
||||
user2 := generateUser(t, users.DisabledStatus, repo)
|
||||
adminRole := users.AdminRole
|
||||
userRole := users.UserRole
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
update string
|
||||
userID string
|
||||
userReq users.User
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "update role of user to admin",
|
||||
userReq: users.User{
|
||||
ID: user1.ID,
|
||||
Role: adminRole,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update role of admin to user",
|
||||
userReq: users.User{
|
||||
ID: user1.ID,
|
||||
Role: userRole,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update role for disabled user",
|
||||
userReq: users.User{
|
||||
ID: user2.ID,
|
||||
Role: adminRole,
|
||||
},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "update role for invalid user",
|
||||
userReq: users.User{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Role: adminRole,
|
||||
},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
updatedAt := time.Now().UTC().Truncate(time.Millisecond)
|
||||
updatedBy := testsutil.GenerateUUID(t)
|
||||
c.userReq.UpdatedAt = updatedAt
|
||||
c.userReq.UpdatedBy = updatedBy
|
||||
expected, err := repo.UpdateRole(context.Background(), c.userReq)
|
||||
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err))
|
||||
if err == nil {
|
||||
assert.Equal(t, c.userReq.Role, expected.Role)
|
||||
assert.Equal(t, c.userReq.UpdatedAt, expected.UpdatedAt)
|
||||
assert.Equal(t, c.userReq.UpdatedBy, expected.UpdatedBy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateEmail(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM users")
|
||||
require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err))
|
||||
})
|
||||
repo := cpostgres.NewRepository(database)
|
||||
|
||||
user1 := generateUser(t, users.EnabledStatus, repo)
|
||||
user2 := generateUser(t, users.DisabledStatus, repo)
|
||||
user3 := generateUser(t, users.EnabledStatus, repo)
|
||||
|
||||
updatedEmail := namesgen.Generate() + emailSuffix
|
||||
emptyName := ""
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
update string
|
||||
userReq users.User
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "update email for enabled user",
|
||||
userReq: users.User{
|
||||
ID: user1.ID,
|
||||
Email: updatedEmail,
|
||||
},
|
||||
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update empty email for enabled user",
|
||||
userReq: users.User{
|
||||
ID: user3.ID,
|
||||
Email: emptyName,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update email for disabled user",
|
||||
userReq: users.User{
|
||||
ID: user2.ID,
|
||||
Email: updatedEmail,
|
||||
},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "update email for invalid user",
|
||||
userReq: users.User{
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Email: updatedEmail,
|
||||
},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.desc, func(t *testing.T) {
|
||||
updatedAt := time.Now().UTC().Truncate(time.Millisecond)
|
||||
updatedBy := testsutil.GenerateUUID(t)
|
||||
c.userReq.UpdatedAt = updatedAt
|
||||
c.userReq.UpdatedBy = updatedBy
|
||||
expected, err := repo.UpdateEmail(context.Background(), c.userReq)
|
||||
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err))
|
||||
if err == nil {
|
||||
assert.Equal(t, c.userReq.Email, expected.Email)
|
||||
assert.Equal(t, c.userReq.UpdatedAt, expected.UpdatedAt)
|
||||
assert.Equal(t, c.userReq.UpdatedBy, expected.UpdatedBy)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM users")
|
||||
@@ -1065,11 +1206,8 @@ func TestUpdate(t *testing.T) {
|
||||
updatedFirstName := namesgen.Generate()
|
||||
updateTags := namesgen.GenerateMultiple(5)
|
||||
updatedProfilePicture := namesgen.Generate()
|
||||
adminRole := users.AdminRole
|
||||
updatedEmail := namesgen.Generate() + emailSuffix
|
||||
emptyName := ""
|
||||
emptyTags := []string{}
|
||||
allRole := users.AllRole
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
@@ -1286,78 +1424,6 @@ func TestUpdate(t *testing.T) {
|
||||
},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "update role for enabled user",
|
||||
userID: user1.ID,
|
||||
userReq: users.UserReq{
|
||||
Role: &adminRole,
|
||||
},
|
||||
userRes: users.User{
|
||||
Role: adminRole,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update all role for enabled user",
|
||||
userID: user3.ID,
|
||||
userReq: users.UserReq{
|
||||
Role: &allRole,
|
||||
},
|
||||
userRes: user3,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update role for disabled user",
|
||||
userID: user2.ID,
|
||||
userReq: users.UserReq{
|
||||
Role: &adminRole,
|
||||
},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "update role for invalid user",
|
||||
userID: testsutil.GenerateUUID(t),
|
||||
userReq: users.UserReq{
|
||||
Role: &adminRole,
|
||||
},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "update email for enabled user",
|
||||
userID: user1.ID,
|
||||
userReq: users.UserReq{
|
||||
Email: &updatedEmail,
|
||||
},
|
||||
userRes: users.User{
|
||||
Email: updatedEmail,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update empty email for enabled user",
|
||||
userID: user3.ID,
|
||||
userReq: users.UserReq{
|
||||
Email: &emptyName,
|
||||
},
|
||||
userRes: user3,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update email for disabled user",
|
||||
userID: user2.ID,
|
||||
userReq: users.UserReq{
|
||||
Email: &updatedEmail,
|
||||
},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "update email for invalid user",
|
||||
userID: testsutil.GenerateUUID(t),
|
||||
userReq: users.UserReq{
|
||||
Email: &updatedEmail,
|
||||
},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
"github.com/absmach/supermq/users"
|
||||
)
|
||||
|
||||
// AddUserVerification adds new verification for given user id and email.
|
||||
func (repo *userRepo) AddUserVerification(ctx context.Context, uv users.UserVerification) error {
|
||||
q := `INSERT INTO users_verifications (user_id, email, otp, created_at, expires_at )
|
||||
VALUES (:user_id, :email, :otp, :created_at, :expires_at );`
|
||||
dbuv := toDBUserVerification(uv)
|
||||
if _, err := repo.Repository.DB.NamedExecContext(ctx, q, dbuv); err != nil {
|
||||
return errors.Wrap(repoerr.ErrCreateEntity, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RetrieveUserVerification retrieves verification details of given user id and email.
|
||||
func (repo *userRepo) RetrieveUserVerification(ctx context.Context, userID, email string) (users.UserVerification, error) {
|
||||
dbuv := dbUserVerification{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
}
|
||||
q := `SELECT user_id, email, otp, created_at, expires_at , used_at FROM users_verifications WHERE user_id = :user_id AND email = :email ORDER BY created_at DESC LIMIT 1 `
|
||||
|
||||
row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbuv)
|
||||
if err != nil {
|
||||
return users.UserVerification{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
if !row.Next() {
|
||||
return users.UserVerification{}, repoerr.ErrNotFound
|
||||
}
|
||||
|
||||
defer row.Close()
|
||||
|
||||
if err := row.StructScan(&dbuv); err != nil {
|
||||
return users.UserVerification{}, errors.Wrap(repoerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
return toUserVerification(dbuv), nil
|
||||
}
|
||||
|
||||
// UpdateUserVerification update user verification details for the given user id and email.
|
||||
func (repo *userRepo) UpdateUserVerification(ctx context.Context, uv users.UserVerification) error {
|
||||
q := `UPDATE users_verifications SET otp = :otp, used_at = :used_at WHERE user_id = :user_id AND email = :email`
|
||||
dbuv := toDBUserVerification(uv)
|
||||
res, err := repo.Repository.DB.NamedExecContext(ctx, q, dbuv)
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
rows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return errors.Wrap(repoerr.ErrUpdateEntity, err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return repoerr.ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dbUserVerification struct {
|
||||
UserID string `db:"user_id"`
|
||||
Email string `db:"email"`
|
||||
OTP sql.NullString `db:"otp"`
|
||||
CreatedAt sql.NullTime `db:"created_at"`
|
||||
ExpiresAt sql.NullTime `db:"expires_at"`
|
||||
UsedAt sql.NullTime `db:"used_at"`
|
||||
}
|
||||
|
||||
func toDBUserVerification(uv users.UserVerification) dbUserVerification {
|
||||
var otp sql.NullString
|
||||
if uv.OTP != "" {
|
||||
otp = sql.NullString{String: uv.OTP, Valid: true}
|
||||
}
|
||||
var createdAt sql.NullTime
|
||||
if !uv.CreatedAt.IsZero() {
|
||||
createdAt = sql.NullTime{Time: uv.CreatedAt, Valid: true}
|
||||
}
|
||||
var expiresAt sql.NullTime
|
||||
if !uv.ExpiresAt.IsZero() {
|
||||
expiresAt = sql.NullTime{Time: uv.ExpiresAt, Valid: true}
|
||||
}
|
||||
var usedAt sql.NullTime
|
||||
if !uv.UsedAt.IsZero() {
|
||||
usedAt = sql.NullTime{Time: uv.UsedAt, Valid: true}
|
||||
}
|
||||
|
||||
return dbUserVerification{
|
||||
UserID: uv.UserID,
|
||||
Email: uv.Email,
|
||||
OTP: otp,
|
||||
CreatedAt: createdAt,
|
||||
ExpiresAt: expiresAt,
|
||||
UsedAt: usedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func toUserVerification(dbuv dbUserVerification) users.UserVerification {
|
||||
var createdAt time.Time
|
||||
if dbuv.CreatedAt.Valid {
|
||||
createdAt = dbuv.CreatedAt.Time.UTC()
|
||||
}
|
||||
|
||||
var expiresAt time.Time
|
||||
if dbuv.ExpiresAt.Valid {
|
||||
expiresAt = dbuv.ExpiresAt.Time.UTC()
|
||||
}
|
||||
|
||||
var usedAt time.Time
|
||||
if dbuv.UsedAt.Valid {
|
||||
usedAt = dbuv.UsedAt.Time.UTC()
|
||||
}
|
||||
|
||||
return users.UserVerification{
|
||||
UserID: dbuv.UserID,
|
||||
Email: dbuv.Email,
|
||||
OTP: dbuv.OTP.String,
|
||||
CreatedAt: createdAt,
|
||||
ExpiresAt: expiresAt,
|
||||
UsedAt: usedAt,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,217 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
repoerr "github.com/absmach/supermq/pkg/errors/repository"
|
||||
"github.com/absmach/supermq/users"
|
||||
"github.com/absmach/supermq/users/postgres"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAddUserVerification(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM users")
|
||||
require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err))
|
||||
_, err = db.Exec("DELETE FROM users_verifications")
|
||||
require.Nil(t, err, fmt.Sprintf("clean users_verifications unexpected error: %s", err))
|
||||
})
|
||||
|
||||
repo := postgres.NewRepository(database)
|
||||
|
||||
first_name := namesgen.Generate()
|
||||
last_name := namesgen.Generate()
|
||||
username := namesgen.Generate()
|
||||
user := users.User{
|
||||
ID: "test-user-id",
|
||||
Email: "test@example.com",
|
||||
FirstName: first_name,
|
||||
LastName: last_name,
|
||||
Credentials: users.Credentials{
|
||||
Username: username,
|
||||
},
|
||||
}
|
||||
_, err := repo.Save(context.Background(), user)
|
||||
require.Nil(t, err, fmt.Sprintf("saving user unexpected error: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
uv users.UserVerification
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "add new user verification",
|
||||
uv: users.UserVerification{
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
OTP: "123456",
|
||||
ExpiresAt: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "add user verification for non-existing user",
|
||||
uv: users.UserVerification{
|
||||
UserID: "non-existing-user",
|
||||
Email: "non-existing@example.com",
|
||||
OTP: "654321",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
err: repoerr.ErrCreateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
err := repo.AddUserVerification(context.Background(), tc.uv)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveUserVerification(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM users")
|
||||
require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err))
|
||||
_, err = db.Exec("DELETE FROM users_verifications")
|
||||
require.Nil(t, err, fmt.Sprintf("clean users_verifications unexpected error: %s", err))
|
||||
})
|
||||
|
||||
repo := postgres.NewRepository(database)
|
||||
|
||||
first_name := namesgen.Generate()
|
||||
last_name := namesgen.Generate()
|
||||
username := namesgen.Generate()
|
||||
user := users.User{
|
||||
ID: "test-user-id",
|
||||
Email: "test@example.com",
|
||||
FirstName: first_name,
|
||||
LastName: last_name,
|
||||
Credentials: users.Credentials{
|
||||
Username: username,
|
||||
},
|
||||
}
|
||||
_, err := repo.Save(context.Background(), user)
|
||||
require.Nil(t, err, fmt.Sprintf("saving user unexpected error: %s", err))
|
||||
|
||||
uv := users.UserVerification{
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
OTP: "123456",
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour),
|
||||
}
|
||||
err = repo.AddUserVerification(context.Background(), uv)
|
||||
require.Nil(t, err, fmt.Sprintf("adding user verification unexpected error: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userID string
|
||||
email string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "retrieve existing user verification",
|
||||
userID: user.ID,
|
||||
email: user.Email,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "retrieve non-existing user verification",
|
||||
userID: "non-existing-user",
|
||||
email: "non-existing@example.com",
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
retrieved, err := repo.RetrieveUserVerification(context.Background(), tc.userID, tc.email)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.Equal(t, uv.UserID, retrieved.UserID, fmt.Sprintf("%s: expected %v got %v", tc.desc, uv.UserID, retrieved.UserID))
|
||||
assert.Equal(t, uv.Email, retrieved.Email, fmt.Sprintf("%s: expected %v got %v", tc.desc, uv.Email, retrieved.Email))
|
||||
assert.Equal(t, uv.OTP, retrieved.OTP, fmt.Sprintf("%s: expected %v got %v", tc.desc, uv.OTP, retrieved.OTP))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserVerification(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := db.Exec("DELETE FROM users")
|
||||
require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err))
|
||||
_, err = db.Exec("DELETE FROM users_verifications")
|
||||
require.Nil(t, err, fmt.Sprintf("clean users_verifications unexpected error: %s", err))
|
||||
})
|
||||
|
||||
repo := postgres.NewRepository(database)
|
||||
|
||||
first_name := namesgen.Generate()
|
||||
last_name := namesgen.Generate()
|
||||
username := namesgen.Generate()
|
||||
user := users.User{
|
||||
ID: "test-user-id",
|
||||
Email: "test@example.com",
|
||||
FirstName: first_name,
|
||||
LastName: last_name,
|
||||
Credentials: users.Credentials{
|
||||
Username: username,
|
||||
},
|
||||
}
|
||||
_, err := repo.Save(context.Background(), user)
|
||||
require.Nil(t, err, fmt.Sprintf("saving user unexpected error: %s", err))
|
||||
|
||||
uv := users.UserVerification{
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
OTP: "123456",
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().UTC().Add(time.Hour),
|
||||
}
|
||||
err = repo.AddUserVerification(context.Background(), uv)
|
||||
require.Nil(t, err, fmt.Sprintf("adding user verification unexpected error: %s", err))
|
||||
|
||||
usedTime := time.Now()
|
||||
cases := []struct {
|
||||
desc string
|
||||
uv users.UserVerification
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "update existing user verification",
|
||||
uv: users.UserVerification{
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
OTP: "654321",
|
||||
UsedAt: usedTime,
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update non-existing user verification",
|
||||
uv: users.UserVerification{
|
||||
UserID: "non-existing-user",
|
||||
Email: "non-existing@example.com",
|
||||
},
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
err := repo.UpdateUserVerification(context.Background(), tc.uv)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
retrieved, err := repo.RetrieveUserVerification(context.Background(), tc.uv.UserID, tc.uv.Email)
|
||||
require.Nil(t, err, fmt.Sprintf("retrieving updated verification unexpected error: %s", err))
|
||||
assert.Equal(t, tc.uv.OTP, retrieved.OTP, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.uv.OTP, retrieved.OTP))
|
||||
assert.WithinDuration(t, tc.uv.UsedAt, retrieved.UsedAt, 10*time.Second, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.uv.UsedAt, retrieved.UsedAt))
|
||||
}
|
||||
}
|
||||
}
|
||||
+104
-15
@@ -5,6 +5,7 @@ package users
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"time"
|
||||
|
||||
@@ -92,6 +93,85 @@ func (svc service) Register(ctx context.Context, session authn.Session, u User,
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (svc service) SendVerification(ctx context.Context, session authn.Session) error {
|
||||
dbUser, err := svc.users.RetrieveByID(ctx, session.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !dbUser.VerifiedAt.IsZero() {
|
||||
return svcerr.ErrUserAlreadyVerified
|
||||
}
|
||||
|
||||
uv, err := svc.users.RetrieveUserVerification(ctx, dbUser.ID, dbUser.Email)
|
||||
if err != nil && err != repoerr.ErrNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = uv.Valid(); err != nil {
|
||||
uv, err = NewUserVerification(dbUser.ID, dbUser.Email)
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
if err := svc.users.AddUserVerification(ctx, uv); err != nil {
|
||||
return errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
}
|
||||
|
||||
uvs, err := uv.Encode()
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
|
||||
if err := svc.email.SendVerification([]string{dbUser.Email}, dbUser.Credentials.Username, uvs); err != nil {
|
||||
return errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc service) VerifyEmail(ctx context.Context, token string) (User, error) {
|
||||
var received UserVerification
|
||||
if err := received.Decode(token); err != nil {
|
||||
return User{}, errors.Wrap(svcerr.ErrInvalidUserVerification, err)
|
||||
}
|
||||
|
||||
stored, err := svc.users.RetrieveUserVerification(ctx, received.UserID, received.Email)
|
||||
if err != nil {
|
||||
return User{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
|
||||
if err := stored.Match(received); err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
|
||||
if err := stored.Valid(); err != nil {
|
||||
if err == svcerr.ErrUserVerificationExpired {
|
||||
return User{}, err
|
||||
}
|
||||
return User{}, errors.Wrap(svcerr.ErrMalformedEntity, err)
|
||||
}
|
||||
|
||||
stored.UsedAt = time.Now().UTC()
|
||||
if err = svc.users.UpdateUserVerification(ctx, stored); err != nil {
|
||||
return User{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
user := User{
|
||||
ID: stored.UserID,
|
||||
Email: stored.Email,
|
||||
VerifiedAt: time.Now().UTC(),
|
||||
}
|
||||
user, err = svc.users.UpdateVerifiedAt(ctx, user)
|
||||
if err == repoerr.ErrNotFound {
|
||||
return User{}, svcerr.ErrInvalidUserVerification
|
||||
}
|
||||
if err != nil {
|
||||
return User{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (svc service) IssueToken(ctx context.Context, identity, secret string) (*grpcTokenV1.Token, error) {
|
||||
var dbUser User
|
||||
var err error
|
||||
@@ -110,7 +190,7 @@ func (svc service) IssueToken(ctx context.Context, identity, secret string) (*gr
|
||||
return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrLogin, err)
|
||||
}
|
||||
|
||||
token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, UserRole: uint32(dbUser.Role + 1), Type: uint32(smqauth.AccessKey)})
|
||||
token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, UserRole: uint32(dbUser.Role + 1), Type: uint32(smqauth.AccessKey), Verified: !dbUser.VerifiedAt.IsZero()})
|
||||
if err != nil {
|
||||
return &grpcTokenV1.Token{}, errors.Wrap(errIssueToken, err)
|
||||
}
|
||||
@@ -127,7 +207,7 @@ func (svc service) RefreshToken(ctx context.Context, session authn.Session, refr
|
||||
return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser)
|
||||
}
|
||||
|
||||
return svc.token.Refresh(ctx, &grpcTokenV1.RefreshReq{RefreshToken: refreshToken})
|
||||
return svc.token.Refresh(ctx, &grpcTokenV1.RefreshReq{RefreshToken: refreshToken, Verified: !dbUser.VerifiedAt.IsZero()})
|
||||
}
|
||||
|
||||
func (svc service) View(ctx context.Context, session authn.Session, id string) (User, error) {
|
||||
@@ -255,14 +335,23 @@ func (svc service) UpdateEmail(ctx context.Context, session authn.Session, userI
|
||||
return User{}, err
|
||||
}
|
||||
}
|
||||
|
||||
updatedAt := time.Now().UTC()
|
||||
usr := UserReq{
|
||||
Email: &email,
|
||||
UpdatedAt: &updatedAt,
|
||||
UpdatedBy: &session.UserID,
|
||||
oldUsr, err := svc.users.RetrieveByID(ctx, userID)
|
||||
if err != nil {
|
||||
return User{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
user, err := svc.users.Update(ctx, userID, usr)
|
||||
if oldUsr.Email == email {
|
||||
return User{}, fmt.Errorf("current email is same as update requested email")
|
||||
}
|
||||
|
||||
usr := User{
|
||||
ID: userID,
|
||||
Email: email,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
UpdatedBy: session.UserID,
|
||||
VerifiedAt: time.Time{},
|
||||
}
|
||||
|
||||
user, err := svc.users.UpdateEmail(ctx, usr)
|
||||
if err != nil {
|
||||
return User{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
@@ -362,18 +451,18 @@ func (svc service) UpdateRole(ctx context.Context, session authn.Session, usr Us
|
||||
if err := svc.checkSuperAdmin(ctx, session); err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
updateAt := time.Now().UTC()
|
||||
uReq := UserReq{
|
||||
Role: &usr.Role,
|
||||
UpdatedAt: &updateAt,
|
||||
UpdatedBy: &session.UserID,
|
||||
usr = User{
|
||||
ID: usr.ID,
|
||||
Role: usr.Role,
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
UpdatedBy: session.UserID,
|
||||
}
|
||||
|
||||
if err := svc.updateUserPolicy(ctx, usr.ID, usr.Role); err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
|
||||
u, err := svc.users.Update(ctx, usr.ID, uReq)
|
||||
u, err := svc.users.UpdateRole(ctx, usr)
|
||||
if err != nil {
|
||||
// If failed to update role in DB, then revert back to platform admin policies in spicedb
|
||||
if errRollback := svc.updateUserPolicy(ctx, usr.ID, UserRole); errRollback != nil {
|
||||
|
||||
+182
-11
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
|
||||
smqauth "github.com/absmach/supermq/auth"
|
||||
@@ -776,13 +777,13 @@ func TestUpdateRole(t *testing.T) {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
policyCall := policies.On("AddPolicy", context.Background(), mock.Anything).Return(tc.addPolicyErr)
|
||||
policyCall1 := policies.On("DeletePolicyFilter", context.Background(), mock.Anything).Return(tc.deletePolicyErr)
|
||||
repoCall1 := cRepo.On("Update", context.Background(), mock.Anything, mock.Anything).Return(tc.updateRoleResponse, tc.updateRoleErr)
|
||||
repoCall1 := cRepo.On("UpdateRole", context.Background(), mock.Anything).Return(tc.updateRoleResponse, tc.updateRoleErr)
|
||||
|
||||
updatedUser, err := svc.UpdateRole(context.Background(), tc.session, tc.user)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateRoleResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateRoleResponse, updatedUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything, mock.Anything)
|
||||
ok := repoCall1.Parent.AssertCalled(t, "UpdateRole", context.Background(), mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
@@ -906,7 +907,7 @@ func TestUpdateEmail(t *testing.T) {
|
||||
svc, cRepo := newServiceMinimal()
|
||||
|
||||
user2 := user
|
||||
user2.Email = "updated@example.com"
|
||||
user2.Email = "user2@example.com"
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
@@ -921,16 +922,26 @@ func TestUpdateEmail(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
desc: "update user as normal user successfully",
|
||||
email: "updated@example.com",
|
||||
email: "user2-update-1@example.com",
|
||||
token: validToken,
|
||||
reqUserID: user.ID,
|
||||
id: user.ID,
|
||||
updateEmailResponse: user2,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update to same email as normal user successfully",
|
||||
email: "user2-update-1@example.com",
|
||||
token: validToken,
|
||||
reqUserID: user.ID,
|
||||
id: user.ID,
|
||||
updateEmailResponse: user2,
|
||||
err: nil,
|
||||
},
|
||||
|
||||
{
|
||||
desc: "update user email as normal user with repo error on update",
|
||||
email: "updated@example.com",
|
||||
email: "user2-update-2@example.com",
|
||||
token: validToken,
|
||||
reqUserID: user.ID,
|
||||
id: user.ID,
|
||||
@@ -940,14 +951,14 @@ func TestUpdateEmail(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "update user email as admin successfully",
|
||||
email: "updated@example.com",
|
||||
email: "user2-update-3@example.com",
|
||||
token: validToken,
|
||||
id: user.ID,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "update user email as admin with repo error on update",
|
||||
email: "updated@exmaple.com",
|
||||
email: "user2-update-4@exmaple.com",
|
||||
token: validToken,
|
||||
reqUserID: user.ID,
|
||||
id: user.ID,
|
||||
@@ -957,7 +968,7 @@ func TestUpdateEmail(t *testing.T) {
|
||||
},
|
||||
{
|
||||
desc: "update user as admin user with failed check on super admin",
|
||||
email: "updated@exmaple.com",
|
||||
email: "user2-update-5@exmaple.com",
|
||||
token: validToken,
|
||||
reqUserID: user.ID,
|
||||
id: "",
|
||||
@@ -970,15 +981,18 @@ func TestUpdateEmail(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
|
||||
repoCall1 := cRepo.On("Update", context.Background(), mock.Anything, mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
|
||||
repocall2 := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
|
||||
repoCall1 := cRepo.On("UpdateEmail", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
|
||||
updatedUser, err := svc.UpdateEmail(context.Background(), authn.Session{DomainUserID: tc.reqUserID, UserID: validID, DomainID: validID}, tc.id, tc.email)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.updateEmailResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateEmailResponse, updatedUser))
|
||||
if tc.err == nil {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything, mock.Anything)
|
||||
if tc.err == nil && user2.Email != tc.email {
|
||||
ok := repoCall1.Parent.AssertCalled(t, "UpdateEmail", context.Background(), mock.Anything, mock.Anything)
|
||||
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
|
||||
user2.Email = tc.email
|
||||
}
|
||||
repoCall.Unset()
|
||||
repocall2.Unset()
|
||||
repoCall1.Unset()
|
||||
}
|
||||
}
|
||||
@@ -1791,3 +1805,160 @@ func TestOAuthCallback(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendVerification(t *testing.T) {
|
||||
svc, _, cRepo, _, e := newService()
|
||||
|
||||
verifiedAt := time.Now().UTC()
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
retrieveByIDResponse users.User
|
||||
retrieveByIDError error
|
||||
retrieveUserVerResponse users.UserVerification
|
||||
retrieveUserVerError error
|
||||
addUserVerError error
|
||||
sendVerificationEmailError error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "send verification email successfully",
|
||||
session: authn.Session{UserID: user.ID},
|
||||
retrieveByIDResponse: user,
|
||||
retrieveUserVerError: repoerr.ErrNotFound,
|
||||
sendVerificationEmailError: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "send verification email for already verified user",
|
||||
session: authn.Session{UserID: user.ID},
|
||||
retrieveByIDResponse: users.User{VerifiedAt: verifiedAt},
|
||||
err: svcerr.ErrUserAlreadyVerified,
|
||||
},
|
||||
{
|
||||
desc: "send verification email for non-existing user",
|
||||
session: authn.Session{UserID: wrongID},
|
||||
retrieveByIDError: repoerr.ErrNotFound,
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "send verification email with failed to retrieve user verification",
|
||||
session: authn.Session{UserID: user.ID},
|
||||
retrieveByIDResponse: user,
|
||||
retrieveUserVerError: svcerr.ErrViewEntity,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
{
|
||||
desc: "send verification email with failed to add user verification",
|
||||
session: authn.Session{UserID: user.ID},
|
||||
retrieveByIDResponse: user,
|
||||
retrieveUserVerError: repoerr.ErrNotFound,
|
||||
addUserVerError: svcerr.ErrCreateEntity,
|
||||
err: svcerr.ErrCreateEntity,
|
||||
},
|
||||
{
|
||||
desc: "send verification email with failed to send email",
|
||||
session: authn.Session{UserID: user.ID},
|
||||
retrieveByIDResponse: user,
|
||||
retrieveUserVerError: repoerr.ErrNotFound,
|
||||
sendVerificationEmailError: svcerr.ErrCreateEntity,
|
||||
err: svcerr.ErrCreateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.retrieveByIDResponse, tc.retrieveByIDError)
|
||||
repoCall1 := cRepo.On("RetrieveUserVerification", context.Background(), mock.Anything, mock.Anything).Return(tc.retrieveUserVerResponse, tc.retrieveUserVerError)
|
||||
repoCall2 := cRepo.On("AddUserVerification", context.Background(), mock.Anything).Return(tc.addUserVerError)
|
||||
emailCall := e.On("SendVerification", []string{user.Email}, user.Credentials.Username, mock.Anything).Return(tc.sendVerificationEmailError)
|
||||
|
||||
err := svc.SendVerification(context.Background(), tc.session)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
emailCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyEmail(t *testing.T) {
|
||||
//nolint:dogsled
|
||||
svc, _, cRepo, _, _ := newService()
|
||||
uv, err := users.NewUserVerification(user.ID, user.Email)
|
||||
assert.Nil(t, err, fmt.Sprintf("failed to create user verification: %v", err))
|
||||
uvs, err := uv.Encode()
|
||||
assert.Nil(t, err, fmt.Sprintf("failed to encode user verification: %v", err))
|
||||
createdAt := time.Now().Add(-5 * users.VerificationExpiryDuration).UTC()
|
||||
expiresdAt := time.Now().Add(-users.VerificationExpiryDuration).UTC()
|
||||
cases := []struct {
|
||||
desc string
|
||||
uvs string
|
||||
retrieveUserVerResponse users.UserVerification
|
||||
retrieveUserVerError error
|
||||
updateUserVerError error
|
||||
updateVerifiedAtError error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "verify email successfully",
|
||||
uvs: uvs,
|
||||
retrieveUserVerResponse: uv,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "verify email with malformed token",
|
||||
uvs: "invalid",
|
||||
err: svcerr.ErrInvalidUserVerification,
|
||||
},
|
||||
{
|
||||
desc: "verify email with non-existing user verification",
|
||||
uvs: uvs,
|
||||
retrieveUserVerError: repoerr.ErrNotFound,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
{
|
||||
desc: "verify email with expired token",
|
||||
uvs: uvs,
|
||||
retrieveUserVerResponse: users.UserVerification{
|
||||
UserID: uv.UserID,
|
||||
Email: uv.Email,
|
||||
OTP: uv.OTP,
|
||||
ExpiresAt: expiresdAt,
|
||||
CreatedAt: createdAt,
|
||||
UsedAt: uv.UsedAt,
|
||||
},
|
||||
err: svcerr.ErrUserVerificationExpired,
|
||||
},
|
||||
{
|
||||
desc: "verify email with failed to update user verification",
|
||||
uvs: uvs,
|
||||
retrieveUserVerResponse: uv,
|
||||
updateUserVerError: svcerr.ErrUpdateEntity,
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
{
|
||||
desc: "verify email with failed to update verified at",
|
||||
uvs: uvs,
|
||||
retrieveUserVerResponse: uv,
|
||||
updateVerifiedAtError: svcerr.ErrUpdateEntity,
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("RetrieveUserVerification", context.Background(), mock.Anything, mock.Anything).Return(tc.retrieveUserVerResponse, tc.retrieveUserVerError)
|
||||
repoCall1 := cRepo.On("UpdateUserVerification", context.Background(), mock.Anything).Return(tc.updateUserVerError)
|
||||
repoCall2 := cRepo.On("UpdateVerifiedAt", context.Background(), mock.Anything).Return(users.User{}, tc.updateVerifiedAtError)
|
||||
_, err := svc.VerifyEmail(context.Background(), tc.uvs)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
|
||||
repoCall.Unset()
|
||||
repoCall1.Unset()
|
||||
repoCall2.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +34,21 @@ func (tm *tracingMiddleware) Register(ctx context.Context, session authn.Session
|
||||
return tm.svc.Register(ctx, session, user, selfRegister)
|
||||
}
|
||||
|
||||
// SendVerification traces the "SendVerification" operation of the wrapped users.Service.
|
||||
func (tm *tracingMiddleware) SendVerification(ctx context.Context, session authn.Session) error {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_send_verification")
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.SendVerification(ctx, session)
|
||||
}
|
||||
|
||||
// VerifyEmail traces the "VerifyEmail" operation of the wrapped users.Service.
|
||||
func (tm *tracingMiddleware) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_verify_email")
|
||||
defer span.End()
|
||||
return tm.svc.VerifyEmail(ctx, verificationToken)
|
||||
}
|
||||
|
||||
// IssueToken traces the "IssueToken" operation of the wrapped users.Service.
|
||||
func (tm *tracingMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_issue_token", trace.WithAttributes(attribute.String("username", username)))
|
||||
|
||||
+26
-2
@@ -29,6 +29,7 @@ type User struct {
|
||||
CreatedAt time.Time `json:"created_at,omitempty"`
|
||||
UpdatedAt time.Time `json:"updated_at,omitempty"`
|
||||
UpdatedBy string `json:"updated_by,omitempty"`
|
||||
VerifiedAt time.Time `json:"verified_at,omitempty"`
|
||||
}
|
||||
|
||||
type Credentials struct {
|
||||
@@ -49,9 +50,7 @@ type UserReq struct {
|
||||
LastName *string `json:"last_name,omitempty"`
|
||||
Metadata *Metadata `json:"metadata,omitempty"`
|
||||
Tags *[]string `json:"tags,omitempty"`
|
||||
Role *Role `json:"role,omitempty"`
|
||||
ProfilePicture *string `json:"profile_picture,omitempty"`
|
||||
Email *string `json:"email,omitempty"`
|
||||
UpdatedBy *string `json:"updated_by,omitempty"`
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"`
|
||||
}
|
||||
@@ -90,6 +89,15 @@ type Repository interface {
|
||||
// UpdateSecret updates secret for user with given email.
|
||||
UpdateSecret(ctx context.Context, user User) (User, error)
|
||||
|
||||
// UpdateEmail updates email for user with given id.
|
||||
UpdateEmail(ctx context.Context, user User) (User, error)
|
||||
|
||||
// UpdateRole updates role for user with given id.
|
||||
UpdateRole(ctx context.Context, user User) (User, error)
|
||||
|
||||
// UpdateVerifiedAt updates the verified time for user with given id.
|
||||
UpdateVerifiedAt(ctx context.Context, user User) (User, error)
|
||||
|
||||
// ChangeStatus changes user status to enabled or disabled
|
||||
ChangeStatus(ctx context.Context, user User) (User, error)
|
||||
|
||||
@@ -107,6 +115,15 @@ type Repository interface {
|
||||
// Save persists the user account. A non-nil error is returned to indicate
|
||||
// operation failure.
|
||||
Save(ctx context.Context, user User) (User, error)
|
||||
|
||||
// AddUserVerification adds new verification for given user id and email
|
||||
AddUserVerification(ctx context.Context, uv UserVerification) error
|
||||
|
||||
// RetrieveVerificationToken retrieves verification token of given user id and email.
|
||||
RetrieveUserVerification(ctx context.Context, userID, email string) (UserVerification, error)
|
||||
|
||||
// UpdateUserVerificationDetails update verification details for the given user id and email.
|
||||
UpdateUserVerification(ctx context.Context, uv UserVerification) error
|
||||
}
|
||||
|
||||
// Validate returns an error if user representation is invalid.
|
||||
@@ -143,6 +160,7 @@ type Page struct {
|
||||
FirstName string `json:"first_name,omitempty"`
|
||||
LastName string `json:"last_name,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Verified bool `json:"verified,omitempty"`
|
||||
}
|
||||
|
||||
// Service specifies an API that must be fullfiled by the domain service
|
||||
@@ -152,6 +170,12 @@ type Service interface {
|
||||
// non-nil error value is returned.
|
||||
Register(ctx context.Context, session authn.Session, user User, selfRegister bool) (User, error)
|
||||
|
||||
// SendVerification sends a verification email to the user.
|
||||
SendVerification(ctx context.Context, session authn.Session) error
|
||||
|
||||
// VerifyEmail verifies user's email using the verification token.
|
||||
VerifyEmail(ctx context.Context, verificationToken string) (User, error)
|
||||
|
||||
// View retrieves user info for a given user ID and an authorized token.
|
||||
View(ctx context.Context, session authn.Session, id string) (User, error)
|
||||
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package users
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
svcerr "github.com/absmach/supermq/pkg/errors/service"
|
||||
)
|
||||
|
||||
const VerificationExpiryDuration = 24 * time.Hour
|
||||
|
||||
var (
|
||||
errFailedToCreateUserVerification = errors.New("failed to create new user verification")
|
||||
errFailedToEncodeUserVerification = errors.New("failed to encode user verification")
|
||||
errFailedToDecodeUserVerification = errors.New("failed to decode user verification")
|
||||
)
|
||||
|
||||
// UserVerification OTP is sent to the user's email as base64 encoded with UserID, Email and OTP. It should not be exposed via API.
|
||||
type UserVerification struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
OTP string `json:"otp"`
|
||||
CreatedAt time.Time `json:"-"`
|
||||
ExpiresAt time.Time `json:"-"`
|
||||
UsedAt time.Time `json:"-"`
|
||||
}
|
||||
|
||||
func NewUserVerification(userID, email string) (UserVerification, error) {
|
||||
randomBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return UserVerification{}, errors.Wrap(errFailedToCreateUserVerification, err)
|
||||
}
|
||||
|
||||
return UserVerification{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
OTP: base64.URLEncoding.EncodeToString(randomBytes),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
ExpiresAt: time.Now().Add(VerificationExpiryDuration).UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u UserVerification) Encode() (string, error) {
|
||||
jsonBytes, err := json.Marshal(u)
|
||||
if err != nil {
|
||||
return "", errors.Wrap(errFailedToEncodeUserVerification, err)
|
||||
}
|
||||
|
||||
return base64.URLEncoding.EncodeToString(jsonBytes), nil
|
||||
}
|
||||
|
||||
func (u *UserVerification) Decode(data string) error {
|
||||
decodedPayload, err := base64.URLEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return errors.Wrap(errFailedToDecodeUserVerification, err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(decodedPayload, u); err != nil {
|
||||
return errors.Wrap(errFailedToDecodeUserVerification, err)
|
||||
}
|
||||
|
||||
if u.UserID == "" || u.Email == "" || u.OTP == "" {
|
||||
return svcerr.ErrInvalidUserVerification
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u UserVerification) Valid() error {
|
||||
if u.UserID == "" || u.Email == "" || u.OTP == "" {
|
||||
return svcerr.ErrInvalidUserVerification
|
||||
}
|
||||
|
||||
// Verification should have created time.
|
||||
if u.CreatedAt.IsZero() {
|
||||
return svcerr.ErrInvalidUserVerification
|
||||
}
|
||||
|
||||
// Verification should have expiry time.
|
||||
if u.ExpiresAt.IsZero() {
|
||||
return svcerr.ErrInvalidUserVerification
|
||||
}
|
||||
|
||||
// Expiry time should not be before Created time
|
||||
if u.ExpiresAt.Before(u.CreatedAt) {
|
||||
return svcerr.ErrInvalidUserVerification
|
||||
}
|
||||
|
||||
// Verification should be not be Expired.
|
||||
if time.Now().After(u.ExpiresAt) {
|
||||
return svcerr.ErrUserVerificationExpired
|
||||
}
|
||||
|
||||
// Verification should not be used.
|
||||
if !u.UsedAt.IsZero() {
|
||||
return svcerr.ErrUserVerificationExpired
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u UserVerification) Match(ruv UserVerification) error {
|
||||
if u.UserID != ruv.UserID {
|
||||
return svcerr.ErrInvalidUserVerification
|
||||
}
|
||||
|
||||
if u.Email != ruv.Email {
|
||||
return svcerr.ErrInvalidUserVerification
|
||||
}
|
||||
|
||||
if u.OTP != ruv.OTP {
|
||||
return svcerr.ErrInvalidUserVerification
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user