diff --git a/api/grpc/auth/v1/auth.pb.go b/api/grpc/auth/v1/auth.pb.go index 46b6ad2ba..74c251fbb 100644 --- a/api/grpc/auth/v1/auth.pb.go +++ b/api/grpc/auth/v1/auth.pb.go @@ -73,6 +73,7 @@ type AuthNRes struct { Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // token id UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // user id UserRole uint32 `protobuf:"varint,3,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"` // user role + Verified bool `protobuf:"varint,4,opt,name=verified,proto3" json:"verified,omitempty"` // verified user unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -128,6 +129,13 @@ func (x *AuthNRes) GetUserRole() uint32 { return 0 } +func (x *AuthNRes) GetVerified() bool { + if x != nil { + return x.Verified + } + return false +} + type AuthZReq struct { state protoimpl.MessageState `protogen:"open.v1"` Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` // Domain @@ -378,11 +386,12 @@ const file_auth_v1_auth_proto_rawDesc = "" + "\n" + "\x12auth/v1/auth.proto\x12\aauth.v1\" \n" + "\bAuthNReq\x12\x14\n" + - "\x05token\x18\x01 \x01(\tR\x05token\"P\n" + + "\x05token\x18\x01 \x01(\tR\x05token\"l\n" + "\bAuthNRes\x12\x0e\n" + "\x02id\x18\x01 \x01(\tR\x02id\x12\x17\n" + "\auser_id\x18\x02 \x01(\tR\x06userId\x12\x1b\n" + - "\tuser_role\x18\x03 \x01(\rR\buserRole\"\xa2\x02\n" + + "\tuser_role\x18\x03 \x01(\rR\buserRole\x12\x1a\n" + + "\bverified\x18\x04 \x01(\bR\bverified\"\xa2\x02\n" + "\bAuthZReq\x12\x16\n" + "\x06domain\x18\x01 \x01(\tR\x06domain\x12!\n" + "\fsubject_type\x18\x02 \x01(\tR\vsubjectType\x12!\n" + diff --git a/api/grpc/token/v1/token.pb.go b/api/grpc/token/v1/token.pb.go index d1f1fadb6..d7eb70ca9 100644 --- a/api/grpc/token/v1/token.pb.go +++ b/api/grpc/token/v1/token.pb.go @@ -29,6 +29,7 @@ type IssueReq struct { UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` UserRole uint32 `protobuf:"varint,2,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"` Type uint32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"` + Verified bool `protobuf:"varint,4,opt,name=verified,proto3" json:"verified,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -84,9 +85,17 @@ func (x *IssueReq) GetType() uint32 { return 0 } +func (x *IssueReq) GetVerified() bool { + if x != nil { + return x.Verified + } + return false +} + type RefreshReq struct { state protoimpl.MessageState `protogen:"open.v1"` RefreshToken string `protobuf:"bytes,1,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"` + Verified bool `protobuf:"varint,2,opt,name=verified,proto3" json:"verified,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -128,6 +137,13 @@ func (x *RefreshReq) GetRefreshToken() string { return "" } +func (x *RefreshReq) GetVerified() bool { + if x != nil { + return x.Verified + } + return false +} + // If a token is not carrying any information itself, the type // field can be used to determine how to validate the token. // Also, different tokens can be encoded in different ways. @@ -195,14 +211,16 @@ var File_token_v1_token_proto protoreflect.FileDescriptor const file_token_v1_token_proto_rawDesc = "" + "\n" + - "\x14token/v1/token.proto\x12\btoken.v1\"T\n" + + "\x14token/v1/token.proto\x12\btoken.v1\"p\n" + "\bIssueReq\x12\x17\n" + "\auser_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n" + "\tuser_role\x18\x02 \x01(\rR\buserRole\x12\x12\n" + - "\x04type\x18\x03 \x01(\rR\x04type\"1\n" + + "\x04type\x18\x03 \x01(\rR\x04type\x12\x1a\n" + + "\bverified\x18\x04 \x01(\bR\bverified\"M\n" + "\n" + "RefreshReq\x12#\n" + - "\rrefresh_token\x18\x01 \x01(\tR\frefreshToken\"\x87\x01\n" + + "\rrefresh_token\x18\x01 \x01(\tR\frefreshToken\x12\x1a\n" + + "\bverified\x18\x02 \x01(\bR\bverified\"\x87\x01\n" + "\x05Token\x12!\n" + "\faccess_token\x18\x01 \x01(\tR\vaccessToken\x12(\n" + "\rrefresh_token\x18\x02 \x01(\tH\x00R\frefreshToken\x88\x01\x01\x12\x1f\n" + diff --git a/api/http/authn.go b/api/http/authn.go deleted file mode 100644 index 8346d4328..000000000 --- a/api/http/authn.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package http - -import ( - "context" - "net/http" - - apiutil "github.com/absmach/supermq/api/http/util" - "github.com/absmach/supermq/auth" - smqauthn "github.com/absmach/supermq/pkg/authn" - "github.com/go-chi/chi/v5" -) - -type sessionKeyType string - -const SessionKey = sessionKeyType("session") - -func AuthenticateMiddleware(authn smqauthn.Authentication, domainCheck bool) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - token := apiutil.ExtractBearerToken(r) - if token == "" { - EncodeError(r.Context(), apiutil.ErrBearerToken, w) - return - } - resp, err := authn.Authenticate(r.Context(), token) - if err != nil { - EncodeError(r.Context(), err, w) - return - } - - if domainCheck { - domain := chi.URLParam(r, "domainID") - if domain == "" { - EncodeError(r.Context(), apiutil.ErrMissingDomainID, w) - return - } - resp.DomainID = domain - switch resp.Role { - case smqauthn.AdminRole: - resp.DomainUserID = resp.UserID - case smqauthn.UserRole: - resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID) - } - } - - ctx := context.WithValue(r.Context(), SessionKey, resp) - - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} diff --git a/api/http/common.go b/api/http/common.go index 9433bf9a6..6240794a6 100644 --- a/api/http/common.go +++ b/api/http/common.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" "net/http" + "net/mail" "regexp" "strings" @@ -110,6 +111,13 @@ func ValidateUUID(extID string) (err error) { return nil } +func ValidateEmail(email string) (err error) { + if _, err := mail.ParseAddress(email); err != nil { + return apiutil.ErrInvalidEmail + } + return nil +} + // ValidateName validates name format. func ValidateName(id string) error { if !nameRegExp.MatchString(id) { @@ -201,6 +209,7 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) { errors.Contains(err, apiutil.ErrMissingSecret), errors.Contains(err, errors.ErrMalformedEntity), errors.Contains(err, apiutil.ErrMissingID), + errors.Contains(err, apiutil.ErrInvalidVerification), errors.Contains(err, apiutil.ErrMissingName), errors.Contains(err, apiutil.ErrMissingEmail), errors.Contains(err, apiutil.ErrInvalidEmail), diff --git a/api/http/util/errors.go b/api/http/util/errors.go index c7554fccf..8af1f0099 100644 --- a/api/http/util/errors.go +++ b/api/http/util/errors.go @@ -268,4 +268,10 @@ var ( // ErrMissingUsernameEmail indicates missing user name / email. ErrMissingUsernameEmail = errors.New("missing username / email") + + // ErrInvalidVerification indicates invalid email verification. + ErrInvalidVerification = errors.New("invalid verification") + + // ErrEmailNotVerified indicates invalid email not verified. + ErrEmailNotVerified = errors.New("email not verified") ) diff --git a/apidocs/openapi/users.yaml b/apidocs/openapi/users.yaml index 96f842d6f..ddbd457e4 100644 --- a/apidocs/openapi/users.yaml +++ b/apidocs/openapi/users.yaml @@ -607,6 +607,54 @@ paths: "500": $ref: "#/components/responses/ServiceError" + /users/send-verification: + post: + operationId: sendVerification + tags: + - Users + summary: Sends a verification email + description: | + Sends a verification email to the user. + security: + - bearerAuth: [] + responses: + "200": + description: Sent verification email if registered. + "400": + description: Failed due to malformed JSON. + "401": + description: Missing or invalid access token provided. + "415": + description: Missing or invalid content type. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /verify-email: + get: + operationId: verifyEmail + tags: + - Users + summary: Verify user's email + description: | + Verify user's email using the token from the verification link. + parameters: + - $ref: "#/components/parameters/VerificationToken" + responses: + "200": + description: Email verified successfully. + "400": + description: Failed due to malformed query parameters. + "401": + description: Missing or invalid access token provided. + "404": + description: A non-existent entity request. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + /health: get: operationId: health @@ -1240,6 +1288,14 @@ components: required: false example: "0" + VerificationToken: + name: token + description: Verification token. + in: query + schema: + type: string + required: true + requestBodies: UserCreateReq: description: JSON-formatted document describing the new user to be registered diff --git a/auth/api/grpc/auth/client.go b/auth/api/grpc/auth/client.go index 8b4797638..672d72185 100644 --- a/auth/api/grpc/auth/client.go +++ b/auth/api/grpc/auth/client.go @@ -75,7 +75,7 @@ func (client authGrpcClient) Authenticate(ctx context.Context, token *grpcAuthV1 return &grpcAuthV1.AuthNRes{}, grpcapi.DecodeError(err) } ir := res.(authenticateRes) - return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID, UserRole: uint32(ir.userRole)}, nil + return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID, UserRole: uint32(ir.userRole), Verified: ir.verified}, nil } func encodeIdentifyRequest(_ context.Context, grpcReq any) (any, error) { @@ -85,7 +85,7 @@ func encodeIdentifyRequest(_ context.Context, grpcReq any) (any, error) { func decodeIdentifyResponse(_ context.Context, grpcRes any) (any, error) { res := grpcRes.(*grpcAuthV1.AuthNRes) - return authenticateRes{id: res.GetId(), userID: res.GetUserId(), userRole: auth.Role(res.UserRole)}, nil + return authenticateRes{id: res.GetId(), userID: res.GetUserId(), userRole: auth.Role(res.UserRole), verified: res.GetVerified()}, nil } func (client authGrpcClient) AuthenticatePAT(ctx context.Context, token *grpcAuthV1.AuthNReq, _ ...grpc.CallOption) (*grpcAuthV1.AuthNRes, error) { diff --git a/auth/api/grpc/auth/endpoint.go b/auth/api/grpc/auth/endpoint.go index c72c7e638..98b9c41f6 100644 --- a/auth/api/grpc/auth/endpoint.go +++ b/auth/api/grpc/auth/endpoint.go @@ -23,7 +23,7 @@ func authenticateEndpoint(svc auth.Service) endpoint.Endpoint { return authenticateRes{}, err } - return authenticateRes{userID: key.Subject, userRole: key.Role}, nil + return authenticateRes{userID: key.Subject, userRole: key.Role, verified: key.Verified}, nil } } diff --git a/auth/api/grpc/auth/responses.go b/auth/api/grpc/auth/responses.go index 78b6d7b2b..861bc5b89 100644 --- a/auth/api/grpc/auth/responses.go +++ b/auth/api/grpc/auth/responses.go @@ -9,6 +9,7 @@ type authenticateRes struct { id string userID string userRole smqauth.Role + verified bool } type authorizeRes struct { diff --git a/auth/api/grpc/auth/server.go b/auth/api/grpc/auth/server.go index 2bc765d49..cbb9f1acb 100644 --- a/auth/api/grpc/auth/server.go +++ b/auth/api/grpc/auth/server.go @@ -82,7 +82,7 @@ func decodeAuthenticateRequest(_ context.Context, grpcReq any) (any, error) { func encodeAuthenticateResponse(_ context.Context, grpcRes any) (any, error) { res := grpcRes.(authenticateRes) - return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, UserRole: uint32(res.userRole)}, nil + return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, UserRole: uint32(res.userRole), Verified: res.verified}, nil } func encodeAuthenticatePATResponse(_ context.Context, grpcRes any) (any, error) { diff --git a/auth/api/grpc/token/client.go b/auth/api/grpc/token/client.go index b747f3dea..25c3bf62f 100644 --- a/auth/api/grpc/token/client.go +++ b/auth/api/grpc/token/client.go @@ -56,6 +56,7 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR userID: req.GetUserId(), userRole: auth.Role(req.GetUserRole()), keyType: auth.KeyType(req.GetType()), + verified: req.GetVerified(), }) if err != nil { return &grpcTokenV1.Token{}, grpcapi.DecodeError(err) @@ -69,6 +70,7 @@ func encodeIssueRequest(_ context.Context, grpcReq any) (any, error) { UserId: req.userID, UserRole: uint32(req.userRole), Type: uint32(req.keyType), + Verified: req.verified, }, nil } @@ -80,7 +82,7 @@ func (client tokenGrpcClient) Refresh(ctx context.Context, req *grpcTokenV1.Refr ctx, cancel := context.WithTimeout(ctx, client.timeout) defer cancel() - res, err := client.refresh(ctx, refreshReq{refreshToken: req.GetRefreshToken()}) + res, err := client.refresh(ctx, refreshReq{refreshToken: req.GetRefreshToken(), verified: req.GetVerified()}) if err != nil { return &grpcTokenV1.Token{}, grpcapi.DecodeError(err) } @@ -89,7 +91,7 @@ func (client tokenGrpcClient) Refresh(ctx context.Context, req *grpcTokenV1.Refr func encodeRefreshRequest(_ context.Context, grpcReq any) (any, error) { req := grpcReq.(refreshReq) - return &grpcTokenV1.RefreshReq{RefreshToken: req.refreshToken}, nil + return &grpcTokenV1.RefreshReq{RefreshToken: req.refreshToken, Verified: req.verified}, nil } func decodeRefreshResponse(_ context.Context, grpcRes any) (any, error) { diff --git a/auth/api/grpc/token/endpoint.go b/auth/api/grpc/token/endpoint.go index 6e4faa1d2..b03e42ae5 100644 --- a/auth/api/grpc/token/endpoint.go +++ b/auth/api/grpc/token/endpoint.go @@ -18,9 +18,10 @@ func issueEndpoint(svc auth.Service) endpoint.Endpoint { } key := auth.Key{ - Type: req.keyType, - Subject: req.userID, - Role: req.userRole, + Type: req.keyType, + Subject: req.userID, + Role: req.userRole, + Verified: req.verified, } tkn, err := svc.Issue(ctx, "", key) if err != nil { @@ -42,7 +43,7 @@ func refreshEndpoint(svc auth.Service) endpoint.Endpoint { return issueRes{}, err } - key := auth.Key{Type: auth.RefreshKey} + key := auth.Key{Type: auth.RefreshKey, Verified: req.verified} tkn, err := svc.Issue(ctx, req.refreshToken, key) if err != nil { return issueRes{}, err diff --git a/auth/api/grpc/token/requests.go b/auth/api/grpc/token/requests.go index eb10bd660..a5ab3c094 100644 --- a/auth/api/grpc/token/requests.go +++ b/auth/api/grpc/token/requests.go @@ -12,6 +12,7 @@ type issueReq struct { userID string userRole auth.Role keyType auth.KeyType + verified bool } func (req issueReq) validate() error { @@ -27,6 +28,7 @@ func (req issueReq) validate() error { type refreshReq struct { refreshToken string + verified bool } func (req refreshReq) validate() error { diff --git a/auth/api/grpc/token/server.go b/auth/api/grpc/token/server.go index c8a79fa01..319e46e6e 100644 --- a/auth/api/grpc/token/server.go +++ b/auth/api/grpc/token/server.go @@ -58,12 +58,13 @@ func decodeIssueRequest(_ context.Context, grpcReq any) (any, error) { userID: req.GetUserId(), userRole: auth.Role(req.GetUserRole()), keyType: auth.KeyType(req.GetType()), + verified: req.Verified, }, nil } func decodeRefreshRequest(_ context.Context, grpcReq any) (any, error) { req := grpcReq.(*grpcTokenV1.RefreshReq) - return refreshReq{refreshToken: req.GetRefreshToken()}, nil + return refreshReq{refreshToken: req.GetRefreshToken(), verified: req.Verified}, nil } func encodeIssueResponse(_ context.Context, grpcRes any) (any, error) { diff --git a/auth/jwt/tokenizer.go b/auth/jwt/tokenizer.go index 8e5cdb0a7..9f0e24838 100644 --- a/auth/jwt/tokenizer.go +++ b/auth/jwt/tokenizer.go @@ -21,6 +21,8 @@ var ( errInvalidType = errors.New("invalid token type") // errInvalidRole is returned when the role is invalid. errInvalidRole = errors.New("invalid role") + // errInvalidVerified is returned when the verified is invalid. + errInvalidVerified = errors.New("invalid verified") // errJWTExpiryKey is used to check if the token is expired. errJWTExpiryKey = errors.New(`"exp" not satisfied`) // ErrSignJWT indicates an error in signing jwt token. @@ -36,6 +38,7 @@ const ( tokenType = "type" userField = "user" RoleField = "role" + VerifiedField = "verified" oauthProviderField = "oauth_provider" oauthAccessTokenField = "access_token" oauthRefreshTokenField = "refresh_token" @@ -62,6 +65,7 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) { Claim(tokenType, key.Type). Expiration(key.ExpiresAt) builder.Claim(RoleField, key.Role) + builder.Claim(VerifiedField, key.Verified) if key.Subject != "" { builder.Subject(key.Subject) } @@ -150,6 +154,16 @@ func toKey(tkn jwt.Token) (auth.Key, error) { if !ok { return auth.Key{}, errInvalidRole } + + tVerified, ok := tkn.Get(VerifiedField) + if !ok { + return auth.Key{}, errInvalidVerified + } + kVerified, ok := tVerified.(bool) + if !ok { + return auth.Key{}, errInvalidVerified + } + kr := auth.Role(kRole) if !kr.Validate() { return auth.Key{}, errInvalidRole @@ -162,6 +176,7 @@ func toKey(tkn jwt.Token) (auth.Key, error) { key.Subject = tkn.Subject() key.IssuedAt = tkn.IssuedAt() key.ExpiresAt = tkn.Expiration() + key.Verified = kVerified return key, nil } diff --git a/auth/keys.go b/auth/keys.go index f149f1624..efde3a00c 100644 --- a/auth/keys.go +++ b/auth/keys.go @@ -88,6 +88,7 @@ type Key struct { Role Role `json:"role,omitempty"` IssuedAt time.Time `json:"issued_at,omitempty"` ExpiresAt time.Time `json:"expires_at,omitempty"` + Verified bool `json:"verified,omitempty"` } func (key Key) String() string { diff --git a/certs/api/endpoint_test.go b/certs/api/endpoint_test.go index 9c1e6dee4..71a19a651 100644 --- a/certs/api/endpoint_test.go +++ b/certs/api/endpoint_test.go @@ -72,7 +72,8 @@ func newCertServer() (*httptest.Server, *mocks.Service, *authnmocks.Authenticati logger := smqlog.NewMock() idp := uuid.NewMock() authn := new(authnmocks.Authentication) - mux := api.MakeHandler(svc, authn, logger, "", idp) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + mux := api.MakeHandler(svc, am, logger, "", idp) return httptest.NewServer(mux), svc, authn } diff --git a/certs/api/transport.go b/certs/api/transport.go index 93dd559d0..c9718c37b 100644 --- a/certs/api/transport.go +++ b/certs/api/transport.go @@ -31,7 +31,7 @@ const ( ) // MakeHandler returns a HTTP handler for API endpoints. -func MakeHandler(svc certs.Service, authn smqauthn.Authentication, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler { +func MakeHandler(svc certs.Service, authn smqauthn.AuthNMiddleware, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler { opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), } @@ -39,7 +39,7 @@ func MakeHandler(svc certs.Service, authn smqauthn.Authentication, logger *slog. r := chi.NewRouter() r.Group(func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, true)) + r.Use(authn.Middleware()) r.Use(api.RequestIDMiddleware(idp)) r.Route("/{domainID}", func(r chi.Router) { diff --git a/channels/api/http/endpoint_test.go b/channels/api/http/endpoint_test.go index ecc3efa51..d4e59052a 100644 --- a/channels/api/http/endpoint_test.go +++ b/channels/api/http/endpoint_test.go @@ -58,7 +58,8 @@ func newChannelsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authenti mux := chi.NewRouter() idp := uuid.NewMock() logger := smqlog.NewMock() - mux = MakeHandler(svc, authn, mux, logger, "", idp) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + mux = MakeHandler(svc, am, mux, logger, "", idp) return httptest.NewServer(mux), svc, authn } diff --git a/channels/api/http/endpoints.go b/channels/api/http/endpoints.go index cdfcd3f0a..2afee3150 100644 --- a/channels/api/http/endpoints.go +++ b/channels/api/http/endpoints.go @@ -6,7 +6,6 @@ package http import ( "context" - api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/channels" "github.com/absmach/supermq/pkg/authn" @@ -22,7 +21,7 @@ func createChannelEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -46,7 +45,7 @@ func createChannelsEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -77,7 +76,7 @@ func viewChannelEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -98,7 +97,7 @@ func listChannelsEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -138,7 +137,7 @@ func updateChannelEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -164,7 +163,7 @@ func updateChannelTagsEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -189,7 +188,7 @@ func setChannelParentGroupEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -209,7 +208,7 @@ func removeChannelParentGroupEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -229,7 +228,7 @@ func enableChannelEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -250,7 +249,7 @@ func disableChannelEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -271,7 +270,7 @@ func connectChannelClientEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -291,7 +290,7 @@ func disconnectChannelClientsEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -311,7 +310,7 @@ func connectEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -331,7 +330,7 @@ func disconnectEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -351,7 +350,7 @@ func deleteChannelEndpoint(svc channels.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } diff --git a/channels/api/http/transport.go b/channels/api/http/transport.go index e356793e9..2767008d4 100644 --- a/channels/api/http/transport.go +++ b/channels/api/http/transport.go @@ -19,7 +19,7 @@ import ( ) // MakeHandler returns a HTTP handler for Channels API endpoints. -func MakeHandler(svc channels.Service, authn smqauthn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) *chi.Mux { +func MakeHandler(svc channels.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) *chi.Mux { opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), } @@ -27,7 +27,7 @@ func MakeHandler(svc channels.Service, authn smqauthn.Authentication, mux *chi.M d := roleManagerHttp.NewDecoder("channelID") mux.Route("/{domainID}/channels", func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, true)) + r.Use(authn.Middleware()) r.Use(api.RequestIDMiddleware(idp)) r.Post("/", otelhttp.NewHandler(kithttp.NewServer( diff --git a/cli/channels.go b/cli/channels.go index 98b0817e5..20e38d4c5 100644 --- a/cli/channels.go +++ b/cli/channels.go @@ -12,14 +12,16 @@ import ( ) const ( - all = "all" - create = "create" - get = "get" - update = "update" - delete = "delete" - enable = "enable" - disable = "disable" - users = "users" + all = "all" + create = "create" + get = "get" + update = "update" + delete = "delete" + enable = "enable" + disable = "disable" + users = "users" + sendVerification = "send-verification" + verifyEmail = "verify-email" usageCreate = "cli channels create " usageGet = "cli channels get " diff --git a/cli/users.go b/cli/users.go index ca5f2f32a..591479522 100644 --- a/cli/users.go +++ b/cli/users.go @@ -72,11 +72,16 @@ Examples: logUsageCmd(*cmd, cmd.Use) return } - switch args[0] { case create: handleUserCreate(cmd, args[1:]) return + case sendVerification: + handleSendVerification(cmd, args[1]) + return + case verifyEmail: + handleVerify(cmd, args[1]) + return case token: handleUserToken(cmd, args[1], args[2:]) return @@ -152,6 +157,34 @@ func handleUserCreate(cmd *cobra.Command, args []string) { logJSONCmd(*cmd, user) } +func handleSendVerification(cmd *cobra.Command, token string) { + if token == "" { + logUsageCmd(*cmd, usageUserToken) + return + } + + if err := sdk.SendVerification(cmd.Context(), token); err != nil { + logErrorCmd(*cmd, err) + return + } + + logJSONCmd(*cmd, "sent verification successfully") +} + +func handleVerify(cmd *cobra.Command, token string) { + if token == "" { + logUsageCmd(*cmd, usageUserToken) + return + } + + if err := sdk.VerifyEmail(cmd.Context(), token); err != nil { + logErrorCmd(*cmd, err) + return + } + + logJSONCmd(*cmd, "verified successfully") +} + func handleUserGet(cmd *cobra.Command, userParams string, args []string) { if len(args) != 1 { logUsageCmd(*cmd, usageUserGet) diff --git a/clients/api/http/clients.go b/clients/api/http/clients.go index bcc2c6bf3..035b50525 100644 --- a/clients/api/http/clients.go +++ b/clients/api/http/clients.go @@ -17,14 +17,14 @@ import ( "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) -func clientsHandler(svc clients.Service, authn smqauthn.Authentication, r *chi.Mux, logger *slog.Logger, idp supermq.IDProvider) *chi.Mux { +func clientsHandler(svc clients.Service, authn smqauthn.AuthNMiddleware, r *chi.Mux, logger *slog.Logger, idp supermq.IDProvider) *chi.Mux { opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), } d := roleManagerHttp.NewDecoder("clientID") r.Group(func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, true)) + r.Use(authn.Middleware()) r.Use(api.RequestIDMiddleware(idp)) r.Route("/{domainID}/clients", func(r chi.Router) { diff --git a/clients/api/http/endpoints.go b/clients/api/http/endpoints.go index 64b75ee3a..2bd6c7dc7 100644 --- a/clients/api/http/endpoints.go +++ b/clients/api/http/endpoints.go @@ -6,7 +6,6 @@ package http import ( "context" - api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/clients" "github.com/absmach/supermq/pkg/authn" @@ -22,7 +21,7 @@ func createClientEndpoint(svc clients.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -46,7 +45,7 @@ func createClientsEndpoint(svc clients.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -77,7 +76,7 @@ func viewClientEndpoint(svc clients.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -98,7 +97,7 @@ func listClientsEndpoint(svc clients.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -138,7 +137,7 @@ func updateClientEndpoint(svc clients.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -164,7 +163,7 @@ func updateClientTagsEndpoint(svc clients.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -189,7 +188,7 @@ func updateClientSecretEndpoint(svc clients.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -210,7 +209,7 @@ func enableClientEndpoint(svc clients.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -231,7 +230,7 @@ func disableClientEndpoint(svc clients.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -252,7 +251,7 @@ func setClientParentGroupEndpoint(svc clients.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -271,7 +270,7 @@ func removeClientParentGroupEndpoint(svc clients.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -290,7 +289,7 @@ func deleteClientEndpoint(svc clients.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } diff --git a/clients/api/http/endpoints_test.go b/clients/api/http/endpoints_test.go index 41976c520..655bad028 100644 --- a/clients/api/http/endpoints_test.go +++ b/clients/api/http/endpoints_test.go @@ -96,7 +96,8 @@ func newClientsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentic logger := smqlog.NewMock() mux := chi.NewRouter() idp := uuid.NewMock() - clientsapi.MakeHandler(svc, authn, mux, logger, "", idp) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + clientsapi.MakeHandler(svc, am, mux, logger, "", idp) return httptest.NewServer(mux), svc, authn } diff --git a/clients/api/http/transport.go b/clients/api/http/transport.go index b78218636..8a3b31ec1 100644 --- a/clients/api/http/transport.go +++ b/clients/api/http/transport.go @@ -15,7 +15,7 @@ import ( ) // MakeHandler returns a HTTP handler for clients and Groups API endpoints. -func MakeHandler(tsvc clients.Service, authn smqauthn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler { +func MakeHandler(tsvc clients.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler { mux = clientsHandler(tsvc, authn, mux, logger, idp) mux.Get("/health", supermq.Health("clients", instanceID)) diff --git a/cmd/certs/main.go b/cmd/certs/main.go index 2742b1c64..2bc87b77d 100644 --- a/cmd/certs/main.go +++ b/cmd/certs/main.go @@ -20,6 +20,7 @@ import ( "github.com/absmach/supermq/certs/postgres" "github.com/absmach/supermq/certs/tracing" smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" "github.com/absmach/supermq/pkg/grpcclient" jaegerclient "github.com/absmach/supermq/pkg/jaeger" @@ -138,6 +139,7 @@ func main() { } defer authnClient.Close() logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) + authnMiddleware := smqauthn.NewAuthNMiddleware(authn) tp, err := jaegerclient.NewProvider(ctx, svcName, cfg.JaegerURL, cfg.InstanceID, cfg.TraceRatio) if err != nil { @@ -163,7 +165,7 @@ func main() { idp := uuid.New() - hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, logger, cfg.InstanceID, idp), logger) + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, logger, cfg.InstanceID, idp), logger) if cfg.SendTelemetry { chc := chclient.New(svcName, supermq.Version, logger, cancel) diff --git a/cmd/channels/main.go b/cmd/channels/main.go index 637b9ccfb..34b7e9dd5 100644 --- a/cmd/channels/main.go +++ b/cmd/channels/main.go @@ -31,6 +31,7 @@ import ( gpostgres "github.com/absmach/supermq/groups/postgres" redisclient "github.com/absmach/supermq/internal/clients/redis" smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" smqauthz "github.com/absmach/supermq/pkg/authz" authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" @@ -179,6 +180,7 @@ func main() { } defer authnClient.Close() logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) + authnMiddleware := smqauthn.NewAuthNMiddleware(authn) domsGrpcCfg := grpcclient.Config{} if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { @@ -304,7 +306,7 @@ func main() { } mux := chi.NewRouter() idp := uuid.New() - httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID, idp), logger) + httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, mux, logger, cfg.InstanceID, idp), logger) if cfg.SendTelemetry { chc := chclient.New(svcName, supermq.Version, logger, cancel) diff --git a/cmd/clients/main.go b/cmd/clients/main.go index bcf4c4574..5892fb1df 100644 --- a/cmd/clients/main.go +++ b/cmd/clients/main.go @@ -31,6 +31,7 @@ import ( gpostgres "github.com/absmach/supermq/groups/postgres" redisclient "github.com/absmach/supermq/internal/clients/redis" smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" smqauthz "github.com/absmach/supermq/pkg/authz" authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" @@ -190,6 +191,7 @@ func main() { } defer authnClient.Close() logger.Info("AuthN successfully connected to auth gRPC server " + authnClient.Secure()) + authnMiddleware := smqauthn.NewAuthNMiddleware(authn) domsGrpcCfg := grpcclient.Config{} if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { @@ -292,7 +294,7 @@ func main() { } mux := chi.NewRouter() idp := uuid.New() - httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID, idp), logger) + httpSvc := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, mux, logger, cfg.InstanceID, idp), logger) grpcServerConfig := server.Config{Port: defSvcAuthGRPCPort} if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGRPC}); err != nil { diff --git a/cmd/domains/main.go b/cmd/domains/main.go index e659ffe69..19e00c958 100644 --- a/cmd/domains/main.go +++ b/cmd/domains/main.go @@ -27,6 +27,7 @@ import ( dtracing "github.com/absmach/supermq/domains/tracing" redisclient "github.com/absmach/supermq/internal/clients/redis" smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" "github.com/absmach/supermq/pkg/authz" authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" @@ -159,6 +160,7 @@ func main() { } defer authnHandler.Close() logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure()) + authnMiddleware := smqauthn.NewAuthNMiddleware(authn) database := postgres.NewDatabase(db, dbConfig, tracer) domainsRepo := dpostgres.NewRepository(database) @@ -239,7 +241,7 @@ func main() { } mux := chi.NewMux() idp := uuid.New() - hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID, idp), logger) + hs := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, mux, logger, cfg.InstanceID, idp), logger) g.Go(func() error { return hs.Start() diff --git a/cmd/groups/main.go b/cmd/groups/main.go index aa5912e26..586cb95e5 100644 --- a/cmd/groups/main.go +++ b/cmd/groups/main.go @@ -28,6 +28,7 @@ import ( pgroups "github.com/absmach/supermq/groups/private" "github.com/absmach/supermq/groups/tracing" smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" smqauthz "github.com/absmach/supermq/pkg/authz" authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" @@ -162,6 +163,7 @@ func main() { } defer authnHandler.Close() logger.Info("Authn successfully connected to auth gRPC server " + authnHandler.Secure()) + authnMiddleware := smqauthn.NewAuthNMiddleware(authn) domsGrpcCfg := grpcclient.Config{} if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { @@ -263,7 +265,7 @@ func main() { mux := chi.NewRouter() idp := uuid.New() - httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, mux, logger, cfg.InstanceID, idp), logger) + httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, mux, logger, cfg.InstanceID, idp), logger) grpcServerConfig := server.Config{} if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixgRPC}); err != nil { diff --git a/cmd/journal/main.go b/cmd/journal/main.go index 4401644c1..4c033a817 100644 --- a/cmd/journal/main.go +++ b/cmd/journal/main.go @@ -20,6 +20,7 @@ import ( "github.com/absmach/supermq/journal/middleware" journalpg "github.com/absmach/supermq/journal/postgres" smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" smqauthz "github.com/absmach/supermq/pkg/authz" authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" @@ -112,6 +113,7 @@ func main() { } defer authnHandler.Close() logger.Info("AuthN successfully connected to auth gRPC server " + authnHandler.Secure()) + authnMiddleware := smqauthn.NewAuthNMiddleware(authn) domsGrpcCfg := grpcclient.Config{} if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { @@ -173,7 +175,7 @@ func main() { return } - hs := http.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authn, logger, svcName, cfg.InstanceID), logger) + hs := http.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(svc, authnMiddleware, logger, svcName, cfg.InstanceID), logger) if cfg.SendTelemetry { chc := chclient.New(svcName, supermq.Version, logger, cancel) diff --git a/cmd/users/main.go b/cmd/users/main.go index e999afe44..ae8b2d007 100644 --- a/cmd/users/main.go +++ b/cmd/users/main.go @@ -20,6 +20,7 @@ import ( grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" "github.com/absmach/supermq/internal/email" smqlog "github.com/absmach/supermq/logger" + smqauthn "github.com/absmach/supermq/pkg/authn" authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" smqauthz "github.com/absmach/supermq/pkg/authz" authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" @@ -67,28 +68,31 @@ const ( ) type config struct { - LogLevel string `env:"SMQ_USERS_LOG_LEVEL" envDefault:"info"` - AdminEmail string `env:"SMQ_USERS_ADMIN_EMAIL" envDefault:"admin@example.com"` - AdminPassword string `env:"SMQ_USERS_ADMIN_PASSWORD" envDefault:"12345678"` - AdminUsername string `env:"SMQ_USERS_ADMIN_USERNAME" envDefault:"admin"` - AdminFirstName string `env:"SMQ_USERS_ADMIN_FIRST_NAME" envDefault:"super"` - AdminLastName string `env:"SMQ_USERS_ADMIN_LAST_NAME" envDefault:"admin"` - PassRegexText string `env:"SMQ_USERS_PASS_REGEX" envDefault:"^.{8,}$"` - JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` - SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` - InstanceID string `env:"SMQ_USERS_INSTANCE_ID" envDefault:""` - ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` - TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` - SelfRegister bool `env:"SMQ_USERS_ALLOW_SELF_REGISTER" envDefault:"false"` - OAuthUIRedirectURL string `env:"SMQ_OAUTH_UI_REDIRECT_URL" envDefault:"http://localhost:9095/domains"` - OAuthUIErrorURL string `env:"SMQ_OAUTH_UI_ERROR_URL" envDefault:"http://localhost:9095/error"` - DeleteInterval time.Duration `env:"SMQ_USERS_DELETE_INTERVAL" envDefault:"24h"` - DeleteAfter time.Duration `env:"SMQ_USERS_DELETE_AFTER" envDefault:"720h"` - SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"` - SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"` - SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` - PasswordResetURLPrefix string `env:"SMQ_PASSWORD_RESET_URL_PREFIX" envDefault:"http://localhost:8080"` - PassRegex *regexp.Regexp + LogLevel string `env:"SMQ_USERS_LOG_LEVEL" envDefault:"info"` + AdminEmail string `env:"SMQ_USERS_ADMIN_EMAIL" envDefault:"admin@example.com"` + AdminPassword string `env:"SMQ_USERS_ADMIN_PASSWORD" envDefault:"12345678"` + AdminUsername string `env:"SMQ_USERS_ADMIN_USERNAME" envDefault:"admin"` + AdminFirstName string `env:"SMQ_USERS_ADMIN_FIRST_NAME" envDefault:"super"` + AdminLastName string `env:"SMQ_USERS_ADMIN_LAST_NAME" envDefault:"admin"` + PassRegexText string `env:"SMQ_USERS_PASS_REGEX" envDefault:"^.{8,}$"` + JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` + SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` + InstanceID string `env:"SMQ_USERS_INSTANCE_ID" envDefault:""` + ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` + TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` + SelfRegister bool `env:"SMQ_USERS_ALLOW_SELF_REGISTER" envDefault:"false"` + OAuthUIRedirectURL string `env:"SMQ_OAUTH_UI_REDIRECT_URL" envDefault:"http://localhost:9095/domains"` + OAuthUIErrorURL string `env:"SMQ_OAUTH_UI_ERROR_URL" envDefault:"http://localhost:9095/error"` + DeleteInterval time.Duration `env:"SMQ_USERS_DELETE_INTERVAL" envDefault:"24h"` + DeleteAfter time.Duration `env:"SMQ_USERS_DELETE_AFTER" envDefault:"720h"` + SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"` + SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"` + SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` + PasswordResetURLPrefix string `env:"SMQ_PASSWORD_RESET_URL_PREFIX" envDefault:"http://localhost/password/reset"` + PasswordResetEmailTemplate string `env:"SMQ_PASSWORD_RESET_EMAIL_TEMPLATE" envDefault:"reset-password-email.tmpl"` + VerificationURLPrefix string `env:"SMQ_VERIFICATION_URL_PREFIX" envDefault:"http://localhost/verify-email"` + VerificationEmailTemplate string `env:"SMQ_VERIFICATION_EMAIL_TEMPLATE" envDefault:"verification-email.tmpl"` + PassRegex *regexp.Regexp } func main() { @@ -121,12 +125,21 @@ func main() { } } - ec := email.Config{} - if err := env.Parse(&ec); err != nil { - logger.Error(fmt.Sprintf("failed to load email configuration : %s", err.Error())) + resetPasswordEmailConfig := email.Config{} + if err := env.Parse(&resetPasswordEmailConfig); err != nil { + logger.Error(fmt.Sprintf("failed to load reset password email configuration : %s", err.Error())) exitCode = 1 return } + resetPasswordEmailConfig.Template = cfg.PasswordResetEmailTemplate + + verificationEmailConfig := email.Config{} + if err := env.Parse(&verificationEmailConfig); err != nil { + logger.Error(fmt.Sprintf("failed to load verification password email configuration : %s", err.Error())) + exitCode = 1 + return + } + verificationEmailConfig.Template = cfg.VerificationEmailTemplate dbConfig := pgclient.Config{Name: defDB} if err := env.ParseWithOptions(&dbConfig, env.Options{Prefix: envPrefixDB}); err != nil { @@ -182,6 +195,8 @@ func main() { defer authnHandler.Close() logger.Info("AuthN successfully connected to auth gRPC server " + authnHandler.Secure()) + authnMiddleware := smqauthn.NewAuthNMiddleware(authn) + domsGrpcCfg := grpcclient.Config{} if err := env.ParseWithOptions(&domsGrpcCfg, env.Options{Prefix: envPrefixDomains}); err != nil { logger.Error(fmt.Sprintf("failed to load domains gRPC client configuration : %s", err)) @@ -213,7 +228,7 @@ func main() { } logger.Info("Policy client successfully connected to spicedb gRPC server") - csvc, err := newService(ctx, authz, tokenClient, policyService, domainsClient, db, dbConfig, tracer, cfg, ec, logger) + csvc, err := newService(ctx, authz, tokenClient, policyService, domainsClient, db, dbConfig, tracer, cfg, resetPasswordEmailConfig, verificationEmailConfig, logger) if err != nil { logger.Error(fmt.Sprintf("failed to setup service: %s", err)) exitCode = 1 @@ -237,7 +252,7 @@ func main() { mux := chi.NewRouter() idp := uuid.New() - httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(csvc, authn, tokenClient, cfg.SelfRegister, mux, logger, cfg.InstanceID, cfg.PassRegex, idp, oauthProvider), logger) + httpSrv := httpserver.NewServer(ctx, cancel, svcName, httpServerConfig, httpapi.MakeHandler(csvc, authnMiddleware, tokenClient, cfg.SelfRegister, mux, logger, cfg.InstanceID, cfg.PassRegex, idp, oauthProvider), logger) if cfg.SendTelemetry { chc := chclient.New(svcName, supermq.Version, logger, cancel) @@ -257,16 +272,23 @@ func main() { } } -func newService(ctx context.Context, authz smqauthz.Authorization, token grpcTokenV1.TokenServiceClient, policyService policies.Service, domainsClient grpcDomainsV1.DomainsServiceClient, db *sqlx.DB, dbConfig pgclient.Config, tracer trace.Tracer, c config, ec email.Config, logger *slog.Logger) (users.Service, error) { +func newService(ctx context.Context, authz smqauthz.Authorization, token grpcTokenV1.TokenServiceClient, policyService policies.Service, domainsClient grpcDomainsV1.DomainsServiceClient, db *sqlx.DB, dbConfig pgclient.Config, tracer trace.Tracer, c config, resetPasswordEmailConfig, verificationEmailConfig email.Config, logger *slog.Logger) (users.Service, error) { database := pg.NewDatabase(db, dbConfig, tracer) idp := uuid.New() hsr := hasher.New() // Creating users service repo := postgres.NewRepository(database) - emailerClient, err := emailer.New(fmt.Sprintf("%s/reset-request", c.PasswordResetURLPrefix), &ec) + + emailerClient, err := emailer.New( + c.PasswordResetURLPrefix, + c.VerificationURLPrefix, + &resetPasswordEmailConfig, + &verificationEmailConfig, + ) if err != nil { logger.Error(fmt.Sprintf("failed to configure e-mailing util: %s", err.Error())) + return nil, err } svc := users.NewService(token, repo, policyService, emailerClient, hsr, idp) diff --git a/docker/.env b/docker/.env index a755423ac..c4b52e1ab 100644 --- a/docker/.env +++ b/docker/.env @@ -248,7 +248,6 @@ SMQ_USERS_DB_SSL_MODE=disable SMQ_USERS_DB_SSL_CERT= SMQ_USERS_DB_SSL_KEY= SMQ_USERS_DB_SSL_ROOT_CERT= -SMQ_USERS_RESET_PWD_TEMPLATE=users.tmpl SMQ_USERS_INSTANCE_ID= SMQ_USERS_SECRET_KEY=HyE2D4RUt9nnKG6v8zKEqAp6g6ka8hhZsqUpzgKvnwpXrNVQSH SMQ_USERS_ADMIN_EMAIL=admin@example.com @@ -261,19 +260,21 @@ SMQ_OAUTH_UI_REDIRECT_URL=http://localhost:9095${SMQ_UI_PATH_PREFIX}/tokens/secu SMQ_OAUTH_UI_ERROR_URL=http://localhost:9095${SMQ_UI_PATH_PREFIX}/error SMQ_USERS_DELETE_INTERVAL=24h SMQ_USERS_DELETE_AFTER=720h -SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost:9001 +SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost/password-reset +SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=reset-password-email.tmpl +SMQ_VERIFICATION_URL_PREFIX=http://localhost/verify-email +SMQ_VERIFICATION_EMAIL_TEMPLATE=verification-email.tmpl #### Users Client Config SMQ_USERS_URL=users:9002 ### Email utility -SMQ_EMAIL_HOST=smtp.mailtrap.io +SMQ_EMAIL_HOST=host.docker.internal SMQ_EMAIL_PORT=2525 -SMQ_EMAIL_USERNAME=18bf7f70705139 -SMQ_EMAIL_PASSWORD=2b0d302e775b1e +SMQ_EMAIL_USERNAME=from@example.com +SMQ_EMAIL_PASSWORD=password SMQ_EMAIL_FROM_ADDRESS=from@example.com SMQ_EMAIL_FROM_NAME=Example -SMQ_EMAIL_TEMPLATE=email.tmpl ### Google OAuth2 SMQ_GOOGLE_CLIENT_ID= @@ -514,5 +515,8 @@ SMQ_GRAFANA_PORT=3000 SMQ_GRAFANA_ADMIN_USER=supermq SMQ_GRAFANA_ADMIN_PASSWORD=supermq +## Allow unverified user to access +SMQ_ALLOW_UNVERIFIED_USER=true + # Docker image tag SMQ_RELEASE_TAG=latest diff --git a/docker/addons/certs/docker-compose.yaml b/docker/addons/certs/docker-compose.yaml index 71f3e746b..26f5dffdc 100644 --- a/docker/addons/certs/docker-compose.yaml +++ b/docker/addons/certs/docker-compose.yaml @@ -74,6 +74,7 @@ services: SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY} SMQ_CERTS_INSTANCE_ID: ${SMQ_CERTS_INSTANCE_ID} + SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} volumes: - ../../ssl/certs/ca.key:/etc/ssl/certs/ca.key - ../../ssl/certs/ca.crt:/etc/ssl/certs/ca.crt diff --git a/docker/addons/journal/docker-compose.yaml b/docker/addons/journal/docker-compose.yaml index 5b1c77dc4..de77ac212 100644 --- a/docker/addons/journal/docker-compose.yaml +++ b/docker/addons/journal/docker-compose.yaml @@ -69,6 +69,7 @@ services: SMQ_DOMAINS_GRPC_CLIENT_CERT: ${SMQ_DOMAINS_GRPC_CLIENT_CERT:+/domains-grpc-client.crt} SMQ_DOMAINS_GRPC_CLIENT_KEY: ${SMQ_DOMAINS_GRPC_CLIENT_KEY:+/domains-grpc-client.key} SMQ_DOMAINS_GRPC_SERVER_CA_CERTS: ${SMQ_DOMAINS_GRPC_SERVER_CA_CERTS:+/domains-grpc-server-ca.crt} + SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} ports: - ${SMQ_JOURNAL_HTTP_PORT}:${SMQ_JOURNAL_HTTP_PORT} networks: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index f00932467..9eae71156 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -285,6 +285,7 @@ services: SMQ_DOMAINS_CALLOUT_CERT: ${SMQ_DOMAINS_CALLOUT_CERT} SMQ_DOMAINS_CALLOUT_KEY: ${SMQ_DOMAINS_CALLOUT_KEY} SMQ_DOMAINS_CALLOUT_OPERATIONS: ${SMQ_DOMAINS_CALLOUT_OPERATIONS} + SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} ports: - ${SMQ_DOMAINS_HTTP_PORT}:${SMQ_DOMAINS_HTTP_PORT} - ${SMQ_DOMAINS_GRPC_PORT}:${SMQ_DOMAINS_GRPC_PORT} @@ -516,6 +517,7 @@ services: SMQ_CLIENTS_CALLOUT_CERT: ${SMQ_CLIENTS_CALLOUT_CERT} SMQ_CLIENTS_CALLOUT_KEY: ${SMQ_CLIENTS_CALLOUT_KEY} SMQ_CLIENTS_CALLOUT_OPERATIONS: ${SMQ_CLIENTS_CALLOUT_OPERATIONS} + SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} ports: - ${SMQ_CLIENTS_HTTP_PORT}:${SMQ_CLIENTS_HTTP_PORT} - ${SMQ_CLIENTS_GRPC_PORT}:${SMQ_CLIENTS_GRPC_PORT} @@ -706,6 +708,7 @@ services: SMQ_CHANNELS_CALLOUT_CERT: ${SMQ_CHANNELS_CALLOUT_CERT} SMQ_CHANNELS_CALLOUT_KEY: ${SMQ_CHANNELS_CALLOUT_KEY} SMQ_CHANNELS_CALLOUT_OPERATIONS: ${SMQ_CHANNELS_CALLOUT_OPERATIONS} + SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} ports: - ${SMQ_CHANNELS_HTTP_PORT}:${SMQ_CHANNELS_HTTP_PORT} - ${SMQ_CHANNELS_GRPC_PORT}:${SMQ_CHANNELS_GRPC_PORT} @@ -855,7 +858,6 @@ services: SMQ_EMAIL_PASSWORD: ${SMQ_EMAIL_PASSWORD} SMQ_EMAIL_FROM_ADDRESS: ${SMQ_EMAIL_FROM_ADDRESS} SMQ_EMAIL_FROM_NAME: ${SMQ_EMAIL_FROM_NAME} - SMQ_EMAIL_TEMPLATE: ${SMQ_EMAIL_TEMPLATE} SMQ_ES_URL: ${SMQ_ES_URL} SMQ_JAEGER_URL: ${SMQ_JAEGER_URL} SMQ_JAEGER_TRACE_RATIO: ${SMQ_JAEGER_TRACE_RATIO} @@ -882,12 +884,16 @@ services: SMQ_SPICEDB_HOST: ${SMQ_SPICEDB_HOST} SMQ_SPICEDB_PORT: ${SMQ_SPICEDB_PORT} SMQ_PASSWORD_RESET_URL_PREFIX: ${SMQ_PASSWORD_RESET_URL_PREFIX} + SMQ_PASSWORD_RESET_EMAIL_TEMPLATE: ${SMQ_PASSWORD_RESET_EMAIL_TEMPLATE} + SMQ_VERIFICATION_URL_PREFIX: ${SMQ_VERIFICATION_URL_PREFIX} + SMQ_VERIFICATION_EMAIL_TEMPLATE: ${SMQ_VERIFICATION_EMAIL_TEMPLATE} ports: - ${SMQ_USERS_HTTP_PORT}:${SMQ_USERS_HTTP_PORT} networks: - supermq-base-net volumes: - - ./templates/${SMQ_USERS_RESET_PWD_TEMPLATE}:/email.tmpl + - ./templates/${SMQ_PASSWORD_RESET_EMAIL_TEMPLATE}:/${SMQ_PASSWORD_RESET_EMAIL_TEMPLATE} + - ./templates/${SMQ_VERIFICATION_EMAIL_TEMPLATE}:/${SMQ_VERIFICATION_EMAIL_TEMPLATE} # Auth gRPC client certificates - type: bind source: ${SMQ_AUTH_GRPC_CLIENT_CERT:-ssl/certs/dummy/client_cert} @@ -1007,6 +1013,7 @@ services: SMQ_GROUPS_CALLOUT_CERT: ${SMQ_GROUPS_CALLOUT_CERT} SMQ_GROUPS_CALLOUT_KEY: ${SMQ_GROUPS_CALLOUT_KEY} SMQ_GROUPS_CALLOUT_OPERATIONS: ${SMQ_GROUPS_CALLOUT_OPERATIONS} + SMQ_ALLOW_UNVERIFIED_USER: ${SMQ_ALLOW_UNVERIFIED_USER} ports: - ${SMQ_GROUPS_HTTP_PORT}:${SMQ_GROUPS_HTTP_PORT} - ${SMQ_GROUPS_GRPC_PORT}:${SMQ_GROUPS_GRPC_PORT} diff --git a/docker/nginx/nginx-key.conf b/docker/nginx/nginx-key.conf index d4ea0dbdc..52c2aaf34 100644 --- a/docker/nginx/nginx-key.conf +++ b/docker/nginx/nginx-key.conf @@ -72,7 +72,7 @@ http { } # Proxy pass to users service - location ~ ^/(users|password|authorize|oauth/callback/[^/]+) { + location ~ ^/(users|password|verify-email|authorize|oauth/callback/[^/]+) { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; proxy_pass http://users:${SMQ_USERS_HTTP_PORT}; diff --git a/docker/nginx/nginx-x509.conf b/docker/nginx/nginx-x509.conf index dadcb547a..11d8bf3d6 100644 --- a/docker/nginx/nginx-x509.conf +++ b/docker/nginx/nginx-x509.conf @@ -81,7 +81,7 @@ http { } # Proxy pass to users service - location ~ ^/(users|groups|password|authorize|oauth/callback/[^/]+) { + location ~ ^/(users|password|verify-email|authorize|oauth/callback/[^/]+) { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; proxy_pass http://users:${SMQ_USERS_HTTP_PORT}; diff --git a/docker/templates/users.tmpl b/docker/templates/reset-password-email.tmpl similarity index 100% rename from docker/templates/users.tmpl rename to docker/templates/reset-password-email.tmpl diff --git a/docker/templates/verification-email.tmpl b/docker/templates/verification-email.tmpl new file mode 100644 index 000000000..2c4ce30fd --- /dev/null +++ b/docker/templates/verification-email.tmpl @@ -0,0 +1,15 @@ +Dear {{.User}}, + +Welcome to {{.Host}}! To complete your account registration, please verify your email address by clicking on the link below: + +{{.Content}} + +This verification link will expire in 24 hours. If you did not create an account on {{.Host}}, please disregard this message. + +Once verified, you'll have full access to all features on {{.Host}}. + +Thank you for joining {{.Host}}! + +Best regards, + +{{.Footer}} diff --git a/domains/api/http/endpoint.go b/domains/api/http/endpoint.go index a3d4fab23..6e7cd8949 100644 --- a/domains/api/http/endpoint.go +++ b/domains/api/http/endpoint.go @@ -6,7 +6,6 @@ package http import ( "context" - api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/domains" "github.com/absmach/supermq/pkg/authn" @@ -25,7 +24,7 @@ func createDomainEndpoint(svc domains.Service) endpoint.Endpoint { return nil, err } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -53,7 +52,7 @@ func retrieveDomainEndpoint(svc domains.Service) endpoint.Endpoint { return nil, err } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -73,7 +72,7 @@ func updateDomainEndpoint(svc domains.Service) endpoint.Endpoint { return nil, err } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -103,7 +102,7 @@ func listDomainsEndpoint(svc domains.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -123,7 +122,7 @@ func enableDomainEndpoint(svc domains.Service) endpoint.Endpoint { return nil, err } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -142,7 +141,7 @@ func disableDomainEndpoint(svc domains.Service) endpoint.Endpoint { return nil, err } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -161,7 +160,7 @@ func freezeDomainEndpoint(svc domains.Service) endpoint.Endpoint { return nil, err } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -179,7 +178,7 @@ func sendInvitationEndpoint(svc domains.Service) endpoint.Endpoint { if err := req.validate(); err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -208,7 +207,7 @@ func listDomainInvitationsEndpoint(svc domains.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -231,7 +230,7 @@ func listUserInvitationsEndpoint(svc domains.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -254,7 +253,7 @@ func acceptInvitationEndpoint(svc domains.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -274,7 +273,7 @@ func rejectInvitationEndpoint(svc domains.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -294,7 +293,7 @@ func deleteInvitationEndpoint(svc domains.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } diff --git a/domains/api/http/endpoint_test.go b/domains/api/http/endpoint_test.go index 676ab32bc..24d920526 100644 --- a/domains/api/http/endpoint_test.go +++ b/domains/api/http/endpoint_test.go @@ -21,6 +21,7 @@ import ( "github.com/absmach/supermq/internal/testsutil" smqlog "github.com/absmach/supermq/logger" "github.com/absmach/supermq/pkg/authn" + smqauthn "github.com/absmach/supermq/pkg/authn" authnmock "github.com/absmach/supermq/pkg/authn/mocks" "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" @@ -98,7 +99,8 @@ func newDomainsServer() (*httptest.Server, *mocks.Service, *authnmock.Authentica authn := new(authnmock.Authentication) mux := chi.NewMux() idp := uuid.NewMock() - domainsapi.MakeHandler(svc, authn, mux, logger, "", idp) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + domainsapi.MakeHandler(svc, am, mux, logger, "", idp) return httptest.NewServer(mux), svc, authn } diff --git a/domains/api/http/transport.go b/domains/api/http/transport.go index eefecf699..ca0b6c08a 100644 --- a/domains/api/http/transport.go +++ b/domains/api/http/transport.go @@ -11,7 +11,7 @@ import ( api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/domains" - "github.com/absmach/supermq/pkg/authn" + smqauthn "github.com/absmach/supermq/pkg/authn" roleManagerHttp "github.com/absmach/supermq/pkg/roles/rolemanager/api" "github.com/go-chi/chi/v5" kithttp "github.com/go-kit/kit/transport/http" @@ -20,7 +20,7 @@ import ( ) // MakeHandler returns a HTTP handler for Domains and Invitations API endpoints. -func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler { +func MakeHandler(svc domains.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) http.Handler { opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), } @@ -30,7 +30,7 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux, r.Use(api.RequestIDMiddleware(idp)) r.Group(func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, false)) + r.Use(authn.WithOptions(smqauthn.WithDomainCheck(false)).Middleware()) r.Post("/", otelhttp.NewHandler(kithttp.NewServer( createDomainEndpoint(svc), decodeCreateDomainRequest, @@ -49,7 +49,7 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux, }) r.Route("/{domainID}", func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, true)) + r.Use(authn.Middleware()) r.Get("/", otelhttp.NewHandler(kithttp.NewServer( retrieveDomainEndpoint(svc), decodeRetrieveDomainRequest, @@ -88,7 +88,7 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux, }) r.Route("/{domainID}/invitations", func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, true)) + r.Use(authn.Middleware()) r.Post("/", otelhttp.NewHandler(kithttp.NewServer( sendInvitationEndpoint(svc), decodeSendInvitationReq, @@ -111,7 +111,7 @@ func MakeHandler(svc domains.Service, authn authn.Authentication, mux *chi.Mux, }) mux.Route("/invitations", func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, false)) + r.Use(authn.WithOptions(smqauthn.WithDomainCheck(false)).Middleware()) r.Get("/", otelhttp.NewHandler(kithttp.NewServer( listUserInvitationsEndpoint(svc), decodeListInvitationsReq, diff --git a/groups/api/http/endpoint_test.go b/groups/api/http/endpoint_test.go index 13a6b8b06..cd9b458c0 100644 --- a/groups/api/http/endpoint_test.go +++ b/groups/api/http/endpoint_test.go @@ -60,7 +60,8 @@ func newGroupsServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentica mux := chi.NewRouter() idp := uuid.NewMock() logger := smqlog.NewMock() - mux = MakeHandler(svc, authn, mux, logger, "", idp) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + mux = MakeHandler(svc, am, mux, logger, "", idp) return httptest.NewServer(mux), svc, authn } diff --git a/groups/api/http/endpoints.go b/groups/api/http/endpoints.go index 6bb946b19..8d8623328 100644 --- a/groups/api/http/endpoints.go +++ b/groups/api/http/endpoints.go @@ -6,7 +6,6 @@ package api import ( "context" - api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/groups" "github.com/absmach/supermq/pkg/authn" @@ -22,7 +21,7 @@ func CreateGroupEndpoint(svc groups.Service) endpoint.Endpoint { return createGroupRes{created: false}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return createGroupRes{created: false}, svcerr.ErrAuthentication } @@ -43,7 +42,7 @@ func ViewGroupEndpoint(svc groups.Service) endpoint.Endpoint { return viewGroupRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return viewGroupRes{}, svcerr.ErrAuthentication } @@ -64,7 +63,7 @@ func UpdateGroupEndpoint(svc groups.Service) endpoint.Endpoint { return updateGroupRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return updateGroupRes{}, svcerr.ErrAuthentication } @@ -92,7 +91,7 @@ func updateGroupTagsEndpoint(svc groups.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -117,7 +116,7 @@ func EnableGroupEndpoint(svc groups.Service) endpoint.Endpoint { return changeStatusRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return changeStatusRes{}, svcerr.ErrAuthentication } @@ -137,7 +136,7 @@ func DisableGroupEndpoint(svc groups.Service) endpoint.Endpoint { return changeStatusRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return changeStatusRes{}, svcerr.ErrAuthentication } @@ -158,7 +157,7 @@ func ListGroupsEndpoint(svc groups.Service) endpoint.Endpoint { return groupPageRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return groupPageRes{}, svcerr.ErrAuthentication } @@ -198,7 +197,7 @@ func DeleteGroupEndpoint(svc groups.Service) endpoint.Endpoint { return deleteGroupRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return deleteGroupRes{}, svcerr.ErrAuthentication } @@ -216,7 +215,7 @@ func retrieveGroupHierarchyEndpoint(svc groups.Service) endpoint.Endpoint { return retrieveGroupHierarchyRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return changeStatusRes{}, svcerr.ErrAuthentication } @@ -244,7 +243,7 @@ func addParentGroupEndpoint(svc groups.Service) endpoint.Endpoint { return addParentGroupRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return changeStatusRes{}, svcerr.ErrAuthentication } @@ -263,7 +262,7 @@ func removeParentGroupEndpoint(svc groups.Service) endpoint.Endpoint { return removeParentGroupRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return changeStatusRes{}, svcerr.ErrAuthentication } @@ -282,7 +281,7 @@ func addChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint { return addChildrenGroupsRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return changeStatusRes{}, svcerr.ErrAuthentication } @@ -301,7 +300,7 @@ func removeChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint { return removeChildrenGroupsRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return changeStatusRes{}, svcerr.ErrAuthentication } @@ -320,7 +319,7 @@ func removeAllChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint { return removeAllChildrenGroupsRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return changeStatusRes{}, svcerr.ErrAuthentication } @@ -339,7 +338,7 @@ func listChildrenGroupsEndpoint(svc groups.Service) endpoint.Endpoint { return listChildrenGroupsRes{}, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return changeStatusRes{}, svcerr.ErrAuthentication } diff --git a/groups/api/http/transport.go b/groups/api/http/transport.go index 689989fea..27c086f1f 100644 --- a/groups/api/http/transport.go +++ b/groups/api/http/transport.go @@ -10,7 +10,7 @@ import ( api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/groups" - "github.com/absmach/supermq/pkg/authn" + smqauthn "github.com/absmach/supermq/pkg/authn" roleManagerHttp "github.com/absmach/supermq/pkg/roles/rolemanager/api" "github.com/go-chi/chi/v5" kithttp "github.com/go-kit/kit/transport/http" @@ -19,14 +19,14 @@ import ( ) // MakeHandler returns a HTTP handler for Groups API endpoints. -func MakeHandler(svc groups.Service, authn authn.Authentication, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) *chi.Mux { +func MakeHandler(svc groups.Service, authn smqauthn.AuthNMiddleware, mux *chi.Mux, logger *slog.Logger, instanceID string, idp supermq.IDProvider) *chi.Mux { opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), } d := roleManagerHttp.NewDecoder("groupID") mux.Route("/{domainID}/groups", func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, true)) + r.Use(authn.Middleware()) r.Use(api.RequestIDMiddleware(idp)) r.Post("/", otelhttp.NewHandler(kithttp.NewServer( diff --git a/internal/proto/auth/v1/auth.proto b/internal/proto/auth/v1/auth.proto index 269ca313e..16b7c14de 100644 --- a/internal/proto/auth/v1/auth.proto +++ b/internal/proto/auth/v1/auth.proto @@ -6,7 +6,7 @@ syntax = "proto3"; package auth.v1; option go_package = "github.com/absmach/supermq/api/grpc/auth/v1"; -// AuthService is a service that provides authentication +// AuthService is a service that provides authentication // and authorization functionalities for SuperMQ services. service AuthService { rpc Authorize(AuthZReq) returns (AuthZRes) {} @@ -23,7 +23,8 @@ message AuthNReq { message AuthNRes { string id = 1; // token id string user_id = 2; // user id - uint32 user_role = 3; // user role + uint32 user_role = 3; // user role + bool verified = 4; // verified user } message AuthZReq { diff --git a/internal/proto/token/v1/token.proto b/internal/proto/token/v1/token.proto index dd7b378c1..10c066511 100644 --- a/internal/proto/token/v1/token.proto +++ b/internal/proto/token/v1/token.proto @@ -15,10 +15,12 @@ message IssueReq { string user_id = 1; uint32 user_role = 2; uint32 type = 3; + bool verified = 4; } message RefreshReq { string refresh_token = 1; + bool verified = 2; } // If a token is not carrying any information itself, the type diff --git a/journal/api/endpoint.go b/journal/api/endpoint.go index 67cecdee0..58a8b71a8 100644 --- a/journal/api/endpoint.go +++ b/journal/api/endpoint.go @@ -6,7 +6,6 @@ package api import ( "context" - api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/journal" "github.com/absmach/supermq/pkg/authn" @@ -22,7 +21,7 @@ func retrieveJournalsEndpoint(svc journal.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -45,7 +44,7 @@ func retrieveClientTelemetryEndpoint(svc journal.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } diff --git a/journal/api/endpoint_test.go b/journal/api/endpoint_test.go index 17cf2f3da..05adda55a 100644 --- a/journal/api/endpoint_test.go +++ b/journal/api/endpoint_test.go @@ -57,7 +57,8 @@ func newjournalServer() (*httptest.Server, *mocks.Service, *authnmocks.Authentic logger := smqlog.NewMock() authn := new(authnmocks.Authentication) - mux := httpapi.MakeHandler(svc, authn, logger, "journal-log", "test") + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + mux := httpapi.MakeHandler(svc, am, logger, "journal-log", "test") return httptest.NewServer(mux), svc, authn } diff --git a/journal/api/transport.go b/journal/api/transport.go index a9f34f094..498c0a9ab 100644 --- a/journal/api/transport.go +++ b/journal/api/transport.go @@ -34,7 +34,7 @@ const ( ) // MakeHandler returns a HTTP API handler with health check and metrics. -func MakeHandler(svc journal.Service, authn smqauthn.Authentication, logger *slog.Logger, svcName, instanceID string) http.Handler { +func MakeHandler(svc journal.Service, authn smqauthn.AuthNMiddleware, logger *slog.Logger, svcName, instanceID string) http.Handler { opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), } @@ -43,7 +43,7 @@ func MakeHandler(svc journal.Service, authn smqauthn.Authentication, logger *slo idp := uuid.New() mux.Use(api.RequestIDMiddleware(idp)) - mux.With(api.AuthenticateMiddleware(authn, false)).Get("/journal/user/{userID}", otelhttp.NewHandler(kithttp.NewServer( + mux.With(authn.WithOptions(smqauthn.WithDomainCheck(false)).Middleware()).Get("/journal/user/{userID}", otelhttp.NewHandler(kithttp.NewServer( retrieveJournalsEndpoint(svc), decodeRetrieveUserJournalReq, api.EncodeResponse, @@ -51,7 +51,7 @@ func MakeHandler(svc journal.Service, authn smqauthn.Authentication, logger *slo ), "list_user_journals").ServeHTTP) mux.Route("/{domainID}/journal", func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, true)) + r.Use(authn.Middleware()) r.Get("/{entityType}/{entityID}", otelhttp.NewHandler(kithttp.NewServer( retrieveJournalsEndpoint(svc), diff --git a/pkg/authn/authn.go b/pkg/authn/authn.go index f6418a062..a525324e3 100644 --- a/pkg/authn/authn.go +++ b/pkg/authn/authn.go @@ -44,6 +44,7 @@ type Session struct { DomainID string DomainUserID string SuperAdmin bool + Verified bool Role Role } diff --git a/pkg/authn/authsvc/authn.go b/pkg/authn/authsvc/authn.go index 81fac2095..d1d04414e 100644 --- a/pkg/authn/authsvc/authn.go +++ b/pkg/authn/authsvc/authn.go @@ -54,5 +54,5 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err) } - return authn.Session{Type: authn.AccessToken, UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil + return authn.Session{Type: authn.AccessToken, UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole()), Verified: res.GetVerified()}, nil } diff --git a/pkg/authn/middleware.go b/pkg/authn/middleware.go new file mode 100644 index 000000000..6d24db625 --- /dev/null +++ b/pkg/authn/middleware.go @@ -0,0 +1,177 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package authn + +import ( + "context" + "encoding/json" + "net/http" + "os" + "strconv" + + apiutil "github.com/absmach/supermq/api/http/util" + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/errors" + "github.com/go-chi/chi/v5" +) + +type sessionKeyType string + +const ( + allowUnverifiedUserEnv = "SMQ_ALLOW_UNVERIFIED_USER" + jsonContentType = "application/json" + + SessionKey = sessionKeyType("session") +) + +// middlewareOptions contains configuration for authentication middleware. +type middlewareOptions struct { + domainCheck bool + allowUnverifiedUser bool +} + +// defaultMiddlewareOptions returns the default middleware configuration. +func defaultMiddlewareOptions() *middlewareOptions { + return &middlewareOptions{ + domainCheck: true, + allowUnverifiedUser: false, + } +} + +// MiddlewareOption is a function that modifies middleware options. +type MiddlewareOption func(*middlewareOptions) + +// WithDomainCheck sets whether domain checking is enabled. +func WithDomainCheck(enabled bool) MiddlewareOption { + return func(opts *middlewareOptions) { + opts.domainCheck = enabled + } +} + +// WithAllowUnverifiedUser sets whether unverified users are allowed. +func WithAllowUnverifiedUser(allowed bool) MiddlewareOption { + return func(opts *middlewareOptions) { + opts.allowUnverifiedUser = allowed + } +} + +// WithDefaultMiddlewareOptions resets options to default values. +func WithDefaultMiddlewareOptions() MiddlewareOption { + return func(opts *middlewareOptions) { + defaults := defaultMiddlewareOptions() + opts.domainCheck = defaults.domainCheck + opts.allowUnverifiedUser = defaults.allowUnverifiedUser + } +} + +// AuthNMiddleware defines the interface for authenticated services with middleware. +type AuthNMiddleware interface { + Authentication + WithOptions(options ...MiddlewareOption) AuthNMiddleware + Middleware() func(http.Handler) http.Handler +} + +// authnMiddleware wraps Authentication with middleware functionality. +type authnMiddleware struct { + Authentication + options []MiddlewareOption +} + +// NewAuthNMiddleware creates a new authenticated service with middleware support. +// The order of precedence for options is as follows, with later options overriding earlier ones: +// 1. Default options (lowest precedence). +// 2. Options from environment variables (e.g., SMQ_ALLOW_UNVERIFIED_USER). +// 3. Options passed as arguments to this function (highest precedence). +// +// For example, consider the 'allowUnverifiedUser' option: +// - By default, it is 'false'. +// - If the SMQ_ALLOW_UNVERIFIED_USER environment variable is set to "true", +// it becomes 'true'. +// - If NewAuthNMiddleware is called with WithAllowUnverifiedUser(false), it will be 'false', +// regardless of the environment variable, as function arguments have the highest precedence. +func NewAuthNMiddleware(authnSvc Authentication, options ...MiddlewareOption) AuthNMiddleware { + allOptions := []MiddlewareOption{WithDefaultMiddlewareOptions()} + if val, ok := os.LookupEnv(allowUnverifiedUserEnv); ok { + allowUnverifiedUser, err := strconv.ParseBool(val) + if err == nil && allowUnverifiedUser { + allOptions = append(allOptions, WithAllowUnverifiedUser(true)) + } + } + allOptions = append(allOptions, options...) + return &authnMiddleware{ + Authentication: authnSvc, + options: allOptions, + } +} + +// WithOptions returns a new service with additional options. +func (a *authnMiddleware) WithOptions(options ...MiddlewareOption) AuthNMiddleware { + return &authnMiddleware{ + Authentication: a.Authentication, + options: append(a.options, options...), + } +} + +// getMiddlewareOptions returns the configured middleware options. +func (a *authnMiddleware) getMiddlewareOptions() *middlewareOptions { + opts := defaultMiddlewareOptions() + for _, option := range a.options { + option(opts) + } + return opts +} + +// Middleware returns an HTTP middleware function that handles authentication. +func (a *authnMiddleware) Middleware() func(http.Handler) http.Handler { + opts := a.getMiddlewareOptions() + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := apiutil.ExtractBearerToken(r) + if token == "" { + encodeError(w, apiutil.ErrBearerToken, http.StatusUnauthorized) + return + } + resp, err := a.Authenticate(r.Context(), token) + if err != nil { + encodeError(w, err, http.StatusUnauthorized) + return + } + + if resp.Type == AccessToken && !opts.allowUnverifiedUser && resp.Role != AdminRole && !resp.Verified { + encodeError(w, apiutil.ErrEmailNotVerified, http.StatusUnauthorized) + return + } + + if opts.domainCheck { + domain := chi.URLParam(r, "domainID") + if domain == "" { + encodeError(w, apiutil.ErrMissingDomainID, http.StatusBadRequest) + return + } + resp.DomainID = domain + switch resp.Role { + case AdminRole: + resp.DomainUserID = resp.UserID + case UserRole: + resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID) + } + } + + ctx := context.WithValue(r.Context(), SessionKey, resp) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func encodeError(w http.ResponseWriter, err error, statusCode int) { + if errorVal, ok := err.(errors.Error); ok { + w.Header().Set("Content-Type", jsonContentType) + w.WriteHeader(statusCode) + if err := json.NewEncoder(w).Encode(errorVal); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + return + } + http.Error(w, err.Error(), statusCode) +} diff --git a/pkg/authn/mocks/auth_n_middleware.go b/pkg/authn/mocks/auth_n_middleware.go new file mode 100644 index 000000000..4ae1acc98 --- /dev/null +++ b/pkg/authn/mocks/auth_n_middleware.go @@ -0,0 +1,217 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +package mocks + +import ( + "context" + "net/http" + + "github.com/absmach/supermq/pkg/authn" + mock "github.com/stretchr/testify/mock" +) + +// NewAuthNMiddleware creates a new instance of AuthNMiddleware. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewAuthNMiddleware(t interface { + mock.TestingT + Cleanup(func()) +}) *AuthNMiddleware { + mock := &AuthNMiddleware{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// AuthNMiddleware is an autogenerated mock type for the AuthNMiddleware type +type AuthNMiddleware struct { + mock.Mock +} + +type AuthNMiddleware_Expecter struct { + mock *mock.Mock +} + +func (_m *AuthNMiddleware) EXPECT() *AuthNMiddleware_Expecter { + return &AuthNMiddleware_Expecter{mock: &_m.Mock} +} + +// Authenticate provides a mock function for the type AuthNMiddleware +func (_mock *AuthNMiddleware) Authenticate(ctx context.Context, token string) (authn.Session, error) { + ret := _mock.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for Authenticate") + } + + var r0 authn.Session + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (authn.Session, error)); ok { + return returnFunc(ctx, token) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) authn.Session); ok { + r0 = returnFunc(ctx, token) + } else { + r0 = ret.Get(0).(authn.Session) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, token) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// AuthNMiddleware_Authenticate_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Authenticate' +type AuthNMiddleware_Authenticate_Call struct { + *mock.Call +} + +// Authenticate is a helper method to define mock.On call +// - ctx context.Context +// - token string +func (_e *AuthNMiddleware_Expecter) Authenticate(ctx interface{}, token interface{}) *AuthNMiddleware_Authenticate_Call { + return &AuthNMiddleware_Authenticate_Call{Call: _e.mock.On("Authenticate", ctx, token)} +} + +func (_c *AuthNMiddleware_Authenticate_Call) Run(run func(ctx context.Context, token string)) *AuthNMiddleware_Authenticate_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *AuthNMiddleware_Authenticate_Call) Return(session authn.Session, err error) *AuthNMiddleware_Authenticate_Call { + _c.Call.Return(session, err) + return _c +} + +func (_c *AuthNMiddleware_Authenticate_Call) RunAndReturn(run func(ctx context.Context, token string) (authn.Session, error)) *AuthNMiddleware_Authenticate_Call { + _c.Call.Return(run) + return _c +} + +// Middleware provides a mock function for the type AuthNMiddleware +func (_mock *AuthNMiddleware) Middleware() func(http.Handler) http.Handler { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Middleware") + } + + var r0 func(http.Handler) http.Handler + if returnFunc, ok := ret.Get(0).(func() func(http.Handler) http.Handler); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(func(http.Handler) http.Handler) + } + } + return r0 +} + +// AuthNMiddleware_Middleware_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Middleware' +type AuthNMiddleware_Middleware_Call struct { + *mock.Call +} + +// Middleware is a helper method to define mock.On call +func (_e *AuthNMiddleware_Expecter) Middleware() *AuthNMiddleware_Middleware_Call { + return &AuthNMiddleware_Middleware_Call{Call: _e.mock.On("Middleware")} +} + +func (_c *AuthNMiddleware_Middleware_Call) Run(run func()) *AuthNMiddleware_Middleware_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *AuthNMiddleware_Middleware_Call) Return(fn func(http.Handler) http.Handler) *AuthNMiddleware_Middleware_Call { + _c.Call.Return(fn) + return _c +} + +func (_c *AuthNMiddleware_Middleware_Call) RunAndReturn(run func() func(http.Handler) http.Handler) *AuthNMiddleware_Middleware_Call { + _c.Call.Return(run) + return _c +} + +// WithOptions provides a mock function for the type AuthNMiddleware +func (_mock *AuthNMiddleware) WithOptions(options ...authn.MiddlewareOption) authn.AuthNMiddleware { + var tmpRet mock.Arguments + if len(options) > 0 { + tmpRet = _mock.Called(options) + } else { + tmpRet = _mock.Called() + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for WithOptions") + } + + var r0 authn.AuthNMiddleware + if returnFunc, ok := ret.Get(0).(func(...authn.MiddlewareOption) authn.AuthNMiddleware); ok { + r0 = returnFunc(options...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(authn.AuthNMiddleware) + } + } + return r0 +} + +// AuthNMiddleware_WithOptions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithOptions' +type AuthNMiddleware_WithOptions_Call struct { + *mock.Call +} + +// WithOptions is a helper method to define mock.On call +// - options ...authn.MiddlewareOption +func (_e *AuthNMiddleware_Expecter) WithOptions(options ...interface{}) *AuthNMiddleware_WithOptions_Call { + return &AuthNMiddleware_WithOptions_Call{Call: _e.mock.On("WithOptions", + append([]interface{}{}, options...)...)} +} + +func (_c *AuthNMiddleware_WithOptions_Call) Run(run func(options ...authn.MiddlewareOption)) *AuthNMiddleware_WithOptions_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []authn.MiddlewareOption + var variadicArgs []authn.MiddlewareOption + if len(args) > 0 { + variadicArgs = args[0].([]authn.MiddlewareOption) + } + arg0 = variadicArgs + run( + arg0..., + ) + }) + return _c +} + +func (_c *AuthNMiddleware_WithOptions_Call) Return(authNMiddleware authn.AuthNMiddleware) *AuthNMiddleware_WithOptions_Call { + _c.Call.Return(authNMiddleware) + return _c +} + +func (_c *AuthNMiddleware_WithOptions_Call) RunAndReturn(run func(options ...authn.MiddlewareOption) authn.AuthNMiddleware) *AuthNMiddleware_WithOptions_Call { + _c.Call.Return(run) + return _c +} diff --git a/pkg/errors/service/types.go b/pkg/errors/service/types.go index a0302ee88..b63cbd6d5 100644 --- a/pkg/errors/service/types.go +++ b/pkg/errors/service/types.go @@ -93,4 +93,13 @@ var ( // ErrSuperAdminAction indicates that the user is not a super admin. ErrSuperAdminAction = errors.New("not authorized to perform admin action") + + // ErrUserAlreadyVerified indicates user is already verified. + ErrUserAlreadyVerified = errors.New("user already verified") + + // ErrInvalidUserVerification indicates user verification is invalid. + ErrInvalidUserVerification = errors.New("invalid verification") + + // ErrUserVerificationExpired indicates user verification is expired. + ErrUserVerificationExpired = errors.New("verification expired, please generate new verification") ) diff --git a/pkg/roles/rolemanager/api/endpoints.go b/pkg/roles/rolemanager/api/endpoints.go index 8828fd8d5..45c608e79 100644 --- a/pkg/roles/rolemanager/api/endpoints.go +++ b/pkg/roles/rolemanager/api/endpoints.go @@ -6,7 +6,6 @@ package http import ( "context" - api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/errors" @@ -22,7 +21,7 @@ func CreateRoleEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -42,7 +41,7 @@ func ListRolesEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -62,7 +61,7 @@ func ListEntityMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -94,7 +93,7 @@ func RemoveEntityMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -113,7 +112,7 @@ func ViewRoleEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -133,7 +132,7 @@ func UpdateRoleEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -153,7 +152,7 @@ func DeleteRoleEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -172,7 +171,7 @@ func ListAvailableActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -192,7 +191,7 @@ func AddRoleActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -212,7 +211,7 @@ func ListRoleActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -232,7 +231,7 @@ func DeleteRoleActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -251,7 +250,7 @@ func DeleteAllRoleActionsEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -270,7 +269,7 @@ func AddRoleMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -290,7 +289,7 @@ func ListRoleMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -310,7 +309,7 @@ func DeleteRoleMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -329,7 +328,7 @@ func DeleteAllRoleMembersEndpoint(svc roles.RoleManager) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } diff --git a/pkg/sdk/certs_test.go b/pkg/sdk/certs_test.go index c50d0549c..d568e64c1 100644 --- a/pkg/sdk/certs_test.go +++ b/pkg/sdk/certs_test.go @@ -67,7 +67,8 @@ func setupCerts() (*httptest.Server, *mocks.Service, *authnmocks.Authentication) logger := smqlog.NewMock() authn := new(authnmocks.Authentication) idp := uuid.NewMock() - mux := httpapi.MakeHandler(svc, authn, logger, instanceID, idp) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + mux := httpapi.MakeHandler(svc, am, logger, instanceID, idp) return httptest.NewServer(mux), svc, authn } diff --git a/pkg/sdk/channels_test.go b/pkg/sdk/channels_test.go index 4e181c913..a57910ef9 100644 --- a/pkg/sdk/channels_test.go +++ b/pkg/sdk/channels_test.go @@ -43,7 +43,8 @@ func setupChannels() (*httptest.Server, *chmocks.Service, *authnmocks.Authentica authn := new(authnmocks.Authentication) mux := chi.NewRouter() idp := uuid.NewMock() - chapi.MakeHandler(svc, authn, mux, logger, "", idp) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + chapi.MakeHandler(svc, am, mux, logger, "", idp) return httptest.NewServer(mux), svc, authn } @@ -1094,7 +1095,7 @@ func TestUpdateChannelTags(t *testing.T) { svcRes: channels.Channel{}, authenticateErr: svcerr.ErrAuthorization, response: sdk.Channel{}, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), }, { desc: "update channel tags with empty token", diff --git a/pkg/sdk/clients_test.go b/pkg/sdk/clients_test.go index c20fca134..3e160838b 100644 --- a/pkg/sdk/clients_test.go +++ b/pkg/sdk/clients_test.go @@ -37,7 +37,8 @@ func setupClients() (*httptest.Server, *mocks.Service, *authnmocks.Authenticatio mux := chi.NewRouter() idp := uuid.NewMock() authn := new(authnmocks.Authentication) - api.MakeHandler(tsvc, authn, mux, logger, "", idp) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + api.MakeHandler(tsvc, am, mux, logger, "", idp) return httptest.NewServer(mux), tsvc, authn } @@ -647,7 +648,7 @@ func TestViewClient(t *testing.T) { svcRes: clients.Client{}, authenticateErr: svcerr.ErrAuthorization, response: sdk.Client{}, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), }, { desc: "view client with empty token", @@ -786,7 +787,7 @@ func TestUpdateClient(t *testing.T) { svcRes: clients.Client{}, authenticateErr: svcerr.ErrAuthorization, response: sdk.Client{}, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), }, { desc: "update client with empty token", @@ -940,7 +941,7 @@ func TestUpdateClientTags(t *testing.T) { svcRes: clients.Client{}, authenticateErr: svcerr.ErrAuthorization, response: sdk.Client{}, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), }, { desc: "update client tags with empty token", @@ -1089,7 +1090,7 @@ func TestUpdateClientSecret(t *testing.T) { svcRes: clients.Client{}, authenticateErr: svcerr.ErrAuthorization, response: sdk.Client{}, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), }, { desc: "update client secret with empty token", @@ -1217,7 +1218,7 @@ func TestEnableClient(t *testing.T) { svcRes: clients.Client{}, authenticateErr: svcerr.ErrAuthorization, response: sdk.Client{}, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), }, { desc: "enable client with an invalid client id", @@ -1320,7 +1321,7 @@ func TestDisableClient(t *testing.T) { svcRes: clients.Client{}, authenticateErr: svcerr.ErrAuthorization, response: sdk.Client{}, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), }, { desc: "disable client with an invalid client id", @@ -1415,7 +1416,7 @@ func TestDeleteClient(t *testing.T) { token: invalidToken, clientID: client.ID, authenticateErr: svcerr.ErrAuthorization, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), }, { desc: "delete client with empty token", diff --git a/pkg/sdk/domains_test.go b/pkg/sdk/domains_test.go index 80e20e7e8..e94852adb 100644 --- a/pkg/sdk/domains_test.go +++ b/pkg/sdk/domains_test.go @@ -62,8 +62,9 @@ func setupDomains() (*httptest.Server, *mocks.Service, *authnmocks.Authenticatio mux := chi.NewRouter() idp := uuid.NewMock() authn := new(authnmocks.Authentication) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) - handler := domainapi.MakeHandler(svc, authn, mux, logger, "", idp) + handler := domainapi.MakeHandler(svc, am, mux, logger, "", idp) return httptest.NewServer(handler), svc, authn } diff --git a/pkg/sdk/groups_test.go b/pkg/sdk/groups_test.go index c6e5ecd15..40c59fe30 100644 --- a/pkg/sdk/groups_test.go +++ b/pkg/sdk/groups_test.go @@ -48,7 +48,8 @@ func setupGroups() (*httptest.Server, *mocks.Service, *authnmocks.Authentication provider := new(oauth2mocks.Provider) provider.On("Name").Return(roleName) authn := new(authnmocks.Authentication) - httpapi.MakeHandler(svc, authn, mux, logger, "", idp) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + httpapi.MakeHandler(svc, am, mux, logger, "", idp) return httptest.NewServer(mux), svc, authn } @@ -922,7 +923,7 @@ func TestUpdateGroupTags(t *testing.T) { svcRes: groups.Group{}, authenticateErr: svcerr.ErrAuthorization, response: sdk.Group{}, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusForbidden), + err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthorization, http.StatusUnauthorized), }, { desc: "update group tags with empty token", diff --git a/pkg/sdk/journal_test.go b/pkg/sdk/journal_test.go index 2cfe933d0..4c4878f7d 100644 --- a/pkg/sdk/journal_test.go +++ b/pkg/sdk/journal_test.go @@ -29,7 +29,8 @@ func setupJournal() (*httptest.Server, *mocks.Service, *authnmocks.Authenticatio svc := new(mocks.Service) authn := new(authnmocks.Authentication) logger := smqlog.NewMock() - mux := api.MakeHandler(svc, authn, logger, "journal-log", "test") + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) + mux := api.MakeHandler(svc, am, logger, "journal-log", "test") return httptest.NewServer(mux), svc, authn } diff --git a/pkg/sdk/mocks/sdk.go b/pkg/sdk/mocks/sdk.go index 9398b6a3b..ba730215f 100644 --- a/pkg/sdk/mocks/sdk.go +++ b/pkg/sdk/mocks/sdk.go @@ -8040,6 +8040,65 @@ func (_c *SDK_SendMessage_Call) RunAndReturn(run func(ctx context.Context, domai return _c } +// SendVerification provides a mock function for the type SDK +func (_mock *SDK) SendVerification(ctx context.Context, token string) errors.SDKError { + ret := _mock.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for SendVerification") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string) errors.SDKError); ok { + r0 = returnFunc(ctx, token) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_SendVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendVerification' +type SDK_SendVerification_Call struct { + *mock.Call +} + +// SendVerification is a helper method to define mock.On call +// - ctx context.Context +// - token string +func (_e *SDK_Expecter) SendVerification(ctx interface{}, token interface{}) *SDK_SendVerification_Call { + return &SDK_SendVerification_Call{Call: _e.mock.On("SendVerification", ctx, token)} +} + +func (_c *SDK_SendVerification_Call) Run(run func(ctx context.Context, token string)) *SDK_SendVerification_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *SDK_SendVerification_Call) Return(sDKError errors.SDKError) *SDK_SendVerification_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_SendVerification_Call) RunAndReturn(run func(ctx context.Context, token string) errors.SDKError) *SDK_SendVerification_Call { + _c.Call.Return(run) + return _c +} + // SetChannelParent provides a mock function for the type SDK func (_mock *SDK) SetChannelParent(ctx context.Context, id string, domainID string, groupID string, token string) errors.SDKError { ret := _mock.Called(ctx, id, domainID, groupID, token) @@ -9974,6 +10033,65 @@ func (_c *SDK_Users_Call) RunAndReturn(run func(ctx context.Context, pm sdk.Page return _c } +// VerifyEmail provides a mock function for the type SDK +func (_mock *SDK) VerifyEmail(ctx context.Context, verificationToken string) errors.SDKError { + ret := _mock.Called(ctx, verificationToken) + + if len(ret) == 0 { + panic("no return value specified for VerifyEmail") + } + + var r0 errors.SDKError + if returnFunc, ok := ret.Get(0).(func(context.Context, string) errors.SDKError); ok { + r0 = returnFunc(ctx, verificationToken) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + return r0 +} + +// SDK_VerifyEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyEmail' +type SDK_VerifyEmail_Call struct { + *mock.Call +} + +// VerifyEmail is a helper method to define mock.On call +// - ctx context.Context +// - verificationToken string +func (_e *SDK_Expecter) VerifyEmail(ctx interface{}, verificationToken interface{}) *SDK_VerifyEmail_Call { + return &SDK_VerifyEmail_Call{Call: _e.mock.On("VerifyEmail", ctx, verificationToken)} +} + +func (_c *SDK_VerifyEmail_Call) Run(run func(ctx context.Context, verificationToken string)) *SDK_VerifyEmail_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *SDK_VerifyEmail_Call) Return(sDKError errors.SDKError) *SDK_VerifyEmail_Call { + _c.Call.Return(sDKError) + return _c +} + +func (_c *SDK_VerifyEmail_Call) RunAndReturn(run func(ctx context.Context, verificationToken string) errors.SDKError) *SDK_VerifyEmail_Call { + _c.Call.Return(run) + return _c +} + // ViewCert provides a mock function for the type SDK func (_mock *SDK) ViewCert(ctx context.Context, certID string, domainID string, token string) (sdk.Cert, errors.SDKError) { ret := _mock.Called(ctx, certID, domainID, token) diff --git a/pkg/sdk/sdk.go b/pkg/sdk/sdk.go index 4080d8c38..d3bc917c5 100644 --- a/pkg/sdk/sdk.go +++ b/pkg/sdk/sdk.go @@ -178,6 +178,20 @@ type SDK interface { // fmt.Println(user) CreateUser(ctx context.Context, user User, token string) (User, errors.SDKError) + // SendVerification sends a verification email to the user. + // + // example: + // err := sdk.SendVerification("token") + // fmt.Println(err) + SendVerification(ctx context.Context, token string) errors.SDKError + + // VerifyEmail verifies the user's email address using the provided token. + // + // example: + // err := sdk.VerifyEmail("verificationToken") + // fmt.Println(user) + VerifyEmail(ctx context.Context, verificationToken string) errors.SDKError + // User returns user object by id. // // example: diff --git a/pkg/sdk/users.go b/pkg/sdk/users.go index aee38b522..611e30c77 100644 --- a/pkg/sdk/users.go +++ b/pkg/sdk/users.go @@ -15,13 +15,16 @@ import ( ) const ( - usersEndpoint = "users" - enableEndpoint = "enable" - disableEndpoint = "disable" - issueTokenEndpoint = "tokens/issue" - refreshTokenEndpoint = "tokens/refresh" - membersEndpoint = "members" - PasswordResetEndpoint = "password" + usersEndpoint = "users" + enableEndpoint = "enable" + disableEndpoint = "disable" + issueTokenEndpoint = "tokens/issue" + refreshTokenEndpoint = "tokens/refresh" + membersEndpoint = "members" + PasswordResetEndpoint = "password" + sendVerificationEndpoint = "send-verification" + verifyEmailEndpoint = "verify-email" + tokenQueryParamKey = "token" ) // User represents supermq user its credentials. @@ -61,6 +64,24 @@ func (sdk mgSDK) CreateUser(ctx context.Context, user User, token string) (User, return user, nil } +func (sdk mgSDK) SendVerification(ctx context.Context, token string) errors.SDKError { + url := fmt.Sprintf("%s/%s/%s", sdk.usersURL, usersEndpoint, sendVerificationEndpoint) + + _, _, sdkErr := sdk.processRequest(ctx, http.MethodPost, url, token, nil, nil, http.StatusOK) + + return sdkErr +} + +func (sdk mgSDK) VerifyEmail(ctx context.Context, verificationToken string) errors.SDKError { + url := fmt.Sprintf("%s/%s?%s=%s", sdk.usersURL, verifyEmailEndpoint, tokenQueryParamKey, verificationToken) + + _, _, sdkErr := sdk.processRequest(ctx, http.MethodGet, url, "", nil, nil, http.StatusOK) + if sdkErr != nil { + return sdkErr + } + return nil +} + func (sdk mgSDK) Users(ctx context.Context, pm PageMetadata, token string) (UsersPage, errors.SDKError) { url, err := sdk.withQueryParams(sdk.usersURL, usersEndpoint, pm) if err != nil { diff --git a/pkg/sdk/users_test.go b/pkg/sdk/users_test.go index 42d59c98d..22b31cdce 100644 --- a/pkg/sdk/users_test.go +++ b/pkg/sdk/users_test.go @@ -45,8 +45,9 @@ func setupUsers() (*httptest.Server, *umocks.Service, *authnmocks.Authentication provider := new(oauth2mocks.Provider) provider.On("Name").Return("test") authn := new(authnmocks.Authentication) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithDomainCheck(false), smqauthn.WithAllowUnverifiedUser(true)) token := new(authmocks.TokenServiceClient) - httpapi.MakeHandler(usvc, authn, token, true, mux, logger, "", passRegex, idp, provider) + httpapi.MakeHandler(usvc, am, token, true, mux, logger, "", passRegex, idp, provider) return httptest.NewServer(mux), usvc, authn } @@ -567,6 +568,10 @@ func TestListUsers(t *testing.T) { resp, err := mgsdk.Users(context.Background(), tc.pageMeta, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, resp) + if tc.token != "" { + ok := authCall.Parent.AssertCalled(t, "Authenticate", mock.Anything, tc.token) + assert.True(t, ok) + } if tc.err == nil { ok := svcCall.Parent.AssertCalled(t, "ListUsers", mock.Anything, tc.session, tc.svcReq) assert.True(t, ok) diff --git a/scripts/run.sh b/scripts/run.sh index 9b9c74041..eb351d9f9 100755 --- a/scripts/run.sh +++ b/scripts/run.sh @@ -38,7 +38,7 @@ done ### # Users ### -SMQ_USERS_LOG_LEVEL=info SMQ_USERS_HTTP_PORT=9002 SMQ_USERS_GRPC_PORT=7001 SMQ_USERS_ADMIN_EMAIL=admin@supermq.com SMQ_USERS_ADMIN_PASSWORD=12345678 SMQ_USERS_ADMIN_USERNAME=admin SMQ_EMAIL_TEMPLATE=../docker/templates/users.tmpl $BUILD_DIR/supermq-users & +SMQ_USERS_LOG_LEVEL=info SMQ_USERS_HTTP_PORT=9002 SMQ_USERS_GRPC_PORT=7001 SMQ_USERS_ADMIN_EMAIL=admin@supermq.com SMQ_USERS_ADMIN_PASSWORD=12345678 SMQ_USERS_ADMIN_USERNAME=admin SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost:9002/password/reset SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=../docker/templates/reset-password-email.tmpl SMQ_VERIFICATION_URL_PREFIX=http://localhost:9002/users/verify-email SMQ_VERIFICATION_EMAIL_TEMPLATE=../docker/templates/verification-email.tmpl $BUILD_DIR/supermq-users & ### # Clients diff --git a/tools/config/.mockery.yaml b/tools/config/.mockery.yaml index 261873db4..8e3df8935 100644 --- a/tools/config/.mockery.yaml +++ b/tools/config/.mockery.yaml @@ -114,6 +114,7 @@ packages: github.com/absmach/supermq/pkg/authn: interfaces: Authentication: + AuthNMiddleware: github.com/absmach/supermq/pkg/authz: interfaces: Authorization: diff --git a/users/README.md b/users/README.md index c7d2c792e..28d0a26ea 100644 --- a/users/README.md +++ b/users/README.md @@ -12,48 +12,51 @@ For in-depth explanation of the aforementioned scenarios, as well as thorough un The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values. -| Variable | Description | Default | -| ------------------------------ | ----------------------------------------------------------------------- | --------------------------------- | -| SMQ_USERS_LOG_LEVEL | Log level for users service (debug, info, warn, error) | info | -| SMQ_USERS_ADMIN_EMAIL | Default user, created on startup | | -| SMQ_USERS_ADMIN_PASSWORD | Default user password, created on startup | 12345678 | -| SMQ_USERS_PASS_REGEX | Password regex | ^.{8,}$ | -| SMQ_USERS_HTTP_HOST | Users service HTTP host | localhost | -| SMQ_USERS_HTTP_PORT | Users service HTTP port | 9002 | -| SMQ_USERS_HTTP_SERVER_CERT | Path to the PEM encoded server certificate file | "" | -| SMQ_USERS_HTTP_SERVER_KEY | Path to the PEM encoded server key file | "" | -| SMQ_USERS_HTTP_SERVER_CA_CERTS | Path to the PEM encoded server CA certificate file | "" | -| SMQ_USERS_HTTP_CLIENT_CA_CERTS | Path to the PEM encoded client CA certificate file | "" | -| SMQ_AUTH_GRPC_URL | Auth service GRPC URL | localhost:8181 | -| SMQ_AUTH_GRPC_TIMEOUT | Auth service GRPC timeout | 1s | -| SMQ_AUTH_GRPC_CLIENT_CERT | Path to the PEM encoded client certificate file | "" | -| SMQ_AUTH_GRPC_CLIENT_KEY | Path to the PEM encoded client key file | "" | -| SMQ_AUTH_GRPC_SERVER_CA_CERTS | Path to the PEM encoded server CA certificate file | "" | -| SMQ_USERS_DB_HOST | Database host address | localhost | -| SMQ_USERS_DB_PORT | Database host port | 5432 | -| SMQ_USERS_DB_USER | Database user | supermq | -| SMQ_USERS_DB_PASS | Database password | supermq | -| SMQ_USERS_DB_NAME | Name of the database used by the service | users | -| SMQ_USERS_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | -| SMQ_USERS_DB_SSL_CERT | Path to the PEM encoded certificate file | "" | -| SMQ_USERS_DB_SSL_KEY | Path to the PEM encoded key file | "" | -| SMQ_USERS_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" | -| SMQ_EMAIL_HOST | Mail server host | localhost | -| SMQ_EMAIL_PORT | Mail server port | 25 | -| SMQ_EMAIL_USERNAME | Mail server username | "" | -| SMQ_EMAIL_PASSWORD | Mail server password | "" | -| SMQ_EMAIL_FROM_ADDRESS | Email "from" address | "" | -| SMQ_EMAIL_FROM_NAME | Email "from" name | "" | -| SMQ_EMAIL_TEMPLATE | Email template for sending emails with password reset link | email.tmpl | -| SMQ_USERS_ES_URL | Event store URL | | -| SMQ_JAEGER_URL | Jaeger server URL | | -| SMQ_OAUTH_UI_REDIRECT_URL | OAuth UI redirect URL | | -| SMQ_OAUTH_UI_ERROR_URL | OAuth UI error URL | | -| SMQ_USERS_DELETE_INTERVAL | Interval for deleting users | 24h | -| SMQ_USERS_DELETE_AFTER | Time after which users are deleted | 720h | -| SMQ_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 | -| SMQ_SEND_TELEMETRY | Send telemetry to supermq call home server. | true | -| SMQ_USERS_INSTANCE_ID | SuperMQ instance ID | "" | +| Variable | Description | Default | +| --------------------------------- | ----------------------------------------------------------------------- | --------------------------------- | +| SMQ_USERS_LOG_LEVEL | Log level for users service (debug, info, warn, error) | info | +| SMQ_USERS_ADMIN_EMAIL | Default user, created on startup | | +| SMQ_USERS_ADMIN_PASSWORD | Default user password, created on startup | 12345678 | +| SMQ_USERS_PASS_REGEX | Password regex | ^.{8,}$ | +| SMQ_USERS_HTTP_HOST | Users service HTTP host | localhost | +| SMQ_USERS_HTTP_PORT | Users service HTTP port | 9002 | +| SMQ_USERS_HTTP_SERVER_CERT | Path to the PEM encoded server certificate file | "" | +| SMQ_USERS_HTTP_SERVER_KEY | Path to the PEM encoded server key file | "" | +| SMQ_USERS_HTTP_SERVER_CA_CERTS | Path to the PEM encoded server CA certificate file | "" | +| SMQ_USERS_HTTP_CLIENT_CA_CERTS | Path to the PEM encoded client CA certificate file | "" | +| SMQ_AUTH_GRPC_URL | Auth service GRPC URL | localhost:8181 | +| SMQ_AUTH_GRPC_TIMEOUT | Auth service GRPC timeout | 1s | +| SMQ_AUTH_GRPC_CLIENT_CERT | Path to the PEM encoded client certificate file | "" | +| SMQ_AUTH_GRPC_CLIENT_KEY | Path to the PEM encoded client key file | "" | +| SMQ_AUTH_GRPC_SERVER_CA_CERTS | Path to the PEM encoded server CA certificate file | "" | +| SMQ_USERS_DB_HOST | Database host address | localhost | +| SMQ_USERS_DB_PORT | Database host port | 5432 | +| SMQ_USERS_DB_USER | Database user | supermq | +| SMQ_USERS_DB_PASS | Database password | supermq | +| SMQ_USERS_DB_NAME | Name of the database used by the service | users | +| SMQ_USERS_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | +| SMQ_USERS_DB_SSL_CERT | Path to the PEM encoded certificate file | "" | +| SMQ_USERS_DB_SSL_KEY | Path to the PEM encoded key file | "" | +| SMQ_USERS_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" | +| SMQ_EMAIL_HOST | Mail server host | localhost | +| SMQ_EMAIL_PORT | Mail server port | 25 | +| SMQ_EMAIL_USERNAME | Mail server username | "" | +| SMQ_EMAIL_PASSWORD | Mail server password | "" | +| SMQ_EMAIL_FROM_ADDRESS | Email "from" address | "" | +| SMQ_EMAIL_FROM_NAME | Email "from" name | "" | +| SMQ_PASSWORD_RESET_URL_PREFIX | Password reset URL prefix | | +| SMQ_PASSWORD_RESET_EMAIL_TEMPLATE | Password reset email template | reset-password-email.tmpl | +| SMQ_VERIFICATION_URL_PREFIX | Verification URL prefix | | +| SMQ_VERIFICATION_EMAIL_TEMPLATE | Verification email template | verification-email.tmpl | +| SMQ_USERS_ES_URL | Event store URL | | +| SMQ_JAEGER_URL | Jaeger server URL | | +| SMQ_OAUTH_UI_REDIRECT_URL | OAuth UI redirect URL | | +| SMQ_OAUTH_UI_ERROR_URL | OAuth UI error URL | | +| SMQ_USERS_DELETE_INTERVAL | Interval for deleting users | 24h | +| SMQ_USERS_DELETE_AFTER | Time after which users are deleted | 720h | +| SMQ_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 | +| SMQ_SEND_TELEMETRY | Send telemetry to supermq call home server. | true | +| SMQ_USERS_INSTANCE_ID | SuperMQ instance ID | "" | ## Deployment @@ -104,7 +107,10 @@ SMQ_EMAIL_USERNAME="18bf7f7070513" \ SMQ_EMAIL_PASSWORD="2b0d302e775b1e" \ SMQ_EMAIL_FROM_ADDRESS=from@example.com \ SMQ_EMAIL_FROM_NAME=Example \ -SMQ_EMAIL_TEMPLATE="docker/templates/users.tmpl" \ +SMQ_PASSWORD_RESET_URL_PREFIX=http://localhost:9002/password/reset \ +SMQ_PASSWORD_RESET_EMAIL_TEMPLATE=docker/templates/reset-password-email.tmpl \ +SMQ_VERIFICATION_URL_PREFIX=http://localhost:9002/users/verify-email \ +SMQ_VERIFICATION_EMAIL_TEMPLATE=docker/templates/verification-email.tmpl \ SMQ_USERS_ES_URL=nats://localhost:4222 \ SMQ_JAEGER_URL=http://localhost:14268/api/traces \ SMQ_JAEGER_TRACE_RATIO=1.0 \ diff --git a/users/api/endpoint_test.go b/users/api/endpoint_test.go index 5c587b9cd..efdd9d53f 100644 --- a/users/api/endpoint_test.go +++ b/users/api/endpoint_test.go @@ -94,8 +94,9 @@ func newUsersServer() (*httptest.Server, *mocks.Service, *authnmocks.Authenticat provider := new(oauth2mocks.Provider) provider.On("Name").Return("test") authn := new(authnmocks.Authentication) + am := smqauthn.NewAuthNMiddleware(authn, smqauthn.WithAllowUnverifiedUser(true)) token := new(authmocks.TokenServiceClient) - usersapi.MakeHandler(svc, authn, token, true, mux, logger, "", passRegex, idp, provider) + usersapi.MakeHandler(svc, am, token, true, mux, logger, "", passRegex, idp, provider) return httptest.NewServer(mux), svc, authn } @@ -1750,6 +1751,151 @@ func TestPasswordResetRequest(t *testing.T) { } } +func TestSendVerification(t *testing.T) { + us, svc, authn := newUsersServer() + defer us.Close() + + cases := []struct { + desc string + token string + status int + authnRes smqauthn.Session + authnErr error + svcErr error + err error + }{ + { + desc: "send verification with valid token", + token: validToken, + status: http.StatusOK, + authnRes: smqauthn.Session{UserID: validID, DomainID: domainID}, + err: nil, + }, + { + desc: "send verification with invalid token", + token: inValidToken, + status: http.StatusUnauthorized, + authnErr: svcerr.ErrAuthentication, + authnRes: smqauthn.Session{}, + err: svcerr.ErrAuthentication, + }, + { + desc: "send verification with empty token", + token: "", + status: http.StatusUnauthorized, + authnErr: svcerr.ErrAuthentication, + authnRes: smqauthn.Session{}, + err: apiutil.ErrBearerToken, + }, + { + desc: "send verification with service error", + token: validToken, + status: http.StatusUnprocessableEntity, + authnRes: smqauthn.Session{UserID: validID, DomainID: domainID}, + svcErr: svcerr.ErrCreateEntity, + err: svcerr.ErrCreateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + user: us.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/users/send-verification", us.URL), + token: tc.token, + } + + authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + svcCall := svc.On("SendVerification", mock.Anything, tc.authnRes).Return(tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + body, err := io.ReadAll(res.Body) + if err != nil { + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while reading response body: %s", tc.desc, err)) + } + defer res.Body.Close() + var errRes respBody + if len(body) > 0 { + if err := json.Unmarshal(body, &errRes); err != nil { + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while unmarshal response body: %s", tc.desc, err)) + } + } + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authnCall.Unset() + }) + } +} + +func TestVerifyEmail(t *testing.T) { + us, svc, _ := newUsersServer() + defer us.Close() + + cases := []struct { + desc string + token string + status int + svcErr error + err error + }{ + { + desc: "verify email with valid token", + token: validToken, + status: http.StatusOK, + err: nil, + }, + { + desc: "verify email with empty token", + token: "", + status: http.StatusBadRequest, + err: apiutil.ErrValidation, + }, + { + desc: "verify email with service error", + token: validToken, + status: http.StatusBadRequest, + svcErr: svcerr.ErrMalformedEntity, + err: svcerr.ErrMalformedEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + user: us.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/verify-email?token=%s", us.URL, tc.token), + } + + svcCall := svc.On("VerifyEmail", mock.Anything, mock.Anything).Return(users.User{}, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + body, err := io.ReadAll(res.Body) + if err != nil { + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while reading response body: %s", tc.desc, err)) + } + defer res.Body.Close() + var errRes respBody + if len(body) > 0 { + if err := json.Unmarshal(body, &errRes); err != nil { + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while unmarshal response body: %s", tc.desc, err)) + } + } + if errRes.Err != "" || errRes.Message != "" { + err = errors.Wrap(errors.New(errRes.Err), errors.New(errRes.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.err, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + }) + } +} + func TestPasswordReset(t *testing.T) { us, svc, authn := newUsersServer() defer us.Close() diff --git a/users/api/endpoints.go b/users/api/endpoints.go index 32e462405..60b5910db 100644 --- a/users/api/endpoints.go +++ b/users/api/endpoints.go @@ -6,7 +6,6 @@ package api import ( "context" - api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/errors" @@ -25,7 +24,7 @@ func registrationEndpoint(svc users.Service, selfRegister bool) endpoint.Endpoin var ok bool if !selfRegister { - session, ok = ctx.Value(api.SessionKey).(authn.Session) + session, ok = ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -43,6 +42,38 @@ func registrationEndpoint(svc users.Service, selfRegister bool) endpoint.Endpoin } } +func sendVerificationEndpoint(svc users.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + _ = request.(sendVerificationReq) + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthentication + } + + if err := svc.SendVerification(ctx, session); err != nil { + return sendVerificationRes{}, err + } + + return sendVerificationRes{}, nil + } +} + +func verifyEmailEndpoint(svc users.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(verifyEmailReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + if _, err := svc.VerifyEmail(ctx, req.token); err != nil { + return verifyEmailRes{}, err + } + + return verifyEmailRes{}, nil + } +} + func viewEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(viewUserReq) @@ -50,7 +81,7 @@ func viewEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -65,7 +96,7 @@ func viewEndpoint(svc users.Service) endpoint.Endpoint { func viewProfileEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -85,7 +116,7 @@ func listUsersEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -172,7 +203,7 @@ func updateEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -199,7 +230,7 @@ func updateTagsEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -224,7 +255,7 @@ func updateEmailEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -273,7 +304,7 @@ func passwordResetEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -292,7 +323,7 @@ func updateSecretEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -312,7 +343,7 @@ func updateUsernameEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -337,7 +368,7 @@ func updateProfilePictureEndpoint(svc users.Service) endpoint.Endpoint { ProfilePicture: req.ProfilePicture, } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthorization } @@ -363,7 +394,7 @@ func updateRoleEndpoint(svc users.Service) endpoint.Endpoint { Role: req.role, } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -404,7 +435,7 @@ func refreshTokenEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -429,7 +460,7 @@ func enableEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -450,7 +481,7 @@ func disableEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } @@ -471,7 +502,7 @@ func deleteEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - session, ok := ctx.Value(api.SessionKey).(authn.Session) + session, ok := ctx.Value(authn.SessionKey).(authn.Session) if !ok { return nil, svcerr.ErrAuthentication } diff --git a/users/api/requests.go b/users/api/requests.go index 137f87d0e..6b620e54f 100644 --- a/users/api/requests.go +++ b/users/api/requests.go @@ -4,7 +4,6 @@ package api import ( - "net/mail" "net/url" api "github.com/absmach/supermq/api/http" @@ -39,15 +38,16 @@ func (req createUserReq) validate() error { return err } // Username must not be a valid email format due to username/email login. - if _, err := mail.ParseAddress(req.User.Credentials.Username); err == nil { + if err := api.ValidateEmail(req.User.Credentials.Username); err == nil { return apiutil.ErrInvalidUsername } + if req.User.Email == "" { return apiutil.ErrMissingEmail } // Email must be in a valid format. - if _, err := mail.ParseAddress(req.User.Email); err != nil { - return apiutil.ErrInvalidEmail + if err := api.ValidateEmail(req.User.Email); err != nil { + return err } if req.User.Credentials.Secret == "" { return apiutil.ErrMissingPass @@ -67,6 +67,20 @@ func (req createUserReq) validate() error { return req.User.Validate() } +type sendVerificationReq struct{} + +type verifyEmailReq struct { + token string +} + +func (req verifyEmailReq) validate() error { + if req.token == "" { + return apiutil.ErrInvalidVerification + } + + return nil +} + type viewUserReq struct { id string } @@ -183,8 +197,8 @@ func (req updateEmailReq) validate() error { if req.id == "" { return apiutil.ErrMissingID } - if _, err := mail.ParseAddress(req.Email); err != nil { - return apiutil.ErrInvalidEmail + if err := api.ValidateEmail(req.Email); err != nil { + return err } return nil diff --git a/users/api/responses.go b/users/api/responses.go index 187da746f..ccaece90c 100644 --- a/users/api/responses.go +++ b/users/api/responses.go @@ -16,6 +16,8 @@ const MailSent = "Email with reset link is sent" var ( _ supermq.Response = (*tokenRes)(nil) + _ supermq.Response = (*sendVerificationRes)(nil) + _ supermq.Response = (*verifyEmailRes)(nil) _ supermq.Response = (*viewUserRes)(nil) _ supermq.Response = (*createUserRes)(nil) _ supermq.Response = (*changeUserStatusRes)(nil) @@ -79,6 +81,34 @@ func (res tokenRes) Empty() bool { return res.AccessToken == "" || res.RefreshToken == "" } +type sendVerificationRes struct{} + +func (res sendVerificationRes) Code() int { + return http.StatusOK +} + +func (res sendVerificationRes) Headers() map[string]string { + return map[string]string{} +} + +func (res sendVerificationRes) Empty() bool { + return true +} + +type verifyEmailRes struct{} + +func (res verifyEmailRes) Code() int { + return http.StatusOK +} + +func (res verifyEmailRes) Headers() map[string]string { + return map[string]string{} +} + +func (res verifyEmailRes) Empty() bool { + return true +} + type updateUserRes struct { users.User `json:",inline"` } diff --git a/users/api/transport.go b/users/api/transport.go index 287615113..1abd74fcc 100644 --- a/users/api/transport.go +++ b/users/api/transport.go @@ -18,7 +18,7 @@ import ( ) // MakeHandler returns a HTTP handler for Users and Groups API endpoints. -func MakeHandler(cls users.Service, authn smqauthn.Authentication, tokensvc grpcTokenV1.TokenServiceClient, selfRegister bool, mux *chi.Mux, logger *slog.Logger, instanceID string, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) http.Handler { +func MakeHandler(cls users.Service, authn smqauthn.AuthNMiddleware, tokensvc grpcTokenV1.TokenServiceClient, selfRegister bool, mux *chi.Mux, logger *slog.Logger, instanceID string, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) http.Handler { mux = usersHandler(cls, authn, tokensvc, selfRegister, mux, logger, pr, idp, providers...) mux.Get("/health", supermq.Health("users", instanceID)) diff --git a/users/api/users.go b/users/api/users.go index b0305e4cc..c6cd91a81 100644 --- a/users/api/users.go +++ b/users/api/users.go @@ -28,13 +28,15 @@ import ( var passRegex = regexp.MustCompile("^.{8,}$") // usersHandler returns a HTTP handler for API endpoints. -func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient grpcTokenV1.TokenServiceClient, selfRegister bool, r *chi.Mux, logger *slog.Logger, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) *chi.Mux { +func usersHandler(svc users.Service, authn smqauthn.AuthNMiddleware, tokenClient grpcTokenV1.TokenServiceClient, selfRegister bool, r *chi.Mux, logger *slog.Logger, pr *regexp.Regexp, idp supermq.IDProvider, providers ...oauth2.Provider) *chi.Mux { passRegex = pr opts := []kithttp.ServerOption{ kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), } + // All endpoints in users service don't required Domain check + authn = authn.WithOptions(smqauthn.WithDomainCheck(false)) r.Route("/users", func(r chi.Router) { r.Use(api.RequestIDMiddleware(idp)) @@ -47,16 +49,22 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient opts..., ), "register_user").ServeHTTP) default: - r.With(api.AuthenticateMiddleware(authn, false)).Post("/", otelhttp.NewHandler(kithttp.NewServer( + r.With(authn.Middleware()).Post("/", otelhttp.NewHandler(kithttp.NewServer( registrationEndpoint(svc, selfRegister), decodeCreateUserReq, api.EncodeResponse, opts..., ), "register_user").ServeHTTP) } - + // Endpoints which are allowed for unverified user r.Group(func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, false)) + r.Use(authn.WithOptions(smqauthn.WithAllowUnverifiedUser(true)).Middleware()) + r.Post("/send-verification", otelhttp.NewHandler(kithttp.NewServer( + sendVerificationEndpoint(svc), + decodeSendVerification, + api.EncodeResponse, + opts..., + ), "send_verification").ServeHTTP) r.Get("/profile", otelhttp.NewHandler(kithttp.NewServer( viewProfileEndpoint(svc), @@ -64,6 +72,22 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient api.EncodeResponse, opts..., ), "view_profile").ServeHTTP) + r.Post("/tokens/refresh", otelhttp.NewHandler(kithttp.NewServer( + refreshTokenEndpoint(svc), + decodeRefreshToken, + api.EncodeResponse, + opts..., + ), "refresh_token").ServeHTTP) + r.Patch("/{id}/email", otelhttp.NewHandler(kithttp.NewServer( + updateEmailEndpoint(svc), + decodeUpdateUserEmail, + api.EncodeResponse, + opts..., + ), "update_user_email").ServeHTTP) + }) + + r.Group(func(r chi.Router) { + r.Use(authn.Middleware()) r.Get("/{id}", otelhttp.NewHandler(kithttp.NewServer( viewEndpoint(svc), @@ -121,13 +145,6 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient opts..., ), "update_user_tags").ServeHTTP) - r.Patch("/{id}/email", otelhttp.NewHandler(kithttp.NewServer( - updateEmailEndpoint(svc), - decodeUpdateUserEmail, - api.EncodeResponse, - opts..., - ), "update_user_email").ServeHTTP) - r.Patch("/{id}/role", otelhttp.NewHandler(kithttp.NewServer( updateRoleEndpoint(svc), decodeUpdateUserRole, @@ -155,18 +172,11 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient api.EncodeResponse, opts..., ), "delete_user").ServeHTTP) - - r.Post("/tokens/refresh", otelhttp.NewHandler(kithttp.NewServer( - refreshTokenEndpoint(svc), - decodeRefreshToken, - api.EncodeResponse, - opts..., - ), "refresh_token").ServeHTTP) }) }) r.Group(func(r chi.Router) { - r.Use(api.AuthenticateMiddleware(authn, false)) + r.Use(authn.Middleware()) r.Put("/password/reset", otelhttp.NewHandler(kithttp.NewServer( passwordResetEndpoint(svc), decodePasswordReset, @@ -189,6 +199,13 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient opts..., ), "password_reset_req").ServeHTTP) + r.Get("/verify-email", otelhttp.NewHandler(kithttp.NewServer( + verifyEmailEndpoint(svc), + decodeVerifyEmail, + api.EncodeResponse, + opts..., + ), "verify_email").ServeHTTP) + for _, provider := range providers { r.HandleFunc("/oauth/callback/"+provider.Name(), oauth2CallbackHandler(provider, svc, tokenClient)) } @@ -196,6 +213,22 @@ func usersHandler(svc users.Service, authn smqauthn.Authentication, tokenClient return r } +func decodeSendVerification(_ context.Context, r *http.Request) (any, error) { + req := sendVerificationReq{} + return req, nil +} + +func decodeVerifyEmail(_ context.Context, r *http.Request) (any, error) { + token, err := apiutil.ReadStringQuery(r, api.TokenKey, "") + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + return verifyEmailReq{ + token: token, + }, nil +} + func decodeViewUser(_ context.Context, r *http.Request) (any, error) { req := viewUserReq{ id: chi.URLParam(r, "id"), diff --git a/users/emailer.go b/users/emailer.go index f06217f93..4b93df933 100644 --- a/users/emailer.go +++ b/users/emailer.go @@ -7,4 +7,7 @@ package users type Emailer interface { // SendPasswordReset sends an email to the user with a link to reset the password. SendPasswordReset(To []string, user, token string) error + + // SendVerification sends an email to the user with a verification token. + SendVerification(To []string, user, verificationToken string) error } diff --git a/users/emailer/emailer.go b/users/emailer/emailer.go index c0ba25f44..6cf4bdfb3 100644 --- a/users/emailer/emailer.go +++ b/users/emailer/emailer.go @@ -13,20 +13,38 @@ import ( var _ users.Emailer = (*emailer)(nil) type emailer struct { - resetURL string - agent *email.Agent + resetURL string + verificationURL string + resetAgent *email.Agent + verifyAgent *email.Agent } // New creates new emailer utility. -func New(url string, c *email.Config) (users.Emailer, error) { - e, err := email.New(c) +func New(resetURL, verificationURL string, resetConfig, verifyConfig *email.Config) (users.Emailer, error) { + resetAgent, err := email.New(resetConfig) + if err != nil { + return nil, err + } + + verifyAgent, err := email.New(verifyConfig) + if err != nil { + return nil, err + } + return &emailer{ - resetURL: url, - agent: e, - }, err + resetURL: resetURL, + verificationURL: verificationURL, + resetAgent: resetAgent, + verifyAgent: verifyAgent, + }, nil } func (e *emailer) SendPasswordReset(to []string, user, token string) error { url := fmt.Sprintf("%s?token=%s", e.resetURL, token) - return e.agent.Send(to, "", "Password Reset Request", "", user, url, "") + return e.resetAgent.Send(to, "", "Password Reset Request", "", user, url, "") +} + +func (e *emailer) SendVerification(to []string, user, verificationToken string) error { + url := fmt.Sprintf("%s?token=%s", e.verificationURL, verificationToken) + return e.verifyAgent.Send(to, "", "Email Verification", "", user, url, "") } diff --git a/users/events/events.go b/users/events/events.go index db35f28fb..71d95e243 100644 --- a/users/events/events.go +++ b/users/events/events.go @@ -14,6 +14,8 @@ import ( const ( userPrefix = "user." userCreate = userPrefix + "create" + userSendVerification = userPrefix + "send_verification" + userVerifyEmail = userPrefix + "verify_email" userUpdate = userPrefix + "update" userUpdateRole = userPrefix + "update_role" userUpdateTags = userPrefix + "update_tags" @@ -41,6 +43,8 @@ const ( var ( _ events.Event = (*createUserEvent)(nil) + _ events.Event = (*sendVerificationEvent)(nil) + _ events.Event = (*verifyEmailEvent)(nil) _ events.Event = (*updateUserEvent)(nil) _ events.Event = (*updateProfilePictureEvent)(nil) _ events.Event = (*updateUsernameEvent)(nil) @@ -99,6 +103,37 @@ func (uce createUserEvent) Encode() (map[string]any, error) { return val, nil } +type sendVerificationEvent struct { + authn.Session + requestID string +} + +func (sve sendVerificationEvent) Encode() (map[string]any, error) { + return map[string]any{ + "operation": userSendVerification, + "user_id": sve.UserID, + "token_type": sve.Type.String(), + "request_id": sve.requestID, + }, nil +} + +type verifyEmailEvent struct { + requestID string + email string + userID string + verifiedAt time.Time +} + +func (vee verifyEmailEvent) Encode() (map[string]any, error) { + return map[string]any{ + "operation": userVerifyEmail, + "request_id": vee.requestID, + "email": vee.email, + "user_id": vee.userID, + "verified_at": vee.verifiedAt, + }, nil +} + type updateUserEvent struct { users.User operation string diff --git a/users/events/streams.go b/users/events/streams.go index 5b44028c2..ad3c5e376 100644 --- a/users/events/streams.go +++ b/users/events/streams.go @@ -17,6 +17,8 @@ import ( const ( supermqPrefix = "supermq." createStream = supermqPrefix + userCreate + sendVerificationStream = supermqPrefix + userSendVerification + verifyEmailStream = supermqPrefix + userVerifyEmail updateStream = supermqPrefix + userUpdate updateRoleStream = supermqPrefix + userUpdateRole updateTagsStream = supermqPrefix + userUpdateTags @@ -82,6 +84,38 @@ func (es *eventStore) Register(ctx context.Context, session authn.Session, user return user, nil } +func (es *eventStore) SendVerification(ctx context.Context, session authn.Session) error { + err := es.svc.SendVerification(ctx, session) + if err != nil { + return err + } + + event := sendVerificationEvent{ + session, + middleware.GetReqID(ctx), + } + + return es.Publish(ctx, sendVerificationStream, event) +} + +func (es *eventStore) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) { + user, err := es.svc.VerifyEmail(ctx, verificationToken) + if err != nil { + return user, err + } + + event := verifyEmailEvent{ + email: user.Email, + userID: user.ID, + verifiedAt: user.VerifiedAt, + requestID: middleware.GetReqID(ctx), + } + if err := es.Publish(ctx, verifyEmailStream, event); err != nil { + return user, err + } + return user, nil +} + func (es *eventStore) Update(ctx context.Context, session authn.Session, id string, usr users.UserReq) (users.User, error) { user, err := es.svc.Update(ctx, session, id, usr) if err != nil { diff --git a/users/middleware/authorization.go b/users/middleware/authorization.go index 608189e77..c244a8846 100644 --- a/users/middleware/authorization.go +++ b/users/middleware/authorization.go @@ -26,11 +26,15 @@ type authorizationMiddleware struct { // AuthorizationMiddleware adds authorization to the clients service. func AuthorizationMiddleware(svc users.Service, authz smqauthz.Authorization, selfRegister bool) users.Service { - return &authorizationMiddleware{ - svc: svc, - authz: authz, - selfRegister: selfRegister, - } + return &authorizationMiddleware{svc: svc, authz: authz, selfRegister: selfRegister} +} + +func (am *authorizationMiddleware) SendVerification(ctx context.Context, session authn.Session) error { + return am.svc.SendVerification(ctx, session) +} + +func (am *authorizationMiddleware) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) { + return am.svc.VerifyEmail(ctx, verificationToken) } func (am *authorizationMiddleware) Register(ctx context.Context, session authn.Session, user users.User, selfRegister bool) (users.User, error) { diff --git a/users/middleware/logging.go b/users/middleware/logging.go index 59e815a2d..3424e7c08 100644 --- a/users/middleware/logging.go +++ b/users/middleware/logging.go @@ -50,6 +50,43 @@ func (lm *loggingMiddleware) Register(ctx context.Context, session authn.Session return lm.svc.Register(ctx, session, user, selfRegister) } +// SendVerification logs the send_verification request. It logs the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) SendVerification(ctx context.Context, session authn.Session) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Send verification failed", args...) + return + } + lm.logger.Info("Send verification completed successfully", args...) + }(time.Now()) + return lm.svc.SendVerification(ctx, session) +} + +// VerifyEmail logs the verify_email request. It logs the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) VerifyEmail(ctx context.Context, verificationToken string) (user users.User, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("user_id", user.ID), + slog.String("request_id", middleware.GetReqID(ctx)), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Verify email failed", args...) + return + } + lm.logger.Info("Verify email completed successfully", args...) + }(time.Now()) + return lm.svc.VerifyEmail(ctx, verificationToken) +} + // IssueToken logs the issue_token request. It logs the username type and the time it took to complete the request. // If the request fails, it logs the error. func (lm *loggingMiddleware) IssueToken(ctx context.Context, username, secret string) (t *grpcTokenV1.Token, err error) { diff --git a/users/middleware/metrics.go b/users/middleware/metrics.go index 22cd70627..dd4d651fd 100644 --- a/users/middleware/metrics.go +++ b/users/middleware/metrics.go @@ -39,6 +39,24 @@ func (ms *metricsMiddleware) Register(ctx context.Context, session authn.Session return ms.svc.Register(ctx, session, user, selfRegister) } +// SendVerification instruments SendVerification method with metrics. +func (ms *metricsMiddleware) SendVerification(ctx context.Context, session authn.Session) error { + defer func(begin time.Time) { + ms.counter.With("method", "send_verification").Add(1) + ms.latency.With("method", "send_verification").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.SendVerification(ctx, session) +} + +// VerifyEmail instruments VerifyEmail method with metrics. +func (ms *metricsMiddleware) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) { + defer func(begin time.Time) { + ms.counter.With("method", "verify_email").Add(1) + ms.latency.With("method", "verify_email").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.VerifyEmail(ctx, verificationToken) +} + // IssueToken instruments IssueToken method with metrics. func (ms *metricsMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) { defer func(begin time.Time) { diff --git a/users/mocks/emailer.go b/users/mocks/emailer.go index 999810af9..e9e562333 100644 --- a/users/mocks/emailer.go +++ b/users/mocks/emailer.go @@ -100,3 +100,66 @@ func (_c *Emailer_SendPasswordReset_Call) RunAndReturn(run func(To []string, use _c.Call.Return(run) return _c } + +// SendVerification provides a mock function for the type Emailer +func (_mock *Emailer) SendVerification(To []string, user string, verificationToken string) error { + ret := _mock.Called(To, user, verificationToken) + + if len(ret) == 0 { + panic("no return value specified for SendVerification") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func([]string, string, string) error); ok { + r0 = returnFunc(To, user, verificationToken) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Emailer_SendVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendVerification' +type Emailer_SendVerification_Call struct { + *mock.Call +} + +// SendVerification is a helper method to define mock.On call +// - To []string +// - user string +// - verificationToken string +func (_e *Emailer_Expecter) SendVerification(To interface{}, user interface{}, verificationToken interface{}) *Emailer_SendVerification_Call { + return &Emailer_SendVerification_Call{Call: _e.mock.On("SendVerification", To, user, verificationToken)} +} + +func (_c *Emailer_SendVerification_Call) Run(run func(To []string, user string, verificationToken string)) *Emailer_SendVerification_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 []string + if args[0] != nil { + arg0 = args[0].([]string) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Emailer_SendVerification_Call) Return(err error) *Emailer_SendVerification_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Emailer_SendVerification_Call) RunAndReturn(run func(To []string, user string, verificationToken string) error) *Emailer_SendVerification_Call { + _c.Call.Return(run) + return _c +} diff --git a/users/mocks/repository.go b/users/mocks/repository.go index 068f86688..ba2e5ac94 100644 --- a/users/mocks/repository.go +++ b/users/mocks/repository.go @@ -41,6 +41,63 @@ func (_m *Repository) EXPECT() *Repository_Expecter { return &Repository_Expecter{mock: &_m.Mock} } +// AddUserVerification provides a mock function for the type Repository +func (_mock *Repository) AddUserVerification(ctx context.Context, uv users.UserVerification) error { + ret := _mock.Called(ctx, uv) + + if len(ret) == 0 { + panic("no return value specified for AddUserVerification") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, users.UserVerification) error); ok { + r0 = returnFunc(ctx, uv) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_AddUserVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddUserVerification' +type Repository_AddUserVerification_Call struct { + *mock.Call +} + +// AddUserVerification is a helper method to define mock.On call +// - ctx context.Context +// - uv users.UserVerification +func (_e *Repository_Expecter) AddUserVerification(ctx interface{}, uv interface{}) *Repository_AddUserVerification_Call { + return &Repository_AddUserVerification_Call{Call: _e.mock.On("AddUserVerification", ctx, uv)} +} + +func (_c *Repository_AddUserVerification_Call) Run(run func(ctx context.Context, uv users.UserVerification)) *Repository_AddUserVerification_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 users.UserVerification + if args[1] != nil { + arg1 = args[1].(users.UserVerification) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_AddUserVerification_Call) Return(err error) *Repository_AddUserVerification_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_AddUserVerification_Call) RunAndReturn(run func(ctx context.Context, uv users.UserVerification) error) *Repository_AddUserVerification_Call { + _c.Call.Return(run) + return _c +} + // ChangeStatus provides a mock function for the type Repository func (_mock *Repository) ChangeStatus(ctx context.Context, user users.User) (users.User, error) { ret := _mock.Called(ctx, user) @@ -551,6 +608,78 @@ func (_c *Repository_RetrieveByUsername_Call) RunAndReturn(run func(ctx context. return _c } +// RetrieveUserVerification provides a mock function for the type Repository +func (_mock *Repository) RetrieveUserVerification(ctx context.Context, userID string, email string) (users.UserVerification, error) { + ret := _mock.Called(ctx, userID, email) + + if len(ret) == 0 { + panic("no return value specified for RetrieveUserVerification") + } + + var r0 users.UserVerification + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (users.UserVerification, error)); ok { + return returnFunc(ctx, userID, email) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) users.UserVerification); ok { + r0 = returnFunc(ctx, userID, email) + } else { + r0 = ret.Get(0).(users.UserVerification) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = returnFunc(ctx, userID, email) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_RetrieveUserVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveUserVerification' +type Repository_RetrieveUserVerification_Call struct { + *mock.Call +} + +// RetrieveUserVerification is a helper method to define mock.On call +// - ctx context.Context +// - userID string +// - email string +func (_e *Repository_Expecter) RetrieveUserVerification(ctx interface{}, userID interface{}, email interface{}) *Repository_RetrieveUserVerification_Call { + return &Repository_RetrieveUserVerification_Call{Call: _e.mock.On("RetrieveUserVerification", ctx, userID, email)} +} + +func (_c *Repository_RetrieveUserVerification_Call) Run(run func(ctx context.Context, userID string, email string)) *Repository_RetrieveUserVerification_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Repository_RetrieveUserVerification_Call) Return(userVerification users.UserVerification, err error) *Repository_RetrieveUserVerification_Call { + _c.Call.Return(userVerification, err) + return _c +} + +func (_c *Repository_RetrieveUserVerification_Call) RunAndReturn(run func(ctx context.Context, userID string, email string) (users.UserVerification, error)) *Repository_RetrieveUserVerification_Call { + _c.Call.Return(run) + return _c +} + // Save provides a mock function for the type Repository func (_mock *Repository) Save(ctx context.Context, user users.User) (users.User, error) { ret := _mock.Called(ctx, user) @@ -755,6 +884,138 @@ func (_c *Repository_Update_Call) RunAndReturn(run func(ctx context.Context, id return _c } +// UpdateEmail provides a mock function for the type Repository +func (_mock *Repository) UpdateEmail(ctx context.Context, user users.User) (users.User, error) { + ret := _mock.Called(ctx, user) + + if len(ret) == 0 { + panic("no return value specified for UpdateEmail") + } + + var r0 users.User + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) (users.User, error)); ok { + return returnFunc(ctx, user) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) users.User); ok { + r0 = returnFunc(ctx, user) + } else { + r0 = ret.Get(0).(users.User) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, users.User) error); ok { + r1 = returnFunc(ctx, user) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateEmail' +type Repository_UpdateEmail_Call struct { + *mock.Call +} + +// UpdateEmail is a helper method to define mock.On call +// - ctx context.Context +// - user users.User +func (_e *Repository_Expecter) UpdateEmail(ctx interface{}, user interface{}) *Repository_UpdateEmail_Call { + return &Repository_UpdateEmail_Call{Call: _e.mock.On("UpdateEmail", ctx, user)} +} + +func (_c *Repository_UpdateEmail_Call) Run(run func(ctx context.Context, user users.User)) *Repository_UpdateEmail_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 users.User + if args[1] != nil { + arg1 = args[1].(users.User) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateEmail_Call) Return(user1 users.User, err error) *Repository_UpdateEmail_Call { + _c.Call.Return(user1, err) + return _c +} + +func (_c *Repository_UpdateEmail_Call) RunAndReturn(run func(ctx context.Context, user users.User) (users.User, error)) *Repository_UpdateEmail_Call { + _c.Call.Return(run) + return _c +} + +// UpdateRole provides a mock function for the type Repository +func (_mock *Repository) UpdateRole(ctx context.Context, user users.User) (users.User, error) { + ret := _mock.Called(ctx, user) + + if len(ret) == 0 { + panic("no return value specified for UpdateRole") + } + + var r0 users.User + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) (users.User, error)); ok { + return returnFunc(ctx, user) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) users.User); ok { + r0 = returnFunc(ctx, user) + } else { + r0 = ret.Get(0).(users.User) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, users.User) error); ok { + r1 = returnFunc(ctx, user) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateRole' +type Repository_UpdateRole_Call struct { + *mock.Call +} + +// UpdateRole is a helper method to define mock.On call +// - ctx context.Context +// - user users.User +func (_e *Repository_Expecter) UpdateRole(ctx interface{}, user interface{}) *Repository_UpdateRole_Call { + return &Repository_UpdateRole_Call{Call: _e.mock.On("UpdateRole", ctx, user)} +} + +func (_c *Repository_UpdateRole_Call) Run(run func(ctx context.Context, user users.User)) *Repository_UpdateRole_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 users.User + if args[1] != nil { + arg1 = args[1].(users.User) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateRole_Call) Return(user1 users.User, err error) *Repository_UpdateRole_Call { + _c.Call.Return(user1, err) + return _c +} + +func (_c *Repository_UpdateRole_Call) RunAndReturn(run func(ctx context.Context, user users.User) (users.User, error)) *Repository_UpdateRole_Call { + _c.Call.Return(run) + return _c +} + // UpdateSecret provides a mock function for the type Repository func (_mock *Repository) UpdateSecret(ctx context.Context, user users.User) (users.User, error) { ret := _mock.Called(ctx, user) @@ -821,6 +1082,63 @@ func (_c *Repository_UpdateSecret_Call) RunAndReturn(run func(ctx context.Contex return _c } +// UpdateUserVerification provides a mock function for the type Repository +func (_mock *Repository) UpdateUserVerification(ctx context.Context, uv users.UserVerification) error { + ret := _mock.Called(ctx, uv) + + if len(ret) == 0 { + panic("no return value specified for UpdateUserVerification") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, users.UserVerification) error); ok { + r0 = returnFunc(ctx, uv) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Repository_UpdateUserVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateUserVerification' +type Repository_UpdateUserVerification_Call struct { + *mock.Call +} + +// UpdateUserVerification is a helper method to define mock.On call +// - ctx context.Context +// - uv users.UserVerification +func (_e *Repository_Expecter) UpdateUserVerification(ctx interface{}, uv interface{}) *Repository_UpdateUserVerification_Call { + return &Repository_UpdateUserVerification_Call{Call: _e.mock.On("UpdateUserVerification", ctx, uv)} +} + +func (_c *Repository_UpdateUserVerification_Call) Run(run func(ctx context.Context, uv users.UserVerification)) *Repository_UpdateUserVerification_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 users.UserVerification + if args[1] != nil { + arg1 = args[1].(users.UserVerification) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateUserVerification_Call) Return(err error) *Repository_UpdateUserVerification_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Repository_UpdateUserVerification_Call) RunAndReturn(run func(ctx context.Context, uv users.UserVerification) error) *Repository_UpdateUserVerification_Call { + _c.Call.Return(run) + return _c +} + // UpdateUsername provides a mock function for the type Repository func (_mock *Repository) UpdateUsername(ctx context.Context, user users.User) (users.User, error) { ret := _mock.Called(ctx, user) @@ -886,3 +1204,69 @@ func (_c *Repository_UpdateUsername_Call) RunAndReturn(run func(ctx context.Cont _c.Call.Return(run) return _c } + +// UpdateVerifiedAt provides a mock function for the type Repository +func (_mock *Repository) UpdateVerifiedAt(ctx context.Context, user users.User) (users.User, error) { + ret := _mock.Called(ctx, user) + + if len(ret) == 0 { + panic("no return value specified for UpdateVerifiedAt") + } + + var r0 users.User + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) (users.User, error)); ok { + return returnFunc(ctx, user) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, users.User) users.User); ok { + r0 = returnFunc(ctx, user) + } else { + r0 = ret.Get(0).(users.User) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, users.User) error); ok { + r1 = returnFunc(ctx, user) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Repository_UpdateVerifiedAt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateVerifiedAt' +type Repository_UpdateVerifiedAt_Call struct { + *mock.Call +} + +// UpdateVerifiedAt is a helper method to define mock.On call +// - ctx context.Context +// - user users.User +func (_e *Repository_Expecter) UpdateVerifiedAt(ctx interface{}, user interface{}) *Repository_UpdateVerifiedAt_Call { + return &Repository_UpdateVerifiedAt_Call{Call: _e.mock.On("UpdateVerifiedAt", ctx, user)} +} + +func (_c *Repository_UpdateVerifiedAt_Call) Run(run func(ctx context.Context, user users.User)) *Repository_UpdateVerifiedAt_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 users.User + if args[1] != nil { + arg1 = args[1].(users.User) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Repository_UpdateVerifiedAt_Call) Return(user1 users.User, err error) *Repository_UpdateVerifiedAt_Call { + _c.Call.Return(user1, err) + return _c +} + +func (_c *Repository_UpdateVerifiedAt_Call) RunAndReturn(run func(ctx context.Context, user users.User) (users.User, error)) *Repository_UpdateVerifiedAt_Call { + _c.Call.Return(run) + return _c +} diff --git a/users/mocks/service.go b/users/mocks/service.go index 50ccda2ca..087e04e05 100644 --- a/users/mocks/service.go +++ b/users/mocks/service.go @@ -923,6 +923,63 @@ func (_c *Service_SendPasswordReset_Call) RunAndReturn(run func(ctx context.Cont return _c } +// SendVerification provides a mock function for the type Service +func (_mock *Service) SendVerification(ctx context.Context, session authn.Session) error { + ret := _mock.Called(ctx, session) + + if len(ret) == 0 { + panic("no return value specified for SendVerification") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) error); ok { + r0 = returnFunc(ctx, session) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_SendVerification_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SendVerification' +type Service_SendVerification_Call struct { + *mock.Call +} + +// SendVerification is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +func (_e *Service_Expecter) SendVerification(ctx interface{}, session interface{}) *Service_SendVerification_Call { + return &Service_SendVerification_Call{Call: _e.mock.On("SendVerification", ctx, session)} +} + +func (_c *Service_SendVerification_Call) Run(run func(ctx context.Context, session authn.Session)) *Service_SendVerification_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_SendVerification_Call) Return(err error) *Service_SendVerification_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_SendVerification_Call) RunAndReturn(run func(ctx context.Context, session authn.Session) error) *Service_SendVerification_Call { + _c.Call.Return(run) + return _c +} + // Update provides a mock function for the type Service func (_mock *Service) Update(ctx context.Context, session authn.Session, id string, user users.UserReq) (users.User, error) { ret := _mock.Called(ctx, session, id, user) @@ -1463,6 +1520,72 @@ func (_c *Service_UpdateUsername_Call) RunAndReturn(run func(ctx context.Context return _c } +// VerifyEmail provides a mock function for the type Service +func (_mock *Service) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) { + ret := _mock.Called(ctx, verificationToken) + + if len(ret) == 0 { + panic("no return value specified for VerifyEmail") + } + + var r0 users.User + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (users.User, error)); ok { + return returnFunc(ctx, verificationToken) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) users.User); ok { + r0 = returnFunc(ctx, verificationToken) + } else { + r0 = ret.Get(0).(users.User) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, verificationToken) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_VerifyEmail_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'VerifyEmail' +type Service_VerifyEmail_Call struct { + *mock.Call +} + +// VerifyEmail is a helper method to define mock.On call +// - ctx context.Context +// - verificationToken string +func (_e *Service_Expecter) VerifyEmail(ctx interface{}, verificationToken interface{}) *Service_VerifyEmail_Call { + return &Service_VerifyEmail_Call{Call: _e.mock.On("VerifyEmail", ctx, verificationToken)} +} + +func (_c *Service_VerifyEmail_Call) Run(run func(ctx context.Context, verificationToken string)) *Service_VerifyEmail_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_VerifyEmail_Call) Return(user users.User, err error) *Service_VerifyEmail_Call { + _c.Call.Return(user, err) + return _c +} + +func (_c *Service_VerifyEmail_Call) RunAndReturn(run func(ctx context.Context, verificationToken string) (users.User, error)) *Service_VerifyEmail_Call { + _c.Call.Return(run) + return _c +} + // View provides a mock function for the type Service func (_mock *Service) View(ctx context.Context, session authn.Session, id string) (users.User, error) { ret := _mock.Called(ctx, session, id) diff --git a/users/postgres/init.go b/users/postgres/init.go index 4e8ac0d8d..29c1a19c8 100644 --- a/users/postgres/init.go +++ b/users/postgres/init.go @@ -14,7 +14,7 @@ func Migration() *migrate.MemoryMigrationSource { Migrations: []*migrate.Migration{ { Id: "clients_01", - // VARCHAR(36) for colums with IDs as UUIDS have a maximum of 36 characters + // VARCHAR(36) for column with IDs as UUIDS have a maximum of 36 characters // STATUS 0 to imply enabled and 1 to imply disabled // Role 0 to imply user role and 1 to imply admin role Up: []string{ @@ -50,8 +50,8 @@ func Migration() *migrate.MemoryMigrationSource { Up: []string{ `ALTER TABLE clients ADD COLUMN username VARCHAR(254) UNIQUE, - ADD COLUMN first_name VARCHAR(254) NOT NULL DEFAULT '', - ADD COLUMN last_name VARCHAR(254) NOT NULL DEFAULT '', + ADD COLUMN first_name VARCHAR(254) NOT NULL DEFAULT '', + ADD COLUMN last_name VARCHAR(254) NOT NULL DEFAULT '', ADD COLUMN profile_picture TEXT`, `ALTER TABLE clients RENAME COLUMN identity TO email`, `ALTER TABLE clients DROP COLUMN name`, @@ -97,6 +97,27 @@ func Migration() *migrate.MemoryMigrationSource { `ALTER TABLE users ALTER COLUMN updated_at TYPE TIMESTAMP;`, }, }, + { + Id: "clients_07", + Up: []string{ + `ALTER TABLE users ADD COLUMN verified_at TIMESTAMPTZ DEFAULT NULL;`, + `CREATE TABLE users_verifications ( + user_id VARCHAR(36) NOT NULL, + email VARCHAR(254) NOT NULL, + otp VARCHAR(255), + created_at TIMESTAMPTZ, + expires_at TIMESTAMPTZ, + used_at TIMESTAMPTZ, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE + ); + CREATE INDEX idx_users_verifications_lookup ON users_verifications (user_id, email, created_at DESC); + `, + }, + Down: []string{ + `ALTER TABLE users DROP COLUMN verified_at;`, + `DROP TABLE users_verifications;`, + }, + }, }, } } diff --git a/users/postgres/users.go b/users/postgres/users.go index 5bcc83a6f..0f8d65530 100644 --- a/users/postgres/users.go +++ b/users/postgres/users.go @@ -33,7 +33,7 @@ func NewRepository(db postgres.Database) users.Repository { func (repo *userRepo) Save(ctx context.Context, c users.User) (users.User, error) { q := `INSERT INTO users (id, tags, email, secret, metadata, created_at, status, role, first_name, last_name, username, profile_picture) VALUES (:id, :tags, :email, :secret, :metadata, :created_at, :status, :role, :first_name, :last_name, :username, :profile_picture) - RETURNING id, tags, email, metadata, created_at, status, role, first_name, last_name, username, profile_picture` + RETURNING id, tags, email, metadata, created_at, status, role, first_name, last_name, username, profile_picture, verified_at` dbu, err := toDBUser(c) if err != nil { @@ -81,7 +81,7 @@ func (repo *userRepo) CheckSuperAdmin(ctx context.Context, adminID string) error } func (repo *userRepo) RetrieveByID(ctx context.Context, id string) (users.User, error) { - q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username, profile_picture + q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username, profile_picture, verified_at FROM users WHERE id = :id` dbu := DBUser{ @@ -95,20 +95,20 @@ func (repo *userRepo) RetrieveByID(ctx context.Context, id string) (users.User, defer rows.Close() dbu = DBUser{} - if rows.Next() { - if err = rows.StructScan(&dbu); err != nil { - return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err) - } - - user, err := ToUser(dbu) - if err != nil { - return users.User{}, errors.Wrap(repoerr.ErrFailedOpDB, err) - } - - return user, nil + if !rows.Next() { + return users.User{}, repoerr.ErrNotFound } - return users.User{}, repoerr.ErrNotFound + if err = rows.StructScan(&dbu); err != nil { + return users.User{}, postgres.HandleError(repoerr.ErrViewEntity, err) + } + + user, err := ToUser(dbu) + if err != nil { + return users.User{}, errors.Wrap(repoerr.ErrFailedOpDB, err) + } + + return user, nil } func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.UsersPage, error) { @@ -127,7 +127,7 @@ func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.Use } q := fmt.Sprintf(`SELECT u.id, u.tags, u.email, u.metadata, u.status, u.role, u.first_name, u.last_name, u.username, - u.created_at, u.updated_at, u.profile_picture, COALESCE(u.updated_by, '') AS updated_by + u.created_at, u.updated_at, u.profile_picture, COALESCE(u.updated_by, '') AS updated_by, u.verified_at FROM users u %s %s LIMIT :limit OFFSET :offset;`, query, orderClause) dbPage, err := ToDBUsersPage(pm) @@ -180,35 +180,9 @@ func (repo *userRepo) RetrieveAll(ctx context.Context, pm users.Page) (users.Use func (repo *userRepo) UpdateUsername(ctx context.Context, user users.User) (users.User, error) { q := `UPDATE users SET username = :username, updated_at = :updated_at, updated_by = :updated_by WHERE id = :id AND status = :status - RETURNING id, tags, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, email` + RETURNING id, tags, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, email, role, verified_at` - dbu, err := toDBUser(user) - if err != nil { - return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err) - } - - row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbu) - if err != nil { - return users.User{}, postgres.HandleError(repoerr.ErrUpdateEntity, err) - } - - defer row.Close() - - dbu = DBUser{ - ID: user.ID, - Username: stringToNullString(user.Credentials.Username), - UpdatedAt: sql.NullTime{Time: time.Now().UTC(), Valid: true}, - } - - if ok := row.Next(); !ok { - return users.User{}, errors.Wrap(repoerr.ErrNotFound, row.Err()) - } - - if err := row.StructScan(&dbu); err != nil { - return users.User{}, err - } - - return ToUser(dbu) + return repo.update(ctx, user, q) } func (repo *userRepo) Update(ctx context.Context, id string, ur users.UserReq) (users.User, error) { @@ -231,18 +205,10 @@ func (repo *userRepo) Update(ctx context.Context, id string, ur users.UserReq) ( query = append(query, "tags = :tags") u.Tags = *ur.Tags } - if ur.Role != nil && *ur.Role != users.AllRole { - query = append(query, "role = :role") - u.Role = *ur.Role - } if ur.ProfilePicture != nil { query = append(query, "profile_picture = :profile_picture") u.ProfilePicture = *ur.ProfilePicture } - if ur.Email != nil && *ur.Email != "" { - query = append(query, "email = :email") - u.Email = *ur.Email - } u.UpdatedAt = time.Now().UTC() if ur.UpdatedAt != nil { query = append(query, "updated_at = :updated_at") @@ -259,7 +225,7 @@ func (repo *userRepo) Update(ctx context.Context, id string, ur users.UserReq) ( q := fmt.Sprintf(`UPDATE users SET %s WHERE id = :id AND status = :status - RETURNING id, tags, metadata, status, created_at, updated_at, updated_by, last_name, first_name, username, profile_picture, email, role`, upq) + RETURNING id, tags, metadata, status, created_at, updated_at, updated_by, last_name, first_name, username, profile_picture, email, role, verified_at`, upq) u.Status = users.EnabledStatus return repo.update(ctx, u, q) @@ -278,21 +244,37 @@ func (repo *userRepo) update(ctx context.Context, user users.User, query string) defer row.Close() dbu = DBUser{} - if row.Next() { - if err := row.StructScan(&dbu); err != nil { - return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err) - } - - return ToUser(dbu) + if !row.Next() { + return users.User{}, repoerr.ErrNotFound } - return users.User{}, repoerr.ErrNotFound + if err := row.StructScan(&dbu); err != nil { + return users.User{}, errors.Wrap(repoerr.ErrUpdateEntity, err) + } + + return ToUser(dbu) +} + +func (repo *userRepo) UpdateEmail(ctx context.Context, user users.User) (users.User, error) { + q := `UPDATE users SET email = :email, verified_at = NULL, updated_at = :updated_at, updated_by = :updated_by + WHERE id = :id AND status = :status + RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at` + user.Status = users.EnabledStatus + return repo.update(ctx, user, q) +} + +func (repo *userRepo) UpdateRole(ctx context.Context, user users.User) (users.User, error) { + q := `UPDATE users SET role = :role, updated_at = :updated_at, updated_by = :updated_by + WHERE id = :id AND status = :status + RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at` + user.Status = users.EnabledStatus + return repo.update(ctx, user, q) } func (repo *userRepo) UpdateSecret(ctx context.Context, user users.User) (users.User, error) { q := `UPDATE users SET secret = :secret, updated_at = :updated_at, updated_by = :updated_by WHERE id = :id AND status = :status - RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username` + RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at` user.Status = users.EnabledStatus return repo.update(ctx, user, q) } @@ -300,7 +282,15 @@ func (repo *userRepo) UpdateSecret(ctx context.Context, user users.User) (users. func (repo *userRepo) ChangeStatus(ctx context.Context, user users.User) (users.User, error) { q := `UPDATE users SET status = :status, updated_at = :updated_at, updated_by = :updated_by WHERE id = :id - RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username` + RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at` + + return repo.update(ctx, user, q) +} + +func (repo *userRepo) UpdateVerifiedAt(ctx context.Context, user users.User) (users.User, error) { + q := `UPDATE users SET verified_at = :verified_at + WHERE id = :id and email = :email + RETURNING id, tags, email, metadata, status, created_at, updated_at, updated_by, first_name, last_name, username, role, verified_at` return repo.update(ctx, user, q) } @@ -434,7 +424,7 @@ func (repo *userRepo) RetrieveAllByIDs(ctx context.Context, pm users.Page) (user } func (repo *userRepo) RetrieveByEmail(ctx context.Context, email string) (users.User, error) { - q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username + q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username, verified_at FROM users WHERE email = :email AND status = :status` dbu := DBUser{ @@ -461,7 +451,7 @@ func (repo *userRepo) RetrieveByEmail(ctx context.Context, email string) (users. } func (repo *userRepo) RetrieveByUsername(ctx context.Context, username string) (users.User, error) { - q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username + q := `SELECT id, tags, email, secret, metadata, created_at, updated_at, updated_by, status, role, first_name, last_name, username, verified_at FROM users WHERE username = :username AND status = :status` dbu := DBUser{ @@ -504,6 +494,7 @@ type DBUser struct { LastName sql.NullString `db:"last_name, omitempty"` ProfilePicture sql.NullString `db:"profile_picture, omitempty"` Email string `db:"email,omitempty"` + VerifiedAt sql.NullTime `db:"verified_at,omitempty"` } func toDBUser(u users.User) (DBUser, error) { @@ -527,6 +518,10 @@ func toDBUser(u users.User) (DBUser, error) { if u.UpdatedAt != (time.Time{}) { updatedAt = sql.NullTime{Time: u.UpdatedAt, Valid: true} } + var verifiedAt sql.NullTime + if u.VerifiedAt != (time.Time{}) { + verifiedAt = sql.NullTime{Time: u.VerifiedAt, Valid: true} + } return DBUser{ ID: u.ID, @@ -543,6 +538,7 @@ func toDBUser(u users.User) (DBUser, error) { Username: stringToNullString(u.Credentials.Username), ProfilePicture: stringToNullString(u.ProfilePicture), Email: u.Email, + VerifiedAt: verifiedAt, }, nil } @@ -565,6 +561,10 @@ func ToUser(dbu DBUser) (users.User, error) { if dbu.UpdatedAt.Valid { updatedAt = dbu.UpdatedAt.Time.UTC() } + var verifiedAt time.Time + if dbu.VerifiedAt.Valid { + verifiedAt = dbu.VerifiedAt.Time.UTC() + } user := users.User{ ID: dbu.ID, @@ -582,6 +582,7 @@ func ToUser(dbu DBUser) (users.User, error) { Status: dbu.Status, Tags: tags, ProfilePicture: nullStringString(dbu.ProfilePicture), + VerifiedAt: verifiedAt, } if dbu.Role != nil { user.Role = *dbu.Role diff --git a/users/postgres/users_test.go b/users/postgres/users_test.go index 471524aa6..25bcc65cd 100644 --- a/users/postgres/users_test.go +++ b/users/postgres/users_test.go @@ -1048,6 +1048,147 @@ func TestSearch(t *testing.T) { } } +func TestUpdateRole(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM users") + require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err)) + }) + repo := cpostgres.NewRepository(database) + user1 := generateUser(t, users.EnabledStatus, repo) + user2 := generateUser(t, users.DisabledStatus, repo) + adminRole := users.AdminRole + userRole := users.UserRole + + cases := []struct { + desc string + update string + userID string + userReq users.User + err error + }{ + { + desc: "update role of user to admin", + userReq: users.User{ + ID: user1.ID, + Role: adminRole, + }, + err: nil, + }, + { + desc: "update role of admin to user", + userReq: users.User{ + ID: user1.ID, + Role: userRole, + }, + err: nil, + }, + { + desc: "update role for disabled user", + userReq: users.User{ + ID: user2.ID, + Role: adminRole, + }, + err: repoerr.ErrNotFound, + }, + { + desc: "update role for invalid user", + userReq: users.User{ + ID: testsutil.GenerateUUID(t), + Role: adminRole, + }, + err: repoerr.ErrNotFound, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + updatedAt := time.Now().UTC().Truncate(time.Millisecond) + updatedBy := testsutil.GenerateUUID(t) + c.userReq.UpdatedAt = updatedAt + c.userReq.UpdatedBy = updatedBy + expected, err := repo.UpdateRole(context.Background(), c.userReq) + assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err)) + if err == nil { + assert.Equal(t, c.userReq.Role, expected.Role) + assert.Equal(t, c.userReq.UpdatedAt, expected.UpdatedAt) + assert.Equal(t, c.userReq.UpdatedBy, expected.UpdatedBy) + } + }) + } +} + +func TestUpdateEmail(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM users") + require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err)) + }) + repo := cpostgres.NewRepository(database) + + user1 := generateUser(t, users.EnabledStatus, repo) + user2 := generateUser(t, users.DisabledStatus, repo) + user3 := generateUser(t, users.EnabledStatus, repo) + + updatedEmail := namesgen.Generate() + emailSuffix + emptyName := "" + + cases := []struct { + desc string + update string + userReq users.User + err error + }{ + { + desc: "update email for enabled user", + userReq: users.User{ + ID: user1.ID, + Email: updatedEmail, + }, + + err: nil, + }, + { + desc: "update empty email for enabled user", + userReq: users.User{ + ID: user3.ID, + Email: emptyName, + }, + err: nil, + }, + { + desc: "update email for disabled user", + userReq: users.User{ + ID: user2.ID, + Email: updatedEmail, + }, + err: repoerr.ErrNotFound, + }, + { + desc: "update email for invalid user", + userReq: users.User{ + ID: testsutil.GenerateUUID(t), + Email: updatedEmail, + }, + err: repoerr.ErrNotFound, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + updatedAt := time.Now().UTC().Truncate(time.Millisecond) + updatedBy := testsutil.GenerateUUID(t) + c.userReq.UpdatedAt = updatedAt + c.userReq.UpdatedBy = updatedBy + expected, err := repo.UpdateEmail(context.Background(), c.userReq) + assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err)) + if err == nil { + assert.Equal(t, c.userReq.Email, expected.Email) + assert.Equal(t, c.userReq.UpdatedAt, expected.UpdatedAt) + assert.Equal(t, c.userReq.UpdatedBy, expected.UpdatedBy) + } + }) + } +} + func TestUpdate(t *testing.T) { t.Cleanup(func() { _, err := db.Exec("DELETE FROM users") @@ -1065,11 +1206,8 @@ func TestUpdate(t *testing.T) { updatedFirstName := namesgen.Generate() updateTags := namesgen.GenerateMultiple(5) updatedProfilePicture := namesgen.Generate() - adminRole := users.AdminRole - updatedEmail := namesgen.Generate() + emailSuffix emptyName := "" emptyTags := []string{} - allRole := users.AllRole cases := []struct { desc string @@ -1286,78 +1424,6 @@ func TestUpdate(t *testing.T) { }, err: repoerr.ErrNotFound, }, - { - desc: "update role for enabled user", - userID: user1.ID, - userReq: users.UserReq{ - Role: &adminRole, - }, - userRes: users.User{ - Role: adminRole, - }, - err: nil, - }, - { - desc: "update all role for enabled user", - userID: user3.ID, - userReq: users.UserReq{ - Role: &allRole, - }, - userRes: user3, - err: nil, - }, - { - desc: "update role for disabled user", - userID: user2.ID, - userReq: users.UserReq{ - Role: &adminRole, - }, - err: repoerr.ErrNotFound, - }, - { - desc: "update role for invalid user", - userID: testsutil.GenerateUUID(t), - userReq: users.UserReq{ - Role: &adminRole, - }, - err: repoerr.ErrNotFound, - }, - { - desc: "update email for enabled user", - userID: user1.ID, - userReq: users.UserReq{ - Email: &updatedEmail, - }, - userRes: users.User{ - Email: updatedEmail, - }, - err: nil, - }, - { - desc: "update empty email for enabled user", - userID: user3.ID, - userReq: users.UserReq{ - Email: &emptyName, - }, - userRes: user3, - err: nil, - }, - { - desc: "update email for disabled user", - userID: user2.ID, - userReq: users.UserReq{ - Email: &updatedEmail, - }, - err: repoerr.ErrNotFound, - }, - { - desc: "update email for invalid user", - userID: testsutil.GenerateUUID(t), - userReq: users.UserReq{ - Email: &updatedEmail, - }, - err: repoerr.ErrNotFound, - }, } for _, c := range cases { diff --git a/users/postgres/verfications.go b/users/postgres/verfications.go new file mode 100644 index 000000000..884f455d1 --- /dev/null +++ b/users/postgres/verfications.go @@ -0,0 +1,131 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + "database/sql" + "time" + + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/users" +) + +// AddUserVerification adds new verification for given user id and email. +func (repo *userRepo) AddUserVerification(ctx context.Context, uv users.UserVerification) error { + q := `INSERT INTO users_verifications (user_id, email, otp, created_at, expires_at ) + VALUES (:user_id, :email, :otp, :created_at, :expires_at );` + dbuv := toDBUserVerification(uv) + if _, err := repo.Repository.DB.NamedExecContext(ctx, q, dbuv); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + return nil +} + +// RetrieveUserVerification retrieves verification details of given user id and email. +func (repo *userRepo) RetrieveUserVerification(ctx context.Context, userID, email string) (users.UserVerification, error) { + dbuv := dbUserVerification{ + UserID: userID, + Email: email, + } + q := `SELECT user_id, email, otp, created_at, expires_at , used_at FROM users_verifications WHERE user_id = :user_id AND email = :email ORDER BY created_at DESC LIMIT 1 ` + + row, err := repo.Repository.DB.NamedQueryContext(ctx, q, dbuv) + if err != nil { + return users.UserVerification{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + if !row.Next() { + return users.UserVerification{}, repoerr.ErrNotFound + } + + defer row.Close() + + if err := row.StructScan(&dbuv); err != nil { + return users.UserVerification{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + return toUserVerification(dbuv), nil +} + +// UpdateUserVerification update user verification details for the given user id and email. +func (repo *userRepo) UpdateUserVerification(ctx context.Context, uv users.UserVerification) error { + q := `UPDATE users_verifications SET otp = :otp, used_at = :used_at WHERE user_id = :user_id AND email = :email` + dbuv := toDBUserVerification(uv) + res, err := repo.Repository.DB.NamedExecContext(ctx, q, dbuv) + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + rows, err := res.RowsAffected() + if err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + if rows == 0 { + return repoerr.ErrNotFound + } + return nil +} + +type dbUserVerification struct { + UserID string `db:"user_id"` + Email string `db:"email"` + OTP sql.NullString `db:"otp"` + CreatedAt sql.NullTime `db:"created_at"` + ExpiresAt sql.NullTime `db:"expires_at"` + UsedAt sql.NullTime `db:"used_at"` +} + +func toDBUserVerification(uv users.UserVerification) dbUserVerification { + var otp sql.NullString + if uv.OTP != "" { + otp = sql.NullString{String: uv.OTP, Valid: true} + } + var createdAt sql.NullTime + if !uv.CreatedAt.IsZero() { + createdAt = sql.NullTime{Time: uv.CreatedAt, Valid: true} + } + var expiresAt sql.NullTime + if !uv.ExpiresAt.IsZero() { + expiresAt = sql.NullTime{Time: uv.ExpiresAt, Valid: true} + } + var usedAt sql.NullTime + if !uv.UsedAt.IsZero() { + usedAt = sql.NullTime{Time: uv.UsedAt, Valid: true} + } + + return dbUserVerification{ + UserID: uv.UserID, + Email: uv.Email, + OTP: otp, + CreatedAt: createdAt, + ExpiresAt: expiresAt, + UsedAt: usedAt, + } +} + +func toUserVerification(dbuv dbUserVerification) users.UserVerification { + var createdAt time.Time + if dbuv.CreatedAt.Valid { + createdAt = dbuv.CreatedAt.Time.UTC() + } + + var expiresAt time.Time + if dbuv.ExpiresAt.Valid { + expiresAt = dbuv.ExpiresAt.Time.UTC() + } + + var usedAt time.Time + if dbuv.UsedAt.Valid { + usedAt = dbuv.UsedAt.Time.UTC() + } + + return users.UserVerification{ + UserID: dbuv.UserID, + Email: dbuv.Email, + OTP: dbuv.OTP.String, + CreatedAt: createdAt, + ExpiresAt: expiresAt, + UsedAt: usedAt, + } +} diff --git a/users/postgres/verfications_test.go b/users/postgres/verfications_test.go new file mode 100644 index 000000000..74587659f --- /dev/null +++ b/users/postgres/verfications_test.go @@ -0,0 +1,217 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/users" + "github.com/absmach/supermq/users/postgres" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAddUserVerification(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM users") + require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err)) + _, err = db.Exec("DELETE FROM users_verifications") + require.Nil(t, err, fmt.Sprintf("clean users_verifications unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + first_name := namesgen.Generate() + last_name := namesgen.Generate() + username := namesgen.Generate() + user := users.User{ + ID: "test-user-id", + Email: "test@example.com", + FirstName: first_name, + LastName: last_name, + Credentials: users.Credentials{ + Username: username, + }, + } + _, err := repo.Save(context.Background(), user) + require.Nil(t, err, fmt.Sprintf("saving user unexpected error: %s", err)) + + cases := []struct { + desc string + uv users.UserVerification + err error + }{ + { + desc: "add new user verification", + uv: users.UserVerification{ + UserID: user.ID, + Email: user.Email, + CreatedAt: time.Now().UTC(), + OTP: "123456", + ExpiresAt: time.Now().UTC().Add(time.Hour), + }, + err: nil, + }, + { + desc: "add user verification for non-existing user", + uv: users.UserVerification{ + UserID: "non-existing-user", + Email: "non-existing@example.com", + OTP: "654321", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(time.Hour), + }, + err: repoerr.ErrCreateEntity, + }, + } + + for _, tc := range cases { + err := repo.AddUserVerification(context.Background(), tc.uv) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err)) + } +} + +func TestRetrieveUserVerification(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM users") + require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err)) + _, err = db.Exec("DELETE FROM users_verifications") + require.Nil(t, err, fmt.Sprintf("clean users_verifications unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + first_name := namesgen.Generate() + last_name := namesgen.Generate() + username := namesgen.Generate() + user := users.User{ + ID: "test-user-id", + Email: "test@example.com", + FirstName: first_name, + LastName: last_name, + Credentials: users.Credentials{ + Username: username, + }, + } + _, err := repo.Save(context.Background(), user) + require.Nil(t, err, fmt.Sprintf("saving user unexpected error: %s", err)) + + uv := users.UserVerification{ + UserID: user.ID, + Email: user.Email, + OTP: "123456", + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(time.Hour), + } + err = repo.AddUserVerification(context.Background(), uv) + require.Nil(t, err, fmt.Sprintf("adding user verification unexpected error: %s", err)) + + cases := []struct { + desc string + userID string + email string + err error + }{ + { + desc: "retrieve existing user verification", + userID: user.ID, + email: user.Email, + err: nil, + }, + { + desc: "retrieve non-existing user verification", + userID: "non-existing-user", + email: "non-existing@example.com", + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + retrieved, err := repo.RetrieveUserVerification(context.Background(), tc.userID, tc.email) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err)) + if err == nil { + assert.Equal(t, uv.UserID, retrieved.UserID, fmt.Sprintf("%s: expected %v got %v", tc.desc, uv.UserID, retrieved.UserID)) + assert.Equal(t, uv.Email, retrieved.Email, fmt.Sprintf("%s: expected %v got %v", tc.desc, uv.Email, retrieved.Email)) + assert.Equal(t, uv.OTP, retrieved.OTP, fmt.Sprintf("%s: expected %v got %v", tc.desc, uv.OTP, retrieved.OTP)) + } + } +} + +func TestUpdateUserVerification(t *testing.T) { + t.Cleanup(func() { + _, err := db.Exec("DELETE FROM users") + require.Nil(t, err, fmt.Sprintf("clean users unexpected error: %s", err)) + _, err = db.Exec("DELETE FROM users_verifications") + require.Nil(t, err, fmt.Sprintf("clean users_verifications unexpected error: %s", err)) + }) + + repo := postgres.NewRepository(database) + + first_name := namesgen.Generate() + last_name := namesgen.Generate() + username := namesgen.Generate() + user := users.User{ + ID: "test-user-id", + Email: "test@example.com", + FirstName: first_name, + LastName: last_name, + Credentials: users.Credentials{ + Username: username, + }, + } + _, err := repo.Save(context.Background(), user) + require.Nil(t, err, fmt.Sprintf("saving user unexpected error: %s", err)) + + uv := users.UserVerification{ + UserID: user.ID, + Email: user.Email, + OTP: "123456", + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().UTC().Add(time.Hour), + } + err = repo.AddUserVerification(context.Background(), uv) + require.Nil(t, err, fmt.Sprintf("adding user verification unexpected error: %s", err)) + + usedTime := time.Now() + cases := []struct { + desc string + uv users.UserVerification + err error + }{ + { + desc: "update existing user verification", + uv: users.UserVerification{ + UserID: user.ID, + Email: user.Email, + OTP: "654321", + UsedAt: usedTime, + }, + err: nil, + }, + { + desc: "update non-existing user verification", + uv: users.UserVerification{ + UserID: "non-existing-user", + Email: "non-existing@example.com", + }, + err: repoerr.ErrNotFound, + }, + } + + for _, tc := range cases { + err := repo.UpdateUserVerification(context.Background(), tc.uv) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s", tc.desc, tc.err, err)) + if err == nil { + retrieved, err := repo.RetrieveUserVerification(context.Background(), tc.uv.UserID, tc.uv.Email) + require.Nil(t, err, fmt.Sprintf("retrieving updated verification unexpected error: %s", err)) + assert.Equal(t, tc.uv.OTP, retrieved.OTP, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.uv.OTP, retrieved.OTP)) + assert.WithinDuration(t, tc.uv.UsedAt, retrieved.UsedAt, 10*time.Second, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.uv.UsedAt, retrieved.UsedAt)) + } + } +} diff --git a/users/service.go b/users/service.go index 9be1f6db1..4297e9eba 100644 --- a/users/service.go +++ b/users/service.go @@ -5,6 +5,7 @@ package users import ( "context" + "fmt" "net/mail" "time" @@ -92,6 +93,85 @@ func (svc service) Register(ctx context.Context, session authn.Session, u User, return user, nil } +func (svc service) SendVerification(ctx context.Context, session authn.Session) error { + dbUser, err := svc.users.RetrieveByID(ctx, session.UserID) + if err != nil { + return err + } + + if !dbUser.VerifiedAt.IsZero() { + return svcerr.ErrUserAlreadyVerified + } + + uv, err := svc.users.RetrieveUserVerification(ctx, dbUser.ID, dbUser.Email) + if err != nil && err != repoerr.ErrNotFound { + return err + } + + if err = uv.Valid(); err != nil { + uv, err = NewUserVerification(dbUser.ID, dbUser.Email) + if err != nil { + return errors.Wrap(svcerr.ErrCreateEntity, err) + } + if err := svc.users.AddUserVerification(ctx, uv); err != nil { + return errors.Wrap(svcerr.ErrCreateEntity, err) + } + } + + uvs, err := uv.Encode() + if err != nil { + return errors.Wrap(svcerr.ErrCreateEntity, err) + } + + if err := svc.email.SendVerification([]string{dbUser.Email}, dbUser.Credentials.Username, uvs); err != nil { + return errors.Wrap(svcerr.ErrCreateEntity, err) + } + return nil +} + +func (svc service) VerifyEmail(ctx context.Context, token string) (User, error) { + var received UserVerification + if err := received.Decode(token); err != nil { + return User{}, errors.Wrap(svcerr.ErrInvalidUserVerification, err) + } + + stored, err := svc.users.RetrieveUserVerification(ctx, received.UserID, received.Email) + if err != nil { + return User{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + + if err := stored.Match(received); err != nil { + return User{}, err + } + + if err := stored.Valid(); err != nil { + if err == svcerr.ErrUserVerificationExpired { + return User{}, err + } + return User{}, errors.Wrap(svcerr.ErrMalformedEntity, err) + } + + stored.UsedAt = time.Now().UTC() + if err = svc.users.UpdateUserVerification(ctx, stored); err != nil { + return User{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + user := User{ + ID: stored.UserID, + Email: stored.Email, + VerifiedAt: time.Now().UTC(), + } + user, err = svc.users.UpdateVerifiedAt(ctx, user) + if err == repoerr.ErrNotFound { + return User{}, svcerr.ErrInvalidUserVerification + } + if err != nil { + return User{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + return user, nil +} + func (svc service) IssueToken(ctx context.Context, identity, secret string) (*grpcTokenV1.Token, error) { var dbUser User var err error @@ -110,7 +190,7 @@ func (svc service) IssueToken(ctx context.Context, identity, secret string) (*gr return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrLogin, err) } - token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, UserRole: uint32(dbUser.Role + 1), Type: uint32(smqauth.AccessKey)}) + token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, UserRole: uint32(dbUser.Role + 1), Type: uint32(smqauth.AccessKey), Verified: !dbUser.VerifiedAt.IsZero()}) if err != nil { return &grpcTokenV1.Token{}, errors.Wrap(errIssueToken, err) } @@ -127,7 +207,7 @@ func (svc service) RefreshToken(ctx context.Context, session authn.Session, refr return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser) } - return svc.token.Refresh(ctx, &grpcTokenV1.RefreshReq{RefreshToken: refreshToken}) + return svc.token.Refresh(ctx, &grpcTokenV1.RefreshReq{RefreshToken: refreshToken, Verified: !dbUser.VerifiedAt.IsZero()}) } func (svc service) View(ctx context.Context, session authn.Session, id string) (User, error) { @@ -255,14 +335,23 @@ func (svc service) UpdateEmail(ctx context.Context, session authn.Session, userI return User{}, err } } - - updatedAt := time.Now().UTC() - usr := UserReq{ - Email: &email, - UpdatedAt: &updatedAt, - UpdatedBy: &session.UserID, + oldUsr, err := svc.users.RetrieveByID(ctx, userID) + if err != nil { + return User{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } - user, err := svc.users.Update(ctx, userID, usr) + if oldUsr.Email == email { + return User{}, fmt.Errorf("current email is same as update requested email") + } + + usr := User{ + ID: userID, + Email: email, + UpdatedAt: time.Now().UTC(), + UpdatedBy: session.UserID, + VerifiedAt: time.Time{}, + } + + user, err := svc.users.UpdateEmail(ctx, usr) if err != nil { return User{}, errors.Wrap(svcerr.ErrUpdateEntity, err) } @@ -362,18 +451,18 @@ func (svc service) UpdateRole(ctx context.Context, session authn.Session, usr Us if err := svc.checkSuperAdmin(ctx, session); err != nil { return User{}, err } - updateAt := time.Now().UTC() - uReq := UserReq{ - Role: &usr.Role, - UpdatedAt: &updateAt, - UpdatedBy: &session.UserID, + usr = User{ + ID: usr.ID, + Role: usr.Role, + UpdatedAt: time.Now().UTC(), + UpdatedBy: session.UserID, } if err := svc.updateUserPolicy(ctx, usr.ID, usr.Role); err != nil { return User{}, err } - u, err := svc.users.Update(ctx, usr.ID, uReq) + u, err := svc.users.UpdateRole(ctx, usr) if err != nil { // If failed to update role in DB, then revert back to platform admin policies in spicedb if errRollback := svc.updateUserPolicy(ctx, usr.ID, UserRole); errRollback != nil { diff --git a/users/service_test.go b/users/service_test.go index b1df9a8fb..df2e7f517 100644 --- a/users/service_test.go +++ b/users/service_test.go @@ -8,6 +8,7 @@ import ( "fmt" "strings" "testing" + "time" grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" smqauth "github.com/absmach/supermq/auth" @@ -776,13 +777,13 @@ func TestUpdateRole(t *testing.T) { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) policyCall := policies.On("AddPolicy", context.Background(), mock.Anything).Return(tc.addPolicyErr) policyCall1 := policies.On("DeletePolicyFilter", context.Background(), mock.Anything).Return(tc.deletePolicyErr) - repoCall1 := cRepo.On("Update", context.Background(), mock.Anything, mock.Anything).Return(tc.updateRoleResponse, tc.updateRoleErr) + repoCall1 := cRepo.On("UpdateRole", context.Background(), mock.Anything).Return(tc.updateRoleResponse, tc.updateRoleErr) updatedUser, err := svc.UpdateRole(context.Background(), tc.session, tc.user) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.updateRoleResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateRoleResponse, updatedUser)) if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything, mock.Anything) + ok := repoCall1.Parent.AssertCalled(t, "UpdateRole", context.Background(), mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) } repoCall.Unset() @@ -906,7 +907,7 @@ func TestUpdateEmail(t *testing.T) { svc, cRepo := newServiceMinimal() user2 := user - user2.Email = "updated@example.com" + user2.Email = "user2@example.com" cases := []struct { desc string @@ -921,16 +922,26 @@ func TestUpdateEmail(t *testing.T) { }{ { desc: "update user as normal user successfully", - email: "updated@example.com", + email: "user2-update-1@example.com", token: validToken, reqUserID: user.ID, id: user.ID, updateEmailResponse: user2, err: nil, }, + { + desc: "update to same email as normal user successfully", + email: "user2-update-1@example.com", + token: validToken, + reqUserID: user.ID, + id: user.ID, + updateEmailResponse: user2, + err: nil, + }, + { desc: "update user email as normal user with repo error on update", - email: "updated@example.com", + email: "user2-update-2@example.com", token: validToken, reqUserID: user.ID, id: user.ID, @@ -940,14 +951,14 @@ func TestUpdateEmail(t *testing.T) { }, { desc: "update user email as admin successfully", - email: "updated@example.com", + email: "user2-update-3@example.com", token: validToken, id: user.ID, err: nil, }, { desc: "update user email as admin with repo error on update", - email: "updated@exmaple.com", + email: "user2-update-4@exmaple.com", token: validToken, reqUserID: user.ID, id: user.ID, @@ -957,7 +968,7 @@ func TestUpdateEmail(t *testing.T) { }, { desc: "update user as admin user with failed check on super admin", - email: "updated@exmaple.com", + email: "user2-update-5@exmaple.com", token: validToken, reqUserID: user.ID, id: "", @@ -970,15 +981,18 @@ func TestUpdateEmail(t *testing.T) { for _, tc := range cases { repoCall := cRepo.On("CheckSuperAdmin", context.Background(), mock.Anything).Return(tc.checkSuperAdminErr) - repoCall1 := cRepo.On("Update", context.Background(), mock.Anything, mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr) + repocall2 := cRepo.On("RetrieveByID", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr) + repoCall1 := cRepo.On("UpdateEmail", context.Background(), mock.Anything).Return(tc.updateEmailResponse, tc.updateEmailErr) updatedUser, err := svc.UpdateEmail(context.Background(), authn.Session{DomainUserID: tc.reqUserID, UserID: validID, DomainID: validID}, tc.id, tc.email) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.updateEmailResponse, updatedUser, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.updateEmailResponse, updatedUser)) - if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "Update", context.Background(), mock.Anything, mock.Anything) + if tc.err == nil && user2.Email != tc.email { + ok := repoCall1.Parent.AssertCalled(t, "UpdateEmail", context.Background(), mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) + user2.Email = tc.email } repoCall.Unset() + repocall2.Unset() repoCall1.Unset() } } @@ -1791,3 +1805,160 @@ func TestOAuthCallback(t *testing.T) { }) } } + +func TestSendVerification(t *testing.T) { + svc, _, cRepo, _, e := newService() + + verifiedAt := time.Now().UTC() + cases := []struct { + desc string + session authn.Session + retrieveByIDResponse users.User + retrieveByIDError error + retrieveUserVerResponse users.UserVerification + retrieveUserVerError error + addUserVerError error + sendVerificationEmailError error + err error + }{ + { + desc: "send verification email successfully", + session: authn.Session{UserID: user.ID}, + retrieveByIDResponse: user, + retrieveUserVerError: repoerr.ErrNotFound, + sendVerificationEmailError: nil, + err: nil, + }, + { + desc: "send verification email for already verified user", + session: authn.Session{UserID: user.ID}, + retrieveByIDResponse: users.User{VerifiedAt: verifiedAt}, + err: svcerr.ErrUserAlreadyVerified, + }, + { + desc: "send verification email for non-existing user", + session: authn.Session{UserID: wrongID}, + retrieveByIDError: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "send verification email with failed to retrieve user verification", + session: authn.Session{UserID: user.ID}, + retrieveByIDResponse: user, + retrieveUserVerError: svcerr.ErrViewEntity, + err: svcerr.ErrViewEntity, + }, + { + desc: "send verification email with failed to add user verification", + session: authn.Session{UserID: user.ID}, + retrieveByIDResponse: user, + retrieveUserVerError: repoerr.ErrNotFound, + addUserVerError: svcerr.ErrCreateEntity, + err: svcerr.ErrCreateEntity, + }, + { + desc: "send verification email with failed to send email", + session: authn.Session{UserID: user.ID}, + retrieveByIDResponse: user, + retrieveUserVerError: repoerr.ErrNotFound, + sendVerificationEmailError: svcerr.ErrCreateEntity, + err: svcerr.ErrCreateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := cRepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.retrieveByIDResponse, tc.retrieveByIDError) + repoCall1 := cRepo.On("RetrieveUserVerification", context.Background(), mock.Anything, mock.Anything).Return(tc.retrieveUserVerResponse, tc.retrieveUserVerError) + repoCall2 := cRepo.On("AddUserVerification", context.Background(), mock.Anything).Return(tc.addUserVerError) + emailCall := e.On("SendVerification", []string{user.Email}, user.Credentials.Username, mock.Anything).Return(tc.sendVerificationEmailError) + + err := svc.SendVerification(context.Background(), tc.session) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + emailCall.Unset() + }) + } +} + +func TestVerifyEmail(t *testing.T) { + //nolint:dogsled + svc, _, cRepo, _, _ := newService() + uv, err := users.NewUserVerification(user.ID, user.Email) + assert.Nil(t, err, fmt.Sprintf("failed to create user verification: %v", err)) + uvs, err := uv.Encode() + assert.Nil(t, err, fmt.Sprintf("failed to encode user verification: %v", err)) + createdAt := time.Now().Add(-5 * users.VerificationExpiryDuration).UTC() + expiresdAt := time.Now().Add(-users.VerificationExpiryDuration).UTC() + cases := []struct { + desc string + uvs string + retrieveUserVerResponse users.UserVerification + retrieveUserVerError error + updateUserVerError error + updateVerifiedAtError error + err error + }{ + { + desc: "verify email successfully", + uvs: uvs, + retrieveUserVerResponse: uv, + err: nil, + }, + { + desc: "verify email with malformed token", + uvs: "invalid", + err: svcerr.ErrInvalidUserVerification, + }, + { + desc: "verify email with non-existing user verification", + uvs: uvs, + retrieveUserVerError: repoerr.ErrNotFound, + err: svcerr.ErrViewEntity, + }, + { + desc: "verify email with expired token", + uvs: uvs, + retrieveUserVerResponse: users.UserVerification{ + UserID: uv.UserID, + Email: uv.Email, + OTP: uv.OTP, + ExpiresAt: expiresdAt, + CreatedAt: createdAt, + UsedAt: uv.UsedAt, + }, + err: svcerr.ErrUserVerificationExpired, + }, + { + desc: "verify email with failed to update user verification", + uvs: uvs, + retrieveUserVerResponse: uv, + updateUserVerError: svcerr.ErrUpdateEntity, + err: svcerr.ErrUpdateEntity, + }, + { + desc: "verify email with failed to update verified at", + uvs: uvs, + retrieveUserVerResponse: uv, + updateVerifiedAtError: svcerr.ErrUpdateEntity, + err: svcerr.ErrUpdateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := cRepo.On("RetrieveUserVerification", context.Background(), mock.Anything, mock.Anything).Return(tc.retrieveUserVerResponse, tc.retrieveUserVerError) + repoCall1 := cRepo.On("UpdateUserVerification", context.Background(), mock.Anything).Return(tc.updateUserVerError) + repoCall2 := cRepo.On("UpdateVerifiedAt", context.Background(), mock.Anything).Return(users.User{}, tc.updateVerifiedAtError) + _, err := svc.VerifyEmail(context.Background(), tc.uvs) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + }) + } +} diff --git a/users/tracing/tracing.go b/users/tracing/tracing.go index 40f07c5bd..2befd27c4 100644 --- a/users/tracing/tracing.go +++ b/users/tracing/tracing.go @@ -34,6 +34,21 @@ func (tm *tracingMiddleware) Register(ctx context.Context, session authn.Session return tm.svc.Register(ctx, session, user, selfRegister) } +// SendVerification traces the "SendVerification" operation of the wrapped users.Service. +func (tm *tracingMiddleware) SendVerification(ctx context.Context, session authn.Session) error { + ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_send_verification") + defer span.End() + + return tm.svc.SendVerification(ctx, session) +} + +// VerifyEmail traces the "VerifyEmail" operation of the wrapped users.Service. +func (tm *tracingMiddleware) VerifyEmail(ctx context.Context, verificationToken string) (users.User, error) { + ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_verify_email") + defer span.End() + return tm.svc.VerifyEmail(ctx, verificationToken) +} + // IssueToken traces the "IssueToken" operation of the wrapped users.Service. func (tm *tracingMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) { ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_issue_token", trace.WithAttributes(attribute.String("username", username))) diff --git a/users/users.go b/users/users.go index c3d77b0bf..1c91c6133 100644 --- a/users/users.go +++ b/users/users.go @@ -29,6 +29,7 @@ type User struct { CreatedAt time.Time `json:"created_at,omitempty"` UpdatedAt time.Time `json:"updated_at,omitempty"` UpdatedBy string `json:"updated_by,omitempty"` + VerifiedAt time.Time `json:"verified_at,omitempty"` } type Credentials struct { @@ -49,9 +50,7 @@ type UserReq struct { LastName *string `json:"last_name,omitempty"` Metadata *Metadata `json:"metadata,omitempty"` Tags *[]string `json:"tags,omitempty"` - Role *Role `json:"role,omitempty"` ProfilePicture *string `json:"profile_picture,omitempty"` - Email *string `json:"email,omitempty"` UpdatedBy *string `json:"updated_by,omitempty"` UpdatedAt *time.Time `json:"updated_at,omitempty"` } @@ -90,6 +89,15 @@ type Repository interface { // UpdateSecret updates secret for user with given email. UpdateSecret(ctx context.Context, user User) (User, error) + // UpdateEmail updates email for user with given id. + UpdateEmail(ctx context.Context, user User) (User, error) + + // UpdateRole updates role for user with given id. + UpdateRole(ctx context.Context, user User) (User, error) + + // UpdateVerifiedAt updates the verified time for user with given id. + UpdateVerifiedAt(ctx context.Context, user User) (User, error) + // ChangeStatus changes user status to enabled or disabled ChangeStatus(ctx context.Context, user User) (User, error) @@ -107,6 +115,15 @@ type Repository interface { // Save persists the user account. A non-nil error is returned to indicate // operation failure. Save(ctx context.Context, user User) (User, error) + + // AddUserVerification adds new verification for given user id and email + AddUserVerification(ctx context.Context, uv UserVerification) error + + // RetrieveVerificationToken retrieves verification token of given user id and email. + RetrieveUserVerification(ctx context.Context, userID, email string) (UserVerification, error) + + // UpdateUserVerificationDetails update verification details for the given user id and email. + UpdateUserVerification(ctx context.Context, uv UserVerification) error } // Validate returns an error if user representation is invalid. @@ -143,6 +160,7 @@ type Page struct { FirstName string `json:"first_name,omitempty"` LastName string `json:"last_name,omitempty"` Email string `json:"email,omitempty"` + Verified bool `json:"verified,omitempty"` } // Service specifies an API that must be fullfiled by the domain service @@ -152,6 +170,12 @@ type Service interface { // non-nil error value is returned. Register(ctx context.Context, session authn.Session, user User, selfRegister bool) (User, error) + // SendVerification sends a verification email to the user. + SendVerification(ctx context.Context, session authn.Session) error + + // VerifyEmail verifies user's email using the verification token. + VerifyEmail(ctx context.Context, verificationToken string) (User, error) + // View retrieves user info for a given user ID and an authorized token. View(ctx context.Context, session authn.Session, id string) (User, error) diff --git a/users/verification.go b/users/verification.go new file mode 100644 index 000000000..e4519d0fb --- /dev/null +++ b/users/verification.go @@ -0,0 +1,121 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package users + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "time" + + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" +) + +const VerificationExpiryDuration = 24 * time.Hour + +var ( + errFailedToCreateUserVerification = errors.New("failed to create new user verification") + errFailedToEncodeUserVerification = errors.New("failed to encode user verification") + errFailedToDecodeUserVerification = errors.New("failed to decode user verification") +) + +// UserVerification OTP is sent to the user's email as base64 encoded with UserID, Email and OTP. It should not be exposed via API. +type UserVerification struct { + UserID string `json:"user_id"` + Email string `json:"email"` + OTP string `json:"otp"` + CreatedAt time.Time `json:"-"` + ExpiresAt time.Time `json:"-"` + UsedAt time.Time `json:"-"` +} + +func NewUserVerification(userID, email string) (UserVerification, error) { + randomBytes := make([]byte, 32) + if _, err := rand.Read(randomBytes); err != nil { + return UserVerification{}, errors.Wrap(errFailedToCreateUserVerification, err) + } + + return UserVerification{ + UserID: userID, + Email: email, + OTP: base64.URLEncoding.EncodeToString(randomBytes), + CreatedAt: time.Now().UTC(), + ExpiresAt: time.Now().Add(VerificationExpiryDuration).UTC(), + }, nil +} + +func (u UserVerification) Encode() (string, error) { + jsonBytes, err := json.Marshal(u) + if err != nil { + return "", errors.Wrap(errFailedToEncodeUserVerification, err) + } + + return base64.URLEncoding.EncodeToString(jsonBytes), nil +} + +func (u *UserVerification) Decode(data string) error { + decodedPayload, err := base64.URLEncoding.DecodeString(data) + if err != nil { + return errors.Wrap(errFailedToDecodeUserVerification, err) + } + + if err := json.Unmarshal(decodedPayload, u); err != nil { + return errors.Wrap(errFailedToDecodeUserVerification, err) + } + + if u.UserID == "" || u.Email == "" || u.OTP == "" { + return svcerr.ErrInvalidUserVerification + } + + return nil +} + +func (u UserVerification) Valid() error { + if u.UserID == "" || u.Email == "" || u.OTP == "" { + return svcerr.ErrInvalidUserVerification + } + + // Verification should have created time. + if u.CreatedAt.IsZero() { + return svcerr.ErrInvalidUserVerification + } + + // Verification should have expiry time. + if u.ExpiresAt.IsZero() { + return svcerr.ErrInvalidUserVerification + } + + // Expiry time should not be before Created time + if u.ExpiresAt.Before(u.CreatedAt) { + return svcerr.ErrInvalidUserVerification + } + + // Verification should be not be Expired. + if time.Now().After(u.ExpiresAt) { + return svcerr.ErrUserVerificationExpired + } + + // Verification should not be used. + if !u.UsedAt.IsZero() { + return svcerr.ErrUserVerificationExpired + } + + return nil +} + +func (u UserVerification) Match(ruv UserVerification) error { + if u.UserID != ruv.UserID { + return svcerr.ErrInvalidUserVerification + } + + if u.Email != ruv.Email { + return svcerr.ErrInvalidUserVerification + } + + if u.OTP != ruv.OTP { + return svcerr.ErrInvalidUserVerification + } + return nil +}