SMQ-3093 - User email verification (#3101)

Signed-off-by: Arvindh <arvindh91@gmail.com>
This commit is contained in:
Arvindh
2025-09-05 18:53:58 +05:30
committed by GitHub
parent 22616911d2
commit e57ad79cd4
98 changed files with 3096 additions and 524 deletions
+11 -2
View File
@@ -73,6 +73,7 @@ type AuthNRes struct {
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // token id
UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // user id
UserRole uint32 `protobuf:"varint,3,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"` // user role
Verified bool `protobuf:"varint,4,opt,name=verified,proto3" json:"verified,omitempty"` // verified user
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -128,6 +129,13 @@ func (x *AuthNRes) GetUserRole() uint32 {
return 0
}
func (x *AuthNRes) GetVerified() bool {
if x != nil {
return x.Verified
}
return false
}
type AuthZReq struct {
state protoimpl.MessageState `protogen:"open.v1"`
Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` // Domain
@@ -378,11 +386,12 @@ const file_auth_v1_auth_proto_rawDesc = "" +
"\n" +
"\x12auth/v1/auth.proto\x12\aauth.v1\" \n" +
"\bAuthNReq\x12\x14\n" +
"\x05token\x18\x01 \x01(\tR\x05token\"P\n" +
"\x05token\x18\x01 \x01(\tR\x05token\"l\n" +
"\bAuthNRes\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\x12\x17\n" +
"\auser_id\x18\x02 \x01(\tR\x06userId\x12\x1b\n" +
"\tuser_role\x18\x03 \x01(\rR\buserRole\"\xa2\x02\n" +
"\tuser_role\x18\x03 \x01(\rR\buserRole\x12\x1a\n" +
"\bverified\x18\x04 \x01(\bR\bverified\"\xa2\x02\n" +
"\bAuthZReq\x12\x16\n" +
"\x06domain\x18\x01 \x01(\tR\x06domain\x12!\n" +
"\fsubject_type\x18\x02 \x01(\tR\vsubjectType\x12!\n" +
+21 -3
View File
@@ -29,6 +29,7 @@ type IssueReq struct {
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
UserRole uint32 `protobuf:"varint,2,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"`
Type uint32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"`
Verified bool `protobuf:"varint,4,opt,name=verified,proto3" json:"verified,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -84,9 +85,17 @@ func (x *IssueReq) GetType() uint32 {
return 0
}
func (x *IssueReq) GetVerified() bool {
if x != nil {
return x.Verified
}
return false
}
type RefreshReq struct {
state protoimpl.MessageState `protogen:"open.v1"`
RefreshToken string `protobuf:"bytes,1,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"`
Verified bool `protobuf:"varint,2,opt,name=verified,proto3" json:"verified,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -128,6 +137,13 @@ func (x *RefreshReq) GetRefreshToken() string {
return ""
}
func (x *RefreshReq) GetVerified() bool {
if x != nil {
return x.Verified
}
return false
}
// If a token is not carrying any information itself, the type
// field can be used to determine how to validate the token.
// Also, different tokens can be encoded in different ways.
@@ -195,14 +211,16 @@ var File_token_v1_token_proto protoreflect.FileDescriptor
const file_token_v1_token_proto_rawDesc = "" +
"\n" +
"\x14token/v1/token.proto\x12\btoken.v1\"T\n" +
"\x14token/v1/token.proto\x12\btoken.v1\"p\n" +
"\bIssueReq\x12\x17\n" +
"\auser_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n" +
"\tuser_role\x18\x02 \x01(\rR\buserRole\x12\x12\n" +
"\x04type\x18\x03 \x01(\rR\x04type\"1\n" +
"\x04type\x18\x03 \x01(\rR\x04type\x12\x1a\n" +
"\bverified\x18\x04 \x01(\bR\bverified\"M\n" +
"\n" +
"RefreshReq\x12#\n" +
"\rrefresh_token\x18\x01 \x01(\tR\frefreshToken\"\x87\x01\n" +
"\rrefresh_token\x18\x01 \x01(\tR\frefreshToken\x12\x1a\n" +
"\bverified\x18\x02 \x01(\bR\bverified\"\x87\x01\n" +
"\x05Token\x12!\n" +
"\faccess_token\x18\x01 \x01(\tR\vaccessToken\x12(\n" +
"\rrefresh_token\x18\x02 \x01(\tH\x00R\frefreshToken\x88\x01\x01\x12\x1f\n" +
-54
View File
@@ -1,54 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package http
import (
"context"
"net/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/auth"
smqauthn "github.com/absmach/supermq/pkg/authn"
"github.com/go-chi/chi/v5"
)
type sessionKeyType string
const SessionKey = sessionKeyType("session")
func AuthenticateMiddleware(authn smqauthn.Authentication, domainCheck bool) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := apiutil.ExtractBearerToken(r)
if token == "" {
EncodeError(r.Context(), apiutil.ErrBearerToken, w)
return
}
resp, err := authn.Authenticate(r.Context(), token)
if err != nil {
EncodeError(r.Context(), err, w)
return
}
if domainCheck {
domain := chi.URLParam(r, "domainID")
if domain == "" {
EncodeError(r.Context(), apiutil.ErrMissingDomainID, w)
return
}
resp.DomainID = domain
switch resp.Role {
case smqauthn.AdminRole:
resp.DomainUserID = resp.UserID
case smqauthn.UserRole:
resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID)
}
}
ctx := context.WithValue(r.Context(), SessionKey, resp)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
+9
View File
@@ -7,6 +7,7 @@ import (
"context"
"encoding/json"
"net/http"
"net/mail"
"regexp"
"strings"
@@ -110,6 +111,13 @@ func ValidateUUID(extID string) (err error) {
return nil
}
func ValidateEmail(email string) (err error) {
if _, err := mail.ParseAddress(email); err != nil {
return apiutil.ErrInvalidEmail
}
return nil
}
// ValidateName validates name format.
func ValidateName(id string) error {
if !nameRegExp.MatchString(id) {
@@ -201,6 +209,7 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
errors.Contains(err, apiutil.ErrMissingSecret),
errors.Contains(err, errors.ErrMalformedEntity),
errors.Contains(err, apiutil.ErrMissingID),
errors.Contains(err, apiutil.ErrInvalidVerification),
errors.Contains(err, apiutil.ErrMissingName),
errors.Contains(err, apiutil.ErrMissingEmail),
errors.Contains(err, apiutil.ErrInvalidEmail),
+6
View File
@@ -268,4 +268,10 @@ var (
// ErrMissingUsernameEmail indicates missing user name / email.
ErrMissingUsernameEmail = errors.New("missing username / email")
// ErrInvalidVerification indicates invalid email verification.
ErrInvalidVerification = errors.New("invalid verification")
// ErrEmailNotVerified indicates invalid email not verified.
ErrEmailNotVerified = errors.New("email not verified")
)
+56
View File
@@ -607,6 +607,54 @@ paths:
"500":
$ref: "#/components/responses/ServiceError"
/users/send-verification:
post:
operationId: sendVerification
tags:
- Users
summary: Sends a verification email
description: |
Sends a verification email to the user.
security:
- bearerAuth: []
responses:
"200":
description: Sent verification email if registered.
"400":
description: Failed due to malformed JSON.
"401":
description: Missing or invalid access token provided.
"415":
description: Missing or invalid content type.
"422":
description: Database can't process request.
"500":
$ref: "#/components/responses/ServiceError"
/verify-email:
get:
operationId: verifyEmail
tags:
- Users
summary: Verify user's email
description: |
Verify user's email using the token from the verification link.
parameters:
- $ref: "#/components/parameters/VerificationToken"
responses:
"200":
description: Email verified successfully.
"400":
description: Failed due to malformed query parameters.
"401":
description: Missing or invalid access token provided.
"404":
description: A non-existent entity request.
"422":
description: Database can't process request.
"500":
$ref: "#/components/responses/ServiceError"
/health:
get:
operationId: health
@@ -1240,6 +1288,14 @@ components:
required: false
example: "0"
VerificationToken:
name: token
description: Verification token.
in: query
schema:
type: string
required: true
requestBodies:
UserCreateReq:
description: JSON-formatted document describing the new user to be registered
+2 -2
View File
@@ -75,7 +75,7 @@ func (client authGrpcClient) Authenticate(ctx context.Context, token *grpcAuthV1
return &grpcAuthV1.AuthNRes{}, grpcapi.DecodeError(err)
}
ir := res.(authenticateRes)
return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID, UserRole: uint32(ir.userRole)}, nil
return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID, UserRole: uint32(ir.userRole), Verified: ir.verified}, nil
}
func encodeIdentifyRequest(_ context.Context, grpcReq any) (any, error) {
@@ -85,7 +85,7 @@ func encodeIdentifyRequest(_ context.Context, grpcReq any) (any, error) {
func decodeIdentifyResponse(_ context.Context, grpcRes any) (any, error) {
res := grpcRes.(*grpcAuthV1.AuthNRes)
return authenticateRes{id: res.GetId(), userID: res.GetUserId(), userRole: auth.Role(res.UserRole)}, nil
return authenticateRes{id: res.GetId(), userID: res.GetUserId(), userRole: auth.Role(res.UserRole), verified: res.GetVerified()}, nil
}
func (client authGrpcClient) AuthenticatePAT(ctx context.Context, token *grpcAuthV1.AuthNReq, _ ...grpc.CallOption) (*grpcAuthV1.AuthNRes, error) {
+1 -1
View File
@@ -23,7 +23,7 @@ func authenticateEndpoint(svc auth.Service) endpoint.Endpoint {
return authenticateRes{}, err
}
return authenticateRes{userID: key.Subject, userRole: key.Role}, nil
return authenticateRes{userID: key.Subject, userRole: key.Role, verified: key.Verified}, nil
}
}
+1
View File
@@ -9,6 +9,7 @@ type authenticateRes struct {
id string
userID string
userRole smqauth.Role
verified bool
}
type authorizeRes struct {
+1 -1
View File
@@ -82,7 +82,7 @@ func decodeAuthenticateRequest(_ context.Context, grpcReq any) (any, error) {
func encodeAuthenticateResponse(_ context.Context, grpcRes any) (any, error) {
res := grpcRes.(authenticateRes)
return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, UserRole: uint32(res.userRole)}, nil
return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, UserRole: uint32(res.userRole), Verified: res.verified}, nil
}
func encodeAuthenticatePATResponse(_ context.Context, grpcRes any) (any, error) {
+4 -2
View File
@@ -56,6 +56,7 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR
userID: req.GetUserId(),
userRole: auth.Role(req.GetUserRole()),
keyType: auth.KeyType(req.GetType()),
verified: req.GetVerified(),
})
if err != nil {
return &grpcTokenV1.Token{}, grpcapi.DecodeError(err)
@@ -69,6 +70,7 @@ func encodeIssueRequest(_ context.Context, grpcReq any) (any, error) {
UserId: req.userID,
UserRole: uint32(req.userRole),
Type: uint32(req.keyType),
Verified: req.verified,
}, nil
}
@@ -80,7 +82,7 @@ func (client tokenGrpcClient) Refresh(ctx context.Context, req *grpcTokenV1.Refr
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
res, err := client.refresh(ctx, refreshReq{refreshToken: req.GetRefreshToken()})
res, err := client.refresh(ctx, refreshReq{refreshToken: req.GetRefreshToken(), verified: req.GetVerified()})
if err != nil {
return &grpcTokenV1.Token{}, grpcapi.DecodeError(err)
}
@@ -89,7 +91,7 @@ func (client tokenGrpcClient) Refresh(ctx context.Context, req *grpcTokenV1.Refr
func encodeRefreshRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(refreshReq)
return &grpcTokenV1.RefreshReq{RefreshToken: req.refreshToken}, nil
return &grpcTokenV1.RefreshReq{RefreshToken: req.refreshToken, Verified: req.verified}, nil
}
func decodeRefreshResponse(_ context.Context, grpcRes any) (any, error) {
+5 -4
View File
@@ -18,9 +18,10 @@ func issueEndpoint(svc auth.Service) endpoint.Endpoint {
}
key := auth.Key{
Type: req.keyType,
Subject: req.userID,
Role: req.userRole,
Type: req.keyType,
Subject: req.userID,
Role: req.userRole,
Verified: req.verified,
}
tkn, err := svc.Issue(ctx, "", key)
if err != nil {
@@ -42,7 +43,7 @@ func refreshEndpoint(svc auth.Service) endpoint.Endpoint {
return issueRes{}, err
}
key := auth.Key{Type: auth.RefreshKey}
key := auth.Key{Type: auth.RefreshKey, Verified: req.verified}
tkn, err := svc.Issue(ctx, req.refreshToken, key)
if err != nil {
return issueRes{}, err
+2
View File
@@ -12,6 +12,7 @@ type issueReq struct {
userID string
userRole auth.Role
keyType auth.KeyType
verified bool
}
func (req issueReq) validate() error {
@@ -27,6 +28,7 @@ func (req issueReq) validate() error {
type refreshReq struct {
refreshToken string
verified bool
}
func (req refreshReq) validate() error {
+2 -1
View File
@@ -58,12 +58,13 @@ func decodeIssueRequest(_ context.Context, grpcReq any) (any, error) {
userID: req.GetUserId(),
userRole: auth.Role(req.GetUserRole()),
keyType: auth.KeyType(req.GetType()),
verified: req.Verified,
}, nil
}
func decodeRefreshRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(*grpcTokenV1.RefreshReq)
return refreshReq{refreshToken: req.GetRefreshToken()}, nil
return refreshReq{refreshToken: req.GetRefreshToken(), verified: req.Verified}, nil
}
func encodeIssueResponse(_ context.Context, grpcRes any) (any, error) {
+15
View File
@@ -21,6 +21,8 @@ var (
errInvalidType = errors.New("invalid token type")
// errInvalidRole is returned when the role is invalid.
errInvalidRole = errors.New("invalid role")
// errInvalidVerified is returned when the verified is invalid.
errInvalidVerified = errors.New("invalid verified")
// errJWTExpiryKey is used to check if the token is expired.
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
// ErrSignJWT indicates an error in signing jwt token.
@@ -36,6 +38,7 @@ const (
tokenType = "type"
userField = "user"
RoleField = "role"
VerifiedField = "verified"
oauthProviderField = "oauth_provider"
oauthAccessTokenField = "access_token"
oauthRefreshTokenField = "refresh_token"
@@ -62,6 +65,7 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) {
Claim(tokenType, key.Type).
Expiration(key.ExpiresAt)
builder.Claim(RoleField, key.Role)
builder.Claim(VerifiedField, key.Verified)
if key.Subject != "" {
builder.Subject(key.Subject)
}
@@ -150,6 +154,16 @@ func toKey(tkn jwt.Token) (auth.Key, error) {
if !ok {
return auth.Key{}, errInvalidRole
}
tVerified, ok := tkn.Get(VerifiedField)
if !ok {
return auth.Key{}, errInvalidVerified
}
kVerified, ok := tVerified.(bool)
if !ok {
return auth.Key{}, errInvalidVerified
}
kr := auth.Role(kRole)
if !kr.Validate() {
return auth.Key{}, errInvalidRole
@@ -162,6 +176,7 @@ func toKey(tkn jwt.Token) (auth.Key, error) {
key.Subject = tkn.Subject()
key.IssuedAt = tkn.IssuedAt()
key.ExpiresAt = tkn.Expiration()
key.Verified = kVerified
return key, nil
}
+1
View File
@@ -88,6 +88,7 @@ type Key struct {
Role Role `json:"role,omitempty"`
IssuedAt time.Time `json:"issued_at,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
Verified bool `json:"verified,omitempty"`
}
func (key Key) String() string {
+2 -1
View File
@@ -72,7 +72,8 @@ func newCertServer() (*httptest.Server, *mocks.Service, *authnmocks.Authenticati
logger := smqlog.NewMock()
idp := uuid.NewMock()
authn := new(authnmocks.Authentication)
mux := api.MakeHandler(svc, authn, logger, "", idp)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
mux := api.MakeHandler(svc, am, logger, "", idp)
return httptest.NewServer(mux), svc, authn
}
+2 -2
View File
@@ -31,7 +31,7 @@ const (
)
// MakeHandler returns a HTTP handler for API endpoints.
func MakeHandler(svc certs.Service, authn smqauthn.Authentication, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler {
func MakeHandler(svc certs.Service, authn smqauthn.AuthNMiddleware, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
@@ -39,7 +39,7 @@ func MakeHandler(svc certs.Service, authn smqauthn.Authentication, logger *slog.
r := chi.NewRouter()
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
r.Use(authn.Middleware())
r.Use(api.RequestIDMiddleware(idp))
r.Route("/{domainID}", func(r chi.Router) {
+2 -1
View File
@@ -58,7 +58,8 @@ func newChannelsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authenti
mux := chi.NewRouter()
idp := uuid.NewMock()
logger := smqlog.NewMock()
mux = MakeHandler(svc, authn, mux, logger, "", idp)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
mux = MakeHandler(svc, am, mux, logger, "", idp)
return httptest.NewServer(mux), svc, authn
}
+15 -16
View File
@@ -6,7 +6,6 @@ package http
import (
"context"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/pkg/authn"
@@ -22,7 +21,7 @@ func createChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -46,7 +45,7 @@ func createChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -77,7 +76,7 @@ func viewChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -98,7 +97,7 @@ func listChannelsEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -138,7 +137,7 @@ func updateChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -164,7 +163,7 @@ func updateChannelTagsEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -189,7 +188,7 @@ func setChannelParentGroupEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -209,7 +208,7 @@ func removeChannelParentGroupEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -229,7 +228,7 @@ func enableChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -250,7 +249,7 @@ func disableChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -271,7 +270,7 @@ func connectChannelClientEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -291,7 +290,7 @@ func disconnectChannelClientsEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -311,7 +310,7 @@ func connectEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -331,7 +330,7 @@ func disconnectEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -351,7 +350,7 @@ func deleteChannelEndpoint(svc channels.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
+2 -2
View File
@@ -19,7 +19,7 @@ import (
)
// MakeHandler returns a HTTP handler for Channels API endpoints.
func MakeHandler(svc channels.Service, authn smqauthn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) *chi.Mux {
func MakeHandler(svc channels.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) *chi.Mux {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
@@ -27,7 +27,7 @@ func MakeHandler(svc channels.Service, authn smqauthn.Authentication, mux *chi.M
d := roleManagerHttp.NewDecoder("channelID")
mux.Route("/{domainID}/channels", func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
r.Use(authn.Middleware())
r.Use(api.RequestIDMiddleware(idp))
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
+10 -8
View File
@@ -12,14 +12,16 @@ import (
)
const (
all = "all"
create = "create"
get = "get"
update = "update"
delete = "delete"
enable = "enable"
disable = "disable"
users = "users"
all = "all"
create = "create"
get = "get"
update = "update"
delete = "delete"
enable = "enable"
disable = "disable"
users = "users"
sendVerification = "send-verification"
verifyEmail = "verify-email"
usageCreate = "cli channels <channel_id> create <JSON_channel> <domain_id> <user_auth_token>"
usageGet = "cli channels <channel_id|all> get <domain_id> <user_auth_token>"
+34 -1
View File
@@ -72,11 +72,16 @@ Examples:
logUsageCmd(*cmd, cmd.Use)
return
}
switch args[0] {
case create:
handleUserCreate(cmd, args[1:])
return
case sendVerification:
handleSendVerification(cmd, args[1])
return
case verifyEmail:
handleVerify(cmd, args[1])
return
case token:
handleUserToken(cmd, args[1], args[2:])
return
@@ -152,6 +157,34 @@ func handleUserCreate(cmd *cobra.Command, args []string) {
logJSONCmd(*cmd, user)
}
func handleSendVerification(cmd *cobra.Command, token string) {
if token == "" {
logUsageCmd(*cmd, usageUserToken)
return
}
if err := sdk.SendVerification(cmd.Context(), token); err != nil {
logErrorCmd(*cmd, err)
return
}
logJSONCmd(*cmd, "sent verification successfully")
}
func handleVerify(cmd *cobra.Command, token string) {
if token == "" {
logUsageCmd(*cmd, usageUserToken)
return
}
if err := sdk.VerifyEmail(cmd.Context(), token); err != nil {
logErrorCmd(*cmd, err)
return
}
logJSONCmd(*cmd, "verified successfully")
}
func handleUserGet(cmd *cobra.Command, userParams string, args []string) {
if len(args) != 1 {
logUsageCmd(*cmd, usageUserGet)
+2 -2
View File
@@ -17,14 +17,14 @@ import (
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)
func clientsHandler(svc clients.Service, authn smqauthn.Authentication, r *chi.Mux, logger *slog.Logger, idp supermq.IDProvider) *chi.Mux {
func clientsHandler(svc clients.Service, authn smqauthn.AuthNMiddleware, r *chi.Mux, logger *slog.Logger, idp supermq.IDProvider) *chi.Mux {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
d := roleManagerHttp.NewDecoder("clientID")
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
r.Use(authn.Middleware())
r.Use(api.RequestIDMiddleware(idp))
r.Route("/{domainID}/clients", func(r chi.Router) {
+12 -13
View File
@@ -6,7 +6,6 @@ package http
import (
"context"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/pkg/authn"
@@ -22,7 +21,7 @@ func createClientEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -46,7 +45,7 @@ func createClientsEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -77,7 +76,7 @@ func viewClientEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -98,7 +97,7 @@ func listClientsEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -138,7 +137,7 @@ func updateClientEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -164,7 +163,7 @@ func updateClientTagsEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -189,7 +188,7 @@ func updateClientSecretEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -210,7 +209,7 @@ func enableClientEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -231,7 +230,7 @@ func disableClientEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -252,7 +251,7 @@ func setClientParentGroupEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -271,7 +270,7 @@ func removeClientParentGroupEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -290,7 +289,7 @@ func deleteClientEndpoint(svc clients.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
+2 -1
View File
@@ -96,7 +96,8 @@ func newClientsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentic
logger := smqlog.NewMock()
mux := chi.NewRouter()
idp := uuid.NewMock()
clientsapi.MakeHandler(svc, authn, mux, logger, "", idp)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
clientsapi.MakeHandler(svc, am, mux, logger, "", idp)
return httptest.NewServer(mux), svc, authn
}
+1 -1
View File
@@ -15,7 +15,7 @@ import (
)
// MakeHandler returns a HTTP handler for clients and Groups API endpoints.
func MakeHandler(tsvc clients.Service, authn smqauthn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler {
func MakeHandler(tsvc clients.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler {
mux = clientsHandler(tsvc, authn, mux, logger, idp)
mux.Get("/health", supermq.Health("clients", instanceID))
+3 -1
View File
@@ -20,6 +20,7 @@ import (
"github.com/absmach/supermq/certs/postgres"
"github.com/absmach/supermq/certs/tracing"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
"github.com/absmach/supermq/pkg/grpcclient"
jaegerclient "github.com/absmach/supermq/pkg/jaeger"
@@ -138,6 +139,7 @@ func main() {
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio)
if err != nil {
@@ -163,7 +165,7 @@ func main() {
idp := uuid.New()
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, logger, cfg.InstanceID, idp), logger)
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, logger, cfg.InstanceID, idp), logger)
if cfg.SendTelemetry {
chc := chclient.New(svcName, supermq.Version, logger, cancel)
+3 -1
View File
@@ -31,6 +31,7 @@ import (
gpostgres "github.com/absmach/supermq/groups/postgres"
redisclient "github.com/absmach/supermq/internal/clients/redis"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
@@ -179,6 +180,7 @@ func main() {
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
domsGrpcCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
@@ -304,7 +306,7 @@ func main() {
}
mux := chi.NewRouter()
idp := uuid.New()
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID, idp), logger)
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, mux, logger, cfg.InstanceID, idp), logger)
if cfg.SendTelemetry {
chc := chclient.New(svcName, supermq.Version, logger, cancel)
+3 -1
View File
@@ -31,6 +31,7 @@ import (
gpostgres "github.com/absmach/supermq/groups/postgres"
redisclient "github.com/absmach/supermq/internal/clients/redis"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
@@ -190,6 +191,7 @@ func main() {
}
defer authnClient.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure())
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
domsGrpcCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
@@ -292,7 +294,7 @@ func main() {
}
mux := chi.NewRouter()
idp := uuid.New()
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID, idp), logger)
httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, mux, logger, cfg.InstanceID, idp), logger)
grpcServerConfig := server.Config{Port: defSvcAuthGRPCPort}
if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGRPC}); err != nil {
+3 -1
View File
@@ -27,6 +27,7 @@ import (
dtracing "github.com/absmach/supermq/domains/tracing"
redisclient "github.com/absmach/supermq/internal/clients/redis"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
"github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
@@ -159,6 +160,7 @@ func main() {
}
defer authnHandler.Close()
logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure())
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
database := postgres.NewDatabase(db, dbConfig, tracer)
domainsRepo := dpostgres.NewRepository(database)
@@ -239,7 +241,7 @@ func main() {
}
mux := chi.NewMux()
idp := uuid.New()
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID, idp), logger)
hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, mux, logger, cfg.InstanceID, idp), logger)
g.Go(func() error {
return hs.Start()
+3 -1
View File
@@ -28,6 +28,7 @@ import (
pgroups "github.com/absmach/supermq/groups/private"
"github.com/absmach/supermq/groups/tracing"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
@@ -162,6 +163,7 @@ func main() {
}
defer authnHandler.Close()
logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure())
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
domsGrpcCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
@@ -263,7 +265,7 @@ func main() {
mux := chi.NewRouter()
idp := uuid.New()
httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID, idp), logger)
httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, mux, logger, cfg.InstanceID, idp), logger)
grpcServerConfig := server.Config{}
if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixgRPC}); err != nil {
+3 -1
View File
@@ -20,6 +20,7 @@ import (
"github.com/absmach/supermq/journal/middleware"
journalpg "github.com/absmach/supermq/journal/postgres"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
@@ -112,6 +113,7 @@ func main() {
}
defer authnHandler.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnHandler.Secure())
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
domsGrpcCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
@@ -173,7 +175,7 @@ func main() {
return
}
hs := http.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, logger, svcName, cfg.InstanceID), logger)
hs := http.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, logger, svcName, cfg.InstanceID), logger)
if cfg.SendTelemetry {
chc := chclient.New(svcName, supermq.Version, logger, cancel)
+51 -29
View File
@@ -20,6 +20,7 @@ import (
grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
"github.com/absmach/supermq/internal/email"
smqlog "github.com/absmach/supermq/logger"
smqauthn "github.com/absmach/supermq/pkg/authn"
authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc"
smqauthz "github.com/absmach/supermq/pkg/authz"
authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc"
@@ -67,28 +68,31 @@ const (
)
type config struct {
LogLevel string `env:"SMQ_USERS_LOG_LEVEL" envDefault:"info"`
AdminEmail string `env:"SMQ_USERS_ADMIN_EMAIL" envDefault:"admin@example.com"`
AdminPassword string `env:"SMQ_USERS_ADMIN_PASSWORD" envDefault:"12345678"`
AdminUsername string `env:"SMQ_USERS_ADMIN_USERNAME" envDefault:"admin"`
AdminFirstName string `env:"SMQ_USERS_ADMIN_FIRST_NAME" envDefault:"super"`
AdminLastName string `env:"SMQ_USERS_ADMIN_LAST_NAME" envDefault:"admin"`
PassRegexText string `env:"SMQ_USERS_PASS_REGEX" envDefault:"^.{8,}$"`
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
InstanceID string `env:"SMQ_USERS_INSTANCE_ID" envDefault:""`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
SelfRegister bool `env:"SMQ_USERS_ALLOW_SELF_REGISTER" envDefault:"false"`
OAuthUIRedirectURL string `env:"SMQ_OAUTH_UI_REDIRECT_URL" envDefault:"http://localhost:9095/domains"`
OAuthUIErrorURL string `env:"SMQ_OAUTH_UI_ERROR_URL" envDefault:"http://localhost:9095/error"`
DeleteInterval time.Duration `env:"SMQ_USERS_DELETE_INTERVAL" envDefault:"24h"`
DeleteAfter time.Duration `env:"SMQ_USERS_DELETE_AFTER" envDefault:"720h"`
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
PasswordResetURLPrefix string `env:"SMQ_PASSWORD_RESET_URL_PREFIX" envDefault:"http://localhost:8080"`
PassRegex *regexp.Regexp
LogLevel string `env:"SMQ_USERS_LOG_LEVEL" envDefault:"info"`
AdminEmail string `env:"SMQ_USERS_ADMIN_EMAIL" envDefault:"admin@example.com"`
AdminPassword string `env:"SMQ_USERS_ADMIN_PASSWORD" envDefault:"12345678"`
AdminUsername string `env:"SMQ_USERS_ADMIN_USERNAME" envDefault:"admin"`
AdminFirstName string `env:"SMQ_USERS_ADMIN_FIRST_NAME" envDefault:"super"`
AdminLastName string `env:"SMQ_USERS_ADMIN_LAST_NAME" envDefault:"admin"`
PassRegexText string `env:"SMQ_USERS_PASS_REGEX" envDefault:"^.{8,}$"`
JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"`
SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"`
InstanceID string `env:"SMQ_USERS_INSTANCE_ID" envDefault:""`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
SelfRegister bool `env:"SMQ_USERS_ALLOW_SELF_REGISTER" envDefault:"false"`
OAuthUIRedirectURL string `env:"SMQ_OAUTH_UI_REDIRECT_URL" envDefault:"http://localhost:9095/domains"`
OAuthUIErrorURL string `env:"SMQ_OAUTH_UI_ERROR_URL" envDefault:"http://localhost:9095/error"`
DeleteInterval time.Duration `env:"SMQ_USERS_DELETE_INTERVAL" envDefault:"24h"`
DeleteAfter time.Duration `env:"SMQ_USERS_DELETE_AFTER" envDefault:"720h"`
SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"`
SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"`
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
PasswordResetURLPrefix string `env:"SMQ_PASSWORD_RESET_URL_PREFIX" envDefault:"http://localhost/password/reset"`
PasswordResetEmailTemplate string `env:"SMQ_PASSWORD_RESET_EMAIL_TEMPLATE" envDefault:"reset-password-email.tmpl"`
VerificationURLPrefix string `env:"SMQ_VERIFICATION_URL_PREFIX" envDefault:"http://localhost/verify-email"`
VerificationEmailTemplate string `env:"SMQ_VERIFICATION_EMAIL_TEMPLATE" envDefault:"verification-email.tmpl"`
PassRegex *regexp.Regexp
}
func main() {
@@ -121,12 +125,21 @@ func main() {
}
}
ec := email.Config{}
if err := env.Parse(&ec); err != nil {
logger.Error(fmt.Sprintf("failed to load email configuration : %s", err.Error()))
resetPasswordEmailConfig := email.Config{}
if err := env.Parse(&resetPasswordEmailConfig); err != nil {
logger.Error(fmt.Sprintf("failed to load reset password email configuration : %s", err.Error()))
exitCode = 1
return
}
resetPasswordEmailConfig.Template = cfg.PasswordResetEmailTemplate
verificationEmailConfig := email.Config{}
if err := env.Parse(&verificationEmailConfig); err != nil {
logger.Error(fmt.Sprintf("failed to load verification password email configuration : %s", err.Error()))
exitCode = 1
return
}
verificationEmailConfig.Template = cfg.VerificationEmailTemplate
dbConfig := pgclient.Config{Name: defDB}
if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil {
@@ -182,6 +195,8 @@ func main() {
defer authnHandler.Close()
logger.Info("AuthN successfully connected to auth gRPC server " + authnHandler.Secure())
authnMiddleware := smqauthn.NewAuthNMiddleware(authn)
domsGrpcCfg := grpcclient.Config{}
if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil {
logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err))
@@ -213,7 +228,7 @@ func main() {
}
logger.Info("Policy client successfully connected to spicedb gRPC server")
csvc, err := newService(ctx, authz, tokenClient, policyService, domainsClient, db, dbConfig, tracer, cfg, ec, logger)
csvc, err := newService(ctx, authz, tokenClient, policyService, domainsClient, db, dbConfig, tracer, cfg, resetPasswordEmailConfig, verificationEmailConfig, logger)
if err != nil {
logger.Error(fmt.Sprintf("failed to setup service: %s", err))
exitCode = 1
@@ -237,7 +252,7 @@ func main() {
mux := chi.NewRouter()
idp := uuid.New()
httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(csvc, authn, tokenClient, cfg.SelfRegister, mux, logger, cfg.InstanceID, cfg.PassRegex, idp, oauthProvider), logger)
httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(csvc, authnMiddleware, tokenClient, cfg.SelfRegister, mux, logger, cfg.InstanceID, cfg.PassRegex, idp, oauthProvider), logger)
if cfg.SendTelemetry {
chc := chclient.New(svcName, supermq.Version, logger, cancel)
@@ -257,16 +272,23 @@ func main() {
}
}
func newService(ctx context.Context, authz smqauthz.Authorization, token grpcTokenV1.TokenServiceClient, policyService policies.Service, domainsClient grpcDomainsV1.DomainsServiceClient, db *sqlx.DB, dbConfig pgclient.Config, tracer trace.Tracer, c config, ec email.Config, logger *slog.Logger) (users.Service, error) {
func newService(ctx context.Context, authz smqauthz.Authorization, token grpcTokenV1.TokenServiceClient, policyService policies.Service, domainsClient grpcDomainsV1.DomainsServiceClient, db *sqlx.DB, dbConfig pgclient.Config, tracer trace.Tracer, c config, resetPasswordEmailConfig, verificationEmailConfig email.Config, logger *slog.Logger) (users.Service, error) {
database := pg.NewDatabase(db, dbConfig, tracer)
idp := uuid.New()
hsr := hasher.New()
// Creating users service
repo := postgres.NewRepository(database)
emailerClient, err := emailer.New(fmt.Sprintf("%s/reset-request", c.PasswordResetURLPrefix), &ec)
emailerClient, err := emailer.New(
c.PasswordResetURLPrefix,
c.VerificationURLPrefix,
&resetPasswordEmailConfig,
&verificationEmailConfig,
)
if err != nil {
logger.Error(fmt.Sprintf("failed to configure e-mailing util: %s", err.Error()))
return nil, err
}
svc := users.NewService(token, repo, policyService, emailerClient, hsr, idp)
+10 -6
View File
@@ -248,7 +248,6 @@ SMQ_USERS_DB_SSL_MODE=disable
SMQ_USERS_DB_SSL_CERT=
SMQ_USERS_DB_SSL_KEY=
SMQ_USERS_DB_SSL_ROOT_CERT=
SMQ_USERS_RESET_PWD_TEMPLATE=users.tmpl
SMQ_USERS_INSTANCE_ID=
SMQ_USERS_SECRET_KEY=HyE2D4RUt9nnKG6v8zKEqAp6g6ka8hhZsqUpzgKvnwpXrNVQSH
SMQ_USERS_ADMIN_EMAIL=admin@example.com
@@ -261,19 +260,21 @@ SMQ_OAUTH_UI_REDIRECT_URL=http://localhost:9095${SMQ_UI_PATH_PREFIX}/tokens/secu
SMQ_OAUTH_UI_ERROR_URL=http://localhost:9095${SMQ_UI_PATH_PREFIX}/error
SMQ_USERS_DELETE_INTERVAL=24h
SMQ_USERS_DELETE_AFTER=720h
SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost:9001
SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost/password-reset
SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=reset-password-email.tmpl
SMQ_VERIFICATION_URL_PREFIX=http://localhost/verify-email
SMQ_VERIFICATION_EMAIL_TEMPLATE=verification-email.tmpl
#### Users Client Config
SMQ_USERS_URL=users:9002
### Email utility
SMQ_EMAIL_HOST=smtp.mailtrap.io
SMQ_EMAIL_HOST=host.docker.internal
SMQ_EMAIL_PORT=2525
SMQ_EMAIL_USERNAME=18bf7f70705139
SMQ_EMAIL_PASSWORD=2b0d302e775b1e
SMQ_EMAIL_USERNAME=from@example.com
SMQ_EMAIL_PASSWORD=password
SMQ_EMAIL_FROM_ADDRESS=from@example.com
SMQ_EMAIL_FROM_NAME=Example
SMQ_EMAIL_TEMPLATE=email.tmpl
### Google OAuth2
SMQ_GOOGLE_CLIENT_ID=
@@ -514,5 +515,8 @@ SMQ_GRAFANA_PORT=3000
SMQ_GRAFANA_ADMIN_USER=supermq
SMQ_GRAFANA_ADMIN_PASSWORD=supermq
## Allow unverified user to access
SMQ_ALLOW_UNVERIFIED_USER=true
# Docker image tag
SMQ_RELEASE_TAG=latest
+1
View File
@@ -74,6 +74,7 @@ services:
SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO}
SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY}
SMQ_CERTS_INSTANCE_ID: ${SMQ_CERTS_INSTANCE_ID}
SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER}
volumes:
- ../../ssl/certs/ca.key:/etc/ssl/certs/ca.key
- ../../ssl/certs/ca.crt:/etc/ssl/certs/ca.crt
@@ -69,6 +69,7 @@ services:
SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt}
SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key}
SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt}
SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER}
ports:
- ${SMQ_JOURNAL_HTTP_PORT}:${SMQ_JOURNAL_HTTP_PORT}
networks:
+9 -2
View File
@@ -285,6 +285,7 @@ services:
SMQ_DOMAINS_CALLOUT_CERT: ${SMQ_DOMAINS_CALLOUT_CERT}
SMQ_DOMAINS_CALLOUT_KEY: ${SMQ_DOMAINS_CALLOUT_KEY}
SMQ_DOMAINS_CALLOUT_OPERATIONS: ${SMQ_DOMAINS_CALLOUT_OPERATIONS}
SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER}
ports:
- ${SMQ_DOMAINS_HTTP_PORT}:${SMQ_DOMAINS_HTTP_PORT}
- ${SMQ_DOMAINS_GRPC_PORT}:${SMQ_DOMAINS_GRPC_PORT}
@@ -516,6 +517,7 @@ services:
SMQ_CLIENTS_CALLOUT_CERT: ${SMQ_CLIENTS_CALLOUT_CERT}
SMQ_CLIENTS_CALLOUT_KEY: ${SMQ_CLIENTS_CALLOUT_KEY}
SMQ_CLIENTS_CALLOUT_OPERATIONS: ${SMQ_CLIENTS_CALLOUT_OPERATIONS}
SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER}
ports:
- ${SMQ_CLIENTS_HTTP_PORT}:${SMQ_CLIENTS_HTTP_PORT}
- ${SMQ_CLIENTS_GRPC_PORT}:${SMQ_CLIENTS_GRPC_PORT}
@@ -706,6 +708,7 @@ services:
SMQ_CHANNELS_CALLOUT_CERT: ${SMQ_CHANNELS_CALLOUT_CERT}
SMQ_CHANNELS_CALLOUT_KEY: ${SMQ_CHANNELS_CALLOUT_KEY}
SMQ_CHANNELS_CALLOUT_OPERATIONS: ${SMQ_CHANNELS_CALLOUT_OPERATIONS}
SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER}
ports:
- ${SMQ_CHANNELS_HTTP_PORT}:${SMQ_CHANNELS_HTTP_PORT}
- ${SMQ_CHANNELS_GRPC_PORT}:${SMQ_CHANNELS_GRPC_PORT}
@@ -855,7 +858,6 @@ services:
SMQ_EMAIL_PASSWORD: ${SMQ_EMAIL_PASSWORD}
SMQ_EMAIL_FROM_ADDRESS: ${SMQ_EMAIL_FROM_ADDRESS}
SMQ_EMAIL_FROM_NAME: ${SMQ_EMAIL_FROM_NAME}
SMQ_EMAIL_TEMPLATE: ${SMQ_EMAIL_TEMPLATE}
SMQ_ES_URL: ${SMQ_ES_URL}
SMQ_JAEGER_URL: ${SMQ_JAEGER_URL}
SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO}
@@ -882,12 +884,16 @@ services:
SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST}
SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT}
SMQ_PASSWORD_RESET_URL_PREFIX: ${SMQ_PASSWORD_RESET_URL_PREFIX}
SMQ_PASSWORD_RESET_EMAIL_TEMPLATE: ${SMQ_PASSWORD_RESET_EMAIL_TEMPLATE}
SMQ_VERIFICATION_URL_PREFIX: ${SMQ_VERIFICATION_URL_PREFIX}
SMQ_VERIFICATION_EMAIL_TEMPLATE: ${SMQ_VERIFICATION_EMAIL_TEMPLATE}
ports:
- ${SMQ_USERS_HTTP_PORT}:${SMQ_USERS_HTTP_PORT}
networks:
- supermq-base-net
volumes:
- ./templates/${SMQ_USERS_RESET_PWD_TEMPLATE}:/email.tmpl
- ./templates/${SMQ_PASSWORD_RESET_EMAIL_TEMPLATE}:/${SMQ_PASSWORD_RESET_EMAIL_TEMPLATE}
- ./templates/${SMQ_VERIFICATION_EMAIL_TEMPLATE}:/${SMQ_VERIFICATION_EMAIL_TEMPLATE}
# Auth gRPC client certificates
- type: bind
source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert}
@@ -1007,6 +1013,7 @@ services:
SMQ_GROUPS_CALLOUT_CERT: ${SMQ_GROUPS_CALLOUT_CERT}
SMQ_GROUPS_CALLOUT_KEY: ${SMQ_GROUPS_CALLOUT_KEY}
SMQ_GROUPS_CALLOUT_OPERATIONS: ${SMQ_GROUPS_CALLOUT_OPERATIONS}
SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER}
ports:
- ${SMQ_GROUPS_HTTP_PORT}:${SMQ_GROUPS_HTTP_PORT}
- ${SMQ_GROUPS_GRPC_PORT}:${SMQ_GROUPS_GRPC_PORT}
+1 -1
View File
@@ -72,7 +72,7 @@ http {
}
# Proxy pass to users service
location ~ ^/(users|password|authorize|oauth/callback/[^/]+) {
location ~ ^/(users|password|verify-email|authorize|oauth/callback/[^/]+) {
include snippets/proxy-headers.conf;
add_header Access-Control-Expose-Headers Location;
proxy_pass http://users:${SMQ_USERS_HTTP_PORT};
+1 -1
View File
@@ -81,7 +81,7 @@ http {
}
# Proxy pass to users service
location ~ ^/(users|groups|password|authorize|oauth/callback/[^/]+) {
location ~ ^/(users|password|verify-email|authorize|oauth/callback/[^/]+) {
include snippets/proxy-headers.conf;
add_header Access-Control-Expose-Headers Location;
proxy_pass http://users:${SMQ_USERS_HTTP_PORT};
+15
View File
@@ -0,0 +1,15 @@
Dear {{.User}},
Welcome to {{.Host}}! To complete your account registration, please verify your email address by clicking on the link below:
{{.Content}}
This verification link will expire in 24 hours. If you did not create an account on {{.Host}}, please disregard this message.
Once verified, you'll have full access to all features on {{.Host}}.
Thank you for joining {{.Host}}!
Best regards,
{{.Footer}}
+13 -14
View File
@@ -6,7 +6,6 @@ package http
import (
"context"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/pkg/authn"
@@ -25,7 +24,7 @@ func createDomainEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, err
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -53,7 +52,7 @@ func retrieveDomainEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, err
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -73,7 +72,7 @@ func updateDomainEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, err
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -103,7 +102,7 @@ func listDomainsEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -123,7 +122,7 @@ func enableDomainEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, err
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -142,7 +141,7 @@ func disableDomainEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, err
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -161,7 +160,7 @@ func freezeDomainEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, err
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -179,7 +178,7 @@ func sendInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -208,7 +207,7 @@ func listDomainInvitationsEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -231,7 +230,7 @@ func listUserInvitationsEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -254,7 +253,7 @@ func acceptInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -274,7 +273,7 @@ func rejectInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -294,7 +293,7 @@ func deleteInvitationEndpoint(svc domains.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
+3 -1
View File
@@ -21,6 +21,7 @@ import (
"github.com/absmach/supermq/internal/testsutil"
smqlog "github.com/absmach/supermq/logger"
"github.com/absmach/supermq/pkg/authn"
smqauthn "github.com/absmach/supermq/pkg/authn"
authnmock "github.com/absmach/supermq/pkg/authn/mocks"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
@@ -98,7 +99,8 @@ func newDomainsServer() (*httptest.Server, *mocks.Service, *authnmock.Authentica
authn := new(authnmock.Authentication)
mux := chi.NewMux()
idp := uuid.NewMock()
domainsapi.MakeHandler(svc, authn, mux, logger, "", idp)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
domainsapi.MakeHandler(svc, am, mux, logger, "", idp)
return httptest.NewServer(mux), svc, authn
}
+6 -6
View File
@@ -11,7 +11,7 @@ import (
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/pkg/authn"
smqauthn "github.com/absmach/supermq/pkg/authn"
roleManagerHttp "github.com/absmach/supermq/pkg/roles/rolemanager/api"
"github.com/go-chi/chi/v5"
kithttp "github.com/go-kit/kit/transport/http"
@@ -20,7 +20,7 @@ import (
)
// MakeHandler returns a HTTP handler for Domains and Invitations API endpoints.
func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler {
func MakeHandler(svc domains.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
@@ -30,7 +30,7 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux,
r.Use(api.RequestIDMiddleware(idp))
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, false))
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(false)).Middleware())
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
createDomainEndpoint(svc),
decodeCreateDomainRequest,
@@ -49,7 +49,7 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux,
})
r.Route("/{domainID}", func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
r.Use(authn.Middleware())
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
retrieveDomainEndpoint(svc),
decodeRetrieveDomainRequest,
@@ -88,7 +88,7 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux,
})
r.Route("/{domainID}/invitations", func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
r.Use(authn.Middleware())
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
sendInvitationEndpoint(svc),
decodeSendInvitationReq,
@@ -111,7 +111,7 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux,
})
mux.Route("/invitations", func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, false))
r.Use(authn.WithOptions(smqauthn.WithDomainCheck(false)).Middleware())
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
listUserInvitationsEndpoint(svc),
decodeListInvitationsReq,
+2 -1
View File
@@ -60,7 +60,8 @@ func newGroupsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentica
mux := chi.NewRouter()
idp := uuid.NewMock()
logger := smqlog.NewMock()
mux = MakeHandler(svc, authn, mux, logger, "", idp)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
mux = MakeHandler(svc, am, mux, logger, "", idp)
return httptest.NewServer(mux), svc, authn
}
+15 -16
View File
@@ -6,7 +6,6 @@ package api
import (
"context"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/groups"
"github.com/absmach/supermq/pkg/authn"
@@ -22,7 +21,7 @@ func CreateGroupEndpoint(svc groups.Service) endpoint.Endpoint {
return createGroupRes{created: false}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return createGroupRes{created: false}, svcerr.ErrAuthentication
}
@@ -43,7 +42,7 @@ func ViewGroupEndpoint(svc groups.Service) endpoint.Endpoint {
return viewGroupRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return viewGroupRes{}, svcerr.ErrAuthentication
}
@@ -64,7 +63,7 @@ func UpdateGroupEndpoint(svc groups.Service) endpoint.Endpoint {
return updateGroupRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return updateGroupRes{}, svcerr.ErrAuthentication
}
@@ -92,7 +91,7 @@ func updateGroupTagsEndpoint(svc groups.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -117,7 +116,7 @@ func EnableGroupEndpoint(svc groups.Service) endpoint.Endpoint {
return changeStatusRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return changeStatusRes{}, svcerr.ErrAuthentication
}
@@ -137,7 +136,7 @@ func DisableGroupEndpoint(svc groups.Service) endpoint.Endpoint {
return changeStatusRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return changeStatusRes{}, svcerr.ErrAuthentication
}
@@ -158,7 +157,7 @@ func ListGroupsEndpoint(svc groups.Service) endpoint.Endpoint {
return groupPageRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return groupPageRes{}, svcerr.ErrAuthentication
}
@@ -198,7 +197,7 @@ func DeleteGroupEndpoint(svc groups.Service) endpoint.Endpoint {
return deleteGroupRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return deleteGroupRes{}, svcerr.ErrAuthentication
}
@@ -216,7 +215,7 @@ func retrieveGroupHierarchyEndpoint(svc groups.Service) endpoint.Endpoint {
return retrieveGroupHierarchyRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return changeStatusRes{}, svcerr.ErrAuthentication
}
@@ -244,7 +243,7 @@ func addParentGroupEndpoint(svc groups.Service) endpoint.Endpoint {
return addParentGroupRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return changeStatusRes{}, svcerr.ErrAuthentication
}
@@ -263,7 +262,7 @@ func removeParentGroupEndpoint(svc groups.Service) endpoint.Endpoint {
return removeParentGroupRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return changeStatusRes{}, svcerr.ErrAuthentication
}
@@ -282,7 +281,7 @@ func addChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint {
return addChildrenGroupsRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return changeStatusRes{}, svcerr.ErrAuthentication
}
@@ -301,7 +300,7 @@ func removeChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint {
return removeChildrenGroupsRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return changeStatusRes{}, svcerr.ErrAuthentication
}
@@ -320,7 +319,7 @@ func removeAllChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint {
return removeAllChildrenGroupsRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return changeStatusRes{}, svcerr.ErrAuthentication
}
@@ -339,7 +338,7 @@ func listChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint {
return listChildrenGroupsRes{}, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return changeStatusRes{}, svcerr.ErrAuthentication
}
+3 -3
View File
@@ -10,7 +10,7 @@ import (
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/groups"
"github.com/absmach/supermq/pkg/authn"
smqauthn "github.com/absmach/supermq/pkg/authn"
roleManagerHttp "github.com/absmach/supermq/pkg/roles/rolemanager/api"
"github.com/go-chi/chi/v5"
kithttp "github.com/go-kit/kit/transport/http"
@@ -19,14 +19,14 @@ import (
)
// MakeHandler returns a HTTP handler for Groups API endpoints.
func MakeHandler(svc groups.Service, authn authn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) *chi.Mux {
func MakeHandler(svc groups.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) *chi.Mux {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
d := roleManagerHttp.NewDecoder("groupID")
mux.Route("/{domainID}/groups", func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
r.Use(authn.Middleware())
r.Use(api.RequestIDMiddleware(idp))
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
+3 -2
View File
@@ -6,7 +6,7 @@ syntax = "proto3";
package auth.v1;
option go_package = "github.com/absmach/supermq/api/grpc/auth/v1";
// AuthService is a service that provides authentication
// AuthService is a service that provides authentication
// and authorization functionalities for SuperMQ services.
service AuthService {
rpc Authorize(AuthZReq) returns (AuthZRes) {}
@@ -23,7 +23,8 @@ message AuthNReq {
message AuthNRes {
string id = 1; // token id
string user_id = 2; // user id
uint32 user_role = 3; // user role
uint32 user_role = 3; // user role
bool verified = 4; // verified user
}
message AuthZReq {
+2
View File
@@ -15,10 +15,12 @@ message IssueReq {
string user_id = 1;
uint32 user_role = 2;
uint32 type = 3;
bool verified = 4;
}
message RefreshReq {
string refresh_token = 1;
bool verified = 2;
}
// If a token is not carrying any information itself, the type
+2 -3
View File
@@ -6,7 +6,6 @@ package api
import (
"context"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/journal"
"github.com/absmach/supermq/pkg/authn"
@@ -22,7 +21,7 @@ func retrieveJournalsEndpoint(svc journal.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -45,7 +44,7 @@ func retrieveClientTelemetryEndpoint(svc journal.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
+2 -1
View File
@@ -57,7 +57,8 @@ func newjournalServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentic
logger := smqlog.NewMock()
authn := new(authnmocks.Authentication)
mux := httpapi.MakeHandler(svc, authn, logger, "journal-log", "test")
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
mux := httpapi.MakeHandler(svc, am, logger, "journal-log", "test")
return httptest.NewServer(mux), svc, authn
}
+3 -3
View File
@@ -34,7 +34,7 @@ const (
)
// MakeHandler returns a HTTP API handler with health check and metrics.
func MakeHandler(svc journal.Service, authn smqauthn.Authentication, logger *slog.Logger, svcName, instanceID string) http.Handler {
func MakeHandler(svc journal.Service, authn smqauthn.AuthNMiddleware, logger *slog.Logger, svcName, instanceID string) http.Handler {
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
@@ -43,7 +43,7 @@ func MakeHandler(svc journal.Service, authn smqauthn.Authentication, logger *slo
idp := uuid.New()
mux.Use(api.RequestIDMiddleware(idp))
mux.With(api.AuthenticateMiddleware(authn, false)).Get("/journal/user/{userID}", otelhttp.NewHandler(kithttp.NewServer(
mux.With(authn.WithOptions(smqauthn.WithDomainCheck(false)).Middleware()).Get("/journal/user/{userID}", otelhttp.NewHandler(kithttp.NewServer(
retrieveJournalsEndpoint(svc),
decodeRetrieveUserJournalReq,
api.EncodeResponse,
@@ -51,7 +51,7 @@ func MakeHandler(svc journal.Service, authn smqauthn.Authentication, logger *slo
), "list_user_journals").ServeHTTP)
mux.Route("/{domainID}/journal", func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, true))
r.Use(authn.Middleware())
r.Get("/{entityType}/{entityID}", otelhttp.NewHandler(kithttp.NewServer(
retrieveJournalsEndpoint(svc),
+1
View File
@@ -44,6 +44,7 @@ type Session struct {
DomainID string
DomainUserID string
SuperAdmin bool
Verified bool
Role Role
}
+1 -1
View File
@@ -54,5 +54,5 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S
return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err)
}
return authn.Session{Type: authn.AccessToken, UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil
return authn.Session{Type: authn.AccessToken, UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole()), Verified: res.GetVerified()}, nil
}
+177
View File
@@ -0,0 +1,177 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package authn
import (
"context"
"encoding/json"
"net/http"
"os"
"strconv"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
"github.com/go-chi/chi/v5"
)
type sessionKeyType string
const (
allowUnverifiedUserEnv = "SMQ_ALLOW_UNVERIFIED_USER"
jsonContentType = "application/json"
SessionKey = sessionKeyType("session")
)
// middlewareOptions contains configuration for authentication middleware.
type middlewareOptions struct {
domainCheck bool
allowUnverifiedUser bool
}
// defaultMiddlewareOptions returns the default middleware configuration.
func defaultMiddlewareOptions() *middlewareOptions {
return &middlewareOptions{
domainCheck: true,
allowUnverifiedUser: false,
}
}
// MiddlewareOption is a function that modifies middleware options.
type MiddlewareOption func(*middlewareOptions)
// WithDomainCheck sets whether domain checking is enabled.
func WithDomainCheck(enabled bool) MiddlewareOption {
return func(opts *middlewareOptions) {
opts.domainCheck = enabled
}
}
// WithAllowUnverifiedUser sets whether unverified users are allowed.
func WithAllowUnverifiedUser(allowed bool) MiddlewareOption {
return func(opts *middlewareOptions) {
opts.allowUnverifiedUser = allowed
}
}
// WithDefaultMiddlewareOptions resets options to default values.
func WithDefaultMiddlewareOptions() MiddlewareOption {
return func(opts *middlewareOptions) {
defaults := defaultMiddlewareOptions()
opts.domainCheck = defaults.domainCheck
opts.allowUnverifiedUser = defaults.allowUnverifiedUser
}
}
// AuthNMiddleware defines the interface for authenticated services with middleware.
type AuthNMiddleware interface {
Authentication
WithOptions(options ...MiddlewareOption) AuthNMiddleware
Middleware() func(http.Handler) http.Handler
}
// authnMiddleware wraps Authentication with middleware functionality.
type authnMiddleware struct {
Authentication
options []MiddlewareOption
}
// NewAuthNMiddleware creates a new authenticated service with middleware support.
// The order of precedence for options is as follows, with later options overriding earlier ones:
// 1. Default options (lowest precedence).
// 2. Options from environment variables (e.g., SMQ_ALLOW_UNVERIFIED_USER).
// 3. Options passed as arguments to this function (highest precedence).
//
// For example, consider the 'allowUnverifiedUser' option:
// - By default, it is 'false'.
// - If the SMQ_ALLOW_UNVERIFIED_USER environment variable is set to "true",
// it becomes 'true'.
// - If NewAuthNMiddleware is called with WithAllowUnverifiedUser(false), it will be 'false',
// regardless of the environment variable, as function arguments have the highest precedence.
func NewAuthNMiddleware(authnSvc Authentication, options ...MiddlewareOption) AuthNMiddleware {
allOptions := []MiddlewareOption{WithDefaultMiddlewareOptions()}
if val, ok := os.LookupEnv(allowUnverifiedUserEnv); ok {
allowUnverifiedUser, err := strconv.ParseBool(val)
if err == nil && allowUnverifiedUser {
allOptions = append(allOptions, WithAllowUnverifiedUser(true))
}
}
allOptions = append(allOptions, options...)
return &authnMiddleware{
Authentication: authnSvc,
options: allOptions,
}
}
// WithOptions returns a new service with additional options.
func (a *authnMiddleware) WithOptions(options ...MiddlewareOption) AuthNMiddleware {
return &authnMiddleware{
Authentication: a.Authentication,
options: append(a.options, options...),
}
}
// getMiddlewareOptions returns the configured middleware options.
func (a *authnMiddleware) getMiddlewareOptions() *middlewareOptions {
opts := defaultMiddlewareOptions()
for _, option := range a.options {
option(opts)
}
return opts
}
// Middleware returns an HTTP middleware function that handles authentication.
func (a *authnMiddleware) Middleware() func(http.Handler) http.Handler {
opts := a.getMiddlewareOptions()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := apiutil.ExtractBearerToken(r)
if token == "" {
encodeError(w, apiutil.ErrBearerToken, http.StatusUnauthorized)
return
}
resp, err := a.Authenticate(r.Context(), token)
if err != nil {
encodeError(w, err, http.StatusUnauthorized)
return
}
if resp.Type == AccessToken && !opts.allowUnverifiedUser && resp.Role != AdminRole && !resp.Verified {
encodeError(w, apiutil.ErrEmailNotVerified, http.StatusUnauthorized)
return
}
if opts.domainCheck {
domain := chi.URLParam(r, "domainID")
if domain == "" {
encodeError(w, apiutil.ErrMissingDomainID, http.StatusBadRequest)
return
}
resp.DomainID = domain
switch resp.Role {
case AdminRole:
resp.DomainUserID = resp.UserID
case UserRole:
resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID)
}
}
ctx := context.WithValue(r.Context(), SessionKey, resp)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
func encodeError(w http.ResponseWriter, err error, statusCode int) {
if errorVal, ok := err.(errors.Error); ok {
w.Header().Set("Content-Type", jsonContentType)
w.WriteHeader(statusCode)
if err := json.NewEncoder(w).Encode(errorVal); err != nil {
w.WriteHeader(http.StatusInternalServerError)
}
return
}
http.Error(w, err.Error(), statusCode)
}
+217
View File
@@ -0,0 +1,217 @@
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package mocks
import (
"context"
"net/http"
"github.com/absmach/supermq/pkg/authn"
mock "github.com/stretchr/testify/mock"
)
// NewAuthNMiddleware creates a new instance of AuthNMiddleware. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewAuthNMiddleware(t interface {
mock.TestingT
Cleanup(func())
}) *AuthNMiddleware {
mock := &AuthNMiddleware{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// AuthNMiddleware is an autogenerated mock type for the AuthNMiddleware type
type AuthNMiddleware struct {
mock.Mock
}
type AuthNMiddleware_Expecter struct {
mock *mock.Mock
}
func (_m *AuthNMiddleware) EXPECT() *AuthNMiddleware_Expecter {
return &AuthNMiddleware_Expecter{mock: &_m.Mock}
}
// Authenticate provides a mock function for the type AuthNMiddleware
func (_mock *AuthNMiddleware) Authenticate(ctx context.Context, token string) (authn.Session, error) {
ret := _mock.Called(ctx, token)
if len(ret) == 0 {
panic("no return value specified for Authenticate")
}
var r0 authn.Session
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string) (authn.Session, error)); ok {
return returnFunc(ctx, token)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string) authn.Session); ok {
r0 = returnFunc(ctx, token)
} else {
r0 = ret.Get(0).(authn.Session)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = returnFunc(ctx, token)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// AuthNMiddleware_Authenticate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Authenticate'
type AuthNMiddleware_Authenticate_Call struct {
*mock.Call
}
// Authenticate is a helper method to define mock.On call
// - ctx context.Context
// - token string
func (_e *AuthNMiddleware_Expecter) Authenticate(ctx interface{}, token interface{}) *AuthNMiddleware_Authenticate_Call {
return &AuthNMiddleware_Authenticate_Call{Call: _e.mock.On("Authenticate", ctx, token)}
}
func (_c *AuthNMiddleware_Authenticate_Call) Run(run func(ctx context.Context, token string)) *AuthNMiddleware_Authenticate_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *AuthNMiddleware_Authenticate_Call) Return(session authn.Session, err error) *AuthNMiddleware_Authenticate_Call {
_c.Call.Return(session, err)
return _c
}
func (_c *AuthNMiddleware_Authenticate_Call) RunAndReturn(run func(ctx context.Context, token string) (authn.Session, error)) *AuthNMiddleware_Authenticate_Call {
_c.Call.Return(run)
return _c
}
// Middleware provides a mock function for the type AuthNMiddleware
func (_mock *AuthNMiddleware) Middleware() func(http.Handler) http.Handler {
ret := _mock.Called()
if len(ret) == 0 {
panic("no return value specified for Middleware")
}
var r0 func(http.Handler) http.Handler
if returnFunc, ok := ret.Get(0).(func() func(http.Handler) http.Handler); ok {
r0 = returnFunc()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(func(http.Handler) http.Handler)
}
}
return r0
}
// AuthNMiddleware_Middleware_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Middleware'
type AuthNMiddleware_Middleware_Call struct {
*mock.Call
}
// Middleware is a helper method to define mock.On call
func (_e *AuthNMiddleware_Expecter) Middleware() *AuthNMiddleware_Middleware_Call {
return &AuthNMiddleware_Middleware_Call{Call: _e.mock.On("Middleware")}
}
func (_c *AuthNMiddleware_Middleware_Call) Run(run func()) *AuthNMiddleware_Middleware_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *AuthNMiddleware_Middleware_Call) Return(fn func(http.Handler) http.Handler) *AuthNMiddleware_Middleware_Call {
_c.Call.Return(fn)
return _c
}
func (_c *AuthNMiddleware_Middleware_Call) RunAndReturn(run func() func(http.Handler) http.Handler) *AuthNMiddleware_Middleware_Call {
_c.Call.Return(run)
return _c
}
// WithOptions provides a mock function for the type AuthNMiddleware
func (_mock *AuthNMiddleware) WithOptions(options ...authn.MiddlewareOption) authn.AuthNMiddleware {
var tmpRet mock.Arguments
if len(options) > 0 {
tmpRet = _mock.Called(options)
} else {
tmpRet = _mock.Called()
}
ret := tmpRet
if len(ret) == 0 {
panic("no return value specified for WithOptions")
}
var r0 authn.AuthNMiddleware
if returnFunc, ok := ret.Get(0).(func(...authn.MiddlewareOption) authn.AuthNMiddleware); ok {
r0 = returnFunc(options...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(authn.AuthNMiddleware)
}
}
return r0
}
// AuthNMiddleware_WithOptions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithOptions'
type AuthNMiddleware_WithOptions_Call struct {
*mock.Call
}
// WithOptions is a helper method to define mock.On call
// - options ...authn.MiddlewareOption
func (_e *AuthNMiddleware_Expecter) WithOptions(options ...interface{}) *AuthNMiddleware_WithOptions_Call {
return &AuthNMiddleware_WithOptions_Call{Call: _e.mock.On("WithOptions",
append([]interface{}{}, options...)...)}
}
func (_c *AuthNMiddleware_WithOptions_Call) Run(run func(options ...authn.MiddlewareOption)) *AuthNMiddleware_WithOptions_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 []authn.MiddlewareOption
var variadicArgs []authn.MiddlewareOption
if len(args) > 0 {
variadicArgs = args[0].([]authn.MiddlewareOption)
}
arg0 = variadicArgs
run(
arg0...,
)
})
return _c
}
func (_c *AuthNMiddleware_WithOptions_Call) Return(authNMiddleware authn.AuthNMiddleware) *AuthNMiddleware_WithOptions_Call {
_c.Call.Return(authNMiddleware)
return _c
}
func (_c *AuthNMiddleware_WithOptions_Call) RunAndReturn(run func(options ...authn.MiddlewareOption) authn.AuthNMiddleware) *AuthNMiddleware_WithOptions_Call {
_c.Call.Return(run)
return _c
}
+9
View File
@@ -93,4 +93,13 @@ var (
// ErrSuperAdminAction indicates that the user is not a super admin.
ErrSuperAdminAction = errors.New("not authorized to perform admin action")
// ErrUserAlreadyVerified indicates user is already verified.
ErrUserAlreadyVerified = errors.New("user already verified")
// ErrInvalidUserVerification indicates user verification is invalid.
ErrInvalidUserVerification = errors.New("invalid verification")
// ErrUserVerificationExpired indicates user verification is expired.
ErrUserVerificationExpired = errors.New("verification expired, please generate new verification")
)
+16 -17
View File
@@ -6,7 +6,6 @@ package http
import (
"context"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
@@ -22,7 +21,7 @@ func CreateRoleEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -42,7 +41,7 @@ func ListRolesEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -62,7 +61,7 @@ func ListEntityMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -94,7 +93,7 @@ func RemoveEntityMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -113,7 +112,7 @@ func ViewRoleEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -133,7 +132,7 @@ func UpdateRoleEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -153,7 +152,7 @@ func DeleteRoleEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -172,7 +171,7 @@ func ListAvailableActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -192,7 +191,7 @@ func AddRoleActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -212,7 +211,7 @@ func ListRoleActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -232,7 +231,7 @@ func DeleteRoleActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -251,7 +250,7 @@ func DeleteAllRoleActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -270,7 +269,7 @@ func AddRoleMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -290,7 +289,7 @@ func ListRoleMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -310,7 +309,7 @@ func DeleteRoleMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -329,7 +328,7 @@ func DeleteAllRoleMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
+2 -1
View File
@@ -67,7 +67,8 @@ func setupCerts() (*httptest.Server, *mocks.Service, *authnmocks.Authentication)
logger := smqlog.NewMock()
authn := new(authnmocks.Authentication)
idp := uuid.NewMock()
mux := httpapi.MakeHandler(svc, authn, logger, instanceID, idp)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
mux := httpapi.MakeHandler(svc, am, logger, instanceID, idp)
return httptest.NewServer(mux), svc, authn
}
+3 -2
View File
@@ -43,7 +43,8 @@ func setupChannels() (*httptest.Server, *chmocks.Service, *authnmocks.Authentica
authn := new(authnmocks.Authentication)
mux := chi.NewRouter()
idp := uuid.NewMock()
chapi.MakeHandler(svc, authn, mux, logger, "", idp)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
chapi.MakeHandler(svc, am, mux, logger, "", idp)
return httptest.NewServer(mux), svc, authn
}
@@ -1094,7 +1095,7 @@ func TestUpdateChannelTags(t *testing.T) {
svcRes: channels.Channel{},
authenticateErr: svcerr.ErrAuthorization,
response: sdk.Channel{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
},
{
desc: "update channel tags with empty token",
+9 -8
View File
@@ -37,7 +37,8 @@ func setupClients() (*httptest.Server, *mocks.Service, *authnmocks.Authenticatio
mux := chi.NewRouter()
idp := uuid.NewMock()
authn := new(authnmocks.Authentication)
api.MakeHandler(tsvc, authn, mux, logger, "", idp)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
api.MakeHandler(tsvc, am, mux, logger, "", idp)
return httptest.NewServer(mux), tsvc, authn
}
@@ -647,7 +648,7 @@ func TestViewClient(t *testing.T) {
svcRes: clients.Client{},
authenticateErr: svcerr.ErrAuthorization,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
},
{
desc: "view client with empty token",
@@ -786,7 +787,7 @@ func TestUpdateClient(t *testing.T) {
svcRes: clients.Client{},
authenticateErr: svcerr.ErrAuthorization,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
},
{
desc: "update client with empty token",
@@ -940,7 +941,7 @@ func TestUpdateClientTags(t *testing.T) {
svcRes: clients.Client{},
authenticateErr: svcerr.ErrAuthorization,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
},
{
desc: "update client tags with empty token",
@@ -1089,7 +1090,7 @@ func TestUpdateClientSecret(t *testing.T) {
svcRes: clients.Client{},
authenticateErr: svcerr.ErrAuthorization,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
},
{
desc: "update client secret with empty token",
@@ -1217,7 +1218,7 @@ func TestEnableClient(t *testing.T) {
svcRes: clients.Client{},
authenticateErr: svcerr.ErrAuthorization,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
},
{
desc: "enable client with an invalid client id",
@@ -1320,7 +1321,7 @@ func TestDisableClient(t *testing.T) {
svcRes: clients.Client{},
authenticateErr: svcerr.ErrAuthorization,
response: sdk.Client{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
},
{
desc: "disable client with an invalid client id",
@@ -1415,7 +1416,7 @@ func TestDeleteClient(t *testing.T) {
token: invalidToken,
clientID: client.ID,
authenticateErr: svcerr.ErrAuthorization,
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
},
{
desc: "delete client with empty token",
+2 -1
View File
@@ -62,8 +62,9 @@ func setupDomains() (*httptest.Server, *mocks.Service, *authnmocks.Authenticatio
mux := chi.NewRouter()
idp := uuid.NewMock()
authn := new(authnmocks.Authentication)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
handler := domainapi.MakeHandler(svc, authn, mux, logger, "", idp)
handler := domainapi.MakeHandler(svc, am, mux, logger, "", idp)
return httptest.NewServer(handler), svc, authn
}
+3 -2
View File
@@ -48,7 +48,8 @@ func setupGroups() (*httptest.Server, *mocks.Service, *authnmocks.Authentication
provider := new(oauth2mocks.Provider)
provider.On("Name").Return(roleName)
authn := new(authnmocks.Authentication)
httpapi.MakeHandler(svc, authn, mux, logger, "", idp)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
httpapi.MakeHandler(svc, am, mux, logger, "", idp)
return httptest.NewServer(mux), svc, authn
}
@@ -922,7 +923,7 @@ func TestUpdateGroupTags(t *testing.T) {
svcRes: groups.Group{},
authenticateErr: svcerr.ErrAuthorization,
response: sdk.Group{},
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden),
err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized),
},
{
desc: "update group tags with empty token",
+2 -1
View File
@@ -29,7 +29,8 @@ func setupJournal() (*httptest.Server, *mocks.Service, *authnmocks.Authenticatio
svc := new(mocks.Service)
authn := new(authnmocks.Authentication)
logger := smqlog.NewMock()
mux := api.MakeHandler(svc, authn, logger, "journal-log", "test")
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
mux := api.MakeHandler(svc, am, logger, "journal-log", "test")
return httptest.NewServer(mux), svc, authn
}
+118
View File
@@ -8040,6 +8040,65 @@ func (_c *SDK_SendMessage_Call) RunAndReturn(run func(ctx context.Context, domai
return _c
}
// SendVerification provides a mock function for the type SDK
func (_mock *SDK) SendVerification(ctx context.Context, token string) errors.SDKError {
ret := _mock.Called(ctx, token)
if len(ret) == 0 {
panic("no return value specified for SendVerification")
}
var r0 errors.SDKError
if returnFunc, ok := ret.Get(0).(func(context.Context, string) errors.SDKError); ok {
r0 = returnFunc(ctx, token)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(errors.SDKError)
}
}
return r0
}
// SDK_SendVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendVerification'
type SDK_SendVerification_Call struct {
*mock.Call
}
// SendVerification is a helper method to define mock.On call
// - ctx context.Context
// - token string
func (_e *SDK_Expecter) SendVerification(ctx interface{}, token interface{}) *SDK_SendVerification_Call {
return &SDK_SendVerification_Call{Call: _e.mock.On("SendVerification", ctx, token)}
}
func (_c *SDK_SendVerification_Call) Run(run func(ctx context.Context, token string)) *SDK_SendVerification_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *SDK_SendVerification_Call) Return(sDKError errors.SDKError) *SDK_SendVerification_Call {
_c.Call.Return(sDKError)
return _c
}
func (_c *SDK_SendVerification_Call) RunAndReturn(run func(ctx context.Context, token string) errors.SDKError) *SDK_SendVerification_Call {
_c.Call.Return(run)
return _c
}
// SetChannelParent provides a mock function for the type SDK
func (_mock *SDK) SetChannelParent(ctx context.Context, id string, domainID string, groupID string, token string) errors.SDKError {
ret := _mock.Called(ctx, id, domainID, groupID, token)
@@ -9974,6 +10033,65 @@ func (_c *SDK_Users_Call) RunAndReturn(run func(ctx context.Context, pm sdk.Page
return _c
}
// VerifyEmail provides a mock function for the type SDK
func (_mock *SDK) VerifyEmail(ctx context.Context, verificationToken string) errors.SDKError {
ret := _mock.Called(ctx, verificationToken)
if len(ret) == 0 {
panic("no return value specified for VerifyEmail")
}
var r0 errors.SDKError
if returnFunc, ok := ret.Get(0).(func(context.Context, string) errors.SDKError); ok {
r0 = returnFunc(ctx, verificationToken)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(errors.SDKError)
}
}
return r0
}
// SDK_VerifyEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyEmail'
type SDK_VerifyEmail_Call struct {
*mock.Call
}
// VerifyEmail is a helper method to define mock.On call
// - ctx context.Context
// - verificationToken string
func (_e *SDK_Expecter) VerifyEmail(ctx interface{}, verificationToken interface{}) *SDK_VerifyEmail_Call {
return &SDK_VerifyEmail_Call{Call: _e.mock.On("VerifyEmail", ctx, verificationToken)}
}
func (_c *SDK_VerifyEmail_Call) Run(run func(ctx context.Context, verificationToken string)) *SDK_VerifyEmail_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *SDK_VerifyEmail_Call) Return(sDKError errors.SDKError) *SDK_VerifyEmail_Call {
_c.Call.Return(sDKError)
return _c
}
func (_c *SDK_VerifyEmail_Call) RunAndReturn(run func(ctx context.Context, verificationToken string) errors.SDKError) *SDK_VerifyEmail_Call {
_c.Call.Return(run)
return _c
}
// ViewCert provides a mock function for the type SDK
func (_mock *SDK) ViewCert(ctx context.Context, certID string, domainID string, token string) (sdk.Cert, errors.SDKError) {
ret := _mock.Called(ctx, certID, domainID, token)
+14
View File
@@ -178,6 +178,20 @@ type SDK interface {
// fmt.Println(user)
CreateUser(ctx context.Context, user User, token string) (User, errors.SDKError)
// SendVerification sends a verification email to the user.
//
// example:
// err := sdk.SendVerification("token")
// fmt.Println(err)
SendVerification(ctx context.Context, token string) errors.SDKError
// VerifyEmail verifies the user's email address using the provided token.
//
// example:
// err := sdk.VerifyEmail("verificationToken")
// fmt.Println(user)
VerifyEmail(ctx context.Context, verificationToken string) errors.SDKError
// User returns user object by id.
//
// example:
+28 -7
View File
@@ -15,13 +15,16 @@ import (
)
const (
usersEndpoint = "users"
enableEndpoint = "enable"
disableEndpoint = "disable"
issueTokenEndpoint = "tokens/issue"
refreshTokenEndpoint = "tokens/refresh"
membersEndpoint = "members"
PasswordResetEndpoint = "password"
usersEndpoint = "users"
enableEndpoint = "enable"
disableEndpoint = "disable"
issueTokenEndpoint = "tokens/issue"
refreshTokenEndpoint = "tokens/refresh"
membersEndpoint = "members"
PasswordResetEndpoint = "password"
sendVerificationEndpoint = "send-verification"
verifyEmailEndpoint = "verify-email"
tokenQueryParamKey = "token"
)
// User represents supermq user its credentials.
@@ -61,6 +64,24 @@ func (sdk mgSDK) CreateUser(ctx context.Context, user User, token string) (User,
return user, nil
}
func (sdk mgSDK) SendVerification(ctx context.Context, token string) errors.SDKError {
url := fmt.Sprintf("%s/%s/%s", sdk.usersURL, usersEndpoint, sendVerificationEndpoint)
_, _, sdkErr := sdk.processRequest(ctx, http.MethodPost, url, token, nil, nil, http.StatusOK)
return sdkErr
}
func (sdk mgSDK) VerifyEmail(ctx context.Context, verificationToken string) errors.SDKError {
url := fmt.Sprintf("%s/%s?%s=%s", sdk.usersURL, verifyEmailEndpoint, tokenQueryParamKey, verificationToken)
_, _, sdkErr := sdk.processRequest(ctx, http.MethodGet, url, "", nil, nil, http.StatusOK)
if sdkErr != nil {
return sdkErr
}
return nil
}
func (sdk mgSDK) Users(ctx context.Context, pm PageMetadata, token string) (UsersPage, errors.SDKError) {
url, err := sdk.withQueryParams(sdk.usersURL, usersEndpoint, pm)
if err != nil {
+6 -1
View File
@@ -45,8 +45,9 @@ func setupUsers() (*httptest.Server, *umocks.Service, *authnmocks.Authentication
provider := new(oauth2mocks.Provider)
provider.On("Name").Return("test")
authn := new(authnmocks.Authentication)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithDomainCheck(false), smqauthn.WithAllowUnverifiedUser(true))
token := new(authmocks.TokenServiceClient)
httpapi.MakeHandler(usvc, authn, token, true, mux, logger, "", passRegex, idp, provider)
httpapi.MakeHandler(usvc, am, token, true, mux, logger, "", passRegex, idp, provider)
return httptest.NewServer(mux), usvc, authn
}
@@ -567,6 +568,10 @@ func TestListUsers(t *testing.T) {
resp, err := mgsdk.Users(context.Background(), tc.pageMeta, tc.token)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.token != "" {
ok := authCall.Parent.AssertCalled(t, "Authenticate", mock.Anything, tc.token)
assert.True(t, ok)
}
if tc.err == nil {
ok := svcCall.Parent.AssertCalled(t, "ListUsers", mock.Anything, tc.session, tc.svcReq)
assert.True(t, ok)
+1 -1
View File
@@ -38,7 +38,7 @@ done
###
# Users
###
SMQ_USERS_LOG_LEVEL=info SMQ_USERS_HTTP_PORT=9002 SMQ_USERS_GRPC_PORT=7001 SMQ_USERS_ADMIN_EMAIL=admin@supermq.com SMQ_USERS_ADMIN_PASSWORD=12345678 SMQ_USERS_ADMIN_USERNAME=admin SMQ_EMAIL_TEMPLATE=../docker/templates/users.tmpl $BUILD_DIR/supermq-users &
SMQ_USERS_LOG_LEVEL=info SMQ_USERS_HTTP_PORT=9002 SMQ_USERS_GRPC_PORT=7001 SMQ_USERS_ADMIN_EMAIL=admin@supermq.com SMQ_USERS_ADMIN_PASSWORD=12345678 SMQ_USERS_ADMIN_USERNAME=admin SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost:9002/password/reset SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=../docker/templates/reset-password-email.tmpl SMQ_VERIFICATION_URL_PREFIX=http://localhost:9002/users/verify-email SMQ_VERIFICATION_EMAIL_TEMPLATE=../docker/templates/verification-email.tmpl $BUILD_DIR/supermq-users &
###
# Clients
+1
View File
@@ -114,6 +114,7 @@ packages:
github.com/absmach/supermq/pkg/authn:
interfaces:
Authentication:
AuthNMiddleware:
github.com/absmach/supermq/pkg/authz:
interfaces:
Authorization:
+49 -43
View File
@@ -12,48 +12,51 @@ For in-depth explanation of the aforementioned scenarios, as well as thorough un
The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values.
| Variable | Description | Default |
| ------------------------------ | ----------------------------------------------------------------------- | --------------------------------- |
| SMQ_USERS_LOG_LEVEL | Log level for users service (debug, info, warn, error) | info |
| SMQ_USERS_ADMIN_EMAIL | Default user, created on startup | <admin@example.com> |
| SMQ_USERS_ADMIN_PASSWORD | Default user password, created on startup | 12345678 |
| SMQ_USERS_PASS_REGEX | Password regex | ^.{8,}$ |
| SMQ_USERS_HTTP_HOST | Users service HTTP host | localhost |
| SMQ_USERS_HTTP_PORT | Users service HTTP port | 9002 |
| SMQ_USERS_HTTP_SERVER_CERT | Path to the PEM encoded server certificate file | "" |
| SMQ_USERS_HTTP_SERVER_KEY | Path to the PEM encoded server key file | "" |
| SMQ_USERS_HTTP_SERVER_CA_CERTS | Path to the PEM encoded server CA certificate file | "" |
| SMQ_USERS_HTTP_CLIENT_CA_CERTS | Path to the PEM encoded client CA certificate file | "" |
| SMQ_AUTH_GRPC_URL | Auth service GRPC URL | localhost:8181 |
| SMQ_AUTH_GRPC_TIMEOUT | Auth service GRPC timeout | 1s |
| SMQ_AUTH_GRPC_CLIENT_CERT | Path to the PEM encoded client certificate file | "" |
| SMQ_AUTH_GRPC_CLIENT_KEY | Path to the PEM encoded client key file | "" |
| SMQ_AUTH_GRPC_SERVER_CA_CERTS | Path to the PEM encoded server CA certificate file | "" |
| SMQ_USERS_DB_HOST | Database host address | localhost |
| SMQ_USERS_DB_PORT | Database host port | 5432 |
| SMQ_USERS_DB_USER | Database user | supermq |
| SMQ_USERS_DB_PASS | Database password | supermq |
| SMQ_USERS_DB_NAME | Name of the database used by the service | users |
| SMQ_USERS_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable |
| SMQ_USERS_DB_SSL_CERT | Path to the PEM encoded certificate file | "" |
| SMQ_USERS_DB_SSL_KEY | Path to the PEM encoded key file | "" |
| SMQ_USERS_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" |
| SMQ_EMAIL_HOST | Mail server host | localhost |
| SMQ_EMAIL_PORT | Mail server port | 25 |
| SMQ_EMAIL_USERNAME | Mail server username | "" |
| SMQ_EMAIL_PASSWORD | Mail server password | "" |
| SMQ_EMAIL_FROM_ADDRESS | Email "from" address | "" |
| SMQ_EMAIL_FROM_NAME | Email "from" name | "" |
| SMQ_EMAIL_TEMPLATE | Email template for sending emails with password reset link | email.tmpl |
| SMQ_USERS_ES_URL | Event store URL | <nats://localhost:4222> |
| SMQ_JAEGER_URL | Jaeger server URL | <http://localhost:4318/v1/traces> |
| SMQ_OAUTH_UI_REDIRECT_URL | OAuth UI redirect URL | <http://localhost:9095/domains> |
| SMQ_OAUTH_UI_ERROR_URL | OAuth UI error URL | <http://localhost:9095/error> |
| SMQ_USERS_DELETE_INTERVAL | Interval for deleting users | 24h |
| SMQ_USERS_DELETE_AFTER | Time after which users are deleted | 720h |
| SMQ_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 |
| SMQ_SEND_TELEMETRY | Send telemetry to supermq call home server. | true |
| SMQ_USERS_INSTANCE_ID | SuperMQ instance ID | "" |
| Variable | Description | Default |
| --------------------------------- | ----------------------------------------------------------------------- | --------------------------------- |
| SMQ_USERS_LOG_LEVEL | Log level for users service (debug, info, warn, error) | info |
| SMQ_USERS_ADMIN_EMAIL | Default user, created on startup | <admin@example.com> |
| SMQ_USERS_ADMIN_PASSWORD | Default user password, created on startup | 12345678 |
| SMQ_USERS_PASS_REGEX | Password regex | ^.{8,}$ |
| SMQ_USERS_HTTP_HOST | Users service HTTP host | localhost |
| SMQ_USERS_HTTP_PORT | Users service HTTP port | 9002 |
| SMQ_USERS_HTTP_SERVER_CERT | Path to the PEM encoded server certificate file | "" |
| SMQ_USERS_HTTP_SERVER_KEY | Path to the PEM encoded server key file | "" |
| SMQ_USERS_HTTP_SERVER_CA_CERTS | Path to the PEM encoded server CA certificate file | "" |
| SMQ_USERS_HTTP_CLIENT_CA_CERTS | Path to the PEM encoded client CA certificate file | "" |
| SMQ_AUTH_GRPC_URL | Auth service GRPC URL | localhost:8181 |
| SMQ_AUTH_GRPC_TIMEOUT | Auth service GRPC timeout | 1s |
| SMQ_AUTH_GRPC_CLIENT_CERT | Path to the PEM encoded client certificate file | "" |
| SMQ_AUTH_GRPC_CLIENT_KEY | Path to the PEM encoded client key file | "" |
| SMQ_AUTH_GRPC_SERVER_CA_CERTS | Path to the PEM encoded server CA certificate file | "" |
| SMQ_USERS_DB_HOST | Database host address | localhost |
| SMQ_USERS_DB_PORT | Database host port | 5432 |
| SMQ_USERS_DB_USER | Database user | supermq |
| SMQ_USERS_DB_PASS | Database password | supermq |
| SMQ_USERS_DB_NAME | Name of the database used by the service | users |
| SMQ_USERS_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable |
| SMQ_USERS_DB_SSL_CERT | Path to the PEM encoded certificate file | "" |
| SMQ_USERS_DB_SSL_KEY | Path to the PEM encoded key file | "" |
| SMQ_USERS_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" |
| SMQ_EMAIL_HOST | Mail server host | localhost |
| SMQ_EMAIL_PORT | Mail server port | 25 |
| SMQ_EMAIL_USERNAME | Mail server username | "" |
| SMQ_EMAIL_PASSWORD | Mail server password | "" |
| SMQ_EMAIL_FROM_ADDRESS | Email "from" address | "" |
| SMQ_EMAIL_FROM_NAME | Email "from" name | "" |
| SMQ_PASSWORD_RESET_URL_PREFIX | Password reset URL prefix | <http://localhost/password/reset> |
| SMQ_PASSWORD_RESET_EMAIL_TEMPLATE | Password reset email template | reset-password-email.tmpl |
| SMQ_VERIFICATION_URL_PREFIX | Verification URL prefix | <http://localhost/verify-email> |
| SMQ_VERIFICATION_EMAIL_TEMPLATE | Verification email template | verification-email.tmpl |
| SMQ_USERS_ES_URL | Event store URL | <nats://localhost:4222> |
| SMQ_JAEGER_URL | Jaeger server URL | <http://localhost:4318/v1/traces> |
| SMQ_OAUTH_UI_REDIRECT_URL | OAuth UI redirect URL | <http://localhost:9095/domains> |
| SMQ_OAUTH_UI_ERROR_URL | OAuth UI error URL | <http://localhost:9095/error> |
| SMQ_USERS_DELETE_INTERVAL | Interval for deleting users | 24h |
| SMQ_USERS_DELETE_AFTER | Time after which users are deleted | 720h |
| SMQ_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 |
| SMQ_SEND_TELEMETRY | Send telemetry to supermq call home server. | true |
| SMQ_USERS_INSTANCE_ID | SuperMQ instance ID | "" |
## Deployment
@@ -104,7 +107,10 @@ SMQ_EMAIL_USERNAME="18bf7f7070513" \
SMQ_EMAIL_PASSWORD="2b0d302e775b1e" \
SMQ_EMAIL_FROM_ADDRESS=from@example.com \
SMQ_EMAIL_FROM_NAME=Example \
SMQ_EMAIL_TEMPLATE="docker/templates/users.tmpl" \
SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost:9002/password/reset \
SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=docker/templates/reset-password-email.tmpl \
SMQ_VERIFICATION_URL_PREFIX=http://localhost:9002/users/verify-email \
SMQ_VERIFICATION_EMAIL_TEMPLATE=docker/templates/verification-email.tmpl \
SMQ_USERS_ES_URL=nats://localhost:4222 \
SMQ_JAEGER_URL=http://localhost:14268/api/traces \
SMQ_JAEGER_TRACE_RATIO=1.0 \
+147 -1
View File
@@ -94,8 +94,9 @@ func newUsersServer() (*httptest.Server, *mocks.Service, *authnmocks.Authenticat
provider := new(oauth2mocks.Provider)
provider.On("Name").Return("test")
authn := new(authnmocks.Authentication)
am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true))
token := new(authmocks.TokenServiceClient)
usersapi.MakeHandler(svc, authn, token, true, mux, logger, "", passRegex, idp, provider)
usersapi.MakeHandler(svc, am, token, true, mux, logger, "", passRegex, idp, provider)
return httptest.NewServer(mux), svc, authn
}
@@ -1750,6 +1751,151 @@ func TestPasswordResetRequest(t *testing.T) {
}
}
func TestSendVerification(t *testing.T) {
us, svc, authn := newUsersServer()
defer us.Close()
cases := []struct {
desc string
token string
status int
authnRes smqauthn.Session
authnErr error
svcErr error
err error
}{
{
desc: "send verification with valid token",
token: validToken,
status: http.StatusOK,
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID},
err: nil,
},
{
desc: "send verification with invalid token",
token: inValidToken,
status: http.StatusUnauthorized,
authnErr: svcerr.ErrAuthentication,
authnRes: smqauthn.Session{},
err: svcerr.ErrAuthentication,
},
{
desc: "send verification with empty token",
token: "",
status: http.StatusUnauthorized,
authnErr: svcerr.ErrAuthentication,
authnRes: smqauthn.Session{},
err: apiutil.ErrBearerToken,
},
{
desc: "send verification with service error",
token: validToken,
status: http.StatusUnprocessableEntity,
authnRes: smqauthn.Session{UserID: validID, DomainID: domainID},
svcErr: svcerr.ErrCreateEntity,
err: svcerr.ErrCreateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
req := testRequest{
user: us.Client(),
method: http.MethodPost,
url: fmt.Sprintf("%s/users/send-verification", us.URL),
token: tc.token,
}
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("SendVerification", mock.Anything, tc.authnRes).Return(tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
body, err := io.ReadAll(res.Body)
if err != nil {
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while reading response body: %s", tc.desc, err))
}
defer res.Body.Close()
var errRes respBody
if len(body) > 0 {
if err := json.Unmarshal(body, &errRes); err != nil {
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while unmarshal response body: %s", tc.desc, err))
}
}
if errRes.Err != "" || errRes.Message != "" {
err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.err, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authnCall.Unset()
})
}
}
func TestVerifyEmail(t *testing.T) {
us, svc, _ := newUsersServer()
defer us.Close()
cases := []struct {
desc string
token string
status int
svcErr error
err error
}{
{
desc: "verify email with valid token",
token: validToken,
status: http.StatusOK,
err: nil,
},
{
desc: "verify email with empty token",
token: "",
status: http.StatusBadRequest,
err: apiutil.ErrValidation,
},
{
desc: "verify email with service error",
token: validToken,
status: http.StatusBadRequest,
svcErr: svcerr.ErrMalformedEntity,
err: svcerr.ErrMalformedEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
req := testRequest{
user: us.Client(),
method: http.MethodGet,
url: fmt.Sprintf("%s/verify-email?token=%s", us.URL, tc.token),
}
svcCall := svc.On("VerifyEmail", mock.Anything, mock.Anything).Return(users.User{}, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
body, err := io.ReadAll(res.Body)
if err != nil {
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while reading response body: %s", tc.desc, err))
}
defer res.Body.Close()
var errRes respBody
if len(body) > 0 {
if err := json.Unmarshal(body, &errRes); err != nil {
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while unmarshal response body: %s", tc.desc, err))
}
}
if errRes.Err != "" || errRes.Message != "" {
err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.err, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
})
}
}
func TestPasswordReset(t *testing.T) {
us, svc, authn := newUsersServer()
defer us.Close()
+48 -17
View File
@@ -6,7 +6,6 @@ package api
import (
"context"
api "github.com/absmach/supermq/api/http"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/errors"
@@ -25,7 +24,7 @@ func registrationEndpoint(svc users.Service, selfRegister bool) endpoint.Endpoin
var ok bool
if !selfRegister {
session, ok = ctx.Value(api.SessionKey).(authn.Session)
session, ok = ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -43,6 +42,38 @@ func registrationEndpoint(svc users.Service, selfRegister bool) endpoint.Endpoin
}
}
func sendVerificationEndpoint(svc users.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
_ = request.(sendVerificationReq)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
if err := svc.SendVerification(ctx, session); err != nil {
return sendVerificationRes{}, err
}
return sendVerificationRes{}, nil
}
}
func verifyEmailEndpoint(svc users.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(verifyEmailReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
if _, err := svc.VerifyEmail(ctx, req.token); err != nil {
return verifyEmailRes{}, err
}
return verifyEmailRes{}, nil
}
}
func viewEndpoint(svc users.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(viewUserReq)
@@ -50,7 +81,7 @@ func viewEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -65,7 +96,7 @@ func viewEndpoint(svc users.Service) endpoint.Endpoint {
func viewProfileEndpoint(svc users.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -85,7 +116,7 @@ func listUsersEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -172,7 +203,7 @@ func updateEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -199,7 +230,7 @@ func updateTagsEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -224,7 +255,7 @@ func updateEmailEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -273,7 +304,7 @@ func passwordResetEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -292,7 +323,7 @@ func updateSecretEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -312,7 +343,7 @@ func updateUsernameEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -337,7 +368,7 @@ func updateProfilePictureEndpoint(svc users.Service) endpoint.Endpoint {
ProfilePicture: req.ProfilePicture,
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthorization
}
@@ -363,7 +394,7 @@ func updateRoleEndpoint(svc users.Service) endpoint.Endpoint {
Role: req.role,
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -404,7 +435,7 @@ func refreshTokenEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -429,7 +460,7 @@ func enableEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -450,7 +481,7 @@ func disableEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
@@ -471,7 +502,7 @@ func deleteEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(api.SessionKey).(authn.Session)
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
+20 -6
View File
@@ -4,7 +4,6 @@
package api
import (
"net/mail"
"net/url"
api "github.com/absmach/supermq/api/http"
@@ -39,15 +38,16 @@ func (req createUserReq) validate() error {
return err
}
// Username must not be a valid email format due to username/email login.
if _, err := mail.ParseAddress(req.User.Credentials.Username); err == nil {
if err := api.ValidateEmail(req.User.Credentials.Username); err == nil {
return apiutil.ErrInvalidUsername
}
if req.User.Email == "" {
return apiutil.ErrMissingEmail
}
// Email must be in a valid format.
if _, err := mail.ParseAddress(req.User.Email); err != nil {
return apiutil.ErrInvalidEmail
if err := api.ValidateEmail(req.User.Email); err != nil {
return err
}
if req.User.Credentials.Secret == "" {
return apiutil.ErrMissingPass
@@ -67,6 +67,20 @@ func (req createUserReq) validate() error {
return req.User.Validate()
}
type sendVerificationReq struct{}
type verifyEmailReq struct {
token string
}
func (req verifyEmailReq) validate() error {
if req.token == "" {
return apiutil.ErrInvalidVerification
}
return nil
}
type viewUserReq struct {
id string
}
@@ -183,8 +197,8 @@ func (req updateEmailReq) validate() error {
if req.id == "" {
return apiutil.ErrMissingID
}
if _, err := mail.ParseAddress(req.Email); err != nil {
return apiutil.ErrInvalidEmail
if err := api.ValidateEmail(req.Email); err != nil {
return err
}
return nil
+30
View File
@@ -16,6 +16,8 @@ const MailSent = "Email with reset link is sent"
var (
_ supermq.Response = (*tokenRes)(nil)
_ supermq.Response = (*sendVerificationRes)(nil)
_ supermq.Response = (*verifyEmailRes)(nil)
_ supermq.Response = (*viewUserRes)(nil)
_ supermq.Response = (*createUserRes)(nil)
_ supermq.Response = (*changeUserStatusRes)(nil)
@@ -79,6 +81,34 @@ func (res tokenRes) Empty() bool {
return res.AccessToken == "" || res.RefreshToken == ""
}
type sendVerificationRes struct{}
func (res sendVerificationRes) Code() int {
return http.StatusOK
}
func (res sendVerificationRes) Headers() map[string]string {
return map[string]string{}
}
func (res sendVerificationRes) Empty() bool {
return true
}
type verifyEmailRes struct{}
func (res verifyEmailRes) Code() int {
return http.StatusOK
}
func (res verifyEmailRes) Headers() map[string]string {
return map[string]string{}
}
func (res verifyEmailRes) Empty() bool {
return true
}
type updateUserRes struct {
users.User `json:",inline"`
}
+1 -1
View File
@@ -18,7 +18,7 @@ import (
)
// MakeHandler returns a HTTP handler for Users and Groups API endpoints.
func MakeHandler(cls users.Service, authn smqauthn.Authentication, tokensvc grpcTokenV1.TokenServiceClient, selfRegister bool, mux *chi.Mux, logger *slog.Logger, instanceID string, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) http.Handler {
func MakeHandler(cls users.Service, authn smqauthn.AuthNMiddleware, tokensvc grpcTokenV1.TokenServiceClient, selfRegister bool, mux *chi.Mux, logger *slog.Logger, instanceID string, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) http.Handler {
mux = usersHandler(cls, authn, tokensvc, selfRegister, mux, logger, pr, idp, providers...)
mux.Get("/health", supermq.Health("users", instanceID))
+52 -19
View File
@@ -28,13 +28,15 @@ import (
var passRegex = regexp.MustCompile("^.{8,}$")
// usersHandler returns a HTTP handler for API endpoints.
func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient grpcTokenV1.TokenServiceClient, selfRegister bool, r *chi.Mux, logger *slog.Logger, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) *chi.Mux {
func usersHandler(svc users.Service, authn smqauthn.AuthNMiddleware, tokenClient grpcTokenV1.TokenServiceClient, selfRegister bool, r *chi.Mux, logger *slog.Logger, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) *chi.Mux {
passRegex = pr
opts := []kithttp.ServerOption{
kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)),
}
// All endpoints in users service don't required Domain check
authn = authn.WithOptions(smqauthn.WithDomainCheck(false))
r.Route("/users", func(r chi.Router) {
r.Use(api.RequestIDMiddleware(idp))
@@ -47,16 +49,22 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient
opts...,
), "register_user").ServeHTTP)
default:
r.With(api.AuthenticateMiddleware(authn, false)).Post("/", otelhttp.NewHandler(kithttp.NewServer(
r.With(authn.Middleware()).Post("/", otelhttp.NewHandler(kithttp.NewServer(
registrationEndpoint(svc, selfRegister),
decodeCreateUserReq,
api.EncodeResponse,
opts...,
), "register_user").ServeHTTP)
}
// Endpoints which are allowed for unverified user
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, false))
r.Use(authn.WithOptions(smqauthn.WithAllowUnverifiedUser(true)).Middleware())
r.Post("/send-verification", otelhttp.NewHandler(kithttp.NewServer(
sendVerificationEndpoint(svc),
decodeSendVerification,
api.EncodeResponse,
opts...,
), "send_verification").ServeHTTP)
r.Get("/profile", otelhttp.NewHandler(kithttp.NewServer(
viewProfileEndpoint(svc),
@@ -64,6 +72,22 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient
api.EncodeResponse,
opts...,
), "view_profile").ServeHTTP)
r.Post("/tokens/refresh", otelhttp.NewHandler(kithttp.NewServer(
refreshTokenEndpoint(svc),
decodeRefreshToken,
api.EncodeResponse,
opts...,
), "refresh_token").ServeHTTP)
r.Patch("/{id}/email", otelhttp.NewHandler(kithttp.NewServer(
updateEmailEndpoint(svc),
decodeUpdateUserEmail,
api.EncodeResponse,
opts...,
), "update_user_email").ServeHTTP)
})
r.Group(func(r chi.Router) {
r.Use(authn.Middleware())
r.Get("/{id}", otelhttp.NewHandler(kithttp.NewServer(
viewEndpoint(svc),
@@ -121,13 +145,6 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient
opts...,
), "update_user_tags").ServeHTTP)
r.Patch("/{id}/email", otelhttp.NewHandler(kithttp.NewServer(
updateEmailEndpoint(svc),
decodeUpdateUserEmail,
api.EncodeResponse,
opts...,
), "update_user_email").ServeHTTP)
r.Patch("/{id}/role", otelhttp.NewHandler(kithttp.NewServer(
updateRoleEndpoint(svc),
decodeUpdateUserRole,
@@ -155,18 +172,11 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient
api.EncodeResponse,
opts...,
), "delete_user").ServeHTTP)
r.Post("/tokens/refresh", otelhttp.NewHandler(kithttp.NewServer(
refreshTokenEndpoint(svc),
decodeRefreshToken,
api.EncodeResponse,
opts...,
), "refresh_token").ServeHTTP)
})
})
r.Group(func(r chi.Router) {
r.Use(api.AuthenticateMiddleware(authn, false))
r.Use(authn.Middleware())
r.Put("/password/reset", otelhttp.NewHandler(kithttp.NewServer(
passwordResetEndpoint(svc),
decodePasswordReset,
@@ -189,6 +199,13 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient
opts...,
), "password_reset_req").ServeHTTP)
r.Get("/verify-email", otelhttp.NewHandler(kithttp.NewServer(
verifyEmailEndpoint(svc),
decodeVerifyEmail,
api.EncodeResponse,
opts...,
), "verify_email").ServeHTTP)
for _, provider := range providers {
r.HandleFunc("/oauth/callback/"+provider.Name(), oauth2CallbackHandler(provider, svc, tokenClient))
}
@@ -196,6 +213,22 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient
return r
}
func decodeSendVerification(_ context.Context, r *http.Request) (any, error) {
req := sendVerificationReq{}
return req, nil
}
func decodeVerifyEmail(_ context.Context, r *http.Request) (any, error) {
token, err := apiutil.ReadStringQuery(r, api.TokenKey, "")
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
return verifyEmailReq{
token: token,
}, nil
}
func decodeViewUser(_ context.Context, r *http.Request) (any, error) {
req := viewUserReq{
id: chi.URLParam(r, "id"),
+3
View File
@@ -7,4 +7,7 @@ package users
type Emailer interface {
// SendPasswordReset sends an email to the user with a link to reset the password.
SendPasswordReset(To []string, user, token string) error
// SendVerification sends an email to the user with a verification token.
SendVerification(To []string, user, verificationToken string) error
}
+26 -8
View File
@@ -13,20 +13,38 @@ import (
var _ users.Emailer = (*emailer)(nil)
type emailer struct {
resetURL string
agent *email.Agent
resetURL string
verificationURL string
resetAgent *email.Agent
verifyAgent *email.Agent
}
// New creates new emailer utility.
func New(url string, c *email.Config) (users.Emailer, error) {
e, err := email.New(c)
func New(resetURL, verificationURL string, resetConfig, verifyConfig *email.Config) (users.Emailer, error) {
resetAgent, err := email.New(resetConfig)
if err != nil {
return nil, err
}
verifyAgent, err := email.New(verifyConfig)
if err != nil {
return nil, err
}
return &emailer{
resetURL: url,
agent: e,
}, err
resetURL: resetURL,
verificationURL: verificationURL,
resetAgent: resetAgent,
verifyAgent: verifyAgent,
}, nil
}
func (e *emailer) SendPasswordReset(to []string, user, token string) error {
url := fmt.Sprintf("%s?token=%s", e.resetURL, token)
return e.agent.Send(to, "", "Password Reset Request", "", user, url, "")
return e.resetAgent.Send(to, "", "Password Reset Request", "", user, url, "")
}
func (e *emailer) SendVerification(to []string, user, verificationToken string) error {
url := fmt.Sprintf("%s?token=%s", e.verificationURL, verificationToken)
return e.verifyAgent.Send(to, "", "Email Verification", "", user, url, "")
}
+35
View File
@@ -14,6 +14,8 @@ import (
const (
userPrefix = "user."
userCreate = userPrefix + "create"
userSendVerification = userPrefix + "send_verification"
userVerifyEmail = userPrefix + "verify_email"
userUpdate = userPrefix + "update"
userUpdateRole = userPrefix + "update_role"
userUpdateTags = userPrefix + "update_tags"
@@ -41,6 +43,8 @@ const (
var (
_ events.Event = (*createUserEvent)(nil)
_ events.Event = (*sendVerificationEvent)(nil)
_ events.Event = (*verifyEmailEvent)(nil)
_ events.Event = (*updateUserEvent)(nil)
_ events.Event = (*updateProfilePictureEvent)(nil)
_ events.Event = (*updateUsernameEvent)(nil)
@@ -99,6 +103,37 @@ func (uce createUserEvent) Encode() (map[string]any, error) {
return val, nil
}
type sendVerificationEvent struct {
authn.Session
requestID string
}
func (sve sendVerificationEvent) Encode() (map[string]any, error) {
return map[string]any{
"operation": userSendVerification,
"user_id": sve.UserID,
"token_type": sve.Type.String(),
"request_id": sve.requestID,
}, nil
}
type verifyEmailEvent struct {
requestID string
email string
userID string
verifiedAt time.Time
}
func (vee verifyEmailEvent) Encode() (map[string]any, error) {
return map[string]any{
"operation": userVerifyEmail,
"request_id": vee.requestID,
"email": vee.email,
"user_id": vee.userID,
"verified_at": vee.verifiedAt,
}, nil
}
type updateUserEvent struct {
users.User
operation string
+34
View File
@@ -17,6 +17,8 @@ import (
const (
supermqPrefix = "supermq."
createStream = supermqPrefix + userCreate
sendVerificationStream = supermqPrefix + userSendVerification
verifyEmailStream = supermqPrefix + userVerifyEmail
updateStream = supermqPrefix + userUpdate
updateRoleStream = supermqPrefix + userUpdateRole
updateTagsStream = supermqPrefix + userUpdateTags
@@ -82,6 +84,38 @@ func (es *eventStore) Register(ctx context.Context, session authn.Session, user
return user, nil
}
func (es *eventStore) SendVerification(ctx context.Context, session authn.Session) error {
err := es.svc.SendVerification(ctx, session)
if err != nil {
return err
}
event := sendVerificationEvent{
session,
middleware.GetReqID(ctx),
}
return es.Publish(ctx, sendVerificationStream, event)
}
func (es *eventStore) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) {
user, err := es.svc.VerifyEmail(ctx, verificationToken)
if err != nil {
return user, err
}
event := verifyEmailEvent{
email: user.Email,
userID: user.ID,
verifiedAt: user.VerifiedAt,
requestID: middleware.GetReqID(ctx),
}
if err := es.Publish(ctx, verifyEmailStream, event); err != nil {
return user, err
}
return user, nil
}
func (es *eventStore) Update(ctx context.Context, session authn.Session, id string, usr users.UserReq) (users.User, error) {
user, err := es.svc.Update(ctx, session, id, usr)
if err != nil {
+9 -5
View File
@@ -26,11 +26,15 @@ type authorizationMiddleware struct {
// AuthorizationMiddleware adds authorization to the clients service.
func AuthorizationMiddleware(svc users.Service, authz smqauthz.Authorization, selfRegister bool) users.Service {
return &authorizationMiddleware{
svc: svc,
authz: authz,
selfRegister: selfRegister,
}
return &authorizationMiddleware{svc: svc, authz: authz, selfRegister: selfRegister}
}
func (am *authorizationMiddleware) SendVerification(ctx context.Context, session authn.Session) error {
return am.svc.SendVerification(ctx, session)
}
func (am *authorizationMiddleware) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) {
return am.svc.VerifyEmail(ctx, verificationToken)
}
func (am *authorizationMiddleware) Register(ctx context.Context, session authn.Session, user users.User, selfRegister bool) (users.User, error) {
+37
View File
@@ -50,6 +50,43 @@ func (lm *loggingMiddleware) Register(ctx context.Context, session authn.Session
return lm.svc.Register(ctx, session, user, selfRegister)
}
// SendVerification logs the send_verification request. It logs the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) SendVerification(ctx context.Context, session authn.Session) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("request_id", middleware.GetReqID(ctx)),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("Send verification failed", args...)
return
}
lm.logger.Info("Send verification completed successfully", args...)
}(time.Now())
return lm.svc.SendVerification(ctx, session)
}
// VerifyEmail logs the verify_email request. It logs the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) VerifyEmail(ctx context.Context, verificationToken string) (user users.User, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("user_id", user.ID),
slog.String("request_id", middleware.GetReqID(ctx)),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("Verify email failed", args...)
return
}
lm.logger.Info("Verify email completed successfully", args...)
}(time.Now())
return lm.svc.VerifyEmail(ctx, verificationToken)
}
// IssueToken logs the issue_token request. It logs the username type and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) IssueToken(ctx context.Context, username, secret string) (t *grpcTokenV1.Token, err error) {
+18
View File
@@ -39,6 +39,24 @@ func (ms *metricsMiddleware) Register(ctx context.Context, session authn.Session
return ms.svc.Register(ctx, session, user, selfRegister)
}
// SendVerification instruments SendVerification method with metrics.
func (ms *metricsMiddleware) SendVerification(ctx context.Context, session authn.Session) error {
defer func(begin time.Time) {
ms.counter.With("method", "send_verification").Add(1)
ms.latency.With("method", "send_verification").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.SendVerification(ctx, session)
}
// VerifyEmail instruments VerifyEmail method with metrics.
func (ms *metricsMiddleware) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) {
defer func(begin time.Time) {
ms.counter.With("method", "verify_email").Add(1)
ms.latency.With("method", "verify_email").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.VerifyEmail(ctx, verificationToken)
}
// IssueToken instruments IssueToken method with metrics.
func (ms *metricsMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) {
defer func(begin time.Time) {
+63
View File
@@ -100,3 +100,66 @@ func (_c *Emailer_SendPasswordReset_Call) RunAndReturn(run func(To []string, use
_c.Call.Return(run)
return _c
}
// SendVerification provides a mock function for the type Emailer
func (_mock *Emailer) SendVerification(To []string, user string, verificationToken string) error {
ret := _mock.Called(To, user, verificationToken)
if len(ret) == 0 {
panic("no return value specified for SendVerification")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func([]string, string, string) error); ok {
r0 = returnFunc(To, user, verificationToken)
} else {
r0 = ret.Error(0)
}
return r0
}
// Emailer_SendVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendVerification'
type Emailer_SendVerification_Call struct {
*mock.Call
}
// SendVerification is a helper method to define mock.On call
// - To []string
// - user string
// - verificationToken string
func (_e *Emailer_Expecter) SendVerification(To interface{}, user interface{}, verificationToken interface{}) *Emailer_SendVerification_Call {
return &Emailer_SendVerification_Call{Call: _e.mock.On("SendVerification", To, user, verificationToken)}
}
func (_c *Emailer_SendVerification_Call) Run(run func(To []string, user string, verificationToken string)) *Emailer_SendVerification_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 []string
if args[0] != nil {
arg0 = args[0].([]string)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
run(
arg0,
arg1,
arg2,
)
})
return _c
}
func (_c *Emailer_SendVerification_Call) Return(err error) *Emailer_SendVerification_Call {
_c.Call.Return(err)
return _c
}
func (_c *Emailer_SendVerification_Call) RunAndReturn(run func(To []string, user string, verificationToken string) error) *Emailer_SendVerification_Call {
_c.Call.Return(run)
return _c
}
+384
View File
@@ -41,6 +41,63 @@ func (_m *Repository) EXPECT() *Repository_Expecter {
return &Repository_Expecter{mock: &_m.Mock}
}
// AddUserVerification provides a mock function for the type Repository
func (_mock *Repository) AddUserVerification(ctx context.Context, uv users.UserVerification) error {
ret := _mock.Called(ctx, uv)
if len(ret) == 0 {
panic("no return value specified for AddUserVerification")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, users.UserVerification) error); ok {
r0 = returnFunc(ctx, uv)
} else {
r0 = ret.Error(0)
}
return r0
}
// Repository_AddUserVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddUserVerification'
type Repository_AddUserVerification_Call struct {
*mock.Call
}
// AddUserVerification is a helper method to define mock.On call
// - ctx context.Context
// - uv users.UserVerification
func (_e *Repository_Expecter) AddUserVerification(ctx interface{}, uv interface{}) *Repository_AddUserVerification_Call {
return &Repository_AddUserVerification_Call{Call: _e.mock.On("AddUserVerification", ctx, uv)}
}
func (_c *Repository_AddUserVerification_Call) Run(run func(ctx context.Context, uv users.UserVerification)) *Repository_AddUserVerification_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 users.UserVerification
if args[1] != nil {
arg1 = args[1].(users.UserVerification)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Repository_AddUserVerification_Call) Return(err error) *Repository_AddUserVerification_Call {
_c.Call.Return(err)
return _c
}
func (_c *Repository_AddUserVerification_Call) RunAndReturn(run func(ctx context.Context, uv users.UserVerification) error) *Repository_AddUserVerification_Call {
_c.Call.Return(run)
return _c
}
// ChangeStatus provides a mock function for the type Repository
func (_mock *Repository) ChangeStatus(ctx context.Context, user users.User) (users.User, error) {
ret := _mock.Called(ctx, user)
@@ -551,6 +608,78 @@ func (_c *Repository_RetrieveByUsername_Call) RunAndReturn(run func(ctx context.
return _c
}
// RetrieveUserVerification provides a mock function for the type Repository
func (_mock *Repository) RetrieveUserVerification(ctx context.Context, userID string, email string) (users.UserVerification, error) {
ret := _mock.Called(ctx, userID, email)
if len(ret) == 0 {
panic("no return value specified for RetrieveUserVerification")
}
var r0 users.UserVerification
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (users.UserVerification, error)); ok {
return returnFunc(ctx, userID, email)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) users.UserVerification); ok {
r0 = returnFunc(ctx, userID, email)
} else {
r0 = ret.Get(0).(users.UserVerification)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = returnFunc(ctx, userID, email)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Repository_RetrieveUserVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveUserVerification'
type Repository_RetrieveUserVerification_Call struct {
*mock.Call
}
// RetrieveUserVerification is a helper method to define mock.On call
// - ctx context.Context
// - userID string
// - email string
func (_e *Repository_Expecter) RetrieveUserVerification(ctx interface{}, userID interface{}, email interface{}) *Repository_RetrieveUserVerification_Call {
return &Repository_RetrieveUserVerification_Call{Call: _e.mock.On("RetrieveUserVerification", ctx, userID, email)}
}
func (_c *Repository_RetrieveUserVerification_Call) Run(run func(ctx context.Context, userID string, email string)) *Repository_RetrieveUserVerification_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
run(
arg0,
arg1,
arg2,
)
})
return _c
}
func (_c *Repository_RetrieveUserVerification_Call) Return(userVerification users.UserVerification, err error) *Repository_RetrieveUserVerification_Call {
_c.Call.Return(userVerification, err)
return _c
}
func (_c *Repository_RetrieveUserVerification_Call) RunAndReturn(run func(ctx context.Context, userID string, email string) (users.UserVerification, error)) *Repository_RetrieveUserVerification_Call {
_c.Call.Return(run)
return _c
}
// Save provides a mock function for the type Repository
func (_mock *Repository) Save(ctx context.Context, user users.User) (users.User, error) {
ret := _mock.Called(ctx, user)
@@ -755,6 +884,138 @@ func (_c *Repository_Update_Call) RunAndReturn(run func(ctx context.Context, id
return _c
}
// UpdateEmail provides a mock function for the type Repository
func (_mock *Repository) UpdateEmail(ctx context.Context, user users.User) (users.User, error) {
ret := _mock.Called(ctx, user)
if len(ret) == 0 {
panic("no return value specified for UpdateEmail")
}
var r0 users.User
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) (users.User, error)); ok {
return returnFunc(ctx, user)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) users.User); ok {
r0 = returnFunc(ctx, user)
} else {
r0 = ret.Get(0).(users.User)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, users.User) error); ok {
r1 = returnFunc(ctx, user)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Repository_UpdateEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateEmail'
type Repository_UpdateEmail_Call struct {
*mock.Call
}
// UpdateEmail is a helper method to define mock.On call
// - ctx context.Context
// - user users.User
func (_e *Repository_Expecter) UpdateEmail(ctx interface{}, user interface{}) *Repository_UpdateEmail_Call {
return &Repository_UpdateEmail_Call{Call: _e.mock.On("UpdateEmail", ctx, user)}
}
func (_c *Repository_UpdateEmail_Call) Run(run func(ctx context.Context, user users.User)) *Repository_UpdateEmail_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 users.User
if args[1] != nil {
arg1 = args[1].(users.User)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Repository_UpdateEmail_Call) Return(user1 users.User, err error) *Repository_UpdateEmail_Call {
_c.Call.Return(user1, err)
return _c
}
func (_c *Repository_UpdateEmail_Call) RunAndReturn(run func(ctx context.Context, user users.User) (users.User, error)) *Repository_UpdateEmail_Call {
_c.Call.Return(run)
return _c
}
// UpdateRole provides a mock function for the type Repository
func (_mock *Repository) UpdateRole(ctx context.Context, user users.User) (users.User, error) {
ret := _mock.Called(ctx, user)
if len(ret) == 0 {
panic("no return value specified for UpdateRole")
}
var r0 users.User
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) (users.User, error)); ok {
return returnFunc(ctx, user)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) users.User); ok {
r0 = returnFunc(ctx, user)
} else {
r0 = ret.Get(0).(users.User)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, users.User) error); ok {
r1 = returnFunc(ctx, user)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Repository_UpdateRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRole'
type Repository_UpdateRole_Call struct {
*mock.Call
}
// UpdateRole is a helper method to define mock.On call
// - ctx context.Context
// - user users.User
func (_e *Repository_Expecter) UpdateRole(ctx interface{}, user interface{}) *Repository_UpdateRole_Call {
return &Repository_UpdateRole_Call{Call: _e.mock.On("UpdateRole", ctx, user)}
}
func (_c *Repository_UpdateRole_Call) Run(run func(ctx context.Context, user users.User)) *Repository_UpdateRole_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 users.User
if args[1] != nil {
arg1 = args[1].(users.User)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Repository_UpdateRole_Call) Return(user1 users.User, err error) *Repository_UpdateRole_Call {
_c.Call.Return(user1, err)
return _c
}
func (_c *Repository_UpdateRole_Call) RunAndReturn(run func(ctx context.Context, user users.User) (users.User, error)) *Repository_UpdateRole_Call {
_c.Call.Return(run)
return _c
}
// UpdateSecret provides a mock function for the type Repository
func (_mock *Repository) UpdateSecret(ctx context.Context, user users.User) (users.User, error) {
ret := _mock.Called(ctx, user)
@@ -821,6 +1082,63 @@ func (_c *Repository_UpdateSecret_Call) RunAndReturn(run func(ctx context.Contex
return _c
}
// UpdateUserVerification provides a mock function for the type Repository
func (_mock *Repository) UpdateUserVerification(ctx context.Context, uv users.UserVerification) error {
ret := _mock.Called(ctx, uv)
if len(ret) == 0 {
panic("no return value specified for UpdateUserVerification")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, users.UserVerification) error); ok {
r0 = returnFunc(ctx, uv)
} else {
r0 = ret.Error(0)
}
return r0
}
// Repository_UpdateUserVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateUserVerification'
type Repository_UpdateUserVerification_Call struct {
*mock.Call
}
// UpdateUserVerification is a helper method to define mock.On call
// - ctx context.Context
// - uv users.UserVerification
func (_e *Repository_Expecter) UpdateUserVerification(ctx interface{}, uv interface{}) *Repository_UpdateUserVerification_Call {
return &Repository_UpdateUserVerification_Call{Call: _e.mock.On("UpdateUserVerification", ctx, uv)}
}
func (_c *Repository_UpdateUserVerification_Call) Run(run func(ctx context.Context, uv users.UserVerification)) *Repository_UpdateUserVerification_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 users.UserVerification
if args[1] != nil {
arg1 = args[1].(users.UserVerification)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Repository_UpdateUserVerification_Call) Return(err error) *Repository_UpdateUserVerification_Call {
_c.Call.Return(err)
return _c
}
func (_c *Repository_UpdateUserVerification_Call) RunAndReturn(run func(ctx context.Context, uv users.UserVerification) error) *Repository_UpdateUserVerification_Call {
_c.Call.Return(run)
return _c
}
// UpdateUsername provides a mock function for the type Repository
func (_mock *Repository) UpdateUsername(ctx context.Context, user users.User) (users.User, error) {
ret := _mock.Called(ctx, user)
@@ -886,3 +1204,69 @@ func (_c *Repository_UpdateUsername_Call) RunAndReturn(run func(ctx context.Cont
_c.Call.Return(run)
return _c
}
// UpdateVerifiedAt provides a mock function for the type Repository
func (_mock *Repository) UpdateVerifiedAt(ctx context.Context, user users.User) (users.User, error) {
ret := _mock.Called(ctx, user)
if len(ret) == 0 {
panic("no return value specified for UpdateVerifiedAt")
}
var r0 users.User
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) (users.User, error)); ok {
return returnFunc(ctx, user)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) users.User); ok {
r0 = returnFunc(ctx, user)
} else {
r0 = ret.Get(0).(users.User)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, users.User) error); ok {
r1 = returnFunc(ctx, user)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Repository_UpdateVerifiedAt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateVerifiedAt'
type Repository_UpdateVerifiedAt_Call struct {
*mock.Call
}
// UpdateVerifiedAt is a helper method to define mock.On call
// - ctx context.Context
// - user users.User
func (_e *Repository_Expecter) UpdateVerifiedAt(ctx interface{}, user interface{}) *Repository_UpdateVerifiedAt_Call {
return &Repository_UpdateVerifiedAt_Call{Call: _e.mock.On("UpdateVerifiedAt", ctx, user)}
}
func (_c *Repository_UpdateVerifiedAt_Call) Run(run func(ctx context.Context, user users.User)) *Repository_UpdateVerifiedAt_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 users.User
if args[1] != nil {
arg1 = args[1].(users.User)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Repository_UpdateVerifiedAt_Call) Return(user1 users.User, err error) *Repository_UpdateVerifiedAt_Call {
_c.Call.Return(user1, err)
return _c
}
func (_c *Repository_UpdateVerifiedAt_Call) RunAndReturn(run func(ctx context.Context, user users.User) (users.User, error)) *Repository_UpdateVerifiedAt_Call {
_c.Call.Return(run)
return _c
}
+123
View File
@@ -923,6 +923,63 @@ func (_c *Service_SendPasswordReset_Call) RunAndReturn(run func(ctx context.Cont
return _c
}
// SendVerification provides a mock function for the type Service
func (_mock *Service) SendVerification(ctx context.Context, session authn.Session) error {
ret := _mock.Called(ctx, session)
if len(ret) == 0 {
panic("no return value specified for SendVerification")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) error); ok {
r0 = returnFunc(ctx, session)
} else {
r0 = ret.Error(0)
}
return r0
}
// Service_SendVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendVerification'
type Service_SendVerification_Call struct {
*mock.Call
}
// SendVerification is a helper method to define mock.On call
// - ctx context.Context
// - session authn.Session
func (_e *Service_Expecter) SendVerification(ctx interface{}, session interface{}) *Service_SendVerification_Call {
return &Service_SendVerification_Call{Call: _e.mock.On("SendVerification", ctx, session)}
}
func (_c *Service_SendVerification_Call) Run(run func(ctx context.Context, session authn.Session)) *Service_SendVerification_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 authn.Session
if args[1] != nil {
arg1 = args[1].(authn.Session)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Service_SendVerification_Call) Return(err error) *Service_SendVerification_Call {
_c.Call.Return(err)
return _c
}
func (_c *Service_SendVerification_Call) RunAndReturn(run func(ctx context.Context, session authn.Session) error) *Service_SendVerification_Call {
_c.Call.Return(run)
return _c
}
// Update provides a mock function for the type Service
func (_mock *Service) Update(ctx context.Context, session authn.Session, id string, user users.UserReq) (users.User, error) {
ret := _mock.Called(ctx, session, id, user)
@@ -1463,6 +1520,72 @@ func (_c *Service_UpdateUsername_Call) RunAndReturn(run func(ctx context.Context
return _c
}
// VerifyEmail provides a mock function for the type Service
func (_mock *Service) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) {
ret := _mock.Called(ctx, verificationToken)
if len(ret) == 0 {
panic("no return value specified for VerifyEmail")
}
var r0 users.User
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string) (users.User, error)); ok {
return returnFunc(ctx, verificationToken)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string) users.User); ok {
r0 = returnFunc(ctx, verificationToken)
} else {
r0 = ret.Get(0).(users.User)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = returnFunc(ctx, verificationToken)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Service_VerifyEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyEmail'
type Service_VerifyEmail_Call struct {
*mock.Call
}
// VerifyEmail is a helper method to define mock.On call
// - ctx context.Context
// - verificationToken string
func (_e *Service_Expecter) VerifyEmail(ctx interface{}, verificationToken interface{}) *Service_VerifyEmail_Call {
return &Service_VerifyEmail_Call{Call: _e.mock.On("VerifyEmail", ctx, verificationToken)}
}
func (_c *Service_VerifyEmail_Call) Run(run func(ctx context.Context, verificationToken string)) *Service_VerifyEmail_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Service_VerifyEmail_Call) Return(user users.User, err error) *Service_VerifyEmail_Call {
_c.Call.Return(user, err)
return _c
}
func (_c *Service_VerifyEmail_Call) RunAndReturn(run func(ctx context.Context, verificationToken string) (users.User, error)) *Service_VerifyEmail_Call {
_c.Call.Return(run)
return _c
}
// View provides a mock function for the type Service
func (_mock *Service) View(ctx context.Context, session authn.Session, id string) (users.User, error) {
ret := _mock.Called(ctx, session, id)
+24 -3
View File
@@ -14,7 +14,7 @@ func Migration() *migrate.MemoryMigrationSource {
Migrations: []*migrate.Migration{
{
Id: "clients_01",
// VARCHAR(36) for colums with IDs as UUIDS have a maximum of 36 characters
// VARCHAR(36) for column with IDs as UUIDS have a maximum of 36 characters
// STATUS 0 to imply enabled and 1 to imply disabled
// Role 0 to imply user role and 1 to imply admin role
Up: []string{
@@ -50,8 +50,8 @@ func Migration() *migrate.MemoryMigrationSource {
Up: []string{
`ALTER TABLE clients
ADD COLUMN username VARCHAR(254) UNIQUE,
ADD COLUMN first_name VARCHAR(254) NOT NULL DEFAULT '',
ADD COLUMN last_name VARCHAR(254) NOT NULL DEFAULT '',
ADD COLUMN first_name VARCHAR(254) NOT NULL DEFAULT '',
ADD COLUMN last_name VARCHAR(254) NOT NULL DEFAULT '',
ADD COLUMN profile_picture TEXT`,
`ALTER TABLE clients RENAME COLUMN identity TO email`,
`ALTER TABLE clients DROP COLUMN name`,
@@ -97,6 +97,27 @@ func Migration() *migrate.MemoryMigrationSource {
`ALTER TABLE users ALTER COLUMN updated_at TYPE TIMESTAMP;`,
},
},
{
Id: "clients_07",
Up: []string{
`ALTER TABLE users ADD COLUMN verified_at TIMESTAMPTZ DEFAULT NULL;`,
`CREATE TABLE users_verifications (
user_id VARCHAR(36) NOT NULL,
email VARCHAR(254) NOT NULL,
otp VARCHAR(255),
created_at TIMESTAMPTZ,
expires_at TIMESTAMPTZ,
used_at TIMESTAMPTZ,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
);
CREATE INDEX idx_users_verifications_lookup ON users_verifications (user_id, email, created_at DESC);
`,
},
Down: []string{
`ALTER TABLE users DROP COLUMN verified_at;`,
`DROP TABLE users_verifications;`,
},
},
},
}
}
+64 -63
View File
@@ -33,7 +33,7 @@ func NewRepository(db postgres.Database) users.Repository {
func (repo *userRepo) Save(ctx context.Context, c users.User) (users.User, error) {
q := `INSERT INTO users (id, tags, email, secret, metadata, created_at, status, role, first_name, last_name, username, profile_picture)
VALUES (:id, :tags, :email, :secret, :metadata, :created_at, :status, :role, :first_name, :last_name, :username, :profile_picture)
RETURNING id, tags, email, metadata, created_at, status, role, first_name, last_name, username, profile_picture`
RETURNING id, tags, email, metadata, created_at, status, role, first_name, last_name, username, profile_picture, verified_at`
dbu, err := toDBUser(c)
if err != nil {
@@ -81,7 +81,7 @@ func (repo *userRepo) CheckSuperAdmin(ctx context.Context, adminID string) error
}
func (repo *userRepo) RetrieveByID(ctx context.Context, id string) (users.User, error) {
q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username, profile_picture
q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username, profile_picture, verified_at
FROM users WHERE id = :id`
dbu := DBUser{
@@ -95,20 +95,20 @@ func (repo *userRepo) RetrieveByID(ctx context.Context, id string) (users.User,
defer rows.Close()
dbu = DBUser{}
if rows.Next() {
if err = rows.StructScan(&dbu); err != nil {
return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
user, err := ToUser(dbu)
if err != nil {
return users.User{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
}
return user, nil
if !rows.Next() {
return users.User{}, repoerr.ErrNotFound
}
return users.User{}, repoerr.ErrNotFound
if err = rows.StructScan(&dbu); err != nil {
return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
user, err := ToUser(dbu)
if err != nil {
return users.User{}, errors.Wrap(repoerr.ErrFailedOpDB, err)
}
return user, nil
}
func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.UsersPage, error) {
@@ -127,7 +127,7 @@ func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.Use
}
q := fmt.Sprintf(`SELECT u.id, u.tags, u.email, u.metadata, u.status, u.role, u.first_name, u.last_name, u.username,
u.created_at, u.updated_at, u.profile_picture, COALESCE(u.updated_by, '') AS updated_by
u.created_at, u.updated_at, u.profile_picture, COALESCE(u.updated_by, '') AS updated_by, u.verified_at
FROM users u %s %s LIMIT :limit OFFSET :offset;`, query, orderClause)
dbPage, err := ToDBUsersPage(pm)
@@ -180,35 +180,9 @@ func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.Use
func (repo *userRepo) UpdateUsername(ctx context.Context, user users.User) (users.User, error) {
q := `UPDATE users SET username = :username, updated_at = :updated_at, updated_by = :updated_by
WHERE id = :id AND status = :status
RETURNING id, tags, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, email`
RETURNING id, tags, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, email, role, verified_at`
dbu, err := toDBUser(user)
if err != nil {
return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbu)
if err != nil {
return users.User{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
defer row.Close()
dbu = DBUser{
ID: user.ID,
Username: stringToNullString(user.Credentials.Username),
UpdatedAt: sql.NullTime{Time: time.Now().UTC(), Valid: true},
}
if ok := row.Next(); !ok {
return users.User{}, errors.Wrap(repoerr.ErrNotFound, row.Err())
}
if err := row.StructScan(&dbu); err != nil {
return users.User{}, err
}
return ToUser(dbu)
return repo.update(ctx, user, q)
}
func (repo *userRepo) Update(ctx context.Context, id string, ur users.UserReq) (users.User, error) {
@@ -231,18 +205,10 @@ func (repo *userRepo) Update(ctx context.Context, id string, ur users.UserReq) (
query = append(query, "tags = :tags")
u.Tags = *ur.Tags
}
if ur.Role != nil && *ur.Role != users.AllRole {
query = append(query, "role = :role")
u.Role = *ur.Role
}
if ur.ProfilePicture != nil {
query = append(query, "profile_picture = :profile_picture")
u.ProfilePicture = *ur.ProfilePicture
}
if ur.Email != nil && *ur.Email != "" {
query = append(query, "email = :email")
u.Email = *ur.Email
}
u.UpdatedAt = time.Now().UTC()
if ur.UpdatedAt != nil {
query = append(query, "updated_at = :updated_at")
@@ -259,7 +225,7 @@ func (repo *userRepo) Update(ctx context.Context, id string, ur users.UserReq) (
q := fmt.Sprintf(`UPDATE users SET %s
WHERE id = :id AND status = :status
RETURNING id, tags, metadata, status, created_at, updated_at, updated_by, last_name, first_name, username, profile_picture, email, role`, upq)
RETURNING id, tags, metadata, status, created_at, updated_at, updated_by, last_name, first_name, username, profile_picture, email, role, verified_at`, upq)
u.Status = users.EnabledStatus
return repo.update(ctx, u, q)
@@ -278,21 +244,37 @@ func (repo *userRepo) update(ctx context.Context, user users.User, query string)
defer row.Close()
dbu = DBUser{}
if row.Next() {
if err := row.StructScan(&dbu); err != nil {
return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return ToUser(dbu)
if !row.Next() {
return users.User{}, repoerr.ErrNotFound
}
return users.User{}, repoerr.ErrNotFound
if err := row.StructScan(&dbu); err != nil {
return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return ToUser(dbu)
}
func (repo *userRepo) UpdateEmail(ctx context.Context, user users.User) (users.User, error) {
q := `UPDATE users SET email = :email, verified_at = NULL, updated_at = :updated_at, updated_by = :updated_by
WHERE id = :id AND status = :status
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at`
user.Status = users.EnabledStatus
return repo.update(ctx, user, q)
}
func (repo *userRepo) UpdateRole(ctx context.Context, user users.User) (users.User, error) {
q := `UPDATE users SET role = :role, updated_at = :updated_at, updated_by = :updated_by
WHERE id = :id AND status = :status
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at`
user.Status = users.EnabledStatus
return repo.update(ctx, user, q)
}
func (repo *userRepo) UpdateSecret(ctx context.Context, user users.User) (users.User, error) {
q := `UPDATE users SET secret = :secret, updated_at = :updated_at, updated_by = :updated_by
WHERE id = :id AND status = :status
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username`
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at`
user.Status = users.EnabledStatus
return repo.update(ctx, user, q)
}
@@ -300,7 +282,15 @@ func (repo *userRepo) UpdateSecret(ctx context.Context, user users.User) (users.
func (repo *userRepo) ChangeStatus(ctx context.Context, user users.User) (users.User, error) {
q := `UPDATE users SET status = :status, updated_at = :updated_at, updated_by = :updated_by
WHERE id = :id
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username`
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at`
return repo.update(ctx, user, q)
}
func (repo *userRepo) UpdateVerifiedAt(ctx context.Context, user users.User) (users.User, error) {
q := `UPDATE users SET verified_at = :verified_at
WHERE id = :id and email = :email
RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at`
return repo.update(ctx, user, q)
}
@@ -434,7 +424,7 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user
}
func (repo *userRepo) RetrieveByEmail(ctx context.Context, email string) (users.User, error) {
q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username
q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username, verified_at
FROM users WHERE email = :email AND status = :status`
dbu := DBUser{
@@ -461,7 +451,7 @@ func (repo *userRepo) RetrieveByEmail(ctx context.Context, email string) (users.
}
func (repo *userRepo) RetrieveByUsername(ctx context.Context, username string) (users.User, error) {
q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username
q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username, verified_at
FROM users WHERE username = :username AND status = :status`
dbu := DBUser{
@@ -504,6 +494,7 @@ type DBUser struct {
LastName sql.NullString `db:"last_name, omitempty"`
ProfilePicture sql.NullString `db:"profile_picture, omitempty"`
Email string `db:"email,omitempty"`
VerifiedAt sql.NullTime `db:"verified_at,omitempty"`
}
func toDBUser(u users.User) (DBUser, error) {
@@ -527,6 +518,10 @@ func toDBUser(u users.User) (DBUser, error) {
if u.UpdatedAt != (time.Time{}) {
updatedAt = sql.NullTime{Time: u.UpdatedAt, Valid: true}
}
var verifiedAt sql.NullTime
if u.VerifiedAt != (time.Time{}) {
verifiedAt = sql.NullTime{Time: u.VerifiedAt, Valid: true}
}
return DBUser{
ID: u.ID,
@@ -543,6 +538,7 @@ func toDBUser(u users.User) (DBUser, error) {
Username: stringToNullString(u.Credentials.Username),
ProfilePicture: stringToNullString(u.ProfilePicture),
Email: u.Email,
VerifiedAt: verifiedAt,
}, nil
}
@@ -565,6 +561,10 @@ func ToUser(dbu DBUser) (users.User, error) {
if dbu.UpdatedAt.Valid {
updatedAt = dbu.UpdatedAt.Time.UTC()
}
var verifiedAt time.Time
if dbu.VerifiedAt.Valid {
verifiedAt = dbu.VerifiedAt.Time.UTC()
}
user := users.User{
ID: dbu.ID,
@@ -582,6 +582,7 @@ func ToUser(dbu DBUser) (users.User, error) {
Status: dbu.Status,
Tags: tags,
ProfilePicture: nullStringString(dbu.ProfilePicture),
VerifiedAt: verifiedAt,
}
if dbu.Role != nil {
user.Role = *dbu.Role
+141 -75
View File
@@ -1048,6 +1048,147 @@ func TestSearch(t *testing.T) {
}
}
func TestUpdateRole(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM users")
require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err))
})
repo := cpostgres.NewRepository(database)
user1 := generateUser(t, users.EnabledStatus, repo)
user2 := generateUser(t, users.DisabledStatus, repo)
adminRole := users.AdminRole
userRole := users.UserRole
cases := []struct {
desc string
update string
userID string
userReq users.User
err error
}{
{
desc: "update role of user to admin",
userReq: users.User{
ID: user1.ID,
Role: adminRole,
},
err: nil,
},
{
desc: "update role of admin to user",
userReq: users.User{
ID: user1.ID,
Role: userRole,
},
err: nil,
},
{
desc: "update role for disabled user",
userReq: users.User{
ID: user2.ID,
Role: adminRole,
},
err: repoerr.ErrNotFound,
},
{
desc: "update role for invalid user",
userReq: users.User{
ID: testsutil.GenerateUUID(t),
Role: adminRole,
},
err: repoerr.ErrNotFound,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
updatedAt := time.Now().UTC().Truncate(time.Millisecond)
updatedBy := testsutil.GenerateUUID(t)
c.userReq.UpdatedAt = updatedAt
c.userReq.UpdatedBy = updatedBy
expected, err := repo.UpdateRole(context.Background(), c.userReq)
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err))
if err == nil {
assert.Equal(t, c.userReq.Role, expected.Role)
assert.Equal(t, c.userReq.UpdatedAt, expected.UpdatedAt)
assert.Equal(t, c.userReq.UpdatedBy, expected.UpdatedBy)
}
})
}
}
func TestUpdateEmail(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM users")
require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err))
})
repo := cpostgres.NewRepository(database)
user1 := generateUser(t, users.EnabledStatus, repo)
user2 := generateUser(t, users.DisabledStatus, repo)
user3 := generateUser(t, users.EnabledStatus, repo)
updatedEmail := namesgen.Generate() + emailSuffix
emptyName := ""
cases := []struct {
desc string
update string
userReq users.User
err error
}{
{
desc: "update email for enabled user",
userReq: users.User{
ID: user1.ID,
Email: updatedEmail,
},
err: nil,
},
{
desc: "update empty email for enabled user",
userReq: users.User{
ID: user3.ID,
Email: emptyName,
},
err: nil,
},
{
desc: "update email for disabled user",
userReq: users.User{
ID: user2.ID,
Email: updatedEmail,
},
err: repoerr.ErrNotFound,
},
{
desc: "update email for invalid user",
userReq: users.User{
ID: testsutil.GenerateUUID(t),
Email: updatedEmail,
},
err: repoerr.ErrNotFound,
},
}
for _, c := range cases {
t.Run(c.desc, func(t *testing.T) {
updatedAt := time.Now().UTC().Truncate(time.Millisecond)
updatedBy := testsutil.GenerateUUID(t)
c.userReq.UpdatedAt = updatedAt
c.userReq.UpdatedBy = updatedBy
expected, err := repo.UpdateEmail(context.Background(), c.userReq)
assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err))
if err == nil {
assert.Equal(t, c.userReq.Email, expected.Email)
assert.Equal(t, c.userReq.UpdatedAt, expected.UpdatedAt)
assert.Equal(t, c.userReq.UpdatedBy, expected.UpdatedBy)
}
})
}
}
func TestUpdate(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM users")
@@ -1065,11 +1206,8 @@ func TestUpdate(t *testing.T) {
updatedFirstName := namesgen.Generate()
updateTags := namesgen.GenerateMultiple(5)
updatedProfilePicture := namesgen.Generate()
adminRole := users.AdminRole
updatedEmail := namesgen.Generate() + emailSuffix
emptyName := ""
emptyTags := []string{}
allRole := users.AllRole
cases := []struct {
desc string
@@ -1286,78 +1424,6 @@ func TestUpdate(t *testing.T) {
},
err: repoerr.ErrNotFound,
},
{
desc: "update role for enabled user",
userID: user1.ID,
userReq: users.UserReq{
Role: &adminRole,
},
userRes: users.User{
Role: adminRole,
},
err: nil,
},
{
desc: "update all role for enabled user",
userID: user3.ID,
userReq: users.UserReq{
Role: &allRole,
},
userRes: user3,
err: nil,
},
{
desc: "update role for disabled user",
userID: user2.ID,
userReq: users.UserReq{
Role: &adminRole,
},
err: repoerr.ErrNotFound,
},
{
desc: "update role for invalid user",
userID: testsutil.GenerateUUID(t),
userReq: users.UserReq{
Role: &adminRole,
},
err: repoerr.ErrNotFound,
},
{
desc: "update email for enabled user",
userID: user1.ID,
userReq: users.UserReq{
Email: &updatedEmail,
},
userRes: users.User{
Email: updatedEmail,
},
err: nil,
},
{
desc: "update empty email for enabled user",
userID: user3.ID,
userReq: users.UserReq{
Email: &emptyName,
},
userRes: user3,
err: nil,
},
{
desc: "update email for disabled user",
userID: user2.ID,
userReq: users.UserReq{
Email: &updatedEmail,
},
err: repoerr.ErrNotFound,
},
{
desc: "update email for invalid user",
userID: testsutil.GenerateUUID(t),
userReq: users.UserReq{
Email: &updatedEmail,
},
err: repoerr.ErrNotFound,
},
}
for _, c := range cases {
+131
View File
@@ -0,0 +1,131 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import (
"context"
"database/sql"
"time"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
"github.com/absmach/supermq/users"
)
// AddUserVerification adds new verification for given user id and email.
func (repo *userRepo) AddUserVerification(ctx context.Context, uv users.UserVerification) error {
q := `INSERT INTO users_verifications (user_id, email, otp, created_at, expires_at )
VALUES (:user_id, :email, :otp, :created_at, :expires_at );`
dbuv := toDBUserVerification(uv)
if _, err := repo.Repository.DB.NamedExecContext(ctx, q, dbuv); err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
return nil
}
// RetrieveUserVerification retrieves verification details of given user id and email.
func (repo *userRepo) RetrieveUserVerification(ctx context.Context, userID, email string) (users.UserVerification, error) {
dbuv := dbUserVerification{
UserID: userID,
Email: email,
}
q := `SELECT user_id, email, otp, created_at, expires_at , used_at FROM users_verifications WHERE user_id = :user_id AND email = :email ORDER BY created_at DESC LIMIT 1 `
row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbuv)
if err != nil {
return users.UserVerification{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
if !row.Next() {
return users.UserVerification{}, repoerr.ErrNotFound
}
defer row.Close()
if err := row.StructScan(&dbuv); err != nil {
return users.UserVerification{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
return toUserVerification(dbuv), nil
}
// UpdateUserVerification update user verification details for the given user id and email.
func (repo *userRepo) UpdateUserVerification(ctx context.Context, uv users.UserVerification) error {
q := `UPDATE users_verifications SET otp = :otp, used_at = :used_at WHERE user_id = :user_id AND email = :email`
dbuv := toDBUserVerification(uv)
res, err := repo.Repository.DB.NamedExecContext(ctx, q, dbuv)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
rows, err := res.RowsAffected()
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if rows == 0 {
return repoerr.ErrNotFound
}
return nil
}
type dbUserVerification struct {
UserID string `db:"user_id"`
Email string `db:"email"`
OTP sql.NullString `db:"otp"`
CreatedAt sql.NullTime `db:"created_at"`
ExpiresAt sql.NullTime `db:"expires_at"`
UsedAt sql.NullTime `db:"used_at"`
}
func toDBUserVerification(uv users.UserVerification) dbUserVerification {
var otp sql.NullString
if uv.OTP != "" {
otp = sql.NullString{String: uv.OTP, Valid: true}
}
var createdAt sql.NullTime
if !uv.CreatedAt.IsZero() {
createdAt = sql.NullTime{Time: uv.CreatedAt, Valid: true}
}
var expiresAt sql.NullTime
if !uv.ExpiresAt.IsZero() {
expiresAt = sql.NullTime{Time: uv.ExpiresAt, Valid: true}
}
var usedAt sql.NullTime
if !uv.UsedAt.IsZero() {
usedAt = sql.NullTime{Time: uv.UsedAt, Valid: true}
}
return dbUserVerification{
UserID: uv.UserID,
Email: uv.Email,
OTP: otp,
CreatedAt: createdAt,
ExpiresAt: expiresAt,
UsedAt: usedAt,
}
}
func toUserVerification(dbuv dbUserVerification) users.UserVerification {
var createdAt time.Time
if dbuv.CreatedAt.Valid {
createdAt = dbuv.CreatedAt.Time.UTC()
}
var expiresAt time.Time
if dbuv.ExpiresAt.Valid {
expiresAt = dbuv.ExpiresAt.Time.UTC()
}
var usedAt time.Time
if dbuv.UsedAt.Valid {
usedAt = dbuv.UsedAt.Time.UTC()
}
return users.UserVerification{
UserID: dbuv.UserID,
Email: dbuv.Email,
OTP: dbuv.OTP.String,
CreatedAt: createdAt,
ExpiresAt: expiresAt,
UsedAt: usedAt,
}
}
+217
View File
@@ -0,0 +1,217 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres_test
import (
"context"
"fmt"
"testing"
"time"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
"github.com/absmach/supermq/users"
"github.com/absmach/supermq/users/postgres"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAddUserVerification(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM users")
require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err))
_, err = db.Exec("DELETE FROM users_verifications")
require.Nil(t, err, fmt.Sprintf("clean users_verifications unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
first_name := namesgen.Generate()
last_name := namesgen.Generate()
username := namesgen.Generate()
user := users.User{
ID: "test-user-id",
Email: "test@example.com",
FirstName: first_name,
LastName: last_name,
Credentials: users.Credentials{
Username: username,
},
}
_, err := repo.Save(context.Background(), user)
require.Nil(t, err, fmt.Sprintf("saving user unexpected error: %s", err))
cases := []struct {
desc string
uv users.UserVerification
err error
}{
{
desc: "add new user verification",
uv: users.UserVerification{
UserID: user.ID,
Email: user.Email,
CreatedAt: time.Now().UTC(),
OTP: "123456",
ExpiresAt: time.Now().UTC().Add(time.Hour),
},
err: nil,
},
{
desc: "add user verification for non-existing user",
uv: users.UserVerification{
UserID: "non-existing-user",
Email: "non-existing@example.com",
OTP: "654321",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(time.Hour),
},
err: repoerr.ErrCreateEntity,
},
}
for _, tc := range cases {
err := repo.AddUserVerification(context.Background(), tc.uv)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
}
}
func TestRetrieveUserVerification(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM users")
require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err))
_, err = db.Exec("DELETE FROM users_verifications")
require.Nil(t, err, fmt.Sprintf("clean users_verifications unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
first_name := namesgen.Generate()
last_name := namesgen.Generate()
username := namesgen.Generate()
user := users.User{
ID: "test-user-id",
Email: "test@example.com",
FirstName: first_name,
LastName: last_name,
Credentials: users.Credentials{
Username: username,
},
}
_, err := repo.Save(context.Background(), user)
require.Nil(t, err, fmt.Sprintf("saving user unexpected error: %s", err))
uv := users.UserVerification{
UserID: user.ID,
Email: user.Email,
OTP: "123456",
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(time.Hour),
}
err = repo.AddUserVerification(context.Background(), uv)
require.Nil(t, err, fmt.Sprintf("adding user verification unexpected error: %s", err))
cases := []struct {
desc string
userID string
email string
err error
}{
{
desc: "retrieve existing user verification",
userID: user.ID,
email: user.Email,
err: nil,
},
{
desc: "retrieve non-existing user verification",
userID: "non-existing-user",
email: "non-existing@example.com",
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
retrieved, err := repo.RetrieveUserVerification(context.Background(), tc.userID, tc.email)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
if err == nil {
assert.Equal(t, uv.UserID, retrieved.UserID, fmt.Sprintf("%s: expected %v got %v", tc.desc, uv.UserID, retrieved.UserID))
assert.Equal(t, uv.Email, retrieved.Email, fmt.Sprintf("%s: expected %v got %v", tc.desc, uv.Email, retrieved.Email))
assert.Equal(t, uv.OTP, retrieved.OTP, fmt.Sprintf("%s: expected %v got %v", tc.desc, uv.OTP, retrieved.OTP))
}
}
}
func TestUpdateUserVerification(t *testing.T) {
t.Cleanup(func() {
_, err := db.Exec("DELETE FROM users")
require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err))
_, err = db.Exec("DELETE FROM users_verifications")
require.Nil(t, err, fmt.Sprintf("clean users_verifications unexpected error: %s", err))
})
repo := postgres.NewRepository(database)
first_name := namesgen.Generate()
last_name := namesgen.Generate()
username := namesgen.Generate()
user := users.User{
ID: "test-user-id",
Email: "test@example.com",
FirstName: first_name,
LastName: last_name,
Credentials: users.Credentials{
Username: username,
},
}
_, err := repo.Save(context.Background(), user)
require.Nil(t, err, fmt.Sprintf("saving user unexpected error: %s", err))
uv := users.UserVerification{
UserID: user.ID,
Email: user.Email,
OTP: "123456",
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().UTC().Add(time.Hour),
}
err = repo.AddUserVerification(context.Background(), uv)
require.Nil(t, err, fmt.Sprintf("adding user verification unexpected error: %s", err))
usedTime := time.Now()
cases := []struct {
desc string
uv users.UserVerification
err error
}{
{
desc: "update existing user verification",
uv: users.UserVerification{
UserID: user.ID,
Email: user.Email,
OTP: "654321",
UsedAt: usedTime,
},
err: nil,
},
{
desc: "update non-existing user verification",
uv: users.UserVerification{
UserID: "non-existing-user",
Email: "non-existing@example.com",
},
err: repoerr.ErrNotFound,
},
}
for _, tc := range cases {
err := repo.UpdateUserVerification(context.Background(), tc.uv)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err))
if err == nil {
retrieved, err := repo.RetrieveUserVerification(context.Background(), tc.uv.UserID, tc.uv.Email)
require.Nil(t, err, fmt.Sprintf("retrieving updated verification unexpected error: %s", err))
assert.Equal(t, tc.uv.OTP, retrieved.OTP, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.uv.OTP, retrieved.OTP))
assert.WithinDuration(t, tc.uv.UsedAt, retrieved.UsedAt, 10*time.Second, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.uv.UsedAt, retrieved.UsedAt))
}
}
}
+104 -15
View File
@@ -5,6 +5,7 @@ package users
import (
"context"
"fmt"
"net/mail"
"time"
@@ -92,6 +93,85 @@ func (svc service) Register(ctx context.Context, session authn.Session, u User,
return user, nil
}
func (svc service) SendVerification(ctx context.Context, session authn.Session) error {
dbUser, err := svc.users.RetrieveByID(ctx, session.UserID)
if err != nil {
return err
}
if !dbUser.VerifiedAt.IsZero() {
return svcerr.ErrUserAlreadyVerified
}
uv, err := svc.users.RetrieveUserVerification(ctx, dbUser.ID, dbUser.Email)
if err != nil && err != repoerr.ErrNotFound {
return err
}
if err = uv.Valid(); err != nil {
uv, err = NewUserVerification(dbUser.ID, dbUser.Email)
if err != nil {
return errors.Wrap(svcerr.ErrCreateEntity, err)
}
if err := svc.users.AddUserVerification(ctx, uv); err != nil {
return errors.Wrap(svcerr.ErrCreateEntity, err)
}
}
uvs, err := uv.Encode()
if err != nil {
return errors.Wrap(svcerr.ErrCreateEntity, err)
}
if err := svc.email.SendVerification([]string{dbUser.Email}, dbUser.Credentials.Username, uvs); err != nil {
return errors.Wrap(svcerr.ErrCreateEntity, err)
}
return nil
}
func (svc service) VerifyEmail(ctx context.Context, token string) (User, error) {
var received UserVerification
if err := received.Decode(token); err != nil {
return User{}, errors.Wrap(svcerr.ErrInvalidUserVerification, err)
}
stored, err := svc.users.RetrieveUserVerification(ctx, received.UserID, received.Email)
if err != nil {
return User{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
if err := stored.Match(received); err != nil {
return User{}, err
}
if err := stored.Valid(); err != nil {
if err == svcerr.ErrUserVerificationExpired {
return User{}, err
}
return User{}, errors.Wrap(svcerr.ErrMalformedEntity, err)
}
stored.UsedAt = time.Now().UTC()
if err = svc.users.UpdateUserVerification(ctx, stored); err != nil {
return User{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
user := User{
ID: stored.UserID,
Email: stored.Email,
VerifiedAt: time.Now().UTC(),
}
user, err = svc.users.UpdateVerifiedAt(ctx, user)
if err == repoerr.ErrNotFound {
return User{}, svcerr.ErrInvalidUserVerification
}
if err != nil {
return User{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
return user, nil
}
func (svc service) IssueToken(ctx context.Context, identity, secret string) (*grpcTokenV1.Token, error) {
var dbUser User
var err error
@@ -110,7 +190,7 @@ func (svc service) IssueToken(ctx context.Context, identity, secret string) (*gr
return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrLogin, err)
}
token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, UserRole: uint32(dbUser.Role + 1), Type: uint32(smqauth.AccessKey)})
token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, UserRole: uint32(dbUser.Role + 1), Type: uint32(smqauth.AccessKey), Verified: !dbUser.VerifiedAt.IsZero()})
if err != nil {
return &grpcTokenV1.Token{}, errors.Wrap(errIssueToken, err)
}
@@ -127,7 +207,7 @@ func (svc service) RefreshToken(ctx context.Context, session authn.Session, refr
return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser)
}
return svc.token.Refresh(ctx, &grpcTokenV1.RefreshReq{RefreshToken: refreshToken})
return svc.token.Refresh(ctx, &grpcTokenV1.RefreshReq{RefreshToken: refreshToken, Verified: !dbUser.VerifiedAt.IsZero()})
}
func (svc service) View(ctx context.Context, session authn.Session, id string) (User, error) {
@@ -255,14 +335,23 @@ func (svc service) UpdateEmail(ctx context.Context, session authn.Session, userI
return User{}, err
}
}
updatedAt := time.Now().UTC()
usr := UserReq{
Email: &email,
UpdatedAt: &updatedAt,
UpdatedBy: &session.UserID,
oldUsr, err := svc.users.RetrieveByID(ctx, userID)
if err != nil {
return User{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
user, err := svc.users.Update(ctx, userID, usr)
if oldUsr.Email == email {
return User{}, fmt.Errorf("current email is same as update requested email")
}
usr := User{
ID: userID,
Email: email,
UpdatedAt: time.Now().UTC(),
UpdatedBy: session.UserID,
VerifiedAt: time.Time{},
}
user, err := svc.users.UpdateEmail(ctx, usr)
if err != nil {
return User{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
@@ -362,18 +451,18 @@ func (svc service) UpdateRole(ctx context.Context, session authn.Session, usr Us
if err := svc.checkSuperAdmin(ctx, session); err != nil {
return User{}, err
}
updateAt := time.Now().UTC()
uReq := UserReq{
Role: &usr.Role,
UpdatedAt: &updateAt,
UpdatedBy: &session.UserID,
usr = User{
ID: usr.ID,
Role: usr.Role,
UpdatedAt: time.Now().UTC(),
UpdatedBy: session.UserID,
}
if err := svc.updateUserPolicy(ctx, usr.ID, usr.Role); err != nil {
return User{}, err
}
u, err := svc.users.Update(ctx, usr.ID, uReq)
u, err := svc.users.UpdateRole(ctx, usr)
if err != nil {
// If failed to update role in DB, then revert back to platform admin policies in spicedb
if errRollback := svc.updateUserPolicy(ctx, usr.ID, UserRole); errRollback != nil {
+182 -11
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"strings"
"testing"
"time"
grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
smqauth "github.com/absmach/supermq/auth"
@@ -776,13 +777,13 @@ func TestUpdateRole(t *testing.T) {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
policyCall := policies.On("AddPolicy", context.Background(), mock.Anything).Return(tc.addPolicyErr)
policyCall1 := policies.On("DeletePolicyFilter", context.Background(), mock.Anything).Return(tc.deletePolicyErr)
repoCall1 := cRepo.On("Update", context.Background(), mock.Anything, mock.Anything).Return(tc.updateRoleResponse, tc.updateRoleErr)
repoCall1 := cRepo.On("UpdateRole", context.Background(), mock.Anything).Return(tc.updateRoleResponse, tc.updateRoleErr)
updatedUser, err := svc.UpdateRole(context.Background(), tc.session, tc.user)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateRoleResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateRoleResponse, updatedUser))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything, mock.Anything)
ok := repoCall1.Parent.AssertCalled(t, "UpdateRole", context.Background(), mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
}
repoCall.Unset()
@@ -906,7 +907,7 @@ func TestUpdateEmail(t *testing.T) {
svc, cRepo := newServiceMinimal()
user2 := user
user2.Email = "updated@example.com"
user2.Email = "user2@example.com"
cases := []struct {
desc string
@@ -921,16 +922,26 @@ func TestUpdateEmail(t *testing.T) {
}{
{
desc: "update user as normal user successfully",
email: "updated@example.com",
email: "user2-update-1@example.com",
token: validToken,
reqUserID: user.ID,
id: user.ID,
updateEmailResponse: user2,
err: nil,
},
{
desc: "update to same email as normal user successfully",
email: "user2-update-1@example.com",
token: validToken,
reqUserID: user.ID,
id: user.ID,
updateEmailResponse: user2,
err: nil,
},
{
desc: "update user email as normal user with repo error on update",
email: "updated@example.com",
email: "user2-update-2@example.com",
token: validToken,
reqUserID: user.ID,
id: user.ID,
@@ -940,14 +951,14 @@ func TestUpdateEmail(t *testing.T) {
},
{
desc: "update user email as admin successfully",
email: "updated@example.com",
email: "user2-update-3@example.com",
token: validToken,
id: user.ID,
err: nil,
},
{
desc: "update user email as admin with repo error on update",
email: "updated@exmaple.com",
email: "user2-update-4@exmaple.com",
token: validToken,
reqUserID: user.ID,
id: user.ID,
@@ -957,7 +968,7 @@ func TestUpdateEmail(t *testing.T) {
},
{
desc: "update user as admin user with failed check on super admin",
email: "updated@exmaple.com",
email: "user2-update-5@exmaple.com",
token: validToken,
reqUserID: user.ID,
id: "",
@@ -970,15 +981,18 @@ func TestUpdateEmail(t *testing.T) {
for _, tc := range cases {
repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr)
repoCall1 := cRepo.On("Update", context.Background(), mock.Anything, mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
repocall2 := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
repoCall1 := cRepo.On("UpdateEmail", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr)
updatedUser, err := svc.UpdateEmail(context.Background(), authn.Session{DomainUserID: tc.reqUserID, UserID: validID, DomainID: validID}, tc.id, tc.email)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.updateEmailResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateEmailResponse, updatedUser))
if tc.err == nil {
ok := repoCall1.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything, mock.Anything)
if tc.err == nil && user2.Email != tc.email {
ok := repoCall1.Parent.AssertCalled(t, "UpdateEmail", context.Background(), mock.Anything, mock.Anything)
assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc))
user2.Email = tc.email
}
repoCall.Unset()
repocall2.Unset()
repoCall1.Unset()
}
}
@@ -1791,3 +1805,160 @@ func TestOAuthCallback(t *testing.T) {
})
}
}
func TestSendVerification(t *testing.T) {
svc, _, cRepo, _, e := newService()
verifiedAt := time.Now().UTC()
cases := []struct {
desc string
session authn.Session
retrieveByIDResponse users.User
retrieveByIDError error
retrieveUserVerResponse users.UserVerification
retrieveUserVerError error
addUserVerError error
sendVerificationEmailError error
err error
}{
{
desc: "send verification email successfully",
session: authn.Session{UserID: user.ID},
retrieveByIDResponse: user,
retrieveUserVerError: repoerr.ErrNotFound,
sendVerificationEmailError: nil,
err: nil,
},
{
desc: "send verification email for already verified user",
session: authn.Session{UserID: user.ID},
retrieveByIDResponse: users.User{VerifiedAt: verifiedAt},
err: svcerr.ErrUserAlreadyVerified,
},
{
desc: "send verification email for non-existing user",
session: authn.Session{UserID: wrongID},
retrieveByIDError: repoerr.ErrNotFound,
err: repoerr.ErrNotFound,
},
{
desc: "send verification email with failed to retrieve user verification",
session: authn.Session{UserID: user.ID},
retrieveByIDResponse: user,
retrieveUserVerError: svcerr.ErrViewEntity,
err: svcerr.ErrViewEntity,
},
{
desc: "send verification email with failed to add user verification",
session: authn.Session{UserID: user.ID},
retrieveByIDResponse: user,
retrieveUserVerError: repoerr.ErrNotFound,
addUserVerError: svcerr.ErrCreateEntity,
err: svcerr.ErrCreateEntity,
},
{
desc: "send verification email with failed to send email",
session: authn.Session{UserID: user.ID},
retrieveByIDResponse: user,
retrieveUserVerError: repoerr.ErrNotFound,
sendVerificationEmailError: svcerr.ErrCreateEntity,
err: svcerr.ErrCreateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.retrieveByIDResponse, tc.retrieveByIDError)
repoCall1 := cRepo.On("RetrieveUserVerification", context.Background(), mock.Anything, mock.Anything).Return(tc.retrieveUserVerResponse, tc.retrieveUserVerError)
repoCall2 := cRepo.On("AddUserVerification", context.Background(), mock.Anything).Return(tc.addUserVerError)
emailCall := e.On("SendVerification", []string{user.Email}, user.Credentials.Username, mock.Anything).Return(tc.sendVerificationEmailError)
err := svc.SendVerification(context.Background(), tc.session)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
emailCall.Unset()
})
}
}
func TestVerifyEmail(t *testing.T) {
//nolint:dogsled
svc, _, cRepo, _, _ := newService()
uv, err := users.NewUserVerification(user.ID, user.Email)
assert.Nil(t, err, fmt.Sprintf("failed to create user verification: %v", err))
uvs, err := uv.Encode()
assert.Nil(t, err, fmt.Sprintf("failed to encode user verification: %v", err))
createdAt := time.Now().Add(-5 * users.VerificationExpiryDuration).UTC()
expiresdAt := time.Now().Add(-users.VerificationExpiryDuration).UTC()
cases := []struct {
desc string
uvs string
retrieveUserVerResponse users.UserVerification
retrieveUserVerError error
updateUserVerError error
updateVerifiedAtError error
err error
}{
{
desc: "verify email successfully",
uvs: uvs,
retrieveUserVerResponse: uv,
err: nil,
},
{
desc: "verify email with malformed token",
uvs: "invalid",
err: svcerr.ErrInvalidUserVerification,
},
{
desc: "verify email with non-existing user verification",
uvs: uvs,
retrieveUserVerError: repoerr.ErrNotFound,
err: svcerr.ErrViewEntity,
},
{
desc: "verify email with expired token",
uvs: uvs,
retrieveUserVerResponse: users.UserVerification{
UserID: uv.UserID,
Email: uv.Email,
OTP: uv.OTP,
ExpiresAt: expiresdAt,
CreatedAt: createdAt,
UsedAt: uv.UsedAt,
},
err: svcerr.ErrUserVerificationExpired,
},
{
desc: "verify email with failed to update user verification",
uvs: uvs,
retrieveUserVerResponse: uv,
updateUserVerError: svcerr.ErrUpdateEntity,
err: svcerr.ErrUpdateEntity,
},
{
desc: "verify email with failed to update verified at",
uvs: uvs,
retrieveUserVerResponse: uv,
updateVerifiedAtError: svcerr.ErrUpdateEntity,
err: svcerr.ErrUpdateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("RetrieveUserVerification", context.Background(), mock.Anything, mock.Anything).Return(tc.retrieveUserVerResponse, tc.retrieveUserVerError)
repoCall1 := cRepo.On("UpdateUserVerification", context.Background(), mock.Anything).Return(tc.updateUserVerError)
repoCall2 := cRepo.On("UpdateVerifiedAt", context.Background(), mock.Anything).Return(users.User{}, tc.updateVerifiedAtError)
_, err := svc.VerifyEmail(context.Background(), tc.uvs)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
})
}
}
+15
View File
@@ -34,6 +34,21 @@ func (tm *tracingMiddleware) Register(ctx context.Context, session authn.Session
return tm.svc.Register(ctx, session, user, selfRegister)
}
// SendVerification traces the "SendVerification" operation of the wrapped users.Service.
func (tm *tracingMiddleware) SendVerification(ctx context.Context, session authn.Session) error {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_send_verification")
defer span.End()
return tm.svc.SendVerification(ctx, session)
}
// VerifyEmail traces the "VerifyEmail" operation of the wrapped users.Service.
func (tm *tracingMiddleware) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_verify_email")
defer span.End()
return tm.svc.VerifyEmail(ctx, verificationToken)
}
// IssueToken traces the "IssueToken" operation of the wrapped users.Service.
func (tm *tracingMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_issue_token", trace.WithAttributes(attribute.String("username", username)))
+26 -2
View File
@@ -29,6 +29,7 @@ type User struct {
CreatedAt time.Time `json:"created_at,omitempty"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
UpdatedBy string `json:"updated_by,omitempty"`
VerifiedAt time.Time `json:"verified_at,omitempty"`
}
type Credentials struct {
@@ -49,9 +50,7 @@ type UserReq struct {
LastName *string `json:"last_name,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
Tags *[]string `json:"tags,omitempty"`
Role *Role `json:"role,omitempty"`
ProfilePicture *string `json:"profile_picture,omitempty"`
Email *string `json:"email,omitempty"`
UpdatedBy *string `json:"updated_by,omitempty"`
UpdatedAt *time.Time `json:"updated_at,omitempty"`
}
@@ -90,6 +89,15 @@ type Repository interface {
// UpdateSecret updates secret for user with given email.
UpdateSecret(ctx context.Context, user User) (User, error)
// UpdateEmail updates email for user with given id.
UpdateEmail(ctx context.Context, user User) (User, error)
// UpdateRole updates role for user with given id.
UpdateRole(ctx context.Context, user User) (User, error)
// UpdateVerifiedAt updates the verified time for user with given id.
UpdateVerifiedAt(ctx context.Context, user User) (User, error)
// ChangeStatus changes user status to enabled or disabled
ChangeStatus(ctx context.Context, user User) (User, error)
@@ -107,6 +115,15 @@ type Repository interface {
// Save persists the user account. A non-nil error is returned to indicate
// operation failure.
Save(ctx context.Context, user User) (User, error)
// AddUserVerification adds new verification for given user id and email
AddUserVerification(ctx context.Context, uv UserVerification) error
// RetrieveVerificationToken retrieves verification token of given user id and email.
RetrieveUserVerification(ctx context.Context, userID, email string) (UserVerification, error)
// UpdateUserVerificationDetails update verification details for the given user id and email.
UpdateUserVerification(ctx context.Context, uv UserVerification) error
}
// Validate returns an error if user representation is invalid.
@@ -143,6 +160,7 @@ type Page struct {
FirstName string `json:"first_name,omitempty"`
LastName string `json:"last_name,omitempty"`
Email string `json:"email,omitempty"`
Verified bool `json:"verified,omitempty"`
}
// Service specifies an API that must be fullfiled by the domain service
@@ -152,6 +170,12 @@ type Service interface {
// non-nil error value is returned.
Register(ctx context.Context, session authn.Session, user User, selfRegister bool) (User, error)
// SendVerification sends a verification email to the user.
SendVerification(ctx context.Context, session authn.Session) error
// VerifyEmail verifies user's email using the verification token.
VerifyEmail(ctx context.Context, verificationToken string) (User, error)
// View retrieves user info for a given user ID and an authorized token.
View(ctx context.Context, session authn.Session, id string) (User, error)
+121
View File
@@ -0,0 +1,121 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package users
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"time"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
)
const VerificationExpiryDuration = 24 * time.Hour
var (
errFailedToCreateUserVerification = errors.New("failed to create new user verification")
errFailedToEncodeUserVerification = errors.New("failed to encode user verification")
errFailedToDecodeUserVerification = errors.New("failed to decode user verification")
)
// UserVerification OTP is sent to the user's email as base64 encoded with UserID, Email and OTP. It should not be exposed via API.
type UserVerification struct {
UserID string `json:"user_id"`
Email string `json:"email"`
OTP string `json:"otp"`
CreatedAt time.Time `json:"-"`
ExpiresAt time.Time `json:"-"`
UsedAt time.Time `json:"-"`
}
func NewUserVerification(userID, email string) (UserVerification, error) {
randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil {
return UserVerification{}, errors.Wrap(errFailedToCreateUserVerification, err)
}
return UserVerification{
UserID: userID,
Email: email,
OTP: base64.URLEncoding.EncodeToString(randomBytes),
CreatedAt: time.Now().UTC(),
ExpiresAt: time.Now().Add(VerificationExpiryDuration).UTC(),
}, nil
}
func (u UserVerification) Encode() (string, error) {
jsonBytes, err := json.Marshal(u)
if err != nil {
return "", errors.Wrap(errFailedToEncodeUserVerification, err)
}
return base64.URLEncoding.EncodeToString(jsonBytes), nil
}
func (u *UserVerification) Decode(data string) error {
decodedPayload, err := base64.URLEncoding.DecodeString(data)
if err != nil {
return errors.Wrap(errFailedToDecodeUserVerification, err)
}
if err := json.Unmarshal(decodedPayload, u); err != nil {
return errors.Wrap(errFailedToDecodeUserVerification, err)
}
if u.UserID == "" || u.Email == "" || u.OTP == "" {
return svcerr.ErrInvalidUserVerification
}
return nil
}
func (u UserVerification) Valid() error {
if u.UserID == "" || u.Email == "" || u.OTP == "" {
return svcerr.ErrInvalidUserVerification
}
// Verification should have created time.
if u.CreatedAt.IsZero() {
return svcerr.ErrInvalidUserVerification
}
// Verification should have expiry time.
if u.ExpiresAt.IsZero() {
return svcerr.ErrInvalidUserVerification
}
// Expiry time should not be before Created time
if u.ExpiresAt.Before(u.CreatedAt) {
return svcerr.ErrInvalidUserVerification
}
// Verification should be not be Expired.
if time.Now().After(u.ExpiresAt) {
return svcerr.ErrUserVerificationExpired
}
// Verification should not be used.
if !u.UsedAt.IsZero() {
return svcerr.ErrUserVerificationExpired
}
return nil
}
func (u UserVerification) Match(ruv UserVerification) error {
if u.UserID != ruv.UserID {
return svcerr.ErrInvalidUserVerification
}
if u.Email != ruv.Email {
return svcerr.ErrInvalidUserVerification
}
if u.OTP != ruv.OTP {
return svcerr.ErrInvalidUserVerification
}
return nil
}