From 9c2608659f85c656c22b3ebcc66faaee489138c8 Mon Sep 17 00:00:00 2001 From: Felix Gateru Date: Tue, 3 Mar 2026 17:22:28 +0300 Subject: [PATCH] SMQ-1672 - Revoke refresh token (#3241) Signed-off-by: Felix Gateru Signed-off-by: nyagamunene Co-authored-by: nyagamunene --- Makefile | 7 +- api/grpc/token/v1/token.pb.go | 306 ++++++++++++++++++++++-- api/grpc/token/v1/token_grpc.pb.go | 80 ++++++- apidocs/openapi/users.yaml | 98 ++++++++ auth/api/grpc/token/client.go | 84 ++++++- auth/api/grpc/token/endpoint.go | 40 +++- auth/api/grpc/token/endpoint_test.go | 118 +++++++-- auth/api/grpc/token/requests.go | 34 ++- auth/api/grpc/token/responses.go | 6 + auth/api/grpc/token/server.go | 73 +++++- auth/cache/doc.go | 2 + auth/cache/tokens.go | 132 +++++++++++ auth/cache/tokens_test.go | 288 ++++++++++++++++++++++ auth/key_manager.go | 27 ++- auth/keys.go | 17 +- auth/middleware/logging.go | 36 +++ auth/middleware/metrics.go | 18 ++ auth/middleware/tracing.go | 19 ++ auth/mocks/service.go | 131 ++++++++++ auth/mocks/token_client.go | 166 +++++++++++++ auth/mocks/user_active_tokens_cache.go | 316 +++++++++++++++++++++++++ auth/postgres/errors.go | 24 ++ auth/service.go | 63 ++++- auth/service_test.go | 244 ++++++++++++++++++- auth/tokenizer/asymmetric/README.md | 10 +- auth/tokenizer/asymmetric/tokenizer.go | 30 ++- auth/tokenizer/symmetric/tokenizer.go | 20 +- auth/tokenizer/util/jwt.go | 4 + cmd/auth/main.go | 10 +- cmd/users/main.go | 2 +- docker/docker-compose.yaml | 2 + docker/redis/redis.conf | 14 ++ internal/proto/token/v1/token.proto | 26 ++ pkg/sdk/tokens.go | 5 +- pkg/sdk/tokens_test.go | 4 +- tools/config/.mockery.yaml | 2 +- users/api/endpoint_test.go | 211 ++++++++++++++++- users/api/endpoints.go | 39 ++- users/api/requests.go | 17 +- users/api/responses.go | 34 ++- users/api/users.go | 29 +++ users/events/events.go | 17 +- users/events/streams.go | 73 +++--- users/events/streams_test.go | 121 ++++++++-- users/middleware/authorization.go | 12 +- users/middleware/logging.go | 43 +++- users/middleware/metrics.go | 22 +- users/middleware/tracing.go | 20 +- users/mocks/service.go | 161 ++++++++++++- users/service.go | 48 +++- users/service_test.go | 177 +++++++++++++- users/users.go | 8 +- 52 files changed, 3266 insertions(+), 224 deletions(-) create mode 100644 auth/cache/tokens.go create mode 100644 auth/cache/tokens_test.go create mode 100644 auth/mocks/user_active_tokens_cache.go create mode 100644 auth/postgres/errors.go create mode 100644 docker/redis/redis.conf diff --git a/Makefile b/Makefile index 24c1682e1..b681e4d12 100644 --- a/Makefile +++ b/Makefile @@ -165,11 +165,8 @@ mocks: $(MOCKERY) $(MOCKERY): @mkdir -p $(GOBIN) - @mkdir -p mockery - @echo ">> downloading mockery $(MOCKERY_VERSION)..." - @curl -sL https://github.com/vektra/mockery/releases/download/v$(MOCKERY_VERSION)/mockery_$(MOCKERY_VERSION)_Linux_x86_64.tar.gz | tar -xz -C mockery - @mv mockery/mockery $(GOBIN) - @rm -r mockery + @echo ">> installing mockery $(MOCKERY_VERSION)..." + @go install github.com/vektra/mockery/v3@v$(MOCKERY_VERSION) DIRS = consumers readers postgres internal test: mocks diff --git a/api/grpc/token/v1/token.pb.go b/api/grpc/token/v1/token.pb.go index 475630b65..da09ac553 100644 --- a/api/grpc/token/v1/token.pb.go +++ b/api/grpc/token/v1/token.pb.go @@ -30,6 +30,7 @@ type IssueReq struct { UserRole uint32 `protobuf:"varint,2,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"` Type uint32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"` Verified bool `protobuf:"varint,4,opt,name=verified,proto3" json:"verified,omitempty"` + Description string `protobuf:"bytes,5,opt,name=description,proto3" json:"description,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -92,6 +93,13 @@ func (x *IssueReq) GetVerified() bool { return false } +func (x *IssueReq) GetDescription() string { + if x != nil { + return x.Description + } + return "" +} + type RefreshReq struct { state protoimpl.MessageState `protogen:"open.v1"` RefreshToken string `protobuf:"bytes,1,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"` @@ -144,6 +152,58 @@ func (x *RefreshReq) GetVerified() bool { return false } +type RevokeReq struct { + state protoimpl.MessageState `protogen:"open.v1"` + TokenId string `protobuf:"bytes,1,opt,name=token_id,json=tokenId,proto3" json:"token_id,omitempty"` + UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RevokeReq) Reset() { + *x = RevokeReq{} + mi := &file_token_v1_token_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RevokeReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RevokeReq) ProtoMessage() {} + +func (x *RevokeReq) ProtoReflect() protoreflect.Message { + mi := &file_token_v1_token_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RevokeReq.ProtoReflect.Descriptor instead. +func (*RevokeReq) Descriptor() ([]byte, []int) { + return file_token_v1_token_proto_rawDescGZIP(), []int{2} +} + +func (x *RevokeReq) GetTokenId() string { + if x != nil { + return x.TokenId + } + return "" +} + +func (x *RevokeReq) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + // If a token is not carrying any information itself, the type // field can be used to determine how to validate the token. // Also, different tokens can be encoded in different ways. @@ -158,7 +218,7 @@ type Token struct { func (x *Token) Reset() { *x = Token{} - mi := &file_token_v1_token_proto_msgTypes[2] + mi := &file_token_v1_token_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -170,7 +230,7 @@ func (x *Token) String() string { func (*Token) ProtoMessage() {} func (x *Token) ProtoReflect() protoreflect.Message { - mi := &file_token_v1_token_proto_msgTypes[2] + mi := &file_token_v1_token_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -183,7 +243,7 @@ func (x *Token) ProtoReflect() protoreflect.Message { // Deprecated: Use Token.ProtoReflect.Descriptor instead. func (*Token) Descriptor() ([]byte, []int) { - return file_token_v1_token_proto_rawDescGZIP(), []int{2} + return file_token_v1_token_proto_rawDescGZIP(), []int{3} } func (x *Token) GetAccessToken() string { @@ -207,29 +267,219 @@ func (x *Token) GetAccessType() string { return "" } +type RevokeRes struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RevokeRes) Reset() { + *x = RevokeRes{} + mi := &file_token_v1_token_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RevokeRes) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RevokeRes) ProtoMessage() {} + +func (x *RevokeRes) ProtoReflect() protoreflect.Message { + mi := &file_token_v1_token_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RevokeRes.ProtoReflect.Descriptor instead. +func (*RevokeRes) Descriptor() ([]byte, []int) { + return file_token_v1_token_proto_rawDescGZIP(), []int{4} +} + +type ListUserRefreshTokensReq struct { + state protoimpl.MessageState `protogen:"open.v1"` + UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListUserRefreshTokensReq) Reset() { + *x = ListUserRefreshTokensReq{} + mi := &file_token_v1_token_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListUserRefreshTokensReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListUserRefreshTokensReq) ProtoMessage() {} + +func (x *ListUserRefreshTokensReq) ProtoReflect() protoreflect.Message { + mi := &file_token_v1_token_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListUserRefreshTokensReq.ProtoReflect.Descriptor instead. +func (*ListUserRefreshTokensReq) Descriptor() ([]byte, []int) { + return file_token_v1_token_proto_rawDescGZIP(), []int{5} +} + +func (x *ListUserRefreshTokensReq) GetUserId() string { + if x != nil { + return x.UserId + } + return "" +} + +type ListUserRefreshTokensRes struct { + state protoimpl.MessageState `protogen:"open.v1"` + RefreshTokens []*RefreshToken `protobuf:"bytes,1,rep,name=refresh_tokens,json=refreshTokens,proto3" json:"refresh_tokens,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ListUserRefreshTokensRes) Reset() { + *x = ListUserRefreshTokensRes{} + mi := &file_token_v1_token_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ListUserRefreshTokensRes) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListUserRefreshTokensRes) ProtoMessage() {} + +func (x *ListUserRefreshTokensRes) ProtoReflect() protoreflect.Message { + mi := &file_token_v1_token_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListUserRefreshTokensRes.ProtoReflect.Descriptor instead. +func (*ListUserRefreshTokensRes) Descriptor() ([]byte, []int) { + return file_token_v1_token_proto_rawDescGZIP(), []int{6} +} + +func (x *ListUserRefreshTokensRes) GetRefreshTokens() []*RefreshToken { + if x != nil { + return x.RefreshTokens + } + return nil +} + +type RefreshToken struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + Description string `protobuf:"bytes,2,opt,name=description,proto3" json:"description,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RefreshToken) Reset() { + *x = RefreshToken{} + mi := &file_token_v1_token_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RefreshToken) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RefreshToken) ProtoMessage() {} + +func (x *RefreshToken) ProtoReflect() protoreflect.Message { + mi := &file_token_v1_token_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RefreshToken.ProtoReflect.Descriptor instead. +func (*RefreshToken) Descriptor() ([]byte, []int) { + return file_token_v1_token_proto_rawDescGZIP(), []int{7} +} + +func (x *RefreshToken) GetId() string { + if x != nil { + return x.Id + } + return "" +} + +func (x *RefreshToken) GetDescription() string { + if x != nil { + return x.Description + } + return "" +} + var File_token_v1_token_proto protoreflect.FileDescriptor const file_token_v1_token_proto_rawDesc = "" + "\n" + - "\x14token/v1/token.proto\x12\btoken.v1\"p\n" + + "\x14token/v1/token.proto\x12\btoken.v1\"\x92\x01\n" + "\bIssueReq\x12\x17\n" + "\auser_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n" + "\tuser_role\x18\x02 \x01(\rR\buserRole\x12\x12\n" + "\x04type\x18\x03 \x01(\rR\x04type\x12\x1a\n" + - "\bverified\x18\x04 \x01(\bR\bverified\"M\n" + + "\bverified\x18\x04 \x01(\bR\bverified\x12 \n" + + "\vdescription\x18\x05 \x01(\tR\vdescription\"M\n" + "\n" + "RefreshReq\x12#\n" + "\rrefresh_token\x18\x01 \x01(\tR\frefreshToken\x12\x1a\n" + - "\bverified\x18\x02 \x01(\bR\bverified\"\x87\x01\n" + + "\bverified\x18\x02 \x01(\bR\bverified\"?\n" + + "\tRevokeReq\x12\x19\n" + + "\btoken_id\x18\x01 \x01(\tR\atokenId\x12\x17\n" + + "\auser_id\x18\x02 \x01(\tR\x06userId\"\x87\x01\n" + "\x05Token\x12!\n" + "\faccess_token\x18\x01 \x01(\tR\vaccessToken\x12(\n" + "\rrefresh_token\x18\x02 \x01(\tH\x00R\frefreshToken\x88\x01\x01\x12\x1f\n" + "\vaccess_type\x18\x03 \x01(\tR\n" + "accessTypeB\x10\n" + - "\x0e_refresh_token2r\n" + + "\x0e_refresh_token\"\v\n" + + "\tRevokeRes\"3\n" + + "\x18ListUserRefreshTokensReq\x12\x17\n" + + "\auser_id\x18\x01 \x01(\tR\x06userId\"Y\n" + + "\x18ListUserRefreshTokensRes\x12=\n" + + "\x0erefresh_tokens\x18\x01 \x03(\v2\x16.token.v1.RefreshTokenR\rrefreshTokens\"@\n" + + "\fRefreshToken\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12 \n" + + "\vdescription\x18\x02 \x01(\tR\vdescription2\x8b\x02\n" + "\fTokenService\x12.\n" + "\x05Issue\x12\x12.token.v1.IssueReq\x1a\x0f.token.v1.Token\"\x00\x122\n" + - "\aRefresh\x12\x14.token.v1.RefreshReq\x1a\x0f.token.v1.Token\"\x00B.Z,github.com/absmach/supermq/api/grpc/token/v1b\x06proto3" + "\aRefresh\x12\x14.token.v1.RefreshReq\x1a\x0f.token.v1.Token\"\x00\x124\n" + + "\x06Revoke\x12\x13.token.v1.RevokeReq\x1a\x13.token.v1.RevokeRes\"\x00\x12a\n" + + "\x15ListUserRefreshTokens\x12\".token.v1.ListUserRefreshTokensReq\x1a\".token.v1.ListUserRefreshTokensRes\"\x00B.Z,github.com/absmach/supermq/api/grpc/token/v1b\x06proto3" var ( file_token_v1_token_proto_rawDescOnce sync.Once @@ -243,22 +493,32 @@ func file_token_v1_token_proto_rawDescGZIP() []byte { return file_token_v1_token_proto_rawDescData } -var file_token_v1_token_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_token_v1_token_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var file_token_v1_token_proto_goTypes = []any{ - (*IssueReq)(nil), // 0: token.v1.IssueReq - (*RefreshReq)(nil), // 1: token.v1.RefreshReq - (*Token)(nil), // 2: token.v1.Token + (*IssueReq)(nil), // 0: token.v1.IssueReq + (*RefreshReq)(nil), // 1: token.v1.RefreshReq + (*RevokeReq)(nil), // 2: token.v1.RevokeReq + (*Token)(nil), // 3: token.v1.Token + (*RevokeRes)(nil), // 4: token.v1.RevokeRes + (*ListUserRefreshTokensReq)(nil), // 5: token.v1.ListUserRefreshTokensReq + (*ListUserRefreshTokensRes)(nil), // 6: token.v1.ListUserRefreshTokensRes + (*RefreshToken)(nil), // 7: token.v1.RefreshToken } var file_token_v1_token_proto_depIdxs = []int32{ - 0, // 0: token.v1.TokenService.Issue:input_type -> token.v1.IssueReq - 1, // 1: token.v1.TokenService.Refresh:input_type -> token.v1.RefreshReq - 2, // 2: token.v1.TokenService.Issue:output_type -> token.v1.Token - 2, // 3: token.v1.TokenService.Refresh:output_type -> token.v1.Token - 2, // [2:4] is the sub-list for method output_type - 0, // [0:2] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name + 7, // 0: token.v1.ListUserRefreshTokensRes.refresh_tokens:type_name -> token.v1.RefreshToken + 0, // 1: token.v1.TokenService.Issue:input_type -> token.v1.IssueReq + 1, // 2: token.v1.TokenService.Refresh:input_type -> token.v1.RefreshReq + 2, // 3: token.v1.TokenService.Revoke:input_type -> token.v1.RevokeReq + 5, // 4: token.v1.TokenService.ListUserRefreshTokens:input_type -> token.v1.ListUserRefreshTokensReq + 3, // 5: token.v1.TokenService.Issue:output_type -> token.v1.Token + 3, // 6: token.v1.TokenService.Refresh:output_type -> token.v1.Token + 4, // 7: token.v1.TokenService.Revoke:output_type -> token.v1.RevokeRes + 6, // 8: token.v1.TokenService.ListUserRefreshTokens:output_type -> token.v1.ListUserRefreshTokensRes + 5, // [5:9] is the sub-list for method output_type + 1, // [1:5] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name } func init() { file_token_v1_token_proto_init() } @@ -266,14 +526,14 @@ func file_token_v1_token_proto_init() { if File_token_v1_token_proto != nil { return } - file_token_v1_token_proto_msgTypes[2].OneofWrappers = []any{} + file_token_v1_token_proto_msgTypes[3].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_token_v1_token_proto_rawDesc), len(file_token_v1_token_proto_rawDesc)), NumEnums: 0, - NumMessages: 3, + NumMessages: 8, NumExtensions: 0, NumServices: 1, }, diff --git a/api/grpc/token/v1/token_grpc.pb.go b/api/grpc/token/v1/token_grpc.pb.go index 4cdab04b0..0edec8845 100644 --- a/api/grpc/token/v1/token_grpc.pb.go +++ b/api/grpc/token/v1/token_grpc.pb.go @@ -22,8 +22,10 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( - TokenService_Issue_FullMethodName = "/token.v1.TokenService/Issue" - TokenService_Refresh_FullMethodName = "/token.v1.TokenService/Refresh" + TokenService_Issue_FullMethodName = "/token.v1.TokenService/Issue" + TokenService_Refresh_FullMethodName = "/token.v1.TokenService/Refresh" + TokenService_Revoke_FullMethodName = "/token.v1.TokenService/Revoke" + TokenService_ListUserRefreshTokens_FullMethodName = "/token.v1.TokenService/ListUserRefreshTokens" ) // TokenServiceClient is the client API for TokenService service. @@ -32,6 +34,8 @@ const ( type TokenServiceClient interface { Issue(ctx context.Context, in *IssueReq, opts ...grpc.CallOption) (*Token, error) Refresh(ctx context.Context, in *RefreshReq, opts ...grpc.CallOption) (*Token, error) + Revoke(ctx context.Context, in *RevokeReq, opts ...grpc.CallOption) (*RevokeRes, error) + ListUserRefreshTokens(ctx context.Context, in *ListUserRefreshTokensReq, opts ...grpc.CallOption) (*ListUserRefreshTokensRes, error) } type tokenServiceClient struct { @@ -62,12 +66,34 @@ func (c *tokenServiceClient) Refresh(ctx context.Context, in *RefreshReq, opts . return out, nil } +func (c *tokenServiceClient) Revoke(ctx context.Context, in *RevokeReq, opts ...grpc.CallOption) (*RevokeRes, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(RevokeRes) + err := c.cc.Invoke(ctx, TokenService_Revoke_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *tokenServiceClient) ListUserRefreshTokens(ctx context.Context, in *ListUserRefreshTokensReq, opts ...grpc.CallOption) (*ListUserRefreshTokensRes, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ListUserRefreshTokensRes) + err := c.cc.Invoke(ctx, TokenService_ListUserRefreshTokens_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + // TokenServiceServer is the server API for TokenService service. // All implementations must embed UnimplementedTokenServiceServer // for forward compatibility. type TokenServiceServer interface { Issue(context.Context, *IssueReq) (*Token, error) Refresh(context.Context, *RefreshReq) (*Token, error) + Revoke(context.Context, *RevokeReq) (*RevokeRes, error) + ListUserRefreshTokens(context.Context, *ListUserRefreshTokensReq) (*ListUserRefreshTokensRes, error) mustEmbedUnimplementedTokenServiceServer() } @@ -84,6 +110,12 @@ func (UnimplementedTokenServiceServer) Issue(context.Context, *IssueReq) (*Token func (UnimplementedTokenServiceServer) Refresh(context.Context, *RefreshReq) (*Token, error) { return nil, status.Error(codes.Unimplemented, "method Refresh not implemented") } +func (UnimplementedTokenServiceServer) Revoke(context.Context, *RevokeReq) (*RevokeRes, error) { + return nil, status.Error(codes.Unimplemented, "method Revoke not implemented") +} +func (UnimplementedTokenServiceServer) ListUserRefreshTokens(context.Context, *ListUserRefreshTokensReq) (*ListUserRefreshTokensRes, error) { + return nil, status.Error(codes.Unimplemented, "method ListUserRefreshTokens not implemented") +} func (UnimplementedTokenServiceServer) mustEmbedUnimplementedTokenServiceServer() {} func (UnimplementedTokenServiceServer) testEmbeddedByValue() {} @@ -141,6 +173,42 @@ func _TokenService_Refresh_Handler(srv interface{}, ctx context.Context, dec fun return interceptor(ctx, in, info, handler) } +func _TokenService_Revoke_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RevokeReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TokenServiceServer).Revoke(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: TokenService_Revoke_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TokenServiceServer).Revoke(ctx, req.(*RevokeReq)) + } + return interceptor(ctx, in, info, handler) +} + +func _TokenService_ListUserRefreshTokens_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListUserRefreshTokensReq) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(TokenServiceServer).ListUserRefreshTokens(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: TokenService_ListUserRefreshTokens_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(TokenServiceServer).ListUserRefreshTokens(ctx, req.(*ListUserRefreshTokensReq)) + } + return interceptor(ctx, in, info, handler) +} + // TokenService_ServiceDesc is the grpc.ServiceDesc for TokenService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -156,6 +224,14 @@ var TokenService_ServiceDesc = grpc.ServiceDesc{ MethodName: "Refresh", Handler: _TokenService_Refresh_Handler, }, + { + MethodName: "Revoke", + Handler: _TokenService_Revoke_Handler, + }, + { + MethodName: "ListUserRefreshTokens", + Handler: _TokenService_ListUserRefreshTokens_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "token/v1/token.proto", diff --git a/apidocs/openapi/users.yaml b/apidocs/openapi/users.yaml index 999e2204f..39b73f94b 100644 --- a/apidocs/openapi/users.yaml +++ b/apidocs/openapi/users.yaml @@ -610,6 +610,57 @@ paths: "500": $ref: "#/components/responses/ServiceError" + /users/tokens/revoke: + post: + operationId: revokeRefreshToken + summary: Revoke Refresh Token + description: | + Revokes a specific refresh token by its ID. This invalidates the + refresh token so it can no longer be used to obtain new access tokens. + tags: + - Users + security: + - bearerAuth: [] + requestBody: + $ref: "#/components/requestBodies/RevokeRefreshTokenReq" + responses: + "204": + description: Refresh token revoked successfully. + "400": + description: Failed due to malformed JSON. + "401": + description: Missing or invalid access token provided. + "404": + description: A non-existent entity request. + "415": + description: Missing or invalid content type. + "422": + description: Database can't process request. + "500": + $ref: "#/components/responses/ServiceError" + + /users/tokens/refresh-tokens: + get: + operationId: listActiveRefreshTokens + summary: List Active Refresh Tokens + description: | + Lists all active refresh token sessions for the currently authenticated user. + tags: + - Users + security: + - bearerAuth: [] + responses: + "200": + $ref: "#/components/responses/RefreshTokensPageRes" + "400": + description: Failed due to malformed JSON. + "401": + description: Missing or invalid access token provided. + "404": + description: A non-existent entity request. + "500": + $ref: "#/components/responses/ServiceError" + /users/send-verification: post: operationId: sendVerification @@ -1042,6 +1093,30 @@ components: - username - password + RefreshToken: + type: object + properties: + id: + type: string + format: uuid + example: "bb7edb32-2eac-4aad-aebe-ed96fe073879" + description: Unique identifier of the refresh token. + description: + type: string + example: "Chrome browser session" + description: Description of the refresh token session. + + RefreshTokensPage: + type: object + properties: + refresh_tokens: + type: array + items: + $ref: "#/components/schemas/RefreshToken" + description: List of active refresh tokens. + required: + - refresh_tokens + Error: type: object properties: @@ -1463,6 +1538,22 @@ components: format: jwt description: Reset token generated and sent in email. + RevokeRefreshTokenReq: + description: JSON-formatted document describing the refresh token to revoke. + required: true + content: + application/json: + schema: + type: object + properties: + token_id: + type: string + format: uuid + example: "bb7edb32-2eac-4aad-aebe-ed96fe073879" + description: The unique identifier of the refresh token to revoke. + required: + - token_id + PasswordChange: description: Password change data. User can change its password. required: true @@ -1574,6 +1665,13 @@ components: example: access description: User access token type. + RefreshTokensPageRes: + description: List of active refresh tokens for the authenticated user. + content: + application/json: + schema: + $ref: "#/components/schemas/RefreshTokensPage" + HealthRes: description: Service Health Check. content: diff --git a/auth/api/grpc/token/client.go b/auth/api/grpc/token/client.go index 25c3bf62f..aec772b14 100644 --- a/auth/api/grpc/token/client.go +++ b/auth/api/grpc/token/client.go @@ -18,14 +18,16 @@ import ( const tokenSvcName = "token.v1.TokenService" type tokenGrpcClient struct { - issue endpoint.Endpoint - refresh endpoint.Endpoint - timeout time.Duration + issue endpoint.Endpoint + refresh endpoint.Endpoint + revoke endpoint.Endpoint + listUserRefreshTokens endpoint.Endpoint + timeout time.Duration } var _ grpcTokenV1.TokenServiceClient = (*tokenGrpcClient)(nil) -// NewAuthClient returns new auth gRPC client instance. +// NewTokenClient returns new token gRPC client instance. func NewTokenClient(conn *grpc.ClientConn, timeout time.Duration) grpcTokenV1.TokenServiceClient { return &tokenGrpcClient{ issue: kitgrpc.NewClient( @@ -44,6 +46,22 @@ func NewTokenClient(conn *grpc.ClientConn, timeout time.Duration) grpcTokenV1.To decodeRefreshResponse, grpcTokenV1.Token{}, ).Endpoint(), + revoke: kitgrpc.NewClient( + conn, + tokenSvcName, + "Revoke", + encodeRevokeRequest, + decodeRevokeResponse, + grpcTokenV1.RevokeRes{}, + ).Endpoint(), + listUserRefreshTokens: kitgrpc.NewClient( + conn, + tokenSvcName, + "ListUserRefreshTokens", + encodeListUserRefreshTokensRequest, + decodeListUserRefreshTokensResponse, + grpcTokenV1.ListUserRefreshTokensRes{}, + ).Endpoint(), timeout: timeout, } } @@ -53,10 +71,11 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR defer cancel() res, err := client.issue(ctx, issueReq{ - userID: req.GetUserId(), - userRole: auth.Role(req.GetUserRole()), - keyType: auth.KeyType(req.GetType()), - verified: req.GetVerified(), + userID: req.GetUserId(), + userRole: auth.Role(req.GetUserRole()), + keyType: auth.KeyType(req.GetType()), + verified: req.GetVerified(), + description: req.GetDescription(), }) if err != nil { return &grpcTokenV1.Token{}, grpcapi.DecodeError(err) @@ -67,10 +86,11 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR func encodeIssueRequest(_ context.Context, grpcReq any) (any, error) { req := grpcReq.(issueReq) return &grpcTokenV1.IssueReq{ - UserId: req.userID, - UserRole: uint32(req.userRole), - Type: uint32(req.keyType), - Verified: req.verified, + UserId: req.userID, + UserRole: uint32(req.userRole), + Type: uint32(req.keyType), + Verified: req.verified, + Description: req.description, }, nil } @@ -97,3 +117,43 @@ func encodeRefreshRequest(_ context.Context, grpcReq any) (any, error) { func decodeRefreshResponse(_ context.Context, grpcRes any) (any, error) { return grpcRes, nil } + +func (client tokenGrpcClient) Revoke(ctx context.Context, req *grpcTokenV1.RevokeReq, _ ...grpc.CallOption) (*grpcTokenV1.RevokeRes, error) { + ctx, cancel := context.WithTimeout(ctx, client.timeout) + defer cancel() + + res, err := client.revoke(ctx, revokeReq{userID: req.GetUserId(), tokenID: req.GetTokenId()}) + if err != nil { + return &grpcTokenV1.RevokeRes{}, grpcapi.DecodeError(err) + } + return res.(*grpcTokenV1.RevokeRes), nil +} + +func encodeRevokeRequest(_ context.Context, grpcReq any) (any, error) { + req := grpcReq.(revokeReq) + return &grpcTokenV1.RevokeReq{UserId: req.userID, TokenId: req.tokenID}, nil +} + +func decodeRevokeResponse(_ context.Context, grpcRes any) (any, error) { + return grpcRes, nil +} + +func (client tokenGrpcClient) ListUserRefreshTokens(ctx context.Context, req *grpcTokenV1.ListUserRefreshTokensReq, _ ...grpc.CallOption) (*grpcTokenV1.ListUserRefreshTokensRes, error) { + ctx, cancel := context.WithTimeout(ctx, client.timeout) + defer cancel() + + res, err := client.listUserRefreshTokens(ctx, listUserRefreshTokensReq{userID: req.GetUserId()}) + if err != nil { + return &grpcTokenV1.ListUserRefreshTokensRes{}, grpcapi.DecodeError(err) + } + return res.(*grpcTokenV1.ListUserRefreshTokensRes), nil +} + +func encodeListUserRefreshTokensRequest(_ context.Context, grpcReq any) (any, error) { + req := grpcReq.(listUserRefreshTokensReq) + return &grpcTokenV1.ListUserRefreshTokensReq{UserId: req.userID}, nil +} + +func decodeListUserRefreshTokensResponse(_ context.Context, grpcRes any) (any, error) { + return grpcRes, nil +} diff --git a/auth/api/grpc/token/endpoint.go b/auth/api/grpc/token/endpoint.go index b03e42ae5..e401cf4f1 100644 --- a/auth/api/grpc/token/endpoint.go +++ b/auth/api/grpc/token/endpoint.go @@ -18,10 +18,11 @@ func issueEndpoint(svc auth.Service) endpoint.Endpoint { } key := auth.Key{ - Type: req.keyType, - Subject: req.userID, - Role: req.userRole, - Verified: req.verified, + Type: req.keyType, + Subject: req.userID, + Role: req.userRole, + Verified: req.verified, + Description: req.description, } tkn, err := svc.Issue(ctx, "", key) if err != nil { @@ -56,3 +57,34 @@ func refreshEndpoint(svc auth.Service) endpoint.Endpoint { return ret, nil } } + +func revokeEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(revokeReq) + if err := req.validate(); err != nil { + return nil, err + } + err := svc.RevokeToken(ctx, req.userID, req.tokenID) + if err != nil { + return nil, err + } + + return nil, nil + } +} + +func listUserRefreshTokensEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(listUserRefreshTokensReq) + if err := req.validate(); err != nil { + return listUserRefreshTokensRes{}, err + } + + refreshTokens, err := svc.ListUserRefreshTokens(ctx, req.userID) + if err != nil { + return listUserRefreshTokensRes{}, err + } + + return listUserRefreshTokensRes{refreshTokens: refreshTokens}, nil + } +} diff --git a/auth/api/grpc/token/endpoint_test.go b/auth/api/grpc/token/endpoint_test.go index be1820e42..25c173190 100644 --- a/auth/api/grpc/token/endpoint_test.go +++ b/auth/api/grpc/token/endpoint_test.go @@ -24,24 +24,10 @@ import ( ) const ( - port = 8082 - secret = "secret" - email = "test@example.com" - id = "testID" - clientsType = "clients" - usersType = "users" - description = "Description" - groupName = "smqx" - adminPermission = "admin" - - authoritiesObj = "authorities" - memberRelation = "member" - loginDuration = 30 * time.Minute - refreshDuration = 24 * time.Hour - invalidDuration = 7 * 24 * time.Hour - validToken = "valid" - inValidToken = "invalid" - validPolicy = "valid" + port = 8082 + validToken = "valid" + inValidToken = "invalid" + invalidID = "invalid" ) var ( @@ -63,9 +49,9 @@ func startGRPCServer(svc auth.Service, port int) *grpc.Server { func TestIssue(t *testing.T) { conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) - defer conn.Close() assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) grpcClient := grpcapi.NewTokenClient(conn, time.Second) + defer conn.Close() cases := []struct { desc string @@ -127,9 +113,9 @@ func TestIssue(t *testing.T) { func TestRefresh(t *testing.T) { conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) - defer conn.Close() assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) grpcClient := grpcapi.NewTokenClient(conn, time.Second) + defer conn.Close() cases := []struct { desc string @@ -161,9 +147,99 @@ func TestRefresh(t *testing.T) { } for _, tc := range cases { - svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err) + svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err) _, err := grpcClient.Refresh(context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: tc.token}) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) svcCall.Unset() } } + +func TestRevoke(t *testing.T) { + conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) + grpcClient := grpcapi.NewTokenClient(conn, time.Second) + defer conn.Close() + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "revoke token with valid id", + id: validID, + err: nil, + }, + { + desc: "revoke token with invalid id", + id: invalidID, + err: svcerr.ErrAuthentication, + }, + { + desc: "revoke token with empty id", + id: "", + err: apiutil.ErrMissingID, + }, + { + desc: "revoke already revoked token", + id: validID, + err: svcerr.ErrConflict, + }, + } + + for _, tc := range cases { + svcCall := svc.On("RevokeToken", mock.Anything, mock.Anything, tc.id).Return(tc.err) + _, err := grpcClient.Revoke(context.Background(), &grpcTokenV1.RevokeReq{TokenId: tc.id}) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + svcCall.Unset() + } +} + +func TestListUserRefreshTokens(t *testing.T) { + conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err)) + grpcClient := grpcapi.NewTokenClient(conn, time.Second) + defer conn.Close() + + cases := []struct { + desc string + userID string + listResponse []auth.TokenInfo + err error + }{ + { + desc: "list tokens for user with valid id", + userID: validID, + listResponse: []auth.TokenInfo{ + {ID: testsutil.GenerateUUID(&testing.T{}), Description: "Token 1"}, + {ID: testsutil.GenerateUUID(&testing.T{}), Description: "Token 2"}, + }, + err: nil, + }, + { + desc: "list tokens for user with empty list", + userID: validID, + listResponse: []auth.TokenInfo{}, + err: nil, + }, + { + desc: "list tokens with invalid user id", + userID: invalidID, + listResponse: nil, + err: svcerr.ErrAuthentication, + }, + { + desc: "list tokens with empty user id", + userID: "", + listResponse: nil, + err: apiutil.ErrMissingID, + }, + } + + for _, tc := range cases { + svcCall := svc.On("ListUserRefreshTokens", mock.Anything, tc.userID).Return(tc.listResponse, tc.err) + _, err := grpcClient.ListUserRefreshTokens(context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.userID}) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + svcCall.Unset() + } +} diff --git a/auth/api/grpc/token/requests.go b/auth/api/grpc/token/requests.go index a5ab3c094..fb25dcd34 100644 --- a/auth/api/grpc/token/requests.go +++ b/auth/api/grpc/token/requests.go @@ -9,10 +9,11 @@ import ( ) type issueReq struct { - userID string - userRole auth.Role - keyType auth.KeyType - verified bool + userID string + userRole auth.Role + keyType auth.KeyType + verified bool + description string } func (req issueReq) validate() error { @@ -38,3 +39,28 @@ func (req refreshReq) validate() error { return nil } + +type revokeReq struct { + userID string + tokenID string +} + +func (req revokeReq) validate() error { + if req.tokenID == "" { + return apiutil.ErrMissingID + } + + return nil +} + +type listUserRefreshTokensReq struct { + userID string +} + +func (req listUserRefreshTokensReq) validate() error { + if req.userID == "" { + return apiutil.ErrMissingID + } + + return nil +} diff --git a/auth/api/grpc/token/responses.go b/auth/api/grpc/token/responses.go index cb62744ee..8d3941bd2 100644 --- a/auth/api/grpc/token/responses.go +++ b/auth/api/grpc/token/responses.go @@ -3,8 +3,14 @@ package token +import "github.com/absmach/supermq/auth" + type issueRes struct { accessToken string refreshToken string accessType string } + +type listUserRefreshTokensRes struct { + refreshTokens []auth.TokenInfo +} diff --git a/auth/api/grpc/token/server.go b/auth/api/grpc/token/server.go index 319e46e6e..a88f0575a 100644 --- a/auth/api/grpc/token/server.go +++ b/auth/api/grpc/token/server.go @@ -16,11 +16,13 @@ var _ grpcTokenV1.TokenServiceServer = (*tokenGrpcServer)(nil) type tokenGrpcServer struct { grpcTokenV1.UnimplementedTokenServiceServer - issue kitgrpc.Handler - refresh kitgrpc.Handler + issue kitgrpc.Handler + refresh kitgrpc.Handler + revoke kitgrpc.Handler + listUserRefreshTokens kitgrpc.Handler } -// NewAuthServer returns new AuthnServiceServer instance. +// NewTokenServer returns new TokenServiceServer instance. func NewTokenServer(svc auth.Service) grpcTokenV1.TokenServiceServer { return &tokenGrpcServer{ issue: kitgrpc.NewServer( @@ -33,6 +35,16 @@ func NewTokenServer(svc auth.Service) grpcTokenV1.TokenServiceServer { decodeRefreshRequest, encodeIssueResponse, ), + revoke: kitgrpc.NewServer( + (revokeEndpoint(svc)), + decodeRevokeRequest, + encodeRevokeResponse, + ), + listUserRefreshTokens: kitgrpc.NewServer( + (listUserRefreshTokensEndpoint(svc)), + decodeListUserRefreshTokensRequest, + encodeListUserRefreshTokensResponse, + ), } } @@ -55,10 +67,11 @@ func (s *tokenGrpcServer) Refresh(ctx context.Context, req *grpcTokenV1.RefreshR func decodeIssueRequest(_ context.Context, grpcReq any) (any, error) { req := grpcReq.(*grpcTokenV1.IssueReq) return issueReq{ - userID: req.GetUserId(), - userRole: auth.Role(req.GetUserRole()), - keyType: auth.KeyType(req.GetType()), - verified: req.Verified, + userID: req.GetUserId(), + userRole: auth.Role(req.GetUserRole()), + keyType: auth.KeyType(req.GetType()), + verified: req.Verified, + description: req.GetDescription(), }, nil } @@ -76,3 +89,49 @@ func encodeIssueResponse(_ context.Context, grpcRes any) (any, error) { AccessType: res.accessType, }, nil } + +func (s *tokenGrpcServer) Revoke(ctx context.Context, req *grpcTokenV1.RevokeReq) (*grpcTokenV1.RevokeRes, error) { + _, res, err := s.revoke.ServeGRPC(ctx, req) + if err != nil { + return nil, grpcapi.EncodeError(err) + } + return res.(*grpcTokenV1.RevokeRes), nil +} + +func decodeRevokeRequest(_ context.Context, grpcReq any) (any, error) { + req := grpcReq.(*grpcTokenV1.RevokeReq) + return revokeReq{userID: req.GetUserId(), tokenID: req.GetTokenId()}, nil +} + +func encodeRevokeResponse(_ context.Context, grpcRes any) (any, error) { + return &grpcTokenV1.RevokeRes{}, nil +} + +func (s *tokenGrpcServer) ListUserRefreshTokens(ctx context.Context, req *grpcTokenV1.ListUserRefreshTokensReq) (*grpcTokenV1.ListUserRefreshTokensRes, error) { + _, res, err := s.listUserRefreshTokens.ServeGRPC(ctx, req) + if err != nil { + return nil, grpcapi.EncodeError(err) + } + return res.(*grpcTokenV1.ListUserRefreshTokensRes), nil +} + +func decodeListUserRefreshTokensRequest(_ context.Context, grpcReq any) (any, error) { + req := grpcReq.(*grpcTokenV1.ListUserRefreshTokensReq) + return listUserRefreshTokensReq{userID: req.GetUserId()}, nil +} + +func encodeListUserRefreshTokensResponse(_ context.Context, grpcRes any) (any, error) { + res := grpcRes.(listUserRefreshTokensRes) + + refreshTokens := make([]*grpcTokenV1.RefreshToken, len(res.refreshTokens)) + for i, refreshToken := range res.refreshTokens { + refreshTokens[i] = &grpcTokenV1.RefreshToken{ + Id: refreshToken.ID, + Description: refreshToken.Description, + } + } + + return &grpcTokenV1.ListUserRefreshTokensRes{ + RefreshTokens: refreshTokens, + }, nil +} diff --git a/auth/cache/doc.go b/auth/cache/doc.go index 42396c983..571a973d0 100644 --- a/auth/cache/doc.go +++ b/auth/cache/doc.go @@ -1,4 +1,6 @@ // Copyright (c) Abstract Machines // SPDX-License-Identifier: Apache-2.0 +// Package cache contains the domain concept definitions needed to +// support SuperMQ auth cache service functionality. package cache diff --git a/auth/cache/tokens.go b/auth/cache/tokens.go new file mode 100644 index 000000000..9fc8637ca --- /dev/null +++ b/auth/cache/tokens.go @@ -0,0 +1,132 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "context" + "strconv" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/errors" + "github.com/redis/go-redis/v9" +) + +const ( + refreshPrefix = "refresh_tokens:" + scoreNegInf = "-inf" + scorePosInf = "+inf" +) + +var _ auth.UserActiveTokensCache = (*tokensCache)(nil) + +type tokensCache struct { + client *redis.Client + keyDuration time.Duration +} + +// NewUserActiveTokensCache returns redis auth cache implementation. +func NewUserActiveTokensCache(client *redis.Client, duration time.Duration) (auth.UserActiveTokensCache, error) { + if duration == 0 { + return nil, errors.New("token cache duration must not be zero") + } + return &tokensCache{ + client: client, + keyDuration: duration, + }, nil +} + +// SaveActive saves an active refresh token ID for a user with optional description. +func (tc *tokensCache) SaveActive(ctx context.Context, userID, tokenID, description string, expiry time.Time) error { + ttl := min(tc.keyDuration, time.Until(expiry)) + + pipe := tc.client.TxPipeline() + + pipe.Set(ctx, tokenKey(tokenID), description, ttl) + pipe.ZAdd(ctx, userTokensKey(userID), redis.Z{ + Score: float64(time.Now().Add(ttl).Unix()), + Member: tokenID, + }) + + _, err := pipe.Exec(ctx) + + return err +} + +// IsActive checks if the token ID is active for the given user. +func (tc *tokensCache) IsActive(ctx context.Context, tokenID string) (bool, error) { + count, err := tc.client.Exists(ctx, tokenKey(tokenID)).Result() + if err != nil { + return false, err + } + return count > 0, nil +} + +// ListUserTokens lists all active refresh token IDs with descriptions for a user. +func (tc *tokensCache) ListUserTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) { + key := userTokensKey(userID) + now := strconv.FormatInt(time.Now().Unix(), 10) + + pipe := tc.client.TxPipeline() + pipe.ZRemRangeByScore(ctx, key, scoreNegInf, now) + zrangeCmd := pipe.ZRangeByScore(ctx, key, &redis.ZRangeBy{Min: "(" + now, Max: scorePosInf}) + if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil { + return nil, err + } + + tokenIDs, err := zrangeCmd.Result() + if err != nil { + return nil, err + } + + if len(tokenIDs) == 0 { + return nil, nil + } + + getPipe := tc.client.Pipeline() + getCmds := make([]*redis.StringCmd, len(tokenIDs)) + for i, tokenID := range tokenIDs { + getCmds[i] = getPipe.Get(ctx, tokenKey(tokenID)) + } + + if _, err = getPipe.Exec(ctx); err != nil && err != redis.Nil { + return nil, err + } + + valid := make([]auth.TokenInfo, 0, len(tokenIDs)) + for i, cmd := range getCmds { + description, err := cmd.Result() + if err == redis.Nil { + continue + } + if err != nil { + return nil, err + } + + valid = append(valid, auth.TokenInfo{ + ID: tokenIDs[i], + Description: description, + }) + } + + return valid, nil +} + +// RemoveActive removes an active refresh token ID for a user. +func (tc *tokensCache) RemoveActive(ctx context.Context, userID, tokenID string) error { + pipe := tc.client.TxPipeline() + pipe.Del(ctx, tokenKey(tokenID)) + pipe.ZRem(ctx, userTokensKey(userID), tokenID) + + _, err := pipe.Exec(ctx) + return err +} + +func tokenKey(tokenID string) string { + return refreshPrefix + "token:" + tokenID +} + +func userTokensKey(userID string) string { + return refreshPrefix + "user_tokens:" + userID +} diff --git a/auth/cache/tokens_test.go b/auth/cache/tokens_test.go new file mode 100644 index 000000000..0225b8259 --- /dev/null +++ b/auth/cache/tokens_test.go @@ -0,0 +1,288 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache_test + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/auth/cache" + "github.com/absmach/supermq/internal/testsutil" + "github.com/absmach/supermq/pkg/errors" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +var ( + storeClient *redis.Client + storeURL string +) + +func TestMain(m *testing.M) { + code := testsutil.RunRedisTest(m, &storeClient, &storeURL) + os.Exit(code) +} + +func setupRedisTokensClient() auth.UserActiveTokensCache { + tc, err := cache.NewUserActiveTokensCache(storeClient, 10*time.Minute) + if err != nil { + panic(err) + } + return tc +} + +func TestTokenSave(t *testing.T) { + storeClient.FlushAll(context.Background()) + tokensCache := setupRedisTokensClient() + + userID := testsutil.GenerateUUID(t) + tokenID := testsutil.GenerateUUID(t) + + cases := []struct { + desc string + userID string + tokenID string + description string + expiry time.Time + err error + }{ + { + desc: "Save active token", + userID: userID, + tokenID: tokenID, + description: "Test token", + expiry: time.Now().Add(10 * time.Minute), + err: nil, + }, + { + desc: "Save already cached token", + userID: userID, + tokenID: tokenID, + description: "Updated token", + expiry: time.Now().Add(10 * time.Minute), + err: nil, + }, + { + desc: "Save another token for same user", + userID: userID, + tokenID: testsutil.GenerateUUID(t), + description: "Another token", + expiry: time.Now().Add(10 * time.Minute), + err: nil, + }, + { + desc: "Save token with empty id", + userID: userID, + tokenID: "", + description: "Empty ID token", + expiry: time.Now().Add(10 * time.Minute), + err: nil, + }, + { + desc: "Save token with empty description", + userID: userID, + tokenID: testsutil.GenerateUUID(t), + description: "", + expiry: time.Now().Add(10 * time.Minute), + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tokensCache.SaveActive(context.Background(), tc.userID, tc.tokenID, tc.description, tc.expiry) + if err == nil { + ok, err := tokensCache.IsActive(context.Background(), tc.tokenID) + assert.NoError(t, err) + assert.True(t, ok) + } + assert.True(t, errors.Contains(err, tc.err)) + }) + } +} + +func TestTokenContains(t *testing.T) { + storeClient.FlushAll(context.Background()) + tokensCache := setupRedisTokensClient() + + userID := testsutil.GenerateUUID(t) + tokenID := testsutil.GenerateUUID(t) + + err := tokensCache.SaveActive(context.Background(), userID, tokenID, "Test token", time.Now().Add(10*time.Minute)) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + + cases := []struct { + desc string + userID string + tokenID string + ok bool + }{ + { + desc: "IsActive for existing token", + userID: userID, + tokenID: tokenID, + ok: true, + }, + { + desc: "IsActive for non existing token", + userID: userID, + tokenID: testsutil.GenerateUUID(t), + }, + { + desc: "IsActive with empty token id", + userID: userID, + tokenID: "", + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + ok, err := tokensCache.IsActive(context.Background(), tc.tokenID) + if tc.ok { + assert.NoError(t, err) + } + assert.Equal(t, tc.ok, ok) + }) + } +} + +func TestTokenRemove(t *testing.T) { + storeClient.FlushAll(context.Background()) + tokensCache := setupRedisTokensClient() + + userID := testsutil.GenerateUUID(t) + num := 10 + var tokenIDs []string + for i := range num { + tokenID := testsutil.GenerateUUID(t) + err := tokensCache.SaveActive(context.Background(), userID, tokenID, fmt.Sprintf("Token %d", i), time.Now().Add(10*time.Minute)) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + tokenIDs = append(tokenIDs, tokenID) + } + + cases := []struct { + desc string + userID string + tokenID string + err error + }{ + { + desc: "Remove an existing token from cache", + userID: userID, + tokenID: tokenIDs[0], + err: nil, + }, + { + desc: "Remove token with empty id from cache", + userID: userID, + tokenID: "", + err: nil, + }, + { + desc: "Remove non existing id from cache", + userID: userID, + tokenID: testsutil.GenerateUUID(t), + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tokensCache.RemoveActive(context.Background(), tc.userID, tc.tokenID) + assert.True(t, errors.Contains(err, tc.err)) + if err == nil { + ok, err := tokensCache.IsActive(context.Background(), tc.tokenID) + assert.NoError(t, err) + assert.False(t, ok) + } + }) + } +} + +func TestListUserTokens(t *testing.T) { + storeClient.FlushAll(context.Background()) + tokensCache := setupRedisTokensClient() + + userID := testsutil.GenerateUUID(t) + userID2 := testsutil.GenerateUUID(t) + num := 5 + var expectedTokens []auth.TokenInfo + + for i := range num { + tokenID := testsutil.GenerateUUID(t) + description := fmt.Sprintf("Token %d", i) + err := tokensCache.SaveActive(context.Background(), userID, tokenID, description, time.Now().Add(10*time.Minute)) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + expectedTokens = append(expectedTokens, auth.TokenInfo{ + ID: tokenID, + Description: description, + }) + } + + tokenID2 := testsutil.GenerateUUID(t) + desc2 := "User 2 token" + err := tokensCache.SaveActive(context.Background(), userID2, tokenID2, desc2, time.Now().Add(10*time.Minute)) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + + cases := []struct { + desc string + userID string + expectedCount int + expectedTokens []auth.TokenInfo + err error + }{ + { + desc: "List all tokens for user with multiple tokens", + userID: userID, + expectedCount: num, + expectedTokens: expectedTokens, + err: nil, + }, + { + desc: "List tokens for user with single token", + userID: userID2, + expectedCount: 1, + expectedTokens: []auth.TokenInfo{{ID: tokenID2, Description: desc2}}, + err: nil, + }, + { + desc: "List tokens for user with no tokens", + userID: testsutil.GenerateUUID(t), + expectedCount: 0, + expectedTokens: nil, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + tokens, err := tokensCache.ListUserTokens(context.Background(), tc.userID) + assert.True(t, errors.Contains(err, tc.err)) + assert.Equal(t, tc.expectedCount, len(tokens)) + if tc.expectedTokens != nil { + assert.ElementsMatch(t, tc.expectedTokens, tokens) + } + }) + } + + t.Run("Cleanup expired tokens from list", func(t *testing.T) { + // Remove one token directly from Redis to simulate expiration + err := tokensCache.RemoveActive(context.Background(), userID, expectedTokens[0].ID) + assert.NoError(t, err) + + // List should now return only valid tokens + tokens, err := tokensCache.ListUserTokens(context.Background(), userID) + assert.NoError(t, err) + assert.Equal(t, num-1, len(tokens)) + + // Check that the removed token is not in the list + for _, token := range tokens { + assert.NotEqual(t, expectedTokens[0].ID, token.ID) + } + }) +} diff --git a/auth/key_manager.go b/auth/key_manager.go index 73c74db82..6d23a60bb 100644 --- a/auth/key_manager.go +++ b/auth/key_manager.go @@ -5,13 +5,16 @@ package auth import ( "context" - "errors" + "time" + + "github.com/absmach/supermq/pkg/errors" ) var ( ErrUnsupportedKeyAlgorithm = errors.New("unsupported key algorithm") ErrInvalidSymmetricKey = errors.New("invalid symmetric key") ErrPublicKeysNotSupported = errors.New("public keys not supported for symmetric algorithm") + ErrRevokedToken = errors.NewAuthNError("token is revoked") ) // PublicKeyInfo represents a public key for external distribution via JWKS. @@ -33,6 +36,7 @@ type PublicKeyInfo struct { // Implementations manage underlying cryptographic operations and key distribution. type Tokenizer interface { // Issue creates a signed token string from the given key claims. + // For RefreshKey types, the token ID is stored as active in the cache. Issue(key Key) (token string, err error) // Parse verifies and parses a token string (JWT or PAT), returning the extracted claims. @@ -45,6 +49,27 @@ type Tokenizer interface { RetrieveJWKS() ([]PublicKeyInfo, error) } +// UserActiveTokensCache represents a cache repository for managing active refresh tokens per user. +type UserActiveTokensCache interface { + // SaveActive saves an active refresh token ID for a user with optional description. + SaveActive(ctx context.Context, userID, tokenID, description string, expiry time.Time) error + + // IsActive checks if the token ID is active. + IsActive(ctx context.Context, tokenID string) (bool, error) + + // ListUserTokens lists all active token IDs with descriptions for a given user. + ListUserTokens(ctx context.Context, userID string) ([]TokenInfo, error) + + // RemoveActive removes an active refresh token ID. + RemoveActive(ctx context.Context, userID, tokenID string) error +} + +// TokenInfo represents information about an active refresh token. +type TokenInfo struct { + ID string `json:"id"` + Description string `json:"description,omitempty"` +} + // IsSymmetricAlgorithm determines if the given algorithm is symmetric (HMAC-based). // Returns true for HMAC algorithms (HS256, HS384, HS512). // Returns false for asymmetric algorithms (EdDSA). diff --git a/auth/keys.go b/auth/keys.go index efde3a00c..4577987bf 100644 --- a/auth/keys.go +++ b/auth/keys.go @@ -81,14 +81,15 @@ func (r Role) Validate() bool { // Key represents API key. type Key struct { - ID string `json:"id,omitempty"` - Type KeyType `json:"type,omitempty"` - Issuer string `json:"issuer,omitempty"` - Subject string `json:"subject,omitempty"` // user ID - Role Role `json:"role,omitempty"` - IssuedAt time.Time `json:"issued_at,omitempty"` - ExpiresAt time.Time `json:"expires_at,omitempty"` - Verified bool `json:"verified,omitempty"` + ID string `json:"id,omitempty"` + Type KeyType `json:"type,omitempty"` + Issuer string `json:"issuer,omitempty"` + Subject string `json:"subject,omitempty"` // user ID + Role Role `json:"role,omitempty"` + IssuedAt time.Time `json:"issued_at,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + Verified bool `json:"verified,omitempty"` + Description string `json:"description,omitempty"` // Optional description for refresh tokens } func (key Key) String() string { diff --git a/auth/middleware/logging.go b/auth/middleware/logging.go index 95dbcad3f..d95dfab05 100644 --- a/auth/middleware/logging.go +++ b/auth/middleware/logging.go @@ -110,6 +110,42 @@ func (lm *loggingMiddleware) RetrieveJWKS() (jwks []auth.PublicKeyInfo) { return lm.svc.RetrieveJWKS() } +func (lm *loggingMiddleware) RevokeToken(ctx context.Context, userID, tokenID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("user_id", userID), + slog.String("token_id", tokenID), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Revoke token failed", args...) + return + } + lm.logger.Info("Revoke token completed successfully", args...) + }(time.Now()) + + return lm.svc.RevokeToken(ctx, userID, tokenID) +} + +func (lm *loggingMiddleware) ListUserRefreshTokens(ctx context.Context, userID string) (tokens []auth.TokenInfo, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("user_id", userID), + slog.Int("tokens_count", len(tokens)), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("List user refresh tokens failed", args...) + return + } + lm.logger.Info("List user refresh tokens completed successfully", args...) + }(time.Now()) + + return lm.svc.ListUserRefreshTokens(ctx, userID) +} + func (lm *loggingMiddleware) Authorize(ctx context.Context, pr policies.Policy, patAuthz *auth.PATAuthz) (err error) { defer func(begin time.Time) { args := []any{ diff --git a/auth/middleware/metrics.go b/auth/middleware/metrics.go index 18efd72cc..92d07e997 100644 --- a/auth/middleware/metrics.go +++ b/auth/middleware/metrics.go @@ -40,6 +40,15 @@ func (ms *metricsMiddleware) Issue(ctx context.Context, token string, key auth.K return ms.svc.Issue(ctx, token, key) } +func (ms *metricsMiddleware) RevokeToken(ctx context.Context, userID, tokenID string) error { + defer func(begin time.Time) { + ms.counter.With("method", "revoke_token").Add(1) + ms.latency.With("method", "revoke_token").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RevokeToken(ctx, userID, tokenID) +} + func (ms *metricsMiddleware) Revoke(ctx context.Context, token, id string) error { defer func(begin time.Time) { ms.counter.With("method", "revoke_key").Add(1) @@ -75,6 +84,15 @@ func (ms *metricsMiddleware) RetrieveJWKS() []auth.PublicKeyInfo { return ms.svc.RetrieveJWKS() } +func (ms *metricsMiddleware) ListUserRefreshTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) { + defer func(begin time.Time) { + ms.counter.With("method", "list_user_refresh_tokens").Add(1) + ms.latency.With("method", "list_user_refresh_tokens").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.ListUserRefreshTokens(ctx, userID) +} + func (ms *metricsMiddleware) Authorize(ctx context.Context, pr policies.Policy, patAuthz *auth.PATAuthz) error { defer func(begin time.Time) { ms.counter.With("method", "authorize").Add(1) diff --git a/auth/middleware/tracing.go b/auth/middleware/tracing.go index c975bf332..ccdff8983 100644 --- a/auth/middleware/tracing.go +++ b/auth/middleware/tracing.go @@ -65,6 +65,25 @@ func (tm *tracingMiddleware) RetrieveJWKS() []auth.PublicKeyInfo { return tm.svc.RetrieveJWKS() } +func (tm *tracingMiddleware) RevokeToken(ctx context.Context, userID, tokenID string) error { + ctx, span := tm.tracer.Start(ctx, "revoke_token", trace.WithAttributes( + attribute.String("user_id", userID), + attribute.String("token_id", tokenID), + )) + defer span.End() + + return tm.svc.RevokeToken(ctx, userID, tokenID) +} + +func (tm *tracingMiddleware) ListUserRefreshTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) { + ctx, span := tm.tracer.Start(ctx, "list_user_refresh_tokens", trace.WithAttributes( + attribute.String("user_id", userID), + )) + defer span.End() + + return tm.svc.ListUserRefreshTokens(ctx, userID) +} + func (tm *tracingMiddleware) Authorize(ctx context.Context, pr policies.Policy, patAuthz *auth.PATAuthz) error { attributes := []attribute.KeyValue{ attribute.String("subject", pr.Subject), diff --git a/auth/mocks/service.go b/auth/mocks/service.go index 136408bf7..2f16a68da 100644 --- a/auth/mocks/service.go +++ b/auth/mocks/service.go @@ -758,6 +758,74 @@ func (_c *Service_ListScopes_Call) RunAndReturn(run func(ctx context.Context, to return _c } +// ListUserRefreshTokens provides a mock function for the type Service +func (_mock *Service) ListUserRefreshTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) { + ret := _mock.Called(ctx, userID) + + if len(ret) == 0 { + panic("no return value specified for ListUserRefreshTokens") + } + + var r0 []auth.TokenInfo + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) ([]auth.TokenInfo, error)); ok { + return returnFunc(ctx, userID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) []auth.TokenInfo); ok { + r0 = returnFunc(ctx, userID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]auth.TokenInfo) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, userID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListUserRefreshTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserRefreshTokens' +type Service_ListUserRefreshTokens_Call struct { + *mock.Call +} + +// ListUserRefreshTokens is a helper method to define mock.On call +// - ctx context.Context +// - userID string +func (_e *Service_Expecter) ListUserRefreshTokens(ctx interface{}, userID interface{}) *Service_ListUserRefreshTokens_Call { + return &Service_ListUserRefreshTokens_Call{Call: _e.mock.On("ListUserRefreshTokens", ctx, userID)} +} + +func (_c *Service_ListUserRefreshTokens_Call) Run(run func(ctx context.Context, userID string)) *Service_ListUserRefreshTokens_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_ListUserRefreshTokens_Call) Return(tokenInfos []auth.TokenInfo, err error) *Service_ListUserRefreshTokens_Call { + _c.Call.Return(tokenInfos, err) + return _c +} + +func (_c *Service_ListUserRefreshTokens_Call) RunAndReturn(run func(ctx context.Context, userID string) ([]auth.TokenInfo, error)) *Service_ListUserRefreshTokens_Call { + _c.Call.Return(run) + return _c +} + // RemoveAllPAT provides a mock function for the type Service func (_mock *Service) RemoveAllPAT(ctx context.Context, token string) error { ret := _mock.Called(ctx, token) @@ -1350,6 +1418,69 @@ func (_c *Service_RevokePATSecret_Call) RunAndReturn(run func(ctx context.Contex return _c } +// RevokeToken provides a mock function for the type Service +func (_mock *Service) RevokeToken(ctx context.Context, userID string, tokenID string) error { + ret := _mock.Called(ctx, userID, tokenID) + + if len(ret) == 0 { + panic("no return value specified for RevokeToken") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = returnFunc(ctx, userID, tokenID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RevokeToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeToken' +type Service_RevokeToken_Call struct { + *mock.Call +} + +// RevokeToken is a helper method to define mock.On call +// - ctx context.Context +// - userID string +// - tokenID string +func (_e *Service_Expecter) RevokeToken(ctx interface{}, userID interface{}, tokenID interface{}) *Service_RevokeToken_Call { + return &Service_RevokeToken_Call{Call: _e.mock.On("RevokeToken", ctx, userID, tokenID)} +} + +func (_c *Service_RevokeToken_Call) Run(run func(ctx context.Context, userID string, tokenID string)) *Service_RevokeToken_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_RevokeToken_Call) Return(err error) *Service_RevokeToken_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RevokeToken_Call) RunAndReturn(run func(ctx context.Context, userID string, tokenID string) error) *Service_RevokeToken_Call { + _c.Call.Return(run) + return _c +} + // UpdatePATDescription provides a mock function for the type Service func (_mock *Service) UpdatePATDescription(ctx context.Context, token string, patID string, description string) (auth.PAT, error) { ret := _mock.Called(ctx, token, patID, description) diff --git a/auth/mocks/token_client.go b/auth/mocks/token_client.go index 025065092..72f81de07 100644 --- a/auth/mocks/token_client.go +++ b/auth/mocks/token_client.go @@ -126,6 +126,89 @@ func (_c *TokenServiceClient_Issue_Call) RunAndReturn(run func(ctx context.Conte return _c } +// ListUserRefreshTokens provides a mock function for the type TokenServiceClient +func (_mock *TokenServiceClient) ListUserRefreshTokens(ctx context.Context, in *v1.ListUserRefreshTokensReq, opts ...grpc.CallOption) (*v1.ListUserRefreshTokensRes, error) { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(ctx, in, opts) + } else { + tmpRet = _mock.Called(ctx, in) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for ListUserRefreshTokens") + } + + var r0 *v1.ListUserRefreshTokensRes + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.ListUserRefreshTokensReq, ...grpc.CallOption) (*v1.ListUserRefreshTokensRes, error)); ok { + return returnFunc(ctx, in, opts...) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.ListUserRefreshTokensReq, ...grpc.CallOption) *v1.ListUserRefreshTokensRes); ok { + r0 = returnFunc(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.ListUserRefreshTokensRes) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.ListUserRefreshTokensReq, ...grpc.CallOption) error); ok { + r1 = returnFunc(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// TokenServiceClient_ListUserRefreshTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserRefreshTokens' +type TokenServiceClient_ListUserRefreshTokens_Call struct { + *mock.Call +} + +// ListUserRefreshTokens is a helper method to define mock.On call +// - ctx context.Context +// - in *v1.ListUserRefreshTokensReq +// - opts ...grpc.CallOption +func (_e *TokenServiceClient_Expecter) ListUserRefreshTokens(ctx interface{}, in interface{}, opts ...interface{}) *TokenServiceClient_ListUserRefreshTokens_Call { + return &TokenServiceClient_ListUserRefreshTokens_Call{Call: _e.mock.On("ListUserRefreshTokens", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *TokenServiceClient_ListUserRefreshTokens_Call) Run(run func(ctx context.Context, in *v1.ListUserRefreshTokensReq, opts ...grpc.CallOption)) *TokenServiceClient_ListUserRefreshTokens_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *v1.ListUserRefreshTokensReq + if args[1] != nil { + arg1 = args[1].(*v1.ListUserRefreshTokensReq) + } + var arg2 []grpc.CallOption + var variadicArgs []grpc.CallOption + if len(args) > 2 { + variadicArgs = args[2].([]grpc.CallOption) + } + arg2 = variadicArgs + run( + arg0, + arg1, + arg2..., + ) + }) + return _c +} + +func (_c *TokenServiceClient_ListUserRefreshTokens_Call) Return(listUserRefreshTokensRes *v1.ListUserRefreshTokensRes, err error) *TokenServiceClient_ListUserRefreshTokens_Call { + _c.Call.Return(listUserRefreshTokensRes, err) + return _c +} + +func (_c *TokenServiceClient_ListUserRefreshTokens_Call) RunAndReturn(run func(ctx context.Context, in *v1.ListUserRefreshTokensReq, opts ...grpc.CallOption) (*v1.ListUserRefreshTokensRes, error)) *TokenServiceClient_ListUserRefreshTokens_Call { + _c.Call.Return(run) + return _c +} + // Refresh provides a mock function for the type TokenServiceClient func (_mock *TokenServiceClient) Refresh(ctx context.Context, in *v1.RefreshReq, opts ...grpc.CallOption) (*v1.Token, error) { var tmpRet mock.Arguments @@ -208,3 +291,86 @@ func (_c *TokenServiceClient_Refresh_Call) RunAndReturn(run func(ctx context.Con _c.Call.Return(run) return _c } + +// Revoke provides a mock function for the type TokenServiceClient +func (_mock *TokenServiceClient) Revoke(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption) (*v1.RevokeRes, error) { + var tmpRet mock.Arguments + if len(opts) > 0 { + tmpRet = _mock.Called(ctx, in, opts) + } else { + tmpRet = _mock.Called(ctx, in) + } + ret := tmpRet + + if len(ret) == 0 { + panic("no return value specified for Revoke") + } + + var r0 *v1.RevokeRes + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) (*v1.RevokeRes, error)); ok { + return returnFunc(ctx, in, opts...) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) *v1.RevokeRes); ok { + r0 = returnFunc(ctx, in, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.RevokeRes) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) error); ok { + r1 = returnFunc(ctx, in, opts...) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// TokenServiceClient_Revoke_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Revoke' +type TokenServiceClient_Revoke_Call struct { + *mock.Call +} + +// Revoke is a helper method to define mock.On call +// - ctx context.Context +// - in *v1.RevokeReq +// - opts ...grpc.CallOption +func (_e *TokenServiceClient_Expecter) Revoke(ctx interface{}, in interface{}, opts ...interface{}) *TokenServiceClient_Revoke_Call { + return &TokenServiceClient_Revoke_Call{Call: _e.mock.On("Revoke", + append([]interface{}{ctx, in}, opts...)...)} +} + +func (_c *TokenServiceClient_Revoke_Call) Run(run func(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption)) *TokenServiceClient_Revoke_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 *v1.RevokeReq + if args[1] != nil { + arg1 = args[1].(*v1.RevokeReq) + } + var arg2 []grpc.CallOption + var variadicArgs []grpc.CallOption + if len(args) > 2 { + variadicArgs = args[2].([]grpc.CallOption) + } + arg2 = variadicArgs + run( + arg0, + arg1, + arg2..., + ) + }) + return _c +} + +func (_c *TokenServiceClient_Revoke_Call) Return(revokeRes *v1.RevokeRes, err error) *TokenServiceClient_Revoke_Call { + _c.Call.Return(revokeRes, err) + return _c +} + +func (_c *TokenServiceClient_Revoke_Call) RunAndReturn(run func(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption) (*v1.RevokeRes, error)) *TokenServiceClient_Revoke_Call { + _c.Call.Return(run) + return _c +} diff --git a/auth/mocks/user_active_tokens_cache.go b/auth/mocks/user_active_tokens_cache.go new file mode 100644 index 000000000..450819c10 --- /dev/null +++ b/auth/mocks/user_active_tokens_cache.go @@ -0,0 +1,316 @@ +// Copyright (c) Abstract Machines + +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package mocks + +import ( + "context" + "time" + + "github.com/absmach/supermq/auth" + mock "github.com/stretchr/testify/mock" +) + +// NewUserActiveTokensCache creates a new instance of UserActiveTokensCache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewUserActiveTokensCache(t interface { + mock.TestingT + Cleanup(func()) +}) *UserActiveTokensCache { + mock := &UserActiveTokensCache{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// UserActiveTokensCache is an autogenerated mock type for the UserActiveTokensCache type +type UserActiveTokensCache struct { + mock.Mock +} + +type UserActiveTokensCache_Expecter struct { + mock *mock.Mock +} + +func (_m *UserActiveTokensCache) EXPECT() *UserActiveTokensCache_Expecter { + return &UserActiveTokensCache_Expecter{mock: &_m.Mock} +} + +// IsActive provides a mock function for the type UserActiveTokensCache +func (_mock *UserActiveTokensCache) IsActive(ctx context.Context, tokenID string) (bool, error) { + ret := _mock.Called(ctx, tokenID) + + if len(ret) == 0 { + panic("no return value specified for IsActive") + } + + var r0 bool + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) (bool, error)); ok { + return returnFunc(ctx, tokenID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = returnFunc(ctx, tokenID) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, tokenID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// UserActiveTokensCache_IsActive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsActive' +type UserActiveTokensCache_IsActive_Call struct { + *mock.Call +} + +// IsActive is a helper method to define mock.On call +// - ctx context.Context +// - tokenID string +func (_e *UserActiveTokensCache_Expecter) IsActive(ctx interface{}, tokenID interface{}) *UserActiveTokensCache_IsActive_Call { + return &UserActiveTokensCache_IsActive_Call{Call: _e.mock.On("IsActive", ctx, tokenID)} +} + +func (_c *UserActiveTokensCache_IsActive_Call) Run(run func(ctx context.Context, tokenID string)) *UserActiveTokensCache_IsActive_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *UserActiveTokensCache_IsActive_Call) Return(b bool, err error) *UserActiveTokensCache_IsActive_Call { + _c.Call.Return(b, err) + return _c +} + +func (_c *UserActiveTokensCache_IsActive_Call) RunAndReturn(run func(ctx context.Context, tokenID string) (bool, error)) *UserActiveTokensCache_IsActive_Call { + _c.Call.Return(run) + return _c +} + +// ListUserTokens provides a mock function for the type UserActiveTokensCache +func (_mock *UserActiveTokensCache) ListUserTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) { + ret := _mock.Called(ctx, userID) + + if len(ret) == 0 { + panic("no return value specified for ListUserTokens") + } + + var r0 []auth.TokenInfo + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string) ([]auth.TokenInfo, error)); ok { + return returnFunc(ctx, userID) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string) []auth.TokenInfo); ok { + r0 = returnFunc(ctx, userID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]auth.TokenInfo) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = returnFunc(ctx, userID) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// UserActiveTokensCache_ListUserTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserTokens' +type UserActiveTokensCache_ListUserTokens_Call struct { + *mock.Call +} + +// ListUserTokens is a helper method to define mock.On call +// - ctx context.Context +// - userID string +func (_e *UserActiveTokensCache_Expecter) ListUserTokens(ctx interface{}, userID interface{}) *UserActiveTokensCache_ListUserTokens_Call { + return &UserActiveTokensCache_ListUserTokens_Call{Call: _e.mock.On("ListUserTokens", ctx, userID)} +} + +func (_c *UserActiveTokensCache_ListUserTokens_Call) Run(run func(ctx context.Context, userID string)) *UserActiveTokensCache_ListUserTokens_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *UserActiveTokensCache_ListUserTokens_Call) Return(tokenInfos []auth.TokenInfo, err error) *UserActiveTokensCache_ListUserTokens_Call { + _c.Call.Return(tokenInfos, err) + return _c +} + +func (_c *UserActiveTokensCache_ListUserTokens_Call) RunAndReturn(run func(ctx context.Context, userID string) ([]auth.TokenInfo, error)) *UserActiveTokensCache_ListUserTokens_Call { + _c.Call.Return(run) + return _c +} + +// RemoveActive provides a mock function for the type UserActiveTokensCache +func (_mock *UserActiveTokensCache) RemoveActive(ctx context.Context, userID string, tokenID string) error { + ret := _mock.Called(ctx, userID, tokenID) + + if len(ret) == 0 { + panic("no return value specified for RemoveActive") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = returnFunc(ctx, userID, tokenID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// UserActiveTokensCache_RemoveActive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveActive' +type UserActiveTokensCache_RemoveActive_Call struct { + *mock.Call +} + +// RemoveActive is a helper method to define mock.On call +// - ctx context.Context +// - userID string +// - tokenID string +func (_e *UserActiveTokensCache_Expecter) RemoveActive(ctx interface{}, userID interface{}, tokenID interface{}) *UserActiveTokensCache_RemoveActive_Call { + return &UserActiveTokensCache_RemoveActive_Call{Call: _e.mock.On("RemoveActive", ctx, userID, tokenID)} +} + +func (_c *UserActiveTokensCache_RemoveActive_Call) Run(run func(ctx context.Context, userID string, tokenID string)) *UserActiveTokensCache_RemoveActive_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *UserActiveTokensCache_RemoveActive_Call) Return(err error) *UserActiveTokensCache_RemoveActive_Call { + _c.Call.Return(err) + return _c +} + +func (_c *UserActiveTokensCache_RemoveActive_Call) RunAndReturn(run func(ctx context.Context, userID string, tokenID string) error) *UserActiveTokensCache_RemoveActive_Call { + _c.Call.Return(run) + return _c +} + +// SaveActive provides a mock function for the type UserActiveTokensCache +func (_mock *UserActiveTokensCache) SaveActive(ctx context.Context, userID string, tokenID string, description string, expiry time.Time) error { + ret := _mock.Called(ctx, userID, tokenID, description, expiry) + + if len(ret) == 0 { + panic("no return value specified for SaveActive") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, time.Time) error); ok { + r0 = returnFunc(ctx, userID, tokenID, description, expiry) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// UserActiveTokensCache_SaveActive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveActive' +type UserActiveTokensCache_SaveActive_Call struct { + *mock.Call +} + +// SaveActive is a helper method to define mock.On call +// - ctx context.Context +// - userID string +// - tokenID string +// - description string +// - expiry time.Time +func (_e *UserActiveTokensCache_Expecter) SaveActive(ctx interface{}, userID interface{}, tokenID interface{}, description interface{}, expiry interface{}) *UserActiveTokensCache_SaveActive_Call { + return &UserActiveTokensCache_SaveActive_Call{Call: _e.mock.On("SaveActive", ctx, userID, tokenID, description, expiry)} +} + +func (_c *UserActiveTokensCache_SaveActive_Call) Run(run func(ctx context.Context, userID string, tokenID string, description string, expiry time.Time)) *UserActiveTokensCache_SaveActive_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } + var arg4 time.Time + if args[4] != nil { + arg4 = args[4].(time.Time) + } + run( + arg0, + arg1, + arg2, + arg3, + arg4, + ) + }) + return _c +} + +func (_c *UserActiveTokensCache_SaveActive_Call) Return(err error) *UserActiveTokensCache_SaveActive_Call { + _c.Call.Return(err) + return _c +} + +func (_c *UserActiveTokensCache_SaveActive_Call) RunAndReturn(run func(ctx context.Context, userID string, tokenID string, description string, expiry time.Time) error) *UserActiveTokensCache_SaveActive_Call { + _c.Call.Return(run) + return _c +} diff --git a/auth/postgres/errors.go b/auth/postgres/errors.go new file mode 100644 index 000000000..45477426e --- /dev/null +++ b/auth/postgres/errors.go @@ -0,0 +1,24 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import "github.com/absmach/supermq/pkg/errors" + +var _ errors.Mapper = (*duplicateErrors)(nil) + +type duplicateErrors struct{} + +// GetError maps constraint names to known errors. +func (d duplicateErrors) GetError(constraint string) (error, bool) { + switch constraint { + case "revoked_tokens_pkey": + return errors.NewRequestError("revoked token already exists"), true + default: + return nil, false + } +} + +func NewDuplicateErrors() errors.Mapper { + return duplicateErrors{} +} diff --git a/auth/service.go b/auth/service.go index 0cfd53db5..09a78461d 100644 --- a/auth/service.go +++ b/auth/service.go @@ -29,13 +29,16 @@ var ( // ErrExpiry indicates that the token is expired. ErrExpiry = errors.New("token is expired") - errIssueUser = errors.New("failed to issue new login key") - errIssueTmp = errors.New("failed to issue new temporary key") - errRevoke = errors.New("failed to remove key") - errRetrieve = errors.New("failed to retrieve key data") - errIdentify = errors.New("failed to validate token") - errPlatform = errors.New("invalid platform id") - errRoleAuth = errors.New("failed to authorize user role") + errIssueUser = errors.New("failed to issue new login key") + errIssueTmp = errors.New("failed to issue new temporary key") + errRevoke = errors.New("failed to remove key") + errRetrieve = errors.New("failed to retrieve key data") + errIdentify = errors.New("failed to validate token") + errPlatform = errors.New("invalid platform id") + errRoleAuth = errors.New("failed to authorize user role") + errSaveRefreshKey = errors.NewServiceError("failed to save refresh key") + errRevokeRefreshKey = errors.NewServiceError("failed to revoke refresh key") + errListRefreshKeys = errors.NewServiceError("failed to list refresh keys") errMalformedPAT = errors.New("malformed personal access token") errFailedToParseUUID = errors.New("failed to parse string to UUID") @@ -67,6 +70,9 @@ type Authn interface { // Issue issues a new Key, returning its token value alongside. Issue(ctx context.Context, token string, key Key) (Token, error) + // RevokeToken revokes the refresh token by its ID. + RevokeToken(ctx context.Context, userID, tokenID string) error + // Revoke removes the Key with the provided id that is // issued by the user identified by the provided key. Revoke(ctx context.Context, token, id string) error @@ -82,6 +88,9 @@ type Authn interface { // RetrieveJWKS retrieves public keys to validate issued tokens. RetrieveJWKS() []PublicKeyInfo + + // ListUserRefreshTokens lists all active refresh token sessions for a user. + ListUserRefreshTokens(ctx context.Context, userID string) ([]TokenInfo, error) } // Service specifies an API that must be fulfilled by the domain service @@ -100,6 +109,7 @@ type service struct { keys KeyRepository pats PATSRepository cache Cache + tokensCache UserActiveTokensCache hasher Hasher idProvider supermq.IDProvider evaluator policies.Evaluator @@ -111,12 +121,13 @@ type service struct { } // New instantiates the auth service implementation. -func New(keys KeyRepository, pats PATSRepository, cache Cache, hasher Hasher, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service { +func New(keys KeyRepository, pats PATSRepository, cache Cache, tokensCache UserActiveTokensCache, hasher Hasher, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service { return &service{ tokenizer: tokenizer, keys: keys, pats: pats, cache: cache, + tokensCache: tokensCache, hasher: hasher, idProvider: idp, evaluator: policyEvaluator, @@ -143,6 +154,14 @@ func (svc service) Issue(ctx context.Context, token string, key Key) (Token, err } } +func (svc service) RevokeToken(ctx context.Context, userID, tokenID string) error { + if err := svc.tokensCache.RemoveActive(ctx, userID, tokenID); err != nil { + return errors.Wrap(errRevokeRefreshKey, err) + } + + return nil +} + func (svc service) Revoke(ctx context.Context, token, id string) error { issuerID, _, err := svc.authenticate(ctx, token) if err != nil { @@ -205,6 +224,15 @@ func (svc service) RetrieveJWKS() []PublicKeyInfo { return keys } +func (svc service) ListUserRefreshTokens(ctx context.Context, userID string) ([]TokenInfo, error) { + tokenInfo, err := svc.tokensCache.ListUserTokens(ctx, userID) + if err != nil { + return nil, errors.Wrap(errListRefreshKeys, err) + } + + return tokenInfo, nil +} + func (svc service) Authorize(ctx context.Context, pr policies.Policy, patAuthz *PATAuthz) error { if patAuthz != nil { if err := svc.AuthorizePAT(ctx, patAuthz.UserID, patAuthz.PatID, patAuthz.EntityType, patAuthz.Domain, patAuthz.Operation, patAuthz.EntityID); err != nil { @@ -265,10 +293,20 @@ func (svc service) accessKey(ctx context.Context, key Key) (Token, error) { key.ExpiresAt = time.Now().UTC().Add(svc.refreshDuration) key.Type = RefreshKey + id, err := svc.idProvider.ID() + if err != nil { + return Token{}, errors.Wrap(errIssueTmp, err) + } + key.ID = id refresh, err := svc.tokenizer.Issue(key) if err != nil { return Token{}, errors.Wrap(errIssueTmp, err) } + if key.Subject != "" && key.ExpiresAt.After(time.Now()) { + if err := svc.tokensCache.SaveActive(ctx, key.Subject, key.ID, key.Description, key.ExpiresAt); err != nil { + return Token{}, errors.Wrap(errSaveRefreshKey, err) + } + } return Token{AccessToken: access, RefreshToken: refresh}, nil } @@ -298,6 +336,13 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token if k.Type != RefreshKey { return Token{}, errIssueUser } + ok, err := svc.tokensCache.IsActive(ctx, key.ID) + if err != nil { + return Token{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + if !ok { + return Token{}, ErrRevokedToken + } key.ID = k.ID key.Type = AccessKey key.Subject = k.Subject @@ -313,7 +358,7 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token return Token{}, errors.Wrap(errIssueTmp, err) } - key.ExpiresAt = time.Now().UTC().Add(svc.refreshDuration) + key.ExpiresAt = k.ExpiresAt key.Type = RefreshKey refresh, err := svc.tokenizer.Issue(key) if err != nil { diff --git a/auth/service_test.go b/auth/service_test.go index 575e1a295..219e6c031 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -54,18 +54,20 @@ var ( ) var ( - krepo *mocks.KeyRepository - pService *policymocks.Service - pEvaluator *policymocks.Evaluator - patsrepo *mocks.PATSRepository - cache *mocks.Cache - hasher *mocks.Hasher - tokenizer *mocks.Tokenizer + krepo *mocks.KeyRepository + pService *policymocks.Service + pEvaluator *policymocks.Evaluator + patsrepo *mocks.PATSRepository + cache *mocks.Cache + tokensCache *mocks.UserActiveTokensCache + hasher *mocks.Hasher + tokenizer *mocks.Tokenizer ) func newService(t *testing.T) (auth.Service, string) { krepo = new(mocks.KeyRepository) cache = new(mocks.Cache) + tokensCache = new(mocks.UserActiveTokensCache) pService = new(policymocks.Service) pEvaluator = new(policymocks.Evaluator) patsrepo = new(mocks.PATSRepository) @@ -76,7 +78,7 @@ func newService(t *testing.T) (auth.Service, string) { token, _, err := signToken(t, issuerName, accessKey, false) assert.Nil(t, err, fmt.Sprintf("Issuing access key expected to succeed: %s", err)) - return auth.New(krepo, patsrepo, cache, hasher, idProvider, tokenizer, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token + return auth.New(krepo, patsrepo, cache, tokensCache, hasher, idProvider, tokenizer, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token } func TestIssue(t *testing.T) { @@ -133,7 +135,7 @@ func TestIssue(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.tokenizerErr) + tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.tokenizerErr) policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ Subject: tc.key.Subject, SubjectType: policies.UserType, @@ -154,6 +156,7 @@ func TestIssue(t *testing.T) { saveResponse auth.Key token string tokenizerErr error + cacheErr error saveErr error roleCheckErr error err error @@ -169,11 +172,24 @@ func TestIssue(t *testing.T) { token: accessToken, err: nil, }, + { + desc: "issue access key with cache error", + key: auth.Key{ + Type: auth.AccessKey, + Subject: userID, + Role: auth.UserRole, + IssuedAt: time.Now(), + }, + token: accessToken, + cacheErr: svcerr.ErrCreateEntity, + err: svcerr.ErrCreateEntity, + }, } for _, tc := range cases2 { t.Run(tc.desc, func(t *testing.T) { - tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.tokenizerErr) + tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.tokenizerErr) repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr) + cacheCall := tokensCache.On("SaveActive", context.Background(), tc.key.Subject, mock.Anything, tc.key.Description, mock.Anything).Return(tc.cacheErr) policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ Subject: tc.key.Subject, SubjectType: policies.UserType, @@ -186,6 +202,7 @@ func TestIssue(t *testing.T) { tokenizerCall.Unset() repoCall.Unset() policyCall.Unset() + cacheCall.Unset() }) } @@ -265,7 +282,7 @@ func TestIssue(t *testing.T) { } for _, tc := range cases3 { t.Run(tc.desc, func(t *testing.T) { - tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.issueErr) + tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.issueErr) tokenizerCall1 := tokenizer.On("Parse", mock.Anything, tc.token).Return(tc.parseRes, tc.parseErr) repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr) policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ @@ -292,6 +309,8 @@ func TestIssue(t *testing.T) { parseErr error roleCheckErr error issueErr error + cacheRes bool + cacheErr error err error }{ { @@ -304,6 +323,7 @@ func TestIssue(t *testing.T) { }, token: refreshToken, parseRes: refreshkey, + cacheRes: true, err: nil, }, { @@ -339,15 +359,46 @@ func TestIssue(t *testing.T) { Role: auth.UserRole, }, token: refreshToken, + cacheRes: true, parseRes: refreshkey, roleCheckErr: errRoleAuth, err: errRoleAuth, }, + { + desc: "issue refresh key with revoked refresh token", + key: auth.Key{ + Type: auth.RefreshKey, + IssuedAt: time.Now(), + Subject: userID, + Role: auth.UserRole, + }, + token: refreshToken, + parseRes: refreshkey, + cacheRes: false, + cacheErr: nil, + err: auth.ErrRevokedToken, + }, + { + desc: "issue refresh key with cache error", + key: auth.Key{ + Type: auth.RefreshKey, + IssuedAt: time.Now(), + Subject: userID, + Role: auth.UserRole, + }, + token: refreshToken, + parseRes: refreshkey, + cacheRes: false, + cacheErr: svcerr.ErrCreateEntity, + err: svcerr.ErrCreateEntity, + }, } for _, tc := range cases4 { t.Run(tc.desc, func(t *testing.T) { - tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.issueErr) + tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.issueErr) tokenizerCall1 := tokenizer.On("Parse", mock.Anything, tc.token).Return(tc.parseRes, tc.parseErr) + tokenizerCall2 := tokenizer.On("Revoke", mock.Anything, tc.token).Return(tc.parseErr) + cacheCall := tokensCache.On("IsActive", context.Background(), tc.key.ID).Return(tc.cacheRes, tc.cacheErr) policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{ Subject: tc.key.Subject, SubjectType: policies.UserType, @@ -359,7 +410,9 @@ func TestIssue(t *testing.T) { assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) tokenizerCall.Unset() tokenizerCall1.Unset() + tokenizerCall2.Unset() policyCall.Unset() + cacheCall.Unset() }) } } @@ -642,6 +695,173 @@ func TestIdentify(t *testing.T) { } } +func TestRevokeToken(t *testing.T) { + svc, _ := newService(t) + + cases := []struct { + desc string + userID string + tokenID string + removeErr error + err error + }{ + { + desc: "revoke token successfully", + userID: validID, + tokenID: "validTokenID", + removeErr: nil, + err: nil, + }, + { + desc: "revoke token with cache error", + userID: validID, + tokenID: "validTokenID", + removeErr: svcerr.ErrRemoveEntity, + err: svcerr.ErrRemoveEntity, + }, + { + desc: "revoke token with empty token ID", + userID: validID, + tokenID: "", + removeErr: nil, + err: nil, + }, + { + desc: "revoke token not found", + userID: validID, + tokenID: "nonExistentTokenID", + removeErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + cacheCall := tokensCache.On("RemoveActive", mock.Anything, tc.userID, tc.tokenID).Return(tc.removeErr) + err := svc.RevokeToken(context.Background(), tc.userID, tc.tokenID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) + cacheCall.Unset() + }) + } +} + +func TestRetrieveJWKS(t *testing.T) { + svc, _ := newService(t) + + publicKeys := []auth.PublicKeyInfo{ + { + KeyID: "key1", + Algorithm: "RS256", + }, + { + KeyID: "key2", + Algorithm: "RS256", + }, + } + + cases := []struct { + desc string + tokenizerRes []auth.PublicKeyInfo + tokenizerErr error + expectedResult []auth.PublicKeyInfo + }{ + { + desc: "retrieve JWKS successfully", + tokenizerRes: publicKeys, + tokenizerErr: nil, + expectedResult: publicKeys, + }, + { + desc: "retrieve JWKS with tokenizer error", + tokenizerRes: nil, + tokenizerErr: svcerr.ErrViewEntity, + expectedResult: nil, + }, + { + desc: "retrieve JWKS with empty keys", + tokenizerRes: []auth.PublicKeyInfo{}, + tokenizerErr: nil, + expectedResult: []auth.PublicKeyInfo{}, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + tokenizerCall := tokenizer.On("RetrieveJWKS").Return(tc.tokenizerRes, tc.tokenizerErr) + result := svc.RetrieveJWKS() + assert.Equal(t, tc.expectedResult, result, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.expectedResult, result)) + tokenizerCall.Unset() + }) + } +} + +func TestListUserRefreshTokens(t *testing.T) { + svc, _ := newService(t) + + tokenInfos := []auth.TokenInfo{ + { + ID: "token1", + Description: "Session 1", + }, + { + ID: "token2", + Description: "Session 2", + }, + } + + cases := []struct { + desc string + userID string + cacheRes []auth.TokenInfo + cacheErr error + expectedRes []auth.TokenInfo + err error + }{ + { + desc: "list user refresh tokens successfully", + userID: userID, + cacheRes: tokenInfos, + cacheErr: nil, + expectedRes: tokenInfos, + err: nil, + }, + { + desc: "list user refresh tokens with cache error", + userID: userID, + cacheRes: nil, + cacheErr: svcerr.ErrViewEntity, + expectedRes: nil, + err: svcerr.ErrViewEntity, + }, + { + desc: "list user refresh tokens with empty result", + userID: userID, + cacheRes: []auth.TokenInfo{}, + cacheErr: nil, + expectedRes: []auth.TokenInfo{}, + err: nil, + }, + { + desc: "list user refresh tokens with invalid user ID", + userID: "", + cacheRes: nil, + cacheErr: svcerr.ErrViewEntity, + expectedRes: nil, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + cacheCall := tokensCache.On("ListUserTokens", mock.Anything, tc.userID).Return(tc.cacheRes, tc.cacheErr) + result, err := svc.ListUserRefreshTokens(context.Background(), tc.userID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected error %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.expectedRes, result, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.expectedRes, result)) + cacheCall.Unset() + }) + } +} + func TestAuthorize(t *testing.T) { svc, _ := newService(t) diff --git a/auth/tokenizer/asymmetric/README.md b/auth/tokenizer/asymmetric/README.md index d64f96bd9..2bade9fdb 100644 --- a/auth/tokenizer/asymmetric/README.md +++ b/auth/tokenizer/asymmetric/README.md @@ -28,6 +28,7 @@ export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/private.key" ``` The tokenizer will: + - Issue new tokens signed with the active key - Verify tokens using the active key - Return one public key in JWKS endpoint @@ -42,6 +43,7 @@ export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key" ``` The tokenizer will: + - Issue new tokens signed with the active key - Verify tokens using both active and retiring keys - Return both public keys in JWKS endpoint @@ -103,10 +105,12 @@ The grace period should be longer than your longest-lived access token duration. - Store private keys with `0600` permissions - Use cryptographically secure key generation: + ```bash openssl genpkey -algorithm Ed25519 -out private.key chmod 600 private.key ``` + - Rotate keys regularly: - Standard environments: every 90 days - High-security environments: every 30 days @@ -139,7 +143,7 @@ rm ./keys/key-2024.pem ### Active key not found -``` +```bash Error: active key file not found: ./keys/active.key ``` @@ -149,7 +153,7 @@ Error: active key file not found: ./keys/active.key If the retiring key path is set but the file is missing or invalid, the tokenizer logs a warning but continues with only the active key: -``` +```bash WARN: failed to load retiring key, continuing without it ``` @@ -157,7 +161,7 @@ This is by design - a missing retiring key won't prevent startup. ### Invalid key format -``` +```bash Error: failed to parse private key ``` diff --git a/auth/tokenizer/asymmetric/tokenizer.go b/auth/tokenizer/asymmetric/tokenizer.go index ac8d168ec..aa035215a 100644 --- a/auth/tokenizer/asymmetric/tokenizer.go +++ b/auth/tokenizer/asymmetric/tokenizer.go @@ -24,8 +24,6 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" ) -const patPrefix = "pat" - var ( errLoadingPrivateKey = errors.New("failed to load private key") errDuplicateRetiringKeyID = errors.New("retiring key ID matches active key ID") @@ -95,8 +93,8 @@ func NewTokenizer(activeKeyPath, retiringKeyPath string, idProvider supermq.IDPr return mgr, nil } -func (km *tokenizer) Issue(key auth.Key) (string, error) { - if km.activeKey == nil { +func (tok *tokenizer) Issue(key auth.Key) (string, error) { + if tok.activeKey == nil { return "", errNoActiveKey } @@ -105,11 +103,11 @@ func (km *tokenizer) Issue(key auth.Key) (string, error) { return "", err } headers := jws.NewHeaders() - if err := headers.Set(jwk.KeyIDKey, km.activeKey.id); err != nil { + if err := headers.Set(jwk.KeyIDKey, tok.activeKey.id); err != nil { return "", err } - signedBytes, err := jwt.Sign(tkn, jwt.WithKey(jwa.EdDSA, km.activeKey.privateKey, jws.WithProtectedHeaders(headers))) + signedBytes, err := jwt.Sign(tkn, jwt.WithKey(jwa.EdDSA, tok.activeKey.privateKey, jws.WithProtectedHeaders(headers))) if err != nil { return "", err } @@ -117,17 +115,17 @@ func (km *tokenizer) Issue(key auth.Key) (string, error) { return string(signedBytes), nil } -func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) { - if len(tokenString) >= 3 && tokenString[:3] == patPrefix { +func (tok *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) { + if len(tokenString) >= 3 && tokenString[:3] == smqjwt.PatPrefix { return auth.Key{Type: auth.PersonalAccessToken}, nil } set := jwk.NewSet() - if err := set.AddKey(km.activeKey.publicKey); err != nil { + if err := set.AddKey(tok.activeKey.publicKey); err != nil { return auth.Key{}, err } - if km.retiringKey != nil { - if err := set.AddKey(km.retiringKey.publicKey); err != nil { + if tok.retiringKey != nil { + if err := set.AddKey(tok.retiringKey.publicKey); err != nil { return auth.Key{}, err } } @@ -148,17 +146,17 @@ func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, e return smqjwt.ToKey(tkn) } -func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) { +func (tok *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) { publicKeys := make([]auth.PublicKeyInfo, 0, 2) - if km.activeKey != nil { - if pkInfo := extractPublicKeyInfo(km.activeKey); pkInfo != nil { + if tok.activeKey != nil { + if pkInfo := extractPublicKeyInfo(tok.activeKey); pkInfo != nil { publicKeys = append(publicKeys, *pkInfo) } } - if km.retiringKey != nil { - if pkInfo := extractPublicKeyInfo(km.retiringKey); pkInfo != nil { + if tok.retiringKey != nil { + if pkInfo := extractPublicKeyInfo(tok.retiringKey); pkInfo != nil { publicKeys = append(publicKeys, *pkInfo) } } diff --git a/auth/tokenizer/symmetric/tokenizer.go b/auth/tokenizer/symmetric/tokenizer.go index 5933bb697..3c3b8b573 100644 --- a/auth/tokenizer/symmetric/tokenizer.go +++ b/auth/tokenizer/symmetric/tokenizer.go @@ -14,12 +14,6 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" ) -const ( - patPrefix = "pat" -) - -var errJWTExpiryKey = errors.New(`"exp" not satisfied`) - type tokenizer struct { algorithm jwa.KeyAlgorithm secret []byte @@ -41,13 +35,13 @@ func NewTokenizer(algorithm string, secret []byte) (auth.Tokenizer, error) { }, nil } -func (km *tokenizer) Issue(key auth.Key) (string, error) { +func (tok *tokenizer) Issue(key auth.Key) (string, error) { tkn, err := smqjwt.BuildToken(key) if err != nil { return "", err } - signedBytes, err := jwt.Sign(tkn, jwt.WithKey(km.algorithm, km.secret)) + signedBytes, err := jwt.Sign(tkn, jwt.WithKey(tok.algorithm, tok.secret)) if err != nil { return "", err } @@ -55,18 +49,18 @@ func (km *tokenizer) Issue(key auth.Key) (string, error) { return string(signedBytes), nil } -func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) { - if len(tokenString) >= 3 && tokenString[:3] == patPrefix { +func (tok *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) { + if len(tokenString) >= 3 && tokenString[:3] == smqjwt.PatPrefix { return auth.Key{Type: auth.PersonalAccessToken}, nil } tkn, err := jwt.Parse( []byte(tokenString), jwt.WithValidate(true), - jwt.WithKey(km.algorithm, km.secret), + jwt.WithKey(tok.algorithm, tok.secret), ) if err != nil { - if errors.Contains(err, errJWTExpiryKey) { + if errors.Contains(err, smqjwt.ErrJWTExpiryKey) { return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, auth.ErrExpiry) } return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) @@ -79,6 +73,6 @@ func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, e return smqjwt.ToKey(tkn) } -func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) { +func (tok *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) { return nil, auth.ErrPublicKeysNotSupported } diff --git a/auth/tokenizer/util/jwt.go b/auth/tokenizer/util/jwt.go index f18c6ab2a..6e2badd8d 100644 --- a/auth/tokenizer/util/jwt.go +++ b/auth/tokenizer/util/jwt.go @@ -18,6 +18,9 @@ var ( // ErrJSONHandle indicates an error in handling JSON. ErrJSONHandle = errors.New("failed to perform operation JSON") + // ErrJWTExpiryKey indicates that the "exp" claim in the JWT token is not satisfied. + ErrJWTExpiryKey = errors.New(`"exp" not satisfied`) + errInvalidType = errors.New("invalid token type") errInvalidRole = errors.New("invalid role") errInvalidVerified = errors.New("invalid verified") @@ -28,6 +31,7 @@ const ( TokenType = "type" RoleField = "role" VerifiedField = "verified" + PatPrefix = "pat" ) // ToKey converts a JWT token to an auth.Key by extracting claims. diff --git a/cmd/auth/main.go b/cmd/auth/main.go index d2e36d71c..983c66067 100644 --- a/cmd/auth/main.go +++ b/cmd/auth/main.go @@ -292,17 +292,21 @@ func validateKeyConfig(isSymmetric bool, cfg config, l *slog.Logger) error { } func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration, tokenizer auth.Tokenizer, idProvider supermq.IDProvider) (auth.Service, error) { - cache := cache.NewPatsCache(cacheClient, keyDuration) + patsCache := cache.NewPatsCache(cacheClient, keyDuration) + tokensCache, err := cache.NewUserActiveTokensCache(cacheClient, keyDuration) + if err != nil { + return nil, err + } database := pgclient.NewDatabase(db, dbConfig, tracer) keysRepo := apostgres.New(database) - patsRepo := apostgres.NewPatRepo(database, cache) + patsRepo := apostgres.NewPatRepo(database, patsCache) hasher := hasher.New() pEvaluator := spicedb.NewPolicyEvaluator(spicedbClient, logger) pService := spicedb.NewPolicyService(spicedbClient, logger) - svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, tokenizer, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration) + svc := auth.New(keysRepo, patsRepo, nil, tokensCache, hasher, idProvider, tokenizer, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration) svc = middleware.NewLogging(svc, logger) counter, latency := prometheus.MakeMetrics("auth", "api") svc = middleware.NewMetrics(svc, counter, latency) diff --git a/cmd/users/main.go b/cmd/users/main.go index eea912504..43e5d4815 100644 --- a/cmd/users/main.go +++ b/cmd/users/main.go @@ -399,7 +399,7 @@ func createAdmin(ctx context.Context, c config, repo users.Repository, hsr users if _, err = repo.Save(ctx, user); err != nil { return "", err } - if _, err = svc.IssueToken(ctx, c.AdminUsername, c.AdminPassword); err != nil { + if _, err = svc.IssueToken(ctx, c.AdminUsername, c.AdminPassword, ""); err != nil { return "", err } return user.ID, nil diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 682b1a302..dc82038fb 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -95,6 +95,8 @@ services: - supermq-base-net volumes: - supermq-auth-redis-volume:/data + - ./redis/redis.conf:/etc/redis/redis.conf:ro + command: ["redis-server", "/etc/redis/redis.conf"] auth: image: docker.io/supermq/auth:${SMQ_RELEASE_TAG} diff --git a/docker/redis/redis.conf b/docker/redis/redis.conf new file mode 100644 index 000000000..40bc23df5 --- /dev/null +++ b/docker/redis/redis.conf @@ -0,0 +1,14 @@ +# Copyright (c) Abstract Machines +# SPDX-License-Identifier: Apache-2.0 + +# Enable AOF persistence +appendonly yes +appendfilename "appendonly.aof" +appendfsync everysec + +# Enable periodic snapshots +save 300 10 + +# Persist data in Docker volume +dir /data + diff --git a/internal/proto/token/v1/token.proto b/internal/proto/token/v1/token.proto index 10c066511..a616809bc 100644 --- a/internal/proto/token/v1/token.proto +++ b/internal/proto/token/v1/token.proto @@ -9,6 +9,9 @@ option go_package = "github.com/absmach/supermq/api/grpc/token/v1"; service TokenService { rpc Issue(IssueReq) returns (Token) {} rpc Refresh(RefreshReq) returns (Token) {} + rpc Revoke(RevokeReq) returns (RevokeRes) {} + rpc ListUserRefreshTokens(ListUserRefreshTokensReq) + returns (ListUserRefreshTokensRes) {} } message IssueReq { @@ -16,6 +19,7 @@ message IssueReq { uint32 user_role = 2; uint32 type = 3; bool verified = 4; + string description = 5; } message RefreshReq { @@ -23,6 +27,11 @@ message RefreshReq { bool verified = 2; } +message RevokeReq { + string token_id = 1; + string user_id = 2; +} + // If a token is not carrying any information itself, the type // field can be used to determine how to validate the token. // Also, different tokens can be encoded in different ways. @@ -31,3 +40,20 @@ message Token { optional string refresh_token = 2; string access_type = 3; } + +message RevokeRes{ + +} + +message ListUserRefreshTokensReq { + string user_id = 1; +} + +message ListUserRefreshTokensRes { + repeated RefreshToken refresh_tokens = 1; +} + +message RefreshToken { + string id = 1; + string description = 2; +} diff --git a/pkg/sdk/tokens.go b/pkg/sdk/tokens.go index ccb0db7df..94725f90d 100644 --- a/pkg/sdk/tokens.go +++ b/pkg/sdk/tokens.go @@ -21,8 +21,9 @@ type Token struct { } type Login struct { - Username string `json:"username"` - Password string `json:"password"` + Username string `json:"username"` + Password string `json:"password"` + Description string `json:"description,omitempty"` } func (sdk mgSDK) CreateToken(ctx context.Context, lt Login) (Token, errors.SDKError) { diff --git a/pkg/sdk/tokens_test.go b/pkg/sdk/tokens_test.go index 2e2adc0c5..68ba62a79 100644 --- a/pkg/sdk/tokens_test.go +++ b/pkg/sdk/tokens_test.go @@ -101,12 +101,12 @@ func TestIssueToken(t *testing.T) { } for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - svcCall := svc.On("IssueToken", mock.Anything, tc.login.Username, tc.login.Password).Return(tc.svcRes, tc.svcErr) + svcCall := svc.On("IssueToken", mock.Anything, tc.login.Username, tc.login.Password, tc.login.Description).Return(tc.svcRes, tc.svcErr) resp, err := mgsdk.CreateToken(context.Background(), tc.login) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, resp) if tc.err == nil { - ok := svcCall.Parent.AssertCalled(t, "IssueToken", mock.Anything, tc.login.Username, tc.login.Password) + ok := svcCall.Parent.AssertCalled(t, "IssueToken", mock.Anything, tc.login.Username, tc.login.Password, tc.login.Description) assert.True(t, ok) } svcCall.Unset() diff --git a/tools/config/.mockery.yaml b/tools/config/.mockery.yaml index a977b5754..7ae69ef80 100644 --- a/tools/config/.mockery.yaml +++ b/tools/config/.mockery.yaml @@ -59,6 +59,7 @@ packages: Cache: Hasher: KeyRepository: + UserActiveTokensCache: Tokenizer: PATS: PATSRepository: @@ -145,4 +146,3 @@ packages: github.com/absmach/supermq/notifications: interfaces: Notifier: - diff --git a/users/api/endpoint_test.go b/users/api/endpoint_test.go index 2c13e833b..fe8211f54 100644 --- a/users/api/endpoint_test.go +++ b/users/api/endpoint_test.go @@ -2384,7 +2384,9 @@ func TestIssueToken(t *testing.T) { defer us.Close() validUsername := "valid" + validDescription := "test token" dataFormat := `{"username": "%s", "password": "%s"}` + dataFormatWithDesc := `{"username": "%s", "password": "%s", "description": "%s"}` cases := []struct { desc string @@ -2400,6 +2402,13 @@ func TestIssueToken(t *testing.T) { status: http.StatusCreated, err: nil, }, + { + desc: "issue token with valid identity, secret and description", + data: fmt.Sprintf(dataFormatWithDesc, validUsername, secret, validDescription), + contentType: contentType, + status: http.StatusCreated, + err: nil, + }, { desc: "issue token with empty identity", data: fmt.Sprintf(dataFormat, "", secret), @@ -2447,7 +2456,7 @@ func TestIssueToken(t *testing.T) { body: strings.NewReader(tc.data), } - svcCall := svc.On("IssueToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&grpcTokenV1.Token{AccessToken: validToken}, tc.err) + svcCall := svc.On("IssueToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&grpcTokenV1.Token{AccessToken: validToken}, tc.err) res, err := req.make() assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) if tc.err != nil { @@ -2565,6 +2574,206 @@ func TestRefreshToken(t *testing.T) { } } +func TestRevokeRefreshToken(t *testing.T) { + us, svc, authn := newUsersServer() + defer us.Close() + + cases := []struct { + desc string + data string + contentType string + token string + authnRes smqauthn.Session + authnErr error + status int + svcErr error + err error + }{ + { + desc: "revoke refresh token with valid token", + data: fmt.Sprintf(`{"token_id": "%s"}`, validToken), + contentType: contentType, + token: validToken, + authnRes: verifiedSession, + status: http.StatusNoContent, + err: nil, + }, + { + desc: "revoke refresh token with invalid token", + data: fmt.Sprintf(`{"token_id": "%s"}`, validToken), + contentType: contentType, + token: inValidToken, + status: http.StatusUnauthorized, + authnErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "revoke refresh token with empty token", + data: fmt.Sprintf(`{"token_id": "%s"}`, validToken), + contentType: contentType, + token: "", + status: http.StatusUnauthorized, + authnErr: svcerr.ErrAuthentication, + err: apiutil.ErrBearerToken, + }, + { + desc: "revoke refresh token with empty token id", + data: `{"token_id": ""}`, + contentType: contentType, + token: validToken, + authnRes: verifiedSession, + status: http.StatusBadRequest, + err: apiutil.ErrMissingID, + }, + { + desc: "revoke refresh token with malformed data", + data: fmt.Sprintf(`{"token_id": %s}`, validToken), + contentType: contentType, + token: validToken, + authnRes: verifiedSession, + status: http.StatusBadRequest, + err: apiutil.ErrMalformedRequestBody, + }, + { + desc: "revoke refresh token with invalid content type", + data: fmt.Sprintf(`{"token_id": "%s"}`, validToken), + contentType: "application/xml", + token: validToken, + authnRes: verifiedSession, + status: http.StatusUnsupportedMediaType, + err: apiutil.ErrUnsupportedContentType, + }, + { + desc: "revoke refresh token with service error", + data: fmt.Sprintf(`{"token_id": "%s"}`, validToken), + contentType: contentType, + token: validToken, + authnRes: verifiedSession, + status: http.StatusUnprocessableEntity, + svcErr: svcerr.ErrViewEntity, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + user: us.Client(), + method: http.MethodPost, + url: fmt.Sprintf("%s/users/tokens/revoke", us.URL), + contentType: tc.contentType, + body: strings.NewReader(tc.data), + token: tc.token, + } + authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + svcCall := svc.On("RevokeRefreshToken", mock.Anything, tc.authnRes, mock.Anything).Return(tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + if tc.err != nil { + var resBody respBody + err = json.NewDecoder(res.Body).Decode(&resBody) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if resBody.Err != "" || resBody.Message != "" { + err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authnCall.Unset() + }) + } +} + +func TestListActiveRefreshTokens(t *testing.T) { + us, svc, authn := newUsersServer() + defer us.Close() + + cases := []struct { + desc string + token string + authnRes smqauthn.Session + authnErr error + status int + svcRes *grpcTokenV1.ListUserRefreshTokensRes + svcErr error + err error + }{ + { + desc: "list active refresh tokens with valid token", + token: validToken, + authnRes: verifiedSession, + status: http.StatusOK, + svcRes: &grpcTokenV1.ListUserRefreshTokensRes{ + RefreshTokens: []*grpcTokenV1.RefreshToken{ + {Id: "token1", Description: "token-1"}, + {Id: "token2", Description: "token-2"}, + }, + }, + err: nil, + }, + { + desc: "list active refresh tokens with invalid token", + token: inValidToken, + status: http.StatusUnauthorized, + authnErr: svcerr.ErrAuthentication, + err: svcerr.ErrAuthentication, + }, + { + desc: "list active refresh tokens with empty token", + token: "", + status: http.StatusUnauthorized, + authnErr: svcerr.ErrAuthentication, + err: apiutil.ErrBearerToken, + }, + { + desc: "list active refresh tokens with service error", + token: validToken, + authnRes: verifiedSession, + status: http.StatusUnprocessableEntity, + svcErr: svcerr.ErrViewEntity, + err: svcerr.ErrViewEntity, + }, + { + desc: "list active refresh tokens with empty list", + token: validToken, + authnRes: verifiedSession, + status: http.StatusOK, + svcRes: &grpcTokenV1.ListUserRefreshTokensRes{ + RefreshTokens: []*grpcTokenV1.RefreshToken{}, + }, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + req := testRequest{ + user: us.Client(), + method: http.MethodGet, + url: fmt.Sprintf("%s/users/tokens/refresh-tokens", us.URL), + token: tc.token, + } + authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) + svcCall := svc.On("ListActiveRefreshTokens", mock.Anything, tc.authnRes).Return(tc.svcRes, tc.svcErr) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + if tc.err != nil { + var resBody respBody + err = json.NewDecoder(res.Body).Decode(&resBody) + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err)) + if resBody.Err != "" || resBody.Message != "" { + err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message)) + } + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + } + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + authnCall.Unset() + }) + } +} + func TestEnable(t *testing.T) { us, svc, authn := newUsersServer() defer us.Close() diff --git a/users/api/endpoints.go b/users/api/endpoints.go index 0287aa08c..48e57a204 100644 --- a/users/api/endpoints.go +++ b/users/api/endpoints.go @@ -416,7 +416,7 @@ func issueTokenEndpoint(svc users.Service) endpoint.Endpoint { return nil, errors.Wrap(apiutil.ErrValidation, err) } - token, err := svc.IssueToken(ctx, req.Username, req.Password) + token, err := svc.IssueToken(ctx, req.Username, req.Password, req.Description) if err != nil { return nil, err } @@ -454,6 +454,43 @@ func refreshTokenEndpoint(svc users.Service) endpoint.Endpoint { } } +func revokeRefreshTokenEndpoint(svc users.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + req := request.(revokeTokenReq) + if err := req.validate(); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthentication + } + + err := svc.RevokeRefreshToken(ctx, session, req.TokenID) + if err != nil { + return nil, err + } + + return revokeRes{}, nil + } +} + +func listActiveRefreshTokensEndpoint(svc users.Service) endpoint.Endpoint { + return func(ctx context.Context, request any) (any, error) { + session, ok := ctx.Value(authn.SessionKey).(authn.Session) + if !ok { + return nil, svcerr.ErrAuthentication + } + + refreshTokens, err := svc.ListActiveRefreshTokens(ctx, session) + if err != nil { + return nil, err + } + + return listRefreshTokensRes{RefreshTokens: refreshTokens.GetRefreshTokens()}, nil + } +} + func enableEndpoint(svc users.Service) endpoint.Endpoint { return func(ctx context.Context, request any) (any, error) { req := request.(changeUserStatusReq) diff --git a/users/api/requests.go b/users/api/requests.go index bc22e5afc..7a0205f48 100644 --- a/users/api/requests.go +++ b/users/api/requests.go @@ -270,8 +270,9 @@ func (req changeUserStatusReq) validate() error { } type loginUserReq struct { - Username string `json:"username,omitempty"` - Password string `json:"password,omitempty"` + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` + Description string `json:"description,omitempty"` } func (req loginUserReq) validate() error { @@ -297,6 +298,18 @@ func (req tokenReq) validate() error { return nil } +type revokeTokenReq struct { + TokenID string `json:"token_id,omitempty"` +} + +func (req revokeTokenReq) validate() error { + if req.TokenID == "" { + return apiutil.ErrMissingID + } + + return nil +} + type passResetReq struct { Email string `json:"email"` } diff --git a/users/api/responses.go b/users/api/responses.go index f862912d8..d4bbb3c3c 100644 --- a/users/api/responses.go +++ b/users/api/responses.go @@ -8,6 +8,7 @@ import ( "net/http" "github.com/absmach/supermq" + grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" "github.com/absmach/supermq/users" ) @@ -25,8 +26,9 @@ var ( _ supermq.Response = (*passResetReqRes)(nil) _ supermq.Response = (*passChangeRes)(nil) _ supermq.Response = (*updateUserRes)(nil) - _ supermq.Response = (*tokenRes)(nil) + _ supermq.Response = (*revokeRes)(nil) _ supermq.Response = (*deleteUserRes)(nil) + _ supermq.Response = (*listRefreshTokensRes)(nil) ) type pageRes struct { @@ -80,6 +82,36 @@ func (res tokenRes) Empty() bool { return res.AccessToken == "" || res.RefreshToken == "" } +type revokeRes struct{} + +func (res revokeRes) Code() int { + return http.StatusNoContent +} + +func (res revokeRes) Headers() map[string]string { + return map[string]string{} +} + +func (res revokeRes) Empty() bool { + return true +} + +type listRefreshTokensRes struct { + RefreshTokens []*grpcTokenV1.RefreshToken `json:"refresh_tokens"` +} + +func (res listRefreshTokensRes) Code() int { + return http.StatusOK +} + +func (res listRefreshTokensRes) Headers() map[string]string { + return map[string]string{} +} + +func (res listRefreshTokensRes) Empty() bool { + return false +} + type sendVerificationRes struct{} func (res sendVerificationRes) Code() int { diff --git a/users/api/users.go b/users/api/users.go index d84201505..a9b0773a7 100644 --- a/users/api/users.go +++ b/users/api/users.go @@ -78,6 +78,18 @@ func usersHandler(svc users.Service, authn smqauthn.AuthNMiddleware, tokenClient api.EncodeResponse, opts..., ), "refresh_token").ServeHTTP) + r.Post("/tokens/revoke", otelhttp.NewHandler(kithttp.NewServer( + revokeRefreshTokenEndpoint(svc), + decodeRevokeRefreshToken, + api.EncodeResponse, + opts..., + ), "revoke_refresh_token").ServeHTTP) + r.Get("/tokens/refresh-tokens", otelhttp.NewHandler(kithttp.NewServer( + listActiveRefreshTokensEndpoint(svc), + decodeListActiveRefreshTokens, + api.EncodeResponse, + opts..., + ), "list_active_refresh_tokens").ServeHTTP) r.Patch("/{id}/email", otelhttp.NewHandler(kithttp.NewServer( updateEmailEndpoint(svc), decodeUpdateUserEmail, @@ -532,6 +544,23 @@ func decodeRefreshToken(_ context.Context, r *http.Request) (any, error) { return req, nil } +func decodeRevokeRefreshToken(_ context.Context, r *http.Request) (any, error) { + if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { + return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) + } + + var req revokeTokenReq + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err) + } + + return req, nil +} + +func decodeListActiveRefreshTokens(_ context.Context, r *http.Request) (any, error) { + return nil, nil +} + func decodeCreateUserReq(_ context.Context, r *http.Request) (any, error) { if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) { return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType) diff --git a/users/events/events.go b/users/events/events.go index ec95f3f0e..f317707e2 100644 --- a/users/events/events.go +++ b/users/events/events.go @@ -29,11 +29,10 @@ const ( profileView = userPrefix + "view_profile" userList = userPrefix + "list" userSearch = userPrefix + "search" - userListByGroup = userPrefix + "list_by_group" userIdentify = userPrefix + "identify" - generateResetToken = userPrefix + "generate_reset_token" issueToken = userPrefix + "issue_token" refreshToken = userPrefix + "refresh_token" + revokeRefreshToken = userPrefix + "revoke_refresh_token" resetSecret = userPrefix + "reset_secret" sendPasswordReset = userPrefix + "send_password_reset" oauthCallback = userPrefix + "oauth_callback" @@ -56,6 +55,7 @@ var ( _ events.Event = (*identifyUserEvent)(nil) _ events.Event = (*issueTokenEvent)(nil) _ events.Event = (*refreshTokenEvent)(nil) + _ events.Event = (*revokeRefreshTokenEvent)(nil) _ events.Event = (*resetSecretEvent)(nil) _ events.Event = (*sendPasswordResetEvent)(nil) _ events.Event = (*oauthCallbackEvent)(nil) @@ -492,6 +492,19 @@ func (rte refreshTokenEvent) Encode() (map[string]any, error) { }, nil } +type revokeRefreshTokenEvent struct { + tokenID string + requestID string +} + +func (rrte revokeRefreshTokenEvent) Encode() (map[string]any, error) { + return map[string]any{ + "operation": revokeRefreshToken, + "token_id": rrte.tokenID, + "request_id": rrte.requestID, + }, nil +} + type resetSecretEvent struct { requestID string } diff --git a/users/events/streams.go b/users/events/streams.go index 70346185a..fb176adc8 100644 --- a/users/events/streams.go +++ b/users/events/streams.go @@ -15,31 +15,32 @@ import ( ) const ( - supermqPrefix = "supermq." - createStream = supermqPrefix + userCreate - sendVerificationStream = supermqPrefix + userSendVerification - verifyEmailStream = supermqPrefix + userVerifyEmail - updateStream = supermqPrefix + userUpdate - updateRoleStream = supermqPrefix + userUpdateRole - updateTagsStream = supermqPrefix + userUpdateTags - updateSecretStream = supermqPrefix + userUpdateSecret - updateUsernameStream = supermqPrefix + userUpdateUsername - updatePictureStream = supermqPrefix + userUpdateProfilePicture - UpdateEmailStream = supermqPrefix + userUpdateEmail - enableStream = supermqPrefix + userEnable - disableStream = supermqPrefix + userDisable - viewStream = supermqPrefix + userView - viewProfileStream = supermqPrefix + profileView - listStream = supermqPrefix + userList - searchStream = supermqPrefix + userSearch - identifyStream = supermqPrefix + userIdentify - issueTokenStream = supermqPrefix + issueToken - refreshTokenStream = supermqPrefix + refreshToken - resetSecretStream = supermqPrefix + resetSecret - sendPasswordResetStream = supermqPrefix + sendPasswordReset - oauthStream = supermqPrefix + oauthCallback - addPolicyStream = supermqPrefix + addClientPolicy - deleteStream = supermqPrefix + deleteUser + supermqPrefix = "supermq." + createStream = supermqPrefix + userCreate + sendVerificationStream = supermqPrefix + userSendVerification + verifyEmailStream = supermqPrefix + userVerifyEmail + updateStream = supermqPrefix + userUpdate + updateRoleStream = supermqPrefix + userUpdateRole + updateTagsStream = supermqPrefix + userUpdateTags + updateSecretStream = supermqPrefix + userUpdateSecret + updateUsernameStream = supermqPrefix + userUpdateUsername + updatePictureStream = supermqPrefix + userUpdateProfilePicture + UpdateEmailStream = supermqPrefix + userUpdateEmail + enableStream = supermqPrefix + userEnable + disableStream = supermqPrefix + userDisable + viewStream = supermqPrefix + userView + viewProfileStream = supermqPrefix + profileView + listStream = supermqPrefix + userList + searchStream = supermqPrefix + userSearch + identifyStream = supermqPrefix + userIdentify + issueTokenStream = supermqPrefix + issueToken + refreshTokenStream = supermqPrefix + refreshToken + revokeRefreshTokenStream = supermqPrefix + revokeRefreshToken + resetSecretStream = supermqPrefix + resetSecret + sendPasswordResetStream = supermqPrefix + sendPasswordReset + oauthStream = supermqPrefix + oauthCallback + addPolicyStream = supermqPrefix + addClientPolicy + deleteStream = supermqPrefix + deleteUser ) var _ users.Service = (*eventStore)(nil) @@ -350,8 +351,8 @@ func (es *eventStore) SendPasswordReset(ctx context.Context, email string) error return es.Publish(ctx, sendPasswordResetStream, event) } -func (es *eventStore) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) { - token, err := es.svc.IssueToken(ctx, username, secret) +func (es *eventStore) IssueToken(ctx context.Context, username, secret, description string) (*grpcTokenV1.Token, error) { + token, err := es.svc.IssueToken(ctx, username, secret, description) if err != nil { return token, err } @@ -385,6 +386,24 @@ func (es *eventStore) RefreshToken(ctx context.Context, session authn.Session, r return token, nil } +func (es *eventStore) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error { + err := es.svc.RevokeRefreshToken(ctx, session, tokenID) + if err != nil { + return err + } + + event := revokeRefreshTokenEvent{ + tokenID: tokenID, + requestID: middleware.GetReqID(ctx), + } + + return es.Publish(ctx, revokeRefreshTokenStream, event) +} + +func (es *eventStore) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) { + return es.svc.ListActiveRefreshTokens(ctx, session) +} + func (es *eventStore) ResetSecret(ctx context.Context, session authn.Session, secret string) error { if err := es.svc.ResetSecret(ctx, session, secret); err != nil { return err diff --git a/users/events/streams_test.go b/users/events/streams_test.go index 450d99604..80ab118e6 100644 --- a/users/events/streams_test.go +++ b/users/events/streams_test.go @@ -902,22 +902,24 @@ func TestIssueToken(t *testing.T) { } cases := []struct { - desc string - username string - secret string - svcRes *grpcTokenV1.Token - svcErr error - resp *grpcTokenV1.Token - err error + desc string + username string + secret string + description string + svcRes *grpcTokenV1.Token + svcErr error + resp *grpcTokenV1.Token + err error }{ { - desc: "publish successfully", - username: validUser.Credentials.Username, - secret: validUser.Credentials.Secret, - svcRes: validToken, - svcErr: nil, - resp: validToken, - err: nil, + desc: "publish successfully", + username: validUser.Credentials.Username, + secret: validUser.Credentials.Secret, + description: "valid token", + svcRes: validToken, + svcErr: nil, + resp: validToken, + err: nil, }, { desc: "failed to publish with service error", @@ -932,8 +934,8 @@ func TestIssueToken(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - svcCall := svc.On("IssueToken", validCtx, tc.username, tc.secret).Return(tc.svcRes, tc.svcErr) - resp, err := nsvc.IssueToken(validCtx, tc.username, tc.secret) + svcCall := svc.On("IssueToken", validCtx, tc.username, tc.secret, tc.description).Return(tc.svcRes, tc.svcErr) + resp, err := nsvc.IssueToken(validCtx, tc.username, tc.secret, tc.description) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp)) svcCall.Unset() @@ -990,6 +992,93 @@ func TestRefreshToken(t *testing.T) { } } +func TestRevokeRefreshToken(t *testing.T) { + svc, nsvc := newEventStoreMiddleware(t) + + validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t)) + validTokenID := "validTokenID" + + cases := []struct { + desc string + session authn.Session + tokenID string + svcErr error + err error + }{ + { + desc: "publish successfully", + session: validSession, + tokenID: validTokenID, + svcErr: nil, + err: nil, + }, + { + desc: "failed to publish with service error", + session: validSession, + tokenID: validTokenID, + svcErr: svcerr.ErrUpdateEntity, + err: svcerr.ErrUpdateEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("RevokeRefreshToken", validCtx, tc.session, tc.tokenID).Return(tc.svcErr) + err := nsvc.RevokeRefreshToken(validCtx, tc.session, tc.tokenID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + svcCall.Unset() + }) + } +} + +func TestListActiveRefreshTokens(t *testing.T) { + svc, nsvc := newEventStoreMiddleware(t) + + validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t)) + validTokensList := &grpcTokenV1.ListUserRefreshTokensRes{ + RefreshTokens: []*grpcTokenV1.RefreshToken{ + {Id: "token1", Description: "token1"}, + {Id: "token2", Description: "token2"}, + }, + } + + cases := []struct { + desc string + session authn.Session + svcRes *grpcTokenV1.ListUserRefreshTokensRes + svcErr error + resp *grpcTokenV1.ListUserRefreshTokensRes + err error + }{ + { + desc: "publish successfully", + session: validSession, + svcRes: validTokensList, + svcErr: nil, + resp: validTokensList, + err: nil, + }, + { + desc: "failed to publish with service error", + session: validSession, + svcRes: nil, + svcErr: svcerr.ErrViewEntity, + resp: nil, + err: svcerr.ErrViewEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("ListActiveRefreshTokens", validCtx, tc.session).Return(tc.svcRes, tc.svcErr) + resp, err := nsvc.ListActiveRefreshTokens(validCtx, tc.session) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp)) + svcCall.Unset() + }) + } +} + func TestResetSecret(t *testing.T) { svc, nsvc := newEventStoreMiddleware(t) diff --git a/users/middleware/authorization.go b/users/middleware/authorization.go index 15064629c..b0d3b7cbd 100644 --- a/users/middleware/authorization.go +++ b/users/middleware/authorization.go @@ -162,14 +162,22 @@ func (am *authorizationMiddleware) Identify(ctx context.Context, session authn.S return am.svc.Identify(ctx, session) } -func (am *authorizationMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) { - return am.svc.IssueToken(ctx, username, secret) +func (am *authorizationMiddleware) IssueToken(ctx context.Context, username, secret, description string) (*grpcTokenV1.Token, error) { + return am.svc.IssueToken(ctx, username, secret, description) } func (am *authorizationMiddleware) RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*grpcTokenV1.Token, error) { return am.svc.RefreshToken(ctx, session, refreshToken) } +func (am *authorizationMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error { + return am.svc.RevokeRefreshToken(ctx, session, tokenID) +} + +func (am *authorizationMiddleware) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) { + return am.svc.ListActiveRefreshTokens(ctx, session) +} + func (am *authorizationMiddleware) OAuthCallback(ctx context.Context, user users.User) (users.User, error) { return am.svc.OAuthCallback(ctx, user) } diff --git a/users/middleware/logging.go b/users/middleware/logging.go index 8339cfe9c..c42f7433c 100644 --- a/users/middleware/logging.go +++ b/users/middleware/logging.go @@ -89,7 +89,7 @@ func (lm *loggingMiddleware) VerifyEmail(ctx context.Context, verificationToken // IssueToken logs the issue_token request. It logs the username type and the time it took to complete the request. // If the request fails, it logs the error. -func (lm *loggingMiddleware) IssueToken(ctx context.Context, username, secret string) (t *grpcTokenV1.Token, err error) { +func (lm *loggingMiddleware) IssueToken(ctx context.Context, username, secret, description string) (t *grpcTokenV1.Token, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), @@ -105,7 +105,7 @@ func (lm *loggingMiddleware) IssueToken(ctx context.Context, username, secret st } lm.logger.Info("Issue token completed successfully", args...) }(time.Now()) - return lm.svc.IssueToken(ctx, username, secret) + return lm.svc.IssueToken(ctx, username, secret, description) } // RefreshToken logs the refresh_token request. It logs the refreshtoken, token type and the time it took to complete the request. @@ -129,6 +129,45 @@ func (lm *loggingMiddleware) RefreshToken(ctx context.Context, session authn.Ses return lm.svc.RefreshToken(ctx, session, refreshToken) } +// RevokeRefreshToken logs the revoke_refresh_token request. It logs the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("Revoke refresh token failed", args...) + return + } + lm.logger.Info("Revoke refresh token completed successfully", args...) + }(time.Now()) + return lm.svc.RevokeRefreshToken(ctx, session, tokenID) +} + +// ListActiveRefreshTokens logs the list_active_refresh_tokens request. It logs the time it took to complete the request. +// If the request fails, it logs the error. +func (lm *loggingMiddleware) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (tokens *grpcTokenV1.ListUserRefreshTokensRes, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("request_id", middleware.GetReqID(ctx)), + } + if tokens != nil { + args = append(args, slog.Int("tokens_count", len(tokens.GetRefreshTokens()))) + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("List active refresh tokens failed", args...) + return + } + lm.logger.Info("List active refresh tokens completed successfully", args...) + }(time.Now()) + return lm.svc.ListActiveRefreshTokens(ctx, session) +} + // View logs the view_user request. It logs the user id and the time it took to complete the request. // If the request fails, it logs the error. func (lm *loggingMiddleware) View(ctx context.Context, session authn.Session, id string) (c users.User, err error) { diff --git a/users/middleware/metrics.go b/users/middleware/metrics.go index 1095e267d..6e2b4fb3e 100644 --- a/users/middleware/metrics.go +++ b/users/middleware/metrics.go @@ -58,12 +58,12 @@ func (ms *metricsMiddleware) VerifyEmail(ctx context.Context, verificationToken } // IssueToken instruments IssueToken method with metrics. -func (ms *metricsMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) { +func (ms *metricsMiddleware) IssueToken(ctx context.Context, username, secret, description string) (*grpcTokenV1.Token, error) { defer func(begin time.Time) { ms.counter.With("method", "issue_token").Add(1) ms.latency.With("method", "issue_token").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.IssueToken(ctx, username, secret) + return ms.svc.IssueToken(ctx, username, secret, description) } // RefreshToken instruments RefreshToken method with metrics. @@ -75,6 +75,24 @@ func (ms *metricsMiddleware) RefreshToken(ctx context.Context, session authn.Ses return ms.svc.RefreshToken(ctx, session, refreshToken) } +// RevokeRefreshToken instruments RevokeRefreshToken method with metrics. +func (ms *metricsMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error { + defer func(begin time.Time) { + ms.counter.With("method", "revoke_refresh_token").Add(1) + ms.latency.With("method", "revoke_refresh_token").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.RevokeRefreshToken(ctx, session, tokenID) +} + +// ListActiveRefreshTokens instruments ListActiveRefreshTokens method with metrics. +func (ms *metricsMiddleware) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) { + defer func(begin time.Time) { + ms.counter.With("method", "list_active_refresh_tokens").Add(1) + ms.latency.With("method", "list_active_refresh_tokens").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.ListActiveRefreshTokens(ctx, session) +} + // View instruments View method with metrics. func (ms *metricsMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) { defer func(begin time.Time) { diff --git a/users/middleware/tracing.go b/users/middleware/tracing.go index 9ee39763e..0f2cb0caf 100644 --- a/users/middleware/tracing.go +++ b/users/middleware/tracing.go @@ -50,11 +50,11 @@ func (tm *tracingMiddleware) VerifyEmail(ctx context.Context, verificationToken } // IssueToken traces the "IssueToken" operation of the wrapped users.Service. -func (tm *tracingMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) { +func (tm *tracingMiddleware) IssueToken(ctx context.Context, username, secret, description string) (*grpcTokenV1.Token, error) { ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_issue_token", trace.WithAttributes(attribute.String("username", username))) defer span.End() - return tm.svc.IssueToken(ctx, username, secret) + return tm.svc.IssueToken(ctx, username, secret, description) } // RefreshToken traces the "RefreshToken" operation of the wrapped users.Service. @@ -65,6 +65,22 @@ func (tm *tracingMiddleware) RefreshToken(ctx context.Context, session authn.Ses return tm.svc.RefreshToken(ctx, session, refreshToken) } +// RevokeRefreshToken traces the "RevokeRefreshToken" operation of the wrapped users.Service. +func (tm *tracingMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error { + ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_revoke_refresh_token") + defer span.End() + + return tm.svc.RevokeRefreshToken(ctx, session, tokenID) +} + +// ListActiveRefreshTokens traces the "ListActiveRefreshTokens" operation of the wrapped users.Service. +func (tm *tracingMiddleware) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) { + ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_list_active_refresh_tokens") + defer span.End() + + return tm.svc.ListActiveRefreshTokens(ctx, session) +} + // View traces the "View" operation of the wrapped users.Service. func (tm *tracingMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) { ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_view_user", trace.WithAttributes(attribute.String("id", id))) diff --git a/users/mocks/service.go b/users/mocks/service.go index 74a090668..a509d3869 100644 --- a/users/mocks/service.go +++ b/users/mocks/service.go @@ -318,8 +318,8 @@ func (_c *Service_Identify_Call) RunAndReturn(run func(ctx context.Context, sess } // IssueToken provides a mock function for the type Service -func (_mock *Service) IssueToken(ctx context.Context, identity string, secret string) (*v1.Token, error) { - ret := _mock.Called(ctx, identity, secret) +func (_mock *Service) IssueToken(ctx context.Context, identity string, secret string, description string) (*v1.Token, error) { + ret := _mock.Called(ctx, identity, secret, description) if len(ret) == 0 { panic("no return value specified for IssueToken") @@ -327,18 +327,18 @@ func (_mock *Service) IssueToken(ctx context.Context, identity string, secret st var r0 *v1.Token var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (*v1.Token, error)); ok { - return returnFunc(ctx, identity, secret) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (*v1.Token, error)); ok { + return returnFunc(ctx, identity, secret, description) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) *v1.Token); ok { - r0 = returnFunc(ctx, identity, secret) + if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) *v1.Token); ok { + r0 = returnFunc(ctx, identity, secret, description) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*v1.Token) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = returnFunc(ctx, identity, secret) + if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = returnFunc(ctx, identity, secret, description) } else { r1 = ret.Error(1) } @@ -354,11 +354,12 @@ type Service_IssueToken_Call struct { // - ctx context.Context // - identity string // - secret string -func (_e *Service_Expecter) IssueToken(ctx interface{}, identity interface{}, secret interface{}) *Service_IssueToken_Call { - return &Service_IssueToken_Call{Call: _e.mock.On("IssueToken", ctx, identity, secret)} +// - description string +func (_e *Service_Expecter) IssueToken(ctx interface{}, identity interface{}, secret interface{}, description interface{}) *Service_IssueToken_Call { + return &Service_IssueToken_Call{Call: _e.mock.On("IssueToken", ctx, identity, secret, description)} } -func (_c *Service_IssueToken_Call) Run(run func(ctx context.Context, identity string, secret string)) *Service_IssueToken_Call { +func (_c *Service_IssueToken_Call) Run(run func(ctx context.Context, identity string, secret string, description string)) *Service_IssueToken_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -372,10 +373,15 @@ func (_c *Service_IssueToken_Call) Run(run func(ctx context.Context, identity st if args[2] != nil { arg2 = args[2].(string) } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -386,7 +392,75 @@ func (_c *Service_IssueToken_Call) Return(token *v1.Token, err error) *Service_I return _c } -func (_c *Service_IssueToken_Call) RunAndReturn(run func(ctx context.Context, identity string, secret string) (*v1.Token, error)) *Service_IssueToken_Call { +func (_c *Service_IssueToken_Call) RunAndReturn(run func(ctx context.Context, identity string, secret string, description string) (*v1.Token, error)) *Service_IssueToken_Call { + _c.Call.Return(run) + return _c +} + +// ListActiveRefreshTokens provides a mock function for the type Service +func (_mock *Service) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*v1.ListUserRefreshTokensRes, error) { + ret := _mock.Called(ctx, session) + + if len(ret) == 0 { + panic("no return value specified for ListActiveRefreshTokens") + } + + var r0 *v1.ListUserRefreshTokensRes + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) (*v1.ListUserRefreshTokensRes, error)); ok { + return returnFunc(ctx, session) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) *v1.ListUserRefreshTokensRes); ok { + r0 = returnFunc(ctx, session) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1.ListUserRefreshTokensRes) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session) error); ok { + r1 = returnFunc(ctx, session) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// Service_ListActiveRefreshTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListActiveRefreshTokens' +type Service_ListActiveRefreshTokens_Call struct { + *mock.Call +} + +// ListActiveRefreshTokens is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +func (_e *Service_Expecter) ListActiveRefreshTokens(ctx interface{}, session interface{}) *Service_ListActiveRefreshTokens_Call { + return &Service_ListActiveRefreshTokens_Call{Call: _e.mock.On("ListActiveRefreshTokens", ctx, session)} +} + +func (_c *Service_ListActiveRefreshTokens_Call) Run(run func(ctx context.Context, session authn.Session)) *Service_ListActiveRefreshTokens_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *Service_ListActiveRefreshTokens_Call) Return(listUserRefreshTokensRes *v1.ListUserRefreshTokensRes, err error) *Service_ListActiveRefreshTokens_Call { + _c.Call.Return(listUserRefreshTokensRes, err) + return _c +} + +func (_c *Service_ListActiveRefreshTokens_Call) RunAndReturn(run func(ctx context.Context, session authn.Session) (*v1.ListUserRefreshTokensRes, error)) *Service_ListActiveRefreshTokens_Call { _c.Call.Return(run) return _c } @@ -801,6 +875,69 @@ func (_c *Service_ResetSecret_Call) RunAndReturn(run func(ctx context.Context, s return _c } +// RevokeRefreshToken provides a mock function for the type Service +func (_mock *Service) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error { + ret := _mock.Called(ctx, session, tokenID) + + if len(ret) == 0 { + panic("no return value specified for RevokeRefreshToken") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok { + r0 = returnFunc(ctx, session, tokenID) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// Service_RevokeRefreshToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeRefreshToken' +type Service_RevokeRefreshToken_Call struct { + *mock.Call +} + +// RevokeRefreshToken is a helper method to define mock.On call +// - ctx context.Context +// - session authn.Session +// - tokenID string +func (_e *Service_Expecter) RevokeRefreshToken(ctx interface{}, session interface{}, tokenID interface{}) *Service_RevokeRefreshToken_Call { + return &Service_RevokeRefreshToken_Call{Call: _e.mock.On("RevokeRefreshToken", ctx, session, tokenID)} +} + +func (_c *Service_RevokeRefreshToken_Call) Run(run func(ctx context.Context, session authn.Session, tokenID string)) *Service_RevokeRefreshToken_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 authn.Session + if args[1] != nil { + arg1 = args[1].(authn.Session) + } + var arg2 string + if args[2] != nil { + arg2 = args[2].(string) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *Service_RevokeRefreshToken_Call) Return(err error) *Service_RevokeRefreshToken_Call { + _c.Call.Return(err) + return _c +} + +func (_c *Service_RevokeRefreshToken_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, tokenID string) error) *Service_RevokeRefreshToken_Call { + _c.Call.Return(run) + return _c +} + // SearchUsers provides a mock function for the type Service func (_mock *Service) SearchUsers(ctx context.Context, pm users.Page) (users.UsersPage, error) { ret := _mock.Called(ctx, pm) diff --git a/users/service.go b/users/service.go index d0fcb67c1..ad9515294 100644 --- a/users/service.go +++ b/users/service.go @@ -183,7 +183,7 @@ func (svc service) VerifyEmail(ctx context.Context, token string) (User, error) return user, nil } -func (svc service) IssueToken(ctx context.Context, identity, secret string) (*grpcTokenV1.Token, error) { +func (svc service) IssueToken(ctx context.Context, identity, secret, description string) (*grpcTokenV1.Token, error) { var dbUser User var err error @@ -205,7 +205,13 @@ func (svc service) IssueToken(ctx context.Context, identity, secret string) (*gr return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrLogin, err) } - token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, UserRole: uint32(dbUser.Role + 1), Type: uint32(smqauth.AccessKey), Verified: !dbUser.VerifiedAt.IsZero()}) + token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{ + UserId: dbUser.ID, + UserRole: uint32(dbUser.Role + 1), + Type: uint32(smqauth.AccessKey), + Verified: !dbUser.VerifiedAt.IsZero(), + Description: description, + }) if err != nil { return &grpcTokenV1.Token{}, errors.Wrap(errIssueToken, err) } @@ -229,6 +235,42 @@ func (svc service) RefreshToken(ctx context.Context, session authn.Session, refr return token, nil } +func (svc service) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error { + dbUser, err := svc.users.RetrieveByID(ctx, session.UserID) + if err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + if dbUser.Status == DisabledStatus { + return errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser) + } + _, err = svc.token.Revoke(ctx, &grpcTokenV1.RevokeReq{UserId: session.UserID, TokenId: tokenID}) + if err != nil { + if errors.Contains(err, svcerr.ErrNotFound) { + return errors.Wrap(svcerr.ErrNotFound, err) + } + return errors.Wrap(svcerr.ErrRemoveEntity, err) + } + + return nil +} + +func (svc service) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) { + dbUser, err := svc.users.RetrieveByID(ctx, session.UserID) + if err != nil { + return nil, errors.Wrap(svcerr.ErrAuthentication, err) + } + if dbUser.Status == DisabledStatus { + return nil, errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser) + } + + refreshTokens, err := svc.token.ListUserRefreshTokens(ctx, &grpcTokenV1.ListUserRefreshTokensReq{UserId: session.UserID}) + if err != nil { + return nil, errors.Wrap(svcerr.ErrAuthentication, err) + } + + return refreshTokens, nil +} + func (svc service) View(ctx context.Context, session authn.Session, id string) (User, error) { user, err := svc.users.RetrieveByID(ctx, id) if err != nil { @@ -453,7 +495,7 @@ func (svc service) UpdateSecret(ctx context.Context, session authn.Session, oldS if err != nil { return User{}, errors.Wrap(svcerr.ErrViewEntity, err) } - if _, err := svc.IssueToken(ctx, dbUser.Credentials.Username, oldSecret); err != nil { + if _, err := svc.IssueToken(ctx, dbUser.Credentials.Username, oldSecret, ""); err != nil { return User{}, err } newSecret, err = svc.hasher.Hash(newSecret) diff --git a/users/service_test.go b/users/service_test.go index 7b09ab3b6..3d9a1fbf1 100644 --- a/users/service_test.go +++ b/users/service_test.go @@ -1615,7 +1615,7 @@ func TestIssueToken(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { repoCall := cRepo.On("RetrieveByUsername", context.Background(), tc.user.Credentials.Username).Return(tc.retrieveByUsernameResponse, tc.retrieveByUsernameErr) authCall := auth.On("Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr) - token, err := svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret) + token, err := svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret, "") assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) if err == nil { assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken())) @@ -1710,6 +1710,181 @@ func TestRefreshToken(t *testing.T) { } } +func TestRevokeRefreshToken(t *testing.T) { + svc, authsvc, crepo, _, _ := newService() + + rUser := user + rUser.Credentials.Secret, _ = phasher.Hash(user.Credentials.Secret) + + cases := []struct { + desc string + session authn.Session + tokenID string + revokeResp *grpcTokenV1.RevokeRes + revokeErr error + repoResp users.User + repoErr error + err error + }{ + { + desc: "revoke refresh token successfully", + session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, + tokenID: validToken, + revokeResp: &grpcTokenV1.RevokeRes{}, + repoResp: rUser, + err: nil, + }, + { + desc: "revoke refresh token with empty domain id", + session: authn.Session{UserID: validID}, + tokenID: validToken, + revokeResp: &grpcTokenV1.RevokeRes{}, + repoResp: rUser, + err: nil, + }, + { + desc: "revoke refresh token for non-existing user", + session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, + tokenID: validToken, + repoErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "revoke refresh token for disabled user", + session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, + tokenID: validToken, + repoResp: users.User{Status: users.DisabledStatus}, + err: svcerr.ErrAuthentication, + }, + { + desc: "revoke refresh token with revoke service error", + session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, + tokenID: validToken, + revokeResp: &grpcTokenV1.RevokeRes{}, + revokeErr: svcerr.ErrAuthorization, + repoResp: rUser, + err: svcerr.ErrAuthorization, + }, + { + desc: "revoke refresh token not found", + session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, + tokenID: validToken, + revokeResp: &grpcTokenV1.RevokeRes{}, + revokeErr: svcerr.ErrNotFound, + repoResp: rUser, + err: svcerr.ErrNotFound, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr) + authCall := authsvc.On("Revoke", context.Background(), &grpcTokenV1.RevokeReq{UserId: tc.session.UserID, TokenId: tc.tokenID}).Return(tc.revokeResp, tc.revokeErr) + err := svc.RevokeRefreshToken(context.Background(), tc.session, tc.tokenID) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID) + assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) + ok = authCall.Parent.AssertCalled(t, "Revoke", context.Background(), &grpcTokenV1.RevokeReq{UserId: tc.session.UserID, TokenId: tc.tokenID}) + assert.True(t, ok, fmt.Sprintf("Revoke was not called on %s", tc.desc)) + } + repoCall.Unset() + authCall.Unset() + }) + } +} + +func TestListActiveRefreshTokens(t *testing.T) { + svc, authsvc, crepo, _, _ := newService() + + rUser := user + rUser.Credentials.Secret, _ = phasher.Hash(user.Credentials.Secret) + + cases := []struct { + desc string + session authn.Session + listResp *grpcTokenV1.ListUserRefreshTokensRes + listErr error + repoResp users.User + repoErr error + expectedTokens int + err error + }{ + { + desc: "list active refresh tokens successfully", + session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, + listResp: &grpcTokenV1.ListUserRefreshTokensRes{ + RefreshTokens: []*grpcTokenV1.RefreshToken{ + {Id: "token1", Description: "token1"}, + {Id: "token2", Description: "token2"}, + }, + }, + repoResp: rUser, + expectedTokens: 2, + err: nil, + }, + { + desc: "list active refresh tokens with empty domain id", + session: authn.Session{UserID: validID}, + listResp: &grpcTokenV1.ListUserRefreshTokensRes{ + RefreshTokens: []*grpcTokenV1.RefreshToken{ + {Id: "token1", Description: "token1"}, + }, + }, + repoResp: rUser, + expectedTokens: 1, + err: nil, + }, + { + desc: "list active refresh tokens for non-existing user", + session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, + repoErr: repoerr.ErrNotFound, + err: repoerr.ErrNotFound, + }, + { + desc: "list active refresh tokens for disabled user", + session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, + repoResp: users.User{Status: users.DisabledStatus}, + err: svcerr.ErrAuthentication, + }, + { + desc: "list active refresh tokens with list service error", + session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, + listResp: &grpcTokenV1.ListUserRefreshTokensRes{}, + listErr: svcerr.ErrAuthentication, + repoResp: rUser, + err: svcerr.ErrAuthentication, + }, + { + desc: "list active refresh tokens with empty list", + session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID}, + listResp: &grpcTokenV1.ListUserRefreshTokensRes{RefreshTokens: []*grpcTokenV1.RefreshToken{}}, + repoResp: rUser, + expectedTokens: 0, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr) + authCall := authsvc.On("ListUserRefreshTokens", context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.session.UserID}).Return(tc.listResp, tc.listErr) + tokens, err := svc.ListActiveRefreshTokens(context.Background(), tc.session) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + if err == nil { + assert.NotNil(t, tokens, fmt.Sprintf("%s: expected tokens not to be nil\n", tc.desc)) + assert.Equal(t, tc.expectedTokens, len(tokens.GetRefreshTokens()), fmt.Sprintf("%s: expected %d tokens got %d\n", tc.desc, tc.expectedTokens, len(tokens.GetRefreshTokens()))) + ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID) + assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) + ok = authCall.Parent.AssertCalled(t, "ListUserRefreshTokens", context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.session.UserID}) + assert.True(t, ok, fmt.Sprintf("ListUserRefreshTokens was not called on %s", tc.desc)) + } + repoCall.Unset() + authCall.Unset() + }) + } +} + func TestSendPasswordReset(t *testing.T) { svc, auth, cRepo, _, e := newService() diff --git a/users/users.go b/users/users.go index 327c4460f..bbd6889c6 100644 --- a/users/users.go +++ b/users/users.go @@ -264,13 +264,19 @@ type Service interface { Identify(ctx context.Context, session authn.Session) (string, error) // IssueToken issues a new access and refresh token when provided with either a username or email. - IssueToken(ctx context.Context, identity, secret string) (*grpcTokenV1.Token, error) + IssueToken(ctx context.Context, identity, secret, description string) (*grpcTokenV1.Token, error) // RefreshToken refreshes expired access tokens. // After an access token expires, the refresh token is used to get // a new pair of access and refresh tokens. RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*grpcTokenV1.Token, error) + // RevokeRefreshToken revokes a refresh token by its ID. + RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error + + // ListActiveRefreshTokens lists all active refresh tokens for the authenticated user. + ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) + // OAuthCallback handles the callback from any supported OAuth provider. // It processes the OAuth tokens and either signs in or signs up the user based on the provided state. OAuthCallback(ctx context.Context, user User) (User, error)