SMQ-1672 - Revoke refresh token (#3241)

Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
Co-authored-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
Felix Gateru
2026-03-03 17:22:28 +03:00
committed by GitHub
parent 2c476c17ee
commit 9c2608659f
52 changed files with 3266 additions and 224 deletions
+2 -5
View File
@@ -165,11 +165,8 @@ mocks: $(MOCKERY)
$(MOCKERY):
@mkdir -p $(GOBIN)
@mkdir -p mockery
@echo ">> downloading mockery $(MOCKERY_VERSION)..."
@curl -sL https://github.com/vektra/mockery/releases/download/v$(MOCKERY_VERSION)/mockery_$(MOCKERY_VERSION)_Linux_x86_64.tar.gz | tar -xz -C mockery
@mv mockery/mockery $(GOBIN)
@rm -r mockery
@echo ">> installing mockery $(MOCKERY_VERSION)..."
@go install github.com/vektra/mockery/v3@v$(MOCKERY_VERSION)
DIRS = consumers readers postgres internal
test: mocks
+283 -23
View File
@@ -30,6 +30,7 @@ type IssueReq struct {
UserRole uint32 `protobuf:"varint,2,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"`
Type uint32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"`
Verified bool `protobuf:"varint,4,opt,name=verified,proto3" json:"verified,omitempty"`
Description string `protobuf:"bytes,5,opt,name=description,proto3" json:"description,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -92,6 +93,13 @@ func (x *IssueReq) GetVerified() bool {
return false
}
func (x *IssueReq) GetDescription() string {
if x != nil {
return x.Description
}
return ""
}
type RefreshReq struct {
state protoimpl.MessageState `protogen:"open.v1"`
RefreshToken string `protobuf:"bytes,1,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"`
@@ -144,6 +152,58 @@ func (x *RefreshReq) GetVerified() bool {
return false
}
type RevokeReq struct {
state protoimpl.MessageState `protogen:"open.v1"`
TokenId string `protobuf:"bytes,1,opt,name=token_id,json=tokenId,proto3" json:"token_id,omitempty"`
UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *RevokeReq) Reset() {
*x = RevokeReq{}
mi := &file_token_v1_token_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *RevokeReq) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RevokeReq) ProtoMessage() {}
func (x *RevokeReq) ProtoReflect() protoreflect.Message {
mi := &file_token_v1_token_proto_msgTypes[2]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RevokeReq.ProtoReflect.Descriptor instead.
func (*RevokeReq) Descriptor() ([]byte, []int) {
return file_token_v1_token_proto_rawDescGZIP(), []int{2}
}
func (x *RevokeReq) GetTokenId() string {
if x != nil {
return x.TokenId
}
return ""
}
func (x *RevokeReq) GetUserId() string {
if x != nil {
return x.UserId
}
return ""
}
// If a token is not carrying any information itself, the type
// field can be used to determine how to validate the token.
// Also, different tokens can be encoded in different ways.
@@ -158,7 +218,7 @@ type Token struct {
func (x *Token) Reset() {
*x = Token{}
mi := &file_token_v1_token_proto_msgTypes[2]
mi := &file_token_v1_token_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -170,7 +230,7 @@ func (x *Token) String() string {
func (*Token) ProtoMessage() {}
func (x *Token) ProtoReflect() protoreflect.Message {
mi := &file_token_v1_token_proto_msgTypes[2]
mi := &file_token_v1_token_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -183,7 +243,7 @@ func (x *Token) ProtoReflect() protoreflect.Message {
// Deprecated: Use Token.ProtoReflect.Descriptor instead.
func (*Token) Descriptor() ([]byte, []int) {
return file_token_v1_token_proto_rawDescGZIP(), []int{2}
return file_token_v1_token_proto_rawDescGZIP(), []int{3}
}
func (x *Token) GetAccessToken() string {
@@ -207,29 +267,219 @@ func (x *Token) GetAccessType() string {
return ""
}
type RevokeRes struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *RevokeRes) Reset() {
*x = RevokeRes{}
mi := &file_token_v1_token_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *RevokeRes) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RevokeRes) ProtoMessage() {}
func (x *RevokeRes) ProtoReflect() protoreflect.Message {
mi := &file_token_v1_token_proto_msgTypes[4]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RevokeRes.ProtoReflect.Descriptor instead.
func (*RevokeRes) Descriptor() ([]byte, []int) {
return file_token_v1_token_proto_rawDescGZIP(), []int{4}
}
type ListUserRefreshTokensReq struct {
state protoimpl.MessageState `protogen:"open.v1"`
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ListUserRefreshTokensReq) Reset() {
*x = ListUserRefreshTokensReq{}
mi := &file_token_v1_token_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ListUserRefreshTokensReq) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListUserRefreshTokensReq) ProtoMessage() {}
func (x *ListUserRefreshTokensReq) ProtoReflect() protoreflect.Message {
mi := &file_token_v1_token_proto_msgTypes[5]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListUserRefreshTokensReq.ProtoReflect.Descriptor instead.
func (*ListUserRefreshTokensReq) Descriptor() ([]byte, []int) {
return file_token_v1_token_proto_rawDescGZIP(), []int{5}
}
func (x *ListUserRefreshTokensReq) GetUserId() string {
if x != nil {
return x.UserId
}
return ""
}
type ListUserRefreshTokensRes struct {
state protoimpl.MessageState `protogen:"open.v1"`
RefreshTokens []*RefreshToken `protobuf:"bytes,1,rep,name=refresh_tokens,json=refreshTokens,proto3" json:"refresh_tokens,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *ListUserRefreshTokensRes) Reset() {
*x = ListUserRefreshTokensRes{}
mi := &file_token_v1_token_proto_msgTypes[6]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *ListUserRefreshTokensRes) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ListUserRefreshTokensRes) ProtoMessage() {}
func (x *ListUserRefreshTokensRes) ProtoReflect() protoreflect.Message {
mi := &file_token_v1_token_proto_msgTypes[6]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ListUserRefreshTokensRes.ProtoReflect.Descriptor instead.
func (*ListUserRefreshTokensRes) Descriptor() ([]byte, []int) {
return file_token_v1_token_proto_rawDescGZIP(), []int{6}
}
func (x *ListUserRefreshTokensRes) GetRefreshTokens() []*RefreshToken {
if x != nil {
return x.RefreshTokens
}
return nil
}
type RefreshToken struct {
state protoimpl.MessageState `protogen:"open.v1"`
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
Description string `protobuf:"bytes,2,opt,name=description,proto3" json:"description,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *RefreshToken) Reset() {
*x = RefreshToken{}
mi := &file_token_v1_token_proto_msgTypes[7]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *RefreshToken) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RefreshToken) ProtoMessage() {}
func (x *RefreshToken) ProtoReflect() protoreflect.Message {
mi := &file_token_v1_token_proto_msgTypes[7]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RefreshToken.ProtoReflect.Descriptor instead.
func (*RefreshToken) Descriptor() ([]byte, []int) {
return file_token_v1_token_proto_rawDescGZIP(), []int{7}
}
func (x *RefreshToken) GetId() string {
if x != nil {
return x.Id
}
return ""
}
func (x *RefreshToken) GetDescription() string {
if x != nil {
return x.Description
}
return ""
}
var File_token_v1_token_proto protoreflect.FileDescriptor
const file_token_v1_token_proto_rawDesc = "" +
"\n" +
"\x14token/v1/token.proto\x12\btoken.v1\"p\n" +
"\x14token/v1/token.proto\x12\btoken.v1\"\x92\x01\n" +
"\bIssueReq\x12\x17\n" +
"\auser_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n" +
"\tuser_role\x18\x02 \x01(\rR\buserRole\x12\x12\n" +
"\x04type\x18\x03 \x01(\rR\x04type\x12\x1a\n" +
"\bverified\x18\x04 \x01(\bR\bverified\"M\n" +
"\bverified\x18\x04 \x01(\bR\bverified\x12 \n" +
"\vdescription\x18\x05 \x01(\tR\vdescription\"M\n" +
"\n" +
"RefreshReq\x12#\n" +
"\rrefresh_token\x18\x01 \x01(\tR\frefreshToken\x12\x1a\n" +
"\bverified\x18\x02 \x01(\bR\bverified\"\x87\x01\n" +
"\bverified\x18\x02 \x01(\bR\bverified\"?\n" +
"\tRevokeReq\x12\x19\n" +
"\btoken_id\x18\x01 \x01(\tR\atokenId\x12\x17\n" +
"\auser_id\x18\x02 \x01(\tR\x06userId\"\x87\x01\n" +
"\x05Token\x12!\n" +
"\faccess_token\x18\x01 \x01(\tR\vaccessToken\x12(\n" +
"\rrefresh_token\x18\x02 \x01(\tH\x00R\frefreshToken\x88\x01\x01\x12\x1f\n" +
"\vaccess_type\x18\x03 \x01(\tR\n" +
"accessTypeB\x10\n" +
"\x0e_refresh_token2r\n" +
"\x0e_refresh_token\"\v\n" +
"\tRevokeRes\"3\n" +
"\x18ListUserRefreshTokensReq\x12\x17\n" +
"\auser_id\x18\x01 \x01(\tR\x06userId\"Y\n" +
"\x18ListUserRefreshTokensRes\x12=\n" +
"\x0erefresh_tokens\x18\x01 \x03(\v2\x16.token.v1.RefreshTokenR\rrefreshTokens\"@\n" +
"\fRefreshToken\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\x12 \n" +
"\vdescription\x18\x02 \x01(\tR\vdescription2\x8b\x02\n" +
"\fTokenService\x12.\n" +
"\x05Issue\x12\x12.token.v1.IssueReq\x1a\x0f.token.v1.Token\"\x00\x122\n" +
"\aRefresh\x12\x14.token.v1.RefreshReq\x1a\x0f.token.v1.Token\"\x00B.Z,github.com/absmach/supermq/api/grpc/token/v1b\x06proto3"
"\aRefresh\x12\x14.token.v1.RefreshReq\x1a\x0f.token.v1.Token\"\x00\x124\n" +
"\x06Revoke\x12\x13.token.v1.RevokeReq\x1a\x13.token.v1.RevokeRes\"\x00\x12a\n" +
"\x15ListUserRefreshTokens\x12\".token.v1.ListUserRefreshTokensReq\x1a\".token.v1.ListUserRefreshTokensRes\"\x00B.Z,github.com/absmach/supermq/api/grpc/token/v1b\x06proto3"
var (
file_token_v1_token_proto_rawDescOnce sync.Once
@@ -243,22 +493,32 @@ func file_token_v1_token_proto_rawDescGZIP() []byte {
return file_token_v1_token_proto_rawDescData
}
var file_token_v1_token_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
var file_token_v1_token_proto_msgTypes = make([]protoimpl.MessageInfo, 8)
var file_token_v1_token_proto_goTypes = []any{
(*IssueReq)(nil), // 0: token.v1.IssueReq
(*RefreshReq)(nil), // 1: token.v1.RefreshReq
(*Token)(nil), // 2: token.v1.Token
(*IssueReq)(nil), // 0: token.v1.IssueReq
(*RefreshReq)(nil), // 1: token.v1.RefreshReq
(*RevokeReq)(nil), // 2: token.v1.RevokeReq
(*Token)(nil), // 3: token.v1.Token
(*RevokeRes)(nil), // 4: token.v1.RevokeRes
(*ListUserRefreshTokensReq)(nil), // 5: token.v1.ListUserRefreshTokensReq
(*ListUserRefreshTokensRes)(nil), // 6: token.v1.ListUserRefreshTokensRes
(*RefreshToken)(nil), // 7: token.v1.RefreshToken
}
var file_token_v1_token_proto_depIdxs = []int32{
0, // 0: token.v1.TokenService.Issue:input_type -> token.v1.IssueReq
1, // 1: token.v1.TokenService.Refresh:input_type -> token.v1.RefreshReq
2, // 2: token.v1.TokenService.Issue:output_type -> token.v1.Token
2, // 3: token.v1.TokenService.Refresh:output_type -> token.v1.Token
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
7, // 0: token.v1.ListUserRefreshTokensRes.refresh_tokens:type_name -> token.v1.RefreshToken
0, // 1: token.v1.TokenService.Issue:input_type -> token.v1.IssueReq
1, // 2: token.v1.TokenService.Refresh:input_type -> token.v1.RefreshReq
2, // 3: token.v1.TokenService.Revoke:input_type -> token.v1.RevokeReq
5, // 4: token.v1.TokenService.ListUserRefreshTokens:input_type -> token.v1.ListUserRefreshTokensReq
3, // 5: token.v1.TokenService.Issue:output_type -> token.v1.Token
3, // 6: token.v1.TokenService.Refresh:output_type -> token.v1.Token
4, // 7: token.v1.TokenService.Revoke:output_type -> token.v1.RevokeRes
6, // 8: token.v1.TokenService.ListUserRefreshTokens:output_type -> token.v1.ListUserRefreshTokensRes
5, // [5:9] is the sub-list for method output_type
1, // [1:5] is the sub-list for method input_type
1, // [1:1] is the sub-list for extension type_name
1, // [1:1] is the sub-list for extension extendee
0, // [0:1] is the sub-list for field type_name
}
func init() { file_token_v1_token_proto_init() }
@@ -266,14 +526,14 @@ func file_token_v1_token_proto_init() {
if File_token_v1_token_proto != nil {
return
}
file_token_v1_token_proto_msgTypes[2].OneofWrappers = []any{}
file_token_v1_token_proto_msgTypes[3].OneofWrappers = []any{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_token_v1_token_proto_rawDesc), len(file_token_v1_token_proto_rawDesc)),
NumEnums: 0,
NumMessages: 3,
NumMessages: 8,
NumExtensions: 0,
NumServices: 1,
},
+78 -2
View File
@@ -22,8 +22,10 @@ import (
const _ = grpc.SupportPackageIsVersion9
const (
TokenService_Issue_FullMethodName = "/token.v1.TokenService/Issue"
TokenService_Refresh_FullMethodName = "/token.v1.TokenService/Refresh"
TokenService_Issue_FullMethodName = "/token.v1.TokenService/Issue"
TokenService_Refresh_FullMethodName = "/token.v1.TokenService/Refresh"
TokenService_Revoke_FullMethodName = "/token.v1.TokenService/Revoke"
TokenService_ListUserRefreshTokens_FullMethodName = "/token.v1.TokenService/ListUserRefreshTokens"
)
// TokenServiceClient is the client API for TokenService service.
@@ -32,6 +34,8 @@ const (
type TokenServiceClient interface {
Issue(ctx context.Context, in *IssueReq, opts ...grpc.CallOption) (*Token, error)
Refresh(ctx context.Context, in *RefreshReq, opts ...grpc.CallOption) (*Token, error)
Revoke(ctx context.Context, in *RevokeReq, opts ...grpc.CallOption) (*RevokeRes, error)
ListUserRefreshTokens(ctx context.Context, in *ListUserRefreshTokensReq, opts ...grpc.CallOption) (*ListUserRefreshTokensRes, error)
}
type tokenServiceClient struct {
@@ -62,12 +66,34 @@ func (c *tokenServiceClient) Refresh(ctx context.Context, in *RefreshReq, opts .
return out, nil
}
func (c *tokenServiceClient) Revoke(ctx context.Context, in *RevokeReq, opts ...grpc.CallOption) (*RevokeRes, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RevokeRes)
err := c.cc.Invoke(ctx, TokenService_Revoke_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *tokenServiceClient) ListUserRefreshTokens(ctx context.Context, in *ListUserRefreshTokensReq, opts ...grpc.CallOption) (*ListUserRefreshTokensRes, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ListUserRefreshTokensRes)
err := c.cc.Invoke(ctx, TokenService_ListUserRefreshTokens_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// TokenServiceServer is the server API for TokenService service.
// All implementations must embed UnimplementedTokenServiceServer
// for forward compatibility.
type TokenServiceServer interface {
Issue(context.Context, *IssueReq) (*Token, error)
Refresh(context.Context, *RefreshReq) (*Token, error)
Revoke(context.Context, *RevokeReq) (*RevokeRes, error)
ListUserRefreshTokens(context.Context, *ListUserRefreshTokensReq) (*ListUserRefreshTokensRes, error)
mustEmbedUnimplementedTokenServiceServer()
}
@@ -84,6 +110,12 @@ func (UnimplementedTokenServiceServer) Issue(context.Context, *IssueReq) (*Token
func (UnimplementedTokenServiceServer) Refresh(context.Context, *RefreshReq) (*Token, error) {
return nil, status.Error(codes.Unimplemented, "method Refresh not implemented")
}
func (UnimplementedTokenServiceServer) Revoke(context.Context, *RevokeReq) (*RevokeRes, error) {
return nil, status.Error(codes.Unimplemented, "method Revoke not implemented")
}
func (UnimplementedTokenServiceServer) ListUserRefreshTokens(context.Context, *ListUserRefreshTokensReq) (*ListUserRefreshTokensRes, error) {
return nil, status.Error(codes.Unimplemented, "method ListUserRefreshTokens not implemented")
}
func (UnimplementedTokenServiceServer) mustEmbedUnimplementedTokenServiceServer() {}
func (UnimplementedTokenServiceServer) testEmbeddedByValue() {}
@@ -141,6 +173,42 @@ func _TokenService_Refresh_Handler(srv interface{}, ctx context.Context, dec fun
return interceptor(ctx, in, info, handler)
}
func _TokenService_Revoke_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RevokeReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(TokenServiceServer).Revoke(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: TokenService_Revoke_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(TokenServiceServer).Revoke(ctx, req.(*RevokeReq))
}
return interceptor(ctx, in, info, handler)
}
func _TokenService_ListUserRefreshTokens_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListUserRefreshTokensReq)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(TokenServiceServer).ListUserRefreshTokens(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: TokenService_ListUserRefreshTokens_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(TokenServiceServer).ListUserRefreshTokens(ctx, req.(*ListUserRefreshTokensReq))
}
return interceptor(ctx, in, info, handler)
}
// TokenService_ServiceDesc is the grpc.ServiceDesc for TokenService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -156,6 +224,14 @@ var TokenService_ServiceDesc = grpc.ServiceDesc{
MethodName: "Refresh",
Handler: _TokenService_Refresh_Handler,
},
{
MethodName: "Revoke",
Handler: _TokenService_Revoke_Handler,
},
{
MethodName: "ListUserRefreshTokens",
Handler: _TokenService_ListUserRefreshTokens_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "token/v1/token.proto",
+98
View File
@@ -610,6 +610,57 @@ paths:
"500":
$ref: "#/components/responses/ServiceError"
/users/tokens/revoke:
post:
operationId: revokeRefreshToken
summary: Revoke Refresh Token
description: |
Revokes a specific refresh token by its ID. This invalidates the
refresh token so it can no longer be used to obtain new access tokens.
tags:
- Users
security:
- bearerAuth: []
requestBody:
$ref: "#/components/requestBodies/RevokeRefreshTokenReq"
responses:
"204":
description: Refresh token revoked successfully.
"400":
description: Failed due to malformed JSON.
"401":
description: Missing or invalid access token provided.
"404":
description: A non-existent entity request.
"415":
description: Missing or invalid content type.
"422":
description: Database can't process request.
"500":
$ref: "#/components/responses/ServiceError"
/users/tokens/refresh-tokens:
get:
operationId: listActiveRefreshTokens
summary: List Active Refresh Tokens
description: |
Lists all active refresh token sessions for the currently authenticated user.
tags:
- Users
security:
- bearerAuth: []
responses:
"200":
$ref: "#/components/responses/RefreshTokensPageRes"
"400":
description: Failed due to malformed JSON.
"401":
description: Missing or invalid access token provided.
"404":
description: A non-existent entity request.
"500":
$ref: "#/components/responses/ServiceError"
/users/send-verification:
post:
operationId: sendVerification
@@ -1042,6 +1093,30 @@ components:
- username
- password
RefreshToken:
type: object
properties:
id:
type: string
format: uuid
example: "bb7edb32-2eac-4aad-aebe-ed96fe073879"
description: Unique identifier of the refresh token.
description:
type: string
example: "Chrome browser session"
description: Description of the refresh token session.
RefreshTokensPage:
type: object
properties:
refresh_tokens:
type: array
items:
$ref: "#/components/schemas/RefreshToken"
description: List of active refresh tokens.
required:
- refresh_tokens
Error:
type: object
properties:
@@ -1463,6 +1538,22 @@ components:
format: jwt
description: Reset token generated and sent in email.
RevokeRefreshTokenReq:
description: JSON-formatted document describing the refresh token to revoke.
required: true
content:
application/json:
schema:
type: object
properties:
token_id:
type: string
format: uuid
example: "bb7edb32-2eac-4aad-aebe-ed96fe073879"
description: The unique identifier of the refresh token to revoke.
required:
- token_id
PasswordChange:
description: Password change data. User can change its password.
required: true
@@ -1574,6 +1665,13 @@ components:
example: access
description: User access token type.
RefreshTokensPageRes:
description: List of active refresh tokens for the authenticated user.
content:
application/json:
schema:
$ref: "#/components/schemas/RefreshTokensPage"
HealthRes:
description: Service Health Check.
content:
+72 -12
View File
@@ -18,14 +18,16 @@ import (
const tokenSvcName = "token.v1.TokenService"
type tokenGrpcClient struct {
issue endpoint.Endpoint
refresh endpoint.Endpoint
timeout time.Duration
issue endpoint.Endpoint
refresh endpoint.Endpoint
revoke endpoint.Endpoint
listUserRefreshTokens endpoint.Endpoint
timeout time.Duration
}
var _ grpcTokenV1.TokenServiceClient = (*tokenGrpcClient)(nil)
// NewAuthClient returns new auth gRPC client instance.
// NewTokenClient returns new token gRPC client instance.
func NewTokenClient(conn *grpc.ClientConn, timeout time.Duration) grpcTokenV1.TokenServiceClient {
return &tokenGrpcClient{
issue: kitgrpc.NewClient(
@@ -44,6 +46,22 @@ func NewTokenClient(conn *grpc.ClientConn, timeout time.Duration) grpcTokenV1.To
decodeRefreshResponse,
grpcTokenV1.Token{},
).Endpoint(),
revoke: kitgrpc.NewClient(
conn,
tokenSvcName,
"Revoke",
encodeRevokeRequest,
decodeRevokeResponse,
grpcTokenV1.RevokeRes{},
).Endpoint(),
listUserRefreshTokens: kitgrpc.NewClient(
conn,
tokenSvcName,
"ListUserRefreshTokens",
encodeListUserRefreshTokensRequest,
decodeListUserRefreshTokensResponse,
grpcTokenV1.ListUserRefreshTokensRes{},
).Endpoint(),
timeout: timeout,
}
}
@@ -53,10 +71,11 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR
defer cancel()
res, err := client.issue(ctx, issueReq{
userID: req.GetUserId(),
userRole: auth.Role(req.GetUserRole()),
keyType: auth.KeyType(req.GetType()),
verified: req.GetVerified(),
userID: req.GetUserId(),
userRole: auth.Role(req.GetUserRole()),
keyType: auth.KeyType(req.GetType()),
verified: req.GetVerified(),
description: req.GetDescription(),
})
if err != nil {
return &grpcTokenV1.Token{}, grpcapi.DecodeError(err)
@@ -67,10 +86,11 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR
func encodeIssueRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(issueReq)
return &grpcTokenV1.IssueReq{
UserId: req.userID,
UserRole: uint32(req.userRole),
Type: uint32(req.keyType),
Verified: req.verified,
UserId: req.userID,
UserRole: uint32(req.userRole),
Type: uint32(req.keyType),
Verified: req.verified,
Description: req.description,
}, nil
}
@@ -97,3 +117,43 @@ func encodeRefreshRequest(_ context.Context, grpcReq any) (any, error) {
func decodeRefreshResponse(_ context.Context, grpcRes any) (any, error) {
return grpcRes, nil
}
func (client tokenGrpcClient) Revoke(ctx context.Context, req *grpcTokenV1.RevokeReq, _ ...grpc.CallOption) (*grpcTokenV1.RevokeRes, error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
res, err := client.revoke(ctx, revokeReq{userID: req.GetUserId(), tokenID: req.GetTokenId()})
if err != nil {
return &grpcTokenV1.RevokeRes{}, grpcapi.DecodeError(err)
}
return res.(*grpcTokenV1.RevokeRes), nil
}
func encodeRevokeRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(revokeReq)
return &grpcTokenV1.RevokeReq{UserId: req.userID, TokenId: req.tokenID}, nil
}
func decodeRevokeResponse(_ context.Context, grpcRes any) (any, error) {
return grpcRes, nil
}
func (client tokenGrpcClient) ListUserRefreshTokens(ctx context.Context, req *grpcTokenV1.ListUserRefreshTokensReq, _ ...grpc.CallOption) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
res, err := client.listUserRefreshTokens(ctx, listUserRefreshTokensReq{userID: req.GetUserId()})
if err != nil {
return &grpcTokenV1.ListUserRefreshTokensRes{}, grpcapi.DecodeError(err)
}
return res.(*grpcTokenV1.ListUserRefreshTokensRes), nil
}
func encodeListUserRefreshTokensRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(listUserRefreshTokensReq)
return &grpcTokenV1.ListUserRefreshTokensReq{UserId: req.userID}, nil
}
func decodeListUserRefreshTokensResponse(_ context.Context, grpcRes any) (any, error) {
return grpcRes, nil
}
+36 -4
View File
@@ -18,10 +18,11 @@ func issueEndpoint(svc auth.Service) endpoint.Endpoint {
}
key := auth.Key{
Type: req.keyType,
Subject: req.userID,
Role: req.userRole,
Verified: req.verified,
Type: req.keyType,
Subject: req.userID,
Role: req.userRole,
Verified: req.verified,
Description: req.description,
}
tkn, err := svc.Issue(ctx, "", key)
if err != nil {
@@ -56,3 +57,34 @@ func refreshEndpoint(svc auth.Service) endpoint.Endpoint {
return ret, nil
}
}
func revokeEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(revokeReq)
if err := req.validate(); err != nil {
return nil, err
}
err := svc.RevokeToken(ctx, req.userID, req.tokenID)
if err != nil {
return nil, err
}
return nil, nil
}
}
func listUserRefreshTokensEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(listUserRefreshTokensReq)
if err := req.validate(); err != nil {
return listUserRefreshTokensRes{}, err
}
refreshTokens, err := svc.ListUserRefreshTokens(ctx, req.userID)
if err != nil {
return listUserRefreshTokensRes{}, err
}
return listUserRefreshTokensRes{refreshTokens: refreshTokens}, nil
}
}
+97 -21
View File
@@ -24,24 +24,10 @@ import (
)
const (
port = 8082
secret = "secret"
email = "test@example.com"
id = "testID"
clientsType = "clients"
usersType = "users"
description = "Description"
groupName = "smqx"
adminPermission = "admin"
authoritiesObj = "authorities"
memberRelation = "member"
loginDuration = 30 * time.Minute
refreshDuration = 24 * time.Hour
invalidDuration = 7 * 24 * time.Hour
validToken = "valid"
inValidToken = "invalid"
validPolicy = "valid"
port = 8082
validToken = "valid"
inValidToken = "invalid"
invalidID = "invalid"
)
var (
@@ -63,9 +49,9 @@ func startGRPCServer(svc auth.Service, port int) *grpc.Server {
func TestIssue(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
defer conn.Close()
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
defer conn.Close()
cases := []struct {
desc string
@@ -127,9 +113,9 @@ func TestIssue(t *testing.T) {
func TestRefresh(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
defer conn.Close()
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
defer conn.Close()
cases := []struct {
desc string
@@ -161,9 +147,99 @@ func TestRefresh(t *testing.T) {
}
for _, tc := range cases {
svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err)
svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err)
_, err := grpcClient.Refresh(context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: tc.token})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
}
}
func TestRevoke(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
defer conn.Close()
cases := []struct {
desc string
id string
err error
}{
{
desc: "revoke token with valid id",
id: validID,
err: nil,
},
{
desc: "revoke token with invalid id",
id: invalidID,
err: svcerr.ErrAuthentication,
},
{
desc: "revoke token with empty id",
id: "",
err: apiutil.ErrMissingID,
},
{
desc: "revoke already revoked token",
id: validID,
err: svcerr.ErrConflict,
},
}
for _, tc := range cases {
svcCall := svc.On("RevokeToken", mock.Anything, mock.Anything, tc.id).Return(tc.err)
_, err := grpcClient.Revoke(context.Background(), &grpcTokenV1.RevokeReq{TokenId: tc.id})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
}
}
func TestListUserRefreshTokens(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
defer conn.Close()
cases := []struct {
desc string
userID string
listResponse []auth.TokenInfo
err error
}{
{
desc: "list tokens for user with valid id",
userID: validID,
listResponse: []auth.TokenInfo{
{ID: testsutil.GenerateUUID(&testing.T{}), Description: "Token 1"},
{ID: testsutil.GenerateUUID(&testing.T{}), Description: "Token 2"},
},
err: nil,
},
{
desc: "list tokens for user with empty list",
userID: validID,
listResponse: []auth.TokenInfo{},
err: nil,
},
{
desc: "list tokens with invalid user id",
userID: invalidID,
listResponse: nil,
err: svcerr.ErrAuthentication,
},
{
desc: "list tokens with empty user id",
userID: "",
listResponse: nil,
err: apiutil.ErrMissingID,
},
}
for _, tc := range cases {
svcCall := svc.On("ListUserRefreshTokens", mock.Anything, tc.userID).Return(tc.listResponse, tc.err)
_, err := grpcClient.ListUserRefreshTokens(context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.userID})
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
}
}
+30 -4
View File
@@ -9,10 +9,11 @@ import (
)
type issueReq struct {
userID string
userRole auth.Role
keyType auth.KeyType
verified bool
userID string
userRole auth.Role
keyType auth.KeyType
verified bool
description string
}
func (req issueReq) validate() error {
@@ -38,3 +39,28 @@ func (req refreshReq) validate() error {
return nil
}
type revokeReq struct {
userID string
tokenID string
}
func (req revokeReq) validate() error {
if req.tokenID == "" {
return apiutil.ErrMissingID
}
return nil
}
type listUserRefreshTokensReq struct {
userID string
}
func (req listUserRefreshTokensReq) validate() error {
if req.userID == "" {
return apiutil.ErrMissingID
}
return nil
}
+6
View File
@@ -3,8 +3,14 @@
package token
import "github.com/absmach/supermq/auth"
type issueRes struct {
accessToken string
refreshToken string
accessType string
}
type listUserRefreshTokensRes struct {
refreshTokens []auth.TokenInfo
}
+66 -7
View File
@@ -16,11 +16,13 @@ var _ grpcTokenV1.TokenServiceServer = (*tokenGrpcServer)(nil)
type tokenGrpcServer struct {
grpcTokenV1.UnimplementedTokenServiceServer
issue kitgrpc.Handler
refresh kitgrpc.Handler
issue kitgrpc.Handler
refresh kitgrpc.Handler
revoke kitgrpc.Handler
listUserRefreshTokens kitgrpc.Handler
}
// NewAuthServer returns new AuthnServiceServer instance.
// NewTokenServer returns new TokenServiceServer instance.
func NewTokenServer(svc auth.Service) grpcTokenV1.TokenServiceServer {
return &tokenGrpcServer{
issue: kitgrpc.NewServer(
@@ -33,6 +35,16 @@ func NewTokenServer(svc auth.Service) grpcTokenV1.TokenServiceServer {
decodeRefreshRequest,
encodeIssueResponse,
),
revoke: kitgrpc.NewServer(
(revokeEndpoint(svc)),
decodeRevokeRequest,
encodeRevokeResponse,
),
listUserRefreshTokens: kitgrpc.NewServer(
(listUserRefreshTokensEndpoint(svc)),
decodeListUserRefreshTokensRequest,
encodeListUserRefreshTokensResponse,
),
}
}
@@ -55,10 +67,11 @@ func (s *tokenGrpcServer) Refresh(ctx context.Context, req *grpcTokenV1.RefreshR
func decodeIssueRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(*grpcTokenV1.IssueReq)
return issueReq{
userID: req.GetUserId(),
userRole: auth.Role(req.GetUserRole()),
keyType: auth.KeyType(req.GetType()),
verified: req.Verified,
userID: req.GetUserId(),
userRole: auth.Role(req.GetUserRole()),
keyType: auth.KeyType(req.GetType()),
verified: req.Verified,
description: req.GetDescription(),
}, nil
}
@@ -76,3 +89,49 @@ func encodeIssueResponse(_ context.Context, grpcRes any) (any, error) {
AccessType: res.accessType,
}, nil
}
func (s *tokenGrpcServer) Revoke(ctx context.Context, req *grpcTokenV1.RevokeReq) (*grpcTokenV1.RevokeRes, error) {
_, res, err := s.revoke.ServeGRPC(ctx, req)
if err != nil {
return nil, grpcapi.EncodeError(err)
}
return res.(*grpcTokenV1.RevokeRes), nil
}
func decodeRevokeRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(*grpcTokenV1.RevokeReq)
return revokeReq{userID: req.GetUserId(), tokenID: req.GetTokenId()}, nil
}
func encodeRevokeResponse(_ context.Context, grpcRes any) (any, error) {
return &grpcTokenV1.RevokeRes{}, nil
}
func (s *tokenGrpcServer) ListUserRefreshTokens(ctx context.Context, req *grpcTokenV1.ListUserRefreshTokensReq) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
_, res, err := s.listUserRefreshTokens.ServeGRPC(ctx, req)
if err != nil {
return nil, grpcapi.EncodeError(err)
}
return res.(*grpcTokenV1.ListUserRefreshTokensRes), nil
}
func decodeListUserRefreshTokensRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(*grpcTokenV1.ListUserRefreshTokensReq)
return listUserRefreshTokensReq{userID: req.GetUserId()}, nil
}
func encodeListUserRefreshTokensResponse(_ context.Context, grpcRes any) (any, error) {
res := grpcRes.(listUserRefreshTokensRes)
refreshTokens := make([]*grpcTokenV1.RefreshToken, len(res.refreshTokens))
for i, refreshToken := range res.refreshTokens {
refreshTokens[i] = &grpcTokenV1.RefreshToken{
Id: refreshToken.ID,
Description: refreshToken.Description,
}
}
return &grpcTokenV1.ListUserRefreshTokensRes{
RefreshTokens: refreshTokens,
}, nil
}
+2
View File
@@ -1,4 +1,6 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package cache contains the domain concept definitions needed to
// support SuperMQ auth cache service functionality.
package cache
+132
View File
@@ -0,0 +1,132 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package cache
import (
"context"
"strconv"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
"github.com/redis/go-redis/v9"
)
const (
refreshPrefix = "refresh_tokens:"
scoreNegInf = "-inf"
scorePosInf = "+inf"
)
var _ auth.UserActiveTokensCache = (*tokensCache)(nil)
type tokensCache struct {
client *redis.Client
keyDuration time.Duration
}
// NewUserActiveTokensCache returns redis auth cache implementation.
func NewUserActiveTokensCache(client *redis.Client, duration time.Duration) (auth.UserActiveTokensCache, error) {
if duration == 0 {
return nil, errors.New("token cache duration must not be zero")
}
return &tokensCache{
client: client,
keyDuration: duration,
}, nil
}
// SaveActive saves an active refresh token ID for a user with optional description.
func (tc *tokensCache) SaveActive(ctx context.Context, userID, tokenID, description string, expiry time.Time) error {
ttl := min(tc.keyDuration, time.Until(expiry))
pipe := tc.client.TxPipeline()
pipe.Set(ctx, tokenKey(tokenID), description, ttl)
pipe.ZAdd(ctx, userTokensKey(userID), redis.Z{
Score: float64(time.Now().Add(ttl).Unix()),
Member: tokenID,
})
_, err := pipe.Exec(ctx)
return err
}
// IsActive checks if the token ID is active for the given user.
func (tc *tokensCache) IsActive(ctx context.Context, tokenID string) (bool, error) {
count, err := tc.client.Exists(ctx, tokenKey(tokenID)).Result()
if err != nil {
return false, err
}
return count > 0, nil
}
// ListUserTokens lists all active refresh token IDs with descriptions for a user.
func (tc *tokensCache) ListUserTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) {
key := userTokensKey(userID)
now := strconv.FormatInt(time.Now().Unix(), 10)
pipe := tc.client.TxPipeline()
pipe.ZRemRangeByScore(ctx, key, scoreNegInf, now)
zrangeCmd := pipe.ZRangeByScore(ctx, key, &redis.ZRangeBy{Min: "(" + now, Max: scorePosInf})
if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil {
return nil, err
}
tokenIDs, err := zrangeCmd.Result()
if err != nil {
return nil, err
}
if len(tokenIDs) == 0 {
return nil, nil
}
getPipe := tc.client.Pipeline()
getCmds := make([]*redis.StringCmd, len(tokenIDs))
for i, tokenID := range tokenIDs {
getCmds[i] = getPipe.Get(ctx, tokenKey(tokenID))
}
if _, err = getPipe.Exec(ctx); err != nil && err != redis.Nil {
return nil, err
}
valid := make([]auth.TokenInfo, 0, len(tokenIDs))
for i, cmd := range getCmds {
description, err := cmd.Result()
if err == redis.Nil {
continue
}
if err != nil {
return nil, err
}
valid = append(valid, auth.TokenInfo{
ID: tokenIDs[i],
Description: description,
})
}
return valid, nil
}
// RemoveActive removes an active refresh token ID for a user.
func (tc *tokensCache) RemoveActive(ctx context.Context, userID, tokenID string) error {
pipe := tc.client.TxPipeline()
pipe.Del(ctx, tokenKey(tokenID))
pipe.ZRem(ctx, userTokensKey(userID), tokenID)
_, err := pipe.Exec(ctx)
return err
}
func tokenKey(tokenID string) string {
return refreshPrefix + "token:" + tokenID
}
func userTokensKey(userID string) string {
return refreshPrefix + "user_tokens:" + userID
}
+288
View File
@@ -0,0 +1,288 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package cache_test
import (
"context"
"fmt"
"os"
"testing"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/auth/cache"
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/errors"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
)
var (
storeClient *redis.Client
storeURL string
)
func TestMain(m *testing.M) {
code := testsutil.RunRedisTest(m, &storeClient, &storeURL)
os.Exit(code)
}
func setupRedisTokensClient() auth.UserActiveTokensCache {
tc, err := cache.NewUserActiveTokensCache(storeClient, 10*time.Minute)
if err != nil {
panic(err)
}
return tc
}
func TestTokenSave(t *testing.T) {
storeClient.FlushAll(context.Background())
tokensCache := setupRedisTokensClient()
userID := testsutil.GenerateUUID(t)
tokenID := testsutil.GenerateUUID(t)
cases := []struct {
desc string
userID string
tokenID string
description string
expiry time.Time
err error
}{
{
desc: "Save active token",
userID: userID,
tokenID: tokenID,
description: "Test token",
expiry: time.Now().Add(10 * time.Minute),
err: nil,
},
{
desc: "Save already cached token",
userID: userID,
tokenID: tokenID,
description: "Updated token",
expiry: time.Now().Add(10 * time.Minute),
err: nil,
},
{
desc: "Save another token for same user",
userID: userID,
tokenID: testsutil.GenerateUUID(t),
description: "Another token",
expiry: time.Now().Add(10 * time.Minute),
err: nil,
},
{
desc: "Save token with empty id",
userID: userID,
tokenID: "",
description: "Empty ID token",
expiry: time.Now().Add(10 * time.Minute),
err: nil,
},
{
desc: "Save token with empty description",
userID: userID,
tokenID: testsutil.GenerateUUID(t),
description: "",
expiry: time.Now().Add(10 * time.Minute),
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := tokensCache.SaveActive(context.Background(), tc.userID, tc.tokenID, tc.description, tc.expiry)
if err == nil {
ok, err := tokensCache.IsActive(context.Background(), tc.tokenID)
assert.NoError(t, err)
assert.True(t, ok)
}
assert.True(t, errors.Contains(err, tc.err))
})
}
}
func TestTokenContains(t *testing.T) {
storeClient.FlushAll(context.Background())
tokensCache := setupRedisTokensClient()
userID := testsutil.GenerateUUID(t)
tokenID := testsutil.GenerateUUID(t)
err := tokensCache.SaveActive(context.Background(), userID, tokenID, "Test token", time.Now().Add(10*time.Minute))
assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err))
cases := []struct {
desc string
userID string
tokenID string
ok bool
}{
{
desc: "IsActive for existing token",
userID: userID,
tokenID: tokenID,
ok: true,
},
{
desc: "IsActive for non existing token",
userID: userID,
tokenID: testsutil.GenerateUUID(t),
},
{
desc: "IsActive with empty token id",
userID: userID,
tokenID: "",
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
ok, err := tokensCache.IsActive(context.Background(), tc.tokenID)
if tc.ok {
assert.NoError(t, err)
}
assert.Equal(t, tc.ok, ok)
})
}
}
func TestTokenRemove(t *testing.T) {
storeClient.FlushAll(context.Background())
tokensCache := setupRedisTokensClient()
userID := testsutil.GenerateUUID(t)
num := 10
var tokenIDs []string
for i := range num {
tokenID := testsutil.GenerateUUID(t)
err := tokensCache.SaveActive(context.Background(), userID, tokenID, fmt.Sprintf("Token %d", i), time.Now().Add(10*time.Minute))
assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err))
tokenIDs = append(tokenIDs, tokenID)
}
cases := []struct {
desc string
userID string
tokenID string
err error
}{
{
desc: "Remove an existing token from cache",
userID: userID,
tokenID: tokenIDs[0],
err: nil,
},
{
desc: "Remove token with empty id from cache",
userID: userID,
tokenID: "",
err: nil,
},
{
desc: "Remove non existing id from cache",
userID: userID,
tokenID: testsutil.GenerateUUID(t),
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
err := tokensCache.RemoveActive(context.Background(), tc.userID, tc.tokenID)
assert.True(t, errors.Contains(err, tc.err))
if err == nil {
ok, err := tokensCache.IsActive(context.Background(), tc.tokenID)
assert.NoError(t, err)
assert.False(t, ok)
}
})
}
}
func TestListUserTokens(t *testing.T) {
storeClient.FlushAll(context.Background())
tokensCache := setupRedisTokensClient()
userID := testsutil.GenerateUUID(t)
userID2 := testsutil.GenerateUUID(t)
num := 5
var expectedTokens []auth.TokenInfo
for i := range num {
tokenID := testsutil.GenerateUUID(t)
description := fmt.Sprintf("Token %d", i)
err := tokensCache.SaveActive(context.Background(), userID, tokenID, description, time.Now().Add(10*time.Minute))
assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err))
expectedTokens = append(expectedTokens, auth.TokenInfo{
ID: tokenID,
Description: description,
})
}
tokenID2 := testsutil.GenerateUUID(t)
desc2 := "User 2 token"
err := tokensCache.SaveActive(context.Background(), userID2, tokenID2, desc2, time.Now().Add(10*time.Minute))
assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err))
cases := []struct {
desc string
userID string
expectedCount int
expectedTokens []auth.TokenInfo
err error
}{
{
desc: "List all tokens for user with multiple tokens",
userID: userID,
expectedCount: num,
expectedTokens: expectedTokens,
err: nil,
},
{
desc: "List tokens for user with single token",
userID: userID2,
expectedCount: 1,
expectedTokens: []auth.TokenInfo{{ID: tokenID2, Description: desc2}},
err: nil,
},
{
desc: "List tokens for user with no tokens",
userID: testsutil.GenerateUUID(t),
expectedCount: 0,
expectedTokens: nil,
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
tokens, err := tokensCache.ListUserTokens(context.Background(), tc.userID)
assert.True(t, errors.Contains(err, tc.err))
assert.Equal(t, tc.expectedCount, len(tokens))
if tc.expectedTokens != nil {
assert.ElementsMatch(t, tc.expectedTokens, tokens)
}
})
}
t.Run("Cleanup expired tokens from list", func(t *testing.T) {
// Remove one token directly from Redis to simulate expiration
err := tokensCache.RemoveActive(context.Background(), userID, expectedTokens[0].ID)
assert.NoError(t, err)
// List should now return only valid tokens
tokens, err := tokensCache.ListUserTokens(context.Background(), userID)
assert.NoError(t, err)
assert.Equal(t, num-1, len(tokens))
// Check that the removed token is not in the list
for _, token := range tokens {
assert.NotEqual(t, expectedTokens[0].ID, token.ID)
}
})
}
+26 -1
View File
@@ -5,13 +5,16 @@ package auth
import (
"context"
"errors"
"time"
"github.com/absmach/supermq/pkg/errors"
)
var (
ErrUnsupportedKeyAlgorithm = errors.New("unsupported key algorithm")
ErrInvalidSymmetricKey = errors.New("invalid symmetric key")
ErrPublicKeysNotSupported = errors.New("public keys not supported for symmetric algorithm")
ErrRevokedToken = errors.NewAuthNError("token is revoked")
)
// PublicKeyInfo represents a public key for external distribution via JWKS.
@@ -33,6 +36,7 @@ type PublicKeyInfo struct {
// Implementations manage underlying cryptographic operations and key distribution.
type Tokenizer interface {
// Issue creates a signed token string from the given key claims.
// For RefreshKey types, the token ID is stored as active in the cache.
Issue(key Key) (token string, err error)
// Parse verifies and parses a token string (JWT or PAT), returning the extracted claims.
@@ -45,6 +49,27 @@ type Tokenizer interface {
RetrieveJWKS() ([]PublicKeyInfo, error)
}
// UserActiveTokensCache represents a cache repository for managing active refresh tokens per user.
type UserActiveTokensCache interface {
// SaveActive saves an active refresh token ID for a user with optional description.
SaveActive(ctx context.Context, userID, tokenID, description string, expiry time.Time) error
// IsActive checks if the token ID is active.
IsActive(ctx context.Context, tokenID string) (bool, error)
// ListUserTokens lists all active token IDs with descriptions for a given user.
ListUserTokens(ctx context.Context, userID string) ([]TokenInfo, error)
// RemoveActive removes an active refresh token ID.
RemoveActive(ctx context.Context, userID, tokenID string) error
}
// TokenInfo represents information about an active refresh token.
type TokenInfo struct {
ID string `json:"id"`
Description string `json:"description,omitempty"`
}
// IsSymmetricAlgorithm determines if the given algorithm is symmetric (HMAC-based).
// Returns true for HMAC algorithms (HS256, HS384, HS512).
// Returns false for asymmetric algorithms (EdDSA).
+9 -8
View File
@@ -81,14 +81,15 @@ func (r Role) Validate() bool {
// Key represents API key.
type Key struct {
ID string `json:"id,omitempty"`
Type KeyType `json:"type,omitempty"`
Issuer string `json:"issuer,omitempty"`
Subject string `json:"subject,omitempty"` // user ID
Role Role `json:"role,omitempty"`
IssuedAt time.Time `json:"issued_at,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
Verified bool `json:"verified,omitempty"`
ID string `json:"id,omitempty"`
Type KeyType `json:"type,omitempty"`
Issuer string `json:"issuer,omitempty"`
Subject string `json:"subject,omitempty"` // user ID
Role Role `json:"role,omitempty"`
IssuedAt time.Time `json:"issued_at,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
Verified bool `json:"verified,omitempty"`
Description string `json:"description,omitempty"` // Optional description for refresh tokens
}
func (key Key) String() string {
+36
View File
@@ -110,6 +110,42 @@ func (lm *loggingMiddleware) RetrieveJWKS() (jwks []auth.PublicKeyInfo) {
return lm.svc.RetrieveJWKS()
}
func (lm *loggingMiddleware) RevokeToken(ctx context.Context, userID, tokenID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("user_id", userID),
slog.String("token_id", tokenID),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("Revoke token failed", args...)
return
}
lm.logger.Info("Revoke token completed successfully", args...)
}(time.Now())
return lm.svc.RevokeToken(ctx, userID, tokenID)
}
func (lm *loggingMiddleware) ListUserRefreshTokens(ctx context.Context, userID string) (tokens []auth.TokenInfo, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("user_id", userID),
slog.Int("tokens_count", len(tokens)),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("List user refresh tokens failed", args...)
return
}
lm.logger.Info("List user refresh tokens completed successfully", args...)
}(time.Now())
return lm.svc.ListUserRefreshTokens(ctx, userID)
}
func (lm *loggingMiddleware) Authorize(ctx context.Context, pr policies.Policy, patAuthz *auth.PATAuthz) (err error) {
defer func(begin time.Time) {
args := []any{
+18
View File
@@ -40,6 +40,15 @@ func (ms *metricsMiddleware) Issue(ctx context.Context, token string, key auth.K
return ms.svc.Issue(ctx, token, key)
}
func (ms *metricsMiddleware) RevokeToken(ctx context.Context, userID, tokenID string) error {
defer func(begin time.Time) {
ms.counter.With("method", "revoke_token").Add(1)
ms.latency.With("method", "revoke_token").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.RevokeToken(ctx, userID, tokenID)
}
func (ms *metricsMiddleware) Revoke(ctx context.Context, token, id string) error {
defer func(begin time.Time) {
ms.counter.With("method", "revoke_key").Add(1)
@@ -75,6 +84,15 @@ func (ms *metricsMiddleware) RetrieveJWKS() []auth.PublicKeyInfo {
return ms.svc.RetrieveJWKS()
}
func (ms *metricsMiddleware) ListUserRefreshTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) {
defer func(begin time.Time) {
ms.counter.With("method", "list_user_refresh_tokens").Add(1)
ms.latency.With("method", "list_user_refresh_tokens").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ListUserRefreshTokens(ctx, userID)
}
func (ms *metricsMiddleware) Authorize(ctx context.Context, pr policies.Policy, patAuthz *auth.PATAuthz) error {
defer func(begin time.Time) {
ms.counter.With("method", "authorize").Add(1)
+19
View File
@@ -65,6 +65,25 @@ func (tm *tracingMiddleware) RetrieveJWKS() []auth.PublicKeyInfo {
return tm.svc.RetrieveJWKS()
}
func (tm *tracingMiddleware) RevokeToken(ctx context.Context, userID, tokenID string) error {
ctx, span := tm.tracer.Start(ctx, "revoke_token", trace.WithAttributes(
attribute.String("user_id", userID),
attribute.String("token_id", tokenID),
))
defer span.End()
return tm.svc.RevokeToken(ctx, userID, tokenID)
}
func (tm *tracingMiddleware) ListUserRefreshTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) {
ctx, span := tm.tracer.Start(ctx, "list_user_refresh_tokens", trace.WithAttributes(
attribute.String("user_id", userID),
))
defer span.End()
return tm.svc.ListUserRefreshTokens(ctx, userID)
}
func (tm *tracingMiddleware) Authorize(ctx context.Context, pr policies.Policy, patAuthz *auth.PATAuthz) error {
attributes := []attribute.KeyValue{
attribute.String("subject", pr.Subject),
+131
View File
@@ -758,6 +758,74 @@ func (_c *Service_ListScopes_Call) RunAndReturn(run func(ctx context.Context, to
return _c
}
// ListUserRefreshTokens provides a mock function for the type Service
func (_mock *Service) ListUserRefreshTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) {
ret := _mock.Called(ctx, userID)
if len(ret) == 0 {
panic("no return value specified for ListUserRefreshTokens")
}
var r0 []auth.TokenInfo
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string) ([]auth.TokenInfo, error)); ok {
return returnFunc(ctx, userID)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string) []auth.TokenInfo); ok {
r0 = returnFunc(ctx, userID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]auth.TokenInfo)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = returnFunc(ctx, userID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Service_ListUserRefreshTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserRefreshTokens'
type Service_ListUserRefreshTokens_Call struct {
*mock.Call
}
// ListUserRefreshTokens is a helper method to define mock.On call
// - ctx context.Context
// - userID string
func (_e *Service_Expecter) ListUserRefreshTokens(ctx interface{}, userID interface{}) *Service_ListUserRefreshTokens_Call {
return &Service_ListUserRefreshTokens_Call{Call: _e.mock.On("ListUserRefreshTokens", ctx, userID)}
}
func (_c *Service_ListUserRefreshTokens_Call) Run(run func(ctx context.Context, userID string)) *Service_ListUserRefreshTokens_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Service_ListUserRefreshTokens_Call) Return(tokenInfos []auth.TokenInfo, err error) *Service_ListUserRefreshTokens_Call {
_c.Call.Return(tokenInfos, err)
return _c
}
func (_c *Service_ListUserRefreshTokens_Call) RunAndReturn(run func(ctx context.Context, userID string) ([]auth.TokenInfo, error)) *Service_ListUserRefreshTokens_Call {
_c.Call.Return(run)
return _c
}
// RemoveAllPAT provides a mock function for the type Service
func (_mock *Service) RemoveAllPAT(ctx context.Context, token string) error {
ret := _mock.Called(ctx, token)
@@ -1350,6 +1418,69 @@ func (_c *Service_RevokePATSecret_Call) RunAndReturn(run func(ctx context.Contex
return _c
}
// RevokeToken provides a mock function for the type Service
func (_mock *Service) RevokeToken(ctx context.Context, userID string, tokenID string) error {
ret := _mock.Called(ctx, userID, tokenID)
if len(ret) == 0 {
panic("no return value specified for RevokeToken")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = returnFunc(ctx, userID, tokenID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Service_RevokeToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeToken'
type Service_RevokeToken_Call struct {
*mock.Call
}
// RevokeToken is a helper method to define mock.On call
// - ctx context.Context
// - userID string
// - tokenID string
func (_e *Service_Expecter) RevokeToken(ctx interface{}, userID interface{}, tokenID interface{}) *Service_RevokeToken_Call {
return &Service_RevokeToken_Call{Call: _e.mock.On("RevokeToken", ctx, userID, tokenID)}
}
func (_c *Service_RevokeToken_Call) Run(run func(ctx context.Context, userID string, tokenID string)) *Service_RevokeToken_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
run(
arg0,
arg1,
arg2,
)
})
return _c
}
func (_c *Service_RevokeToken_Call) Return(err error) *Service_RevokeToken_Call {
_c.Call.Return(err)
return _c
}
func (_c *Service_RevokeToken_Call) RunAndReturn(run func(ctx context.Context, userID string, tokenID string) error) *Service_RevokeToken_Call {
_c.Call.Return(run)
return _c
}
// UpdatePATDescription provides a mock function for the type Service
func (_mock *Service) UpdatePATDescription(ctx context.Context, token string, patID string, description string) (auth.PAT, error) {
ret := _mock.Called(ctx, token, patID, description)
+166
View File
@@ -126,6 +126,89 @@ func (_c *TokenServiceClient_Issue_Call) RunAndReturn(run func(ctx context.Conte
return _c
}
// ListUserRefreshTokens provides a mock function for the type TokenServiceClient
func (_mock *TokenServiceClient) ListUserRefreshTokens(ctx context.Context, in *v1.ListUserRefreshTokensReq, opts ...grpc.CallOption) (*v1.ListUserRefreshTokensRes, error) {
var tmpRet mock.Arguments
if len(opts) > 0 {
tmpRet = _mock.Called(ctx, in, opts)
} else {
tmpRet = _mock.Called(ctx, in)
}
ret := tmpRet
if len(ret) == 0 {
panic("no return value specified for ListUserRefreshTokens")
}
var r0 *v1.ListUserRefreshTokensRes
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.ListUserRefreshTokensReq, ...grpc.CallOption) (*v1.ListUserRefreshTokensRes, error)); ok {
return returnFunc(ctx, in, opts...)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.ListUserRefreshTokensReq, ...grpc.CallOption) *v1.ListUserRefreshTokensRes); ok {
r0 = returnFunc(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*v1.ListUserRefreshTokensRes)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.ListUserRefreshTokensReq, ...grpc.CallOption) error); ok {
r1 = returnFunc(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// TokenServiceClient_ListUserRefreshTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserRefreshTokens'
type TokenServiceClient_ListUserRefreshTokens_Call struct {
*mock.Call
}
// ListUserRefreshTokens is a helper method to define mock.On call
// - ctx context.Context
// - in *v1.ListUserRefreshTokensReq
// - opts ...grpc.CallOption
func (_e *TokenServiceClient_Expecter) ListUserRefreshTokens(ctx interface{}, in interface{}, opts ...interface{}) *TokenServiceClient_ListUserRefreshTokens_Call {
return &TokenServiceClient_ListUserRefreshTokens_Call{Call: _e.mock.On("ListUserRefreshTokens",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *TokenServiceClient_ListUserRefreshTokens_Call) Run(run func(ctx context.Context, in *v1.ListUserRefreshTokensReq, opts ...grpc.CallOption)) *TokenServiceClient_ListUserRefreshTokens_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *v1.ListUserRefreshTokensReq
if args[1] != nil {
arg1 = args[1].(*v1.ListUserRefreshTokensReq)
}
var arg2 []grpc.CallOption
var variadicArgs []grpc.CallOption
if len(args) > 2 {
variadicArgs = args[2].([]grpc.CallOption)
}
arg2 = variadicArgs
run(
arg0,
arg1,
arg2...,
)
})
return _c
}
func (_c *TokenServiceClient_ListUserRefreshTokens_Call) Return(listUserRefreshTokensRes *v1.ListUserRefreshTokensRes, err error) *TokenServiceClient_ListUserRefreshTokens_Call {
_c.Call.Return(listUserRefreshTokensRes, err)
return _c
}
func (_c *TokenServiceClient_ListUserRefreshTokens_Call) RunAndReturn(run func(ctx context.Context, in *v1.ListUserRefreshTokensReq, opts ...grpc.CallOption) (*v1.ListUserRefreshTokensRes, error)) *TokenServiceClient_ListUserRefreshTokens_Call {
_c.Call.Return(run)
return _c
}
// Refresh provides a mock function for the type TokenServiceClient
func (_mock *TokenServiceClient) Refresh(ctx context.Context, in *v1.RefreshReq, opts ...grpc.CallOption) (*v1.Token, error) {
var tmpRet mock.Arguments
@@ -208,3 +291,86 @@ func (_c *TokenServiceClient_Refresh_Call) RunAndReturn(run func(ctx context.Con
_c.Call.Return(run)
return _c
}
// Revoke provides a mock function for the type TokenServiceClient
func (_mock *TokenServiceClient) Revoke(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption) (*v1.RevokeRes, error) {
var tmpRet mock.Arguments
if len(opts) > 0 {
tmpRet = _mock.Called(ctx, in, opts)
} else {
tmpRet = _mock.Called(ctx, in)
}
ret := tmpRet
if len(ret) == 0 {
panic("no return value specified for Revoke")
}
var r0 *v1.RevokeRes
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) (*v1.RevokeRes, error)); ok {
return returnFunc(ctx, in, opts...)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) *v1.RevokeRes); ok {
r0 = returnFunc(ctx, in, opts...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*v1.RevokeRes)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) error); ok {
r1 = returnFunc(ctx, in, opts...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// TokenServiceClient_Revoke_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Revoke'
type TokenServiceClient_Revoke_Call struct {
*mock.Call
}
// Revoke is a helper method to define mock.On call
// - ctx context.Context
// - in *v1.RevokeReq
// - opts ...grpc.CallOption
func (_e *TokenServiceClient_Expecter) Revoke(ctx interface{}, in interface{}, opts ...interface{}) *TokenServiceClient_Revoke_Call {
return &TokenServiceClient_Revoke_Call{Call: _e.mock.On("Revoke",
append([]interface{}{ctx, in}, opts...)...)}
}
func (_c *TokenServiceClient_Revoke_Call) Run(run func(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption)) *TokenServiceClient_Revoke_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 *v1.RevokeReq
if args[1] != nil {
arg1 = args[1].(*v1.RevokeReq)
}
var arg2 []grpc.CallOption
var variadicArgs []grpc.CallOption
if len(args) > 2 {
variadicArgs = args[2].([]grpc.CallOption)
}
arg2 = variadicArgs
run(
arg0,
arg1,
arg2...,
)
})
return _c
}
func (_c *TokenServiceClient_Revoke_Call) Return(revokeRes *v1.RevokeRes, err error) *TokenServiceClient_Revoke_Call {
_c.Call.Return(revokeRes, err)
return _c
}
func (_c *TokenServiceClient_Revoke_Call) RunAndReturn(run func(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption) (*v1.RevokeRes, error)) *TokenServiceClient_Revoke_Call {
_c.Call.Return(run)
return _c
}
+316
View File
@@ -0,0 +1,316 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Code generated by mockery; DO NOT EDIT.
// github.com/vektra/mockery
// template: testify
package mocks
import (
"context"
"time"
"github.com/absmach/supermq/auth"
mock "github.com/stretchr/testify/mock"
)
// NewUserActiveTokensCache creates a new instance of UserActiveTokensCache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewUserActiveTokensCache(t interface {
mock.TestingT
Cleanup(func())
}) *UserActiveTokensCache {
mock := &UserActiveTokensCache{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// UserActiveTokensCache is an autogenerated mock type for the UserActiveTokensCache type
type UserActiveTokensCache struct {
mock.Mock
}
type UserActiveTokensCache_Expecter struct {
mock *mock.Mock
}
func (_m *UserActiveTokensCache) EXPECT() *UserActiveTokensCache_Expecter {
return &UserActiveTokensCache_Expecter{mock: &_m.Mock}
}
// IsActive provides a mock function for the type UserActiveTokensCache
func (_mock *UserActiveTokensCache) IsActive(ctx context.Context, tokenID string) (bool, error) {
ret := _mock.Called(ctx, tokenID)
if len(ret) == 0 {
panic("no return value specified for IsActive")
}
var r0 bool
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string) (bool, error)); ok {
return returnFunc(ctx, tokenID)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string) bool); ok {
r0 = returnFunc(ctx, tokenID)
} else {
r0 = ret.Get(0).(bool)
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = returnFunc(ctx, tokenID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UserActiveTokensCache_IsActive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsActive'
type UserActiveTokensCache_IsActive_Call struct {
*mock.Call
}
// IsActive is a helper method to define mock.On call
// - ctx context.Context
// - tokenID string
func (_e *UserActiveTokensCache_Expecter) IsActive(ctx interface{}, tokenID interface{}) *UserActiveTokensCache_IsActive_Call {
return &UserActiveTokensCache_IsActive_Call{Call: _e.mock.On("IsActive", ctx, tokenID)}
}
func (_c *UserActiveTokensCache_IsActive_Call) Run(run func(ctx context.Context, tokenID string)) *UserActiveTokensCache_IsActive_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *UserActiveTokensCache_IsActive_Call) Return(b bool, err error) *UserActiveTokensCache_IsActive_Call {
_c.Call.Return(b, err)
return _c
}
func (_c *UserActiveTokensCache_IsActive_Call) RunAndReturn(run func(ctx context.Context, tokenID string) (bool, error)) *UserActiveTokensCache_IsActive_Call {
_c.Call.Return(run)
return _c
}
// ListUserTokens provides a mock function for the type UserActiveTokensCache
func (_mock *UserActiveTokensCache) ListUserTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) {
ret := _mock.Called(ctx, userID)
if len(ret) == 0 {
panic("no return value specified for ListUserTokens")
}
var r0 []auth.TokenInfo
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string) ([]auth.TokenInfo, error)); ok {
return returnFunc(ctx, userID)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string) []auth.TokenInfo); ok {
r0 = returnFunc(ctx, userID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]auth.TokenInfo)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = returnFunc(ctx, userID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UserActiveTokensCache_ListUserTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserTokens'
type UserActiveTokensCache_ListUserTokens_Call struct {
*mock.Call
}
// ListUserTokens is a helper method to define mock.On call
// - ctx context.Context
// - userID string
func (_e *UserActiveTokensCache_Expecter) ListUserTokens(ctx interface{}, userID interface{}) *UserActiveTokensCache_ListUserTokens_Call {
return &UserActiveTokensCache_ListUserTokens_Call{Call: _e.mock.On("ListUserTokens", ctx, userID)}
}
func (_c *UserActiveTokensCache_ListUserTokens_Call) Run(run func(ctx context.Context, userID string)) *UserActiveTokensCache_ListUserTokens_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *UserActiveTokensCache_ListUserTokens_Call) Return(tokenInfos []auth.TokenInfo, err error) *UserActiveTokensCache_ListUserTokens_Call {
_c.Call.Return(tokenInfos, err)
return _c
}
func (_c *UserActiveTokensCache_ListUserTokens_Call) RunAndReturn(run func(ctx context.Context, userID string) ([]auth.TokenInfo, error)) *UserActiveTokensCache_ListUserTokens_Call {
_c.Call.Return(run)
return _c
}
// RemoveActive provides a mock function for the type UserActiveTokensCache
func (_mock *UserActiveTokensCache) RemoveActive(ctx context.Context, userID string, tokenID string) error {
ret := _mock.Called(ctx, userID, tokenID)
if len(ret) == 0 {
panic("no return value specified for RemoveActive")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = returnFunc(ctx, userID, tokenID)
} else {
r0 = ret.Error(0)
}
return r0
}
// UserActiveTokensCache_RemoveActive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveActive'
type UserActiveTokensCache_RemoveActive_Call struct {
*mock.Call
}
// RemoveActive is a helper method to define mock.On call
// - ctx context.Context
// - userID string
// - tokenID string
func (_e *UserActiveTokensCache_Expecter) RemoveActive(ctx interface{}, userID interface{}, tokenID interface{}) *UserActiveTokensCache_RemoveActive_Call {
return &UserActiveTokensCache_RemoveActive_Call{Call: _e.mock.On("RemoveActive", ctx, userID, tokenID)}
}
func (_c *UserActiveTokensCache_RemoveActive_Call) Run(run func(ctx context.Context, userID string, tokenID string)) *UserActiveTokensCache_RemoveActive_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
run(
arg0,
arg1,
arg2,
)
})
return _c
}
func (_c *UserActiveTokensCache_RemoveActive_Call) Return(err error) *UserActiveTokensCache_RemoveActive_Call {
_c.Call.Return(err)
return _c
}
func (_c *UserActiveTokensCache_RemoveActive_Call) RunAndReturn(run func(ctx context.Context, userID string, tokenID string) error) *UserActiveTokensCache_RemoveActive_Call {
_c.Call.Return(run)
return _c
}
// SaveActive provides a mock function for the type UserActiveTokensCache
func (_mock *UserActiveTokensCache) SaveActive(ctx context.Context, userID string, tokenID string, description string, expiry time.Time) error {
ret := _mock.Called(ctx, userID, tokenID, description, expiry)
if len(ret) == 0 {
panic("no return value specified for SaveActive")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, time.Time) error); ok {
r0 = returnFunc(ctx, userID, tokenID, description, expiry)
} else {
r0 = ret.Error(0)
}
return r0
}
// UserActiveTokensCache_SaveActive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveActive'
type UserActiveTokensCache_SaveActive_Call struct {
*mock.Call
}
// SaveActive is a helper method to define mock.On call
// - ctx context.Context
// - userID string
// - tokenID string
// - description string
// - expiry time.Time
func (_e *UserActiveTokensCache_Expecter) SaveActive(ctx interface{}, userID interface{}, tokenID interface{}, description interface{}, expiry interface{}) *UserActiveTokensCache_SaveActive_Call {
return &UserActiveTokensCache_SaveActive_Call{Call: _e.mock.On("SaveActive", ctx, userID, tokenID, description, expiry)}
}
func (_c *UserActiveTokensCache_SaveActive_Call) Run(run func(ctx context.Context, userID string, tokenID string, description string, expiry time.Time)) *UserActiveTokensCache_SaveActive_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 string
if args[1] != nil {
arg1 = args[1].(string)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
var arg3 string
if args[3] != nil {
arg3 = args[3].(string)
}
var arg4 time.Time
if args[4] != nil {
arg4 = args[4].(time.Time)
}
run(
arg0,
arg1,
arg2,
arg3,
arg4,
)
})
return _c
}
func (_c *UserActiveTokensCache_SaveActive_Call) Return(err error) *UserActiveTokensCache_SaveActive_Call {
_c.Call.Return(err)
return _c
}
func (_c *UserActiveTokensCache_SaveActive_Call) RunAndReturn(run func(ctx context.Context, userID string, tokenID string, description string, expiry time.Time) error) *UserActiveTokensCache_SaveActive_Call {
_c.Call.Return(run)
return _c
}
+24
View File
@@ -0,0 +1,24 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import "github.com/absmach/supermq/pkg/errors"
var _ errors.Mapper = (*duplicateErrors)(nil)
type duplicateErrors struct{}
// GetError maps constraint names to known errors.
func (d duplicateErrors) GetError(constraint string) (error, bool) {
switch constraint {
case "revoked_tokens_pkey":
return errors.NewRequestError("revoked token already exists"), true
default:
return nil, false
}
}
func NewDuplicateErrors() errors.Mapper {
return duplicateErrors{}
}
+54 -9
View File
@@ -29,13 +29,16 @@ var (
// ErrExpiry indicates that the token is expired.
ErrExpiry = errors.New("token is expired")
errIssueUser = errors.New("failed to issue new login key")
errIssueTmp = errors.New("failed to issue new temporary key")
errRevoke = errors.New("failed to remove key")
errRetrieve = errors.New("failed to retrieve key data")
errIdentify = errors.New("failed to validate token")
errPlatform = errors.New("invalid platform id")
errRoleAuth = errors.New("failed to authorize user role")
errIssueUser = errors.New("failed to issue new login key")
errIssueTmp = errors.New("failed to issue new temporary key")
errRevoke = errors.New("failed to remove key")
errRetrieve = errors.New("failed to retrieve key data")
errIdentify = errors.New("failed to validate token")
errPlatform = errors.New("invalid platform id")
errRoleAuth = errors.New("failed to authorize user role")
errSaveRefreshKey = errors.NewServiceError("failed to save refresh key")
errRevokeRefreshKey = errors.NewServiceError("failed to revoke refresh key")
errListRefreshKeys = errors.NewServiceError("failed to list refresh keys")
errMalformedPAT = errors.New("malformed personal access token")
errFailedToParseUUID = errors.New("failed to parse string to UUID")
@@ -67,6 +70,9 @@ type Authn interface {
// Issue issues a new Key, returning its token value alongside.
Issue(ctx context.Context, token string, key Key) (Token, error)
// RevokeToken revokes the refresh token by its ID.
RevokeToken(ctx context.Context, userID, tokenID string) error
// Revoke removes the Key with the provided id that is
// issued by the user identified by the provided key.
Revoke(ctx context.Context, token, id string) error
@@ -82,6 +88,9 @@ type Authn interface {
// RetrieveJWKS retrieves public keys to validate issued tokens.
RetrieveJWKS() []PublicKeyInfo
// ListUserRefreshTokens lists all active refresh token sessions for a user.
ListUserRefreshTokens(ctx context.Context, userID string) ([]TokenInfo, error)
}
// Service specifies an API that must be fulfilled by the domain service
@@ -100,6 +109,7 @@ type service struct {
keys KeyRepository
pats PATSRepository
cache Cache
tokensCache UserActiveTokensCache
hasher Hasher
idProvider supermq.IDProvider
evaluator policies.Evaluator
@@ -111,12 +121,13 @@ type service struct {
}
// New instantiates the auth service implementation.
func New(keys KeyRepository, pats PATSRepository, cache Cache, hasher Hasher, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service {
func New(keys KeyRepository, pats PATSRepository, cache Cache, tokensCache UserActiveTokensCache, hasher Hasher, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service {
return &service{
tokenizer: tokenizer,
keys: keys,
pats: pats,
cache: cache,
tokensCache: tokensCache,
hasher: hasher,
idProvider: idp,
evaluator: policyEvaluator,
@@ -143,6 +154,14 @@ func (svc service) Issue(ctx context.Context, token string, key Key) (Token, err
}
}
func (svc service) RevokeToken(ctx context.Context, userID, tokenID string) error {
if err := svc.tokensCache.RemoveActive(ctx, userID, tokenID); err != nil {
return errors.Wrap(errRevokeRefreshKey, err)
}
return nil
}
func (svc service) Revoke(ctx context.Context, token, id string) error {
issuerID, _, err := svc.authenticate(ctx, token)
if err != nil {
@@ -205,6 +224,15 @@ func (svc service) RetrieveJWKS() []PublicKeyInfo {
return keys
}
func (svc service) ListUserRefreshTokens(ctx context.Context, userID string) ([]TokenInfo, error) {
tokenInfo, err := svc.tokensCache.ListUserTokens(ctx, userID)
if err != nil {
return nil, errors.Wrap(errListRefreshKeys, err)
}
return tokenInfo, nil
}
func (svc service) Authorize(ctx context.Context, pr policies.Policy, patAuthz *PATAuthz) error {
if patAuthz != nil {
if err := svc.AuthorizePAT(ctx, patAuthz.UserID, patAuthz.PatID, patAuthz.EntityType, patAuthz.Domain, patAuthz.Operation, patAuthz.EntityID); err != nil {
@@ -265,10 +293,20 @@ func (svc service) accessKey(ctx context.Context, key Key) (Token, error) {
key.ExpiresAt = time.Now().UTC().Add(svc.refreshDuration)
key.Type = RefreshKey
id, err := svc.idProvider.ID()
if err != nil {
return Token{}, errors.Wrap(errIssueTmp, err)
}
key.ID = id
refresh, err := svc.tokenizer.Issue(key)
if err != nil {
return Token{}, errors.Wrap(errIssueTmp, err)
}
if key.Subject != "" && key.ExpiresAt.After(time.Now()) {
if err := svc.tokensCache.SaveActive(ctx, key.Subject, key.ID, key.Description, key.ExpiresAt); err != nil {
return Token{}, errors.Wrap(errSaveRefreshKey, err)
}
}
return Token{AccessToken: access, RefreshToken: refresh}, nil
}
@@ -298,6 +336,13 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token
if k.Type != RefreshKey {
return Token{}, errIssueUser
}
ok, err := svc.tokensCache.IsActive(ctx, key.ID)
if err != nil {
return Token{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
if !ok {
return Token{}, ErrRevokedToken
}
key.ID = k.ID
key.Type = AccessKey
key.Subject = k.Subject
@@ -313,7 +358,7 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token
return Token{}, errors.Wrap(errIssueTmp, err)
}
key.ExpiresAt = time.Now().UTC().Add(svc.refreshDuration)
key.ExpiresAt = k.ExpiresAt
key.Type = RefreshKey
refresh, err := svc.tokenizer.Issue(key)
if err != nil {
+232 -12
View File
@@ -54,18 +54,20 @@ var (
)
var (
krepo *mocks.KeyRepository
pService *policymocks.Service
pEvaluator *policymocks.Evaluator
patsrepo *mocks.PATSRepository
cache *mocks.Cache
hasher *mocks.Hasher
tokenizer *mocks.Tokenizer
krepo *mocks.KeyRepository
pService *policymocks.Service
pEvaluator *policymocks.Evaluator
patsrepo *mocks.PATSRepository
cache *mocks.Cache
tokensCache *mocks.UserActiveTokensCache
hasher *mocks.Hasher
tokenizer *mocks.Tokenizer
)
func newService(t *testing.T) (auth.Service, string) {
krepo = new(mocks.KeyRepository)
cache = new(mocks.Cache)
tokensCache = new(mocks.UserActiveTokensCache)
pService = new(policymocks.Service)
pEvaluator = new(policymocks.Evaluator)
patsrepo = new(mocks.PATSRepository)
@@ -76,7 +78,7 @@ func newService(t *testing.T) (auth.Service, string) {
token, _, err := signToken(t, issuerName, accessKey, false)
assert.Nil(t, err, fmt.Sprintf("Issuing access key expected to succeed: %s", err))
return auth.New(krepo, patsrepo, cache, hasher, idProvider, tokenizer, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token
return auth.New(krepo, patsrepo, cache, tokensCache, hasher, idProvider, tokenizer, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token
}
func TestIssue(t *testing.T) {
@@ -133,7 +135,7 @@ func TestIssue(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.tokenizerErr)
tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.tokenizerErr)
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: tc.key.Subject,
SubjectType: policies.UserType,
@@ -154,6 +156,7 @@ func TestIssue(t *testing.T) {
saveResponse auth.Key
token string
tokenizerErr error
cacheErr error
saveErr error
roleCheckErr error
err error
@@ -169,11 +172,24 @@ func TestIssue(t *testing.T) {
token: accessToken,
err: nil,
},
{
desc: "issue access key with cache error",
key: auth.Key{
Type: auth.AccessKey,
Subject: userID,
Role: auth.UserRole,
IssuedAt: time.Now(),
},
token: accessToken,
cacheErr: svcerr.ErrCreateEntity,
err: svcerr.ErrCreateEntity,
},
}
for _, tc := range cases2 {
t.Run(tc.desc, func(t *testing.T) {
tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.tokenizerErr)
tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.tokenizerErr)
repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr)
cacheCall := tokensCache.On("SaveActive", context.Background(), tc.key.Subject, mock.Anything, tc.key.Description, mock.Anything).Return(tc.cacheErr)
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: tc.key.Subject,
SubjectType: policies.UserType,
@@ -186,6 +202,7 @@ func TestIssue(t *testing.T) {
tokenizerCall.Unset()
repoCall.Unset()
policyCall.Unset()
cacheCall.Unset()
})
}
@@ -265,7 +282,7 @@ func TestIssue(t *testing.T) {
}
for _, tc := range cases3 {
t.Run(tc.desc, func(t *testing.T) {
tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.issueErr)
tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.issueErr)
tokenizerCall1 := tokenizer.On("Parse", mock.Anything, tc.token).Return(tc.parseRes, tc.parseErr)
repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr)
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
@@ -292,6 +309,8 @@ func TestIssue(t *testing.T) {
parseErr error
roleCheckErr error
issueErr error
cacheRes bool
cacheErr error
err error
}{
{
@@ -304,6 +323,7 @@ func TestIssue(t *testing.T) {
},
token: refreshToken,
parseRes: refreshkey,
cacheRes: true,
err: nil,
},
{
@@ -339,15 +359,46 @@ func TestIssue(t *testing.T) {
Role: auth.UserRole,
},
token: refreshToken,
cacheRes: true,
parseRes: refreshkey,
roleCheckErr: errRoleAuth,
err: errRoleAuth,
},
{
desc: "issue refresh key with revoked refresh token",
key: auth.Key{
Type: auth.RefreshKey,
IssuedAt: time.Now(),
Subject: userID,
Role: auth.UserRole,
},
token: refreshToken,
parseRes: refreshkey,
cacheRes: false,
cacheErr: nil,
err: auth.ErrRevokedToken,
},
{
desc: "issue refresh key with cache error",
key: auth.Key{
Type: auth.RefreshKey,
IssuedAt: time.Now(),
Subject: userID,
Role: auth.UserRole,
},
token: refreshToken,
parseRes: refreshkey,
cacheRes: false,
cacheErr: svcerr.ErrCreateEntity,
err: svcerr.ErrCreateEntity,
},
}
for _, tc := range cases4 {
t.Run(tc.desc, func(t *testing.T) {
tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.issueErr)
tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.issueErr)
tokenizerCall1 := tokenizer.On("Parse", mock.Anything, tc.token).Return(tc.parseRes, tc.parseErr)
tokenizerCall2 := tokenizer.On("Revoke", mock.Anything, tc.token).Return(tc.parseErr)
cacheCall := tokensCache.On("IsActive", context.Background(), tc.key.ID).Return(tc.cacheRes, tc.cacheErr)
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
Subject: tc.key.Subject,
SubjectType: policies.UserType,
@@ -359,7 +410,9 @@ func TestIssue(t *testing.T) {
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
tokenizerCall.Unset()
tokenizerCall1.Unset()
tokenizerCall2.Unset()
policyCall.Unset()
cacheCall.Unset()
})
}
}
@@ -642,6 +695,173 @@ func TestIdentify(t *testing.T) {
}
}
func TestRevokeToken(t *testing.T) {
svc, _ := newService(t)
cases := []struct {
desc string
userID string
tokenID string
removeErr error
err error
}{
{
desc: "revoke token successfully",
userID: validID,
tokenID: "validTokenID",
removeErr: nil,
err: nil,
},
{
desc: "revoke token with cache error",
userID: validID,
tokenID: "validTokenID",
removeErr: svcerr.ErrRemoveEntity,
err: svcerr.ErrRemoveEntity,
},
{
desc: "revoke token with empty token ID",
userID: validID,
tokenID: "",
removeErr: nil,
err: nil,
},
{
desc: "revoke token not found",
userID: validID,
tokenID: "nonExistentTokenID",
removeErr: svcerr.ErrNotFound,
err: svcerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
cacheCall := tokensCache.On("RemoveActive", mock.Anything, tc.userID, tc.tokenID).Return(tc.removeErr)
err := svc.RevokeToken(context.Background(), tc.userID, tc.tokenID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
cacheCall.Unset()
})
}
}
func TestRetrieveJWKS(t *testing.T) {
svc, _ := newService(t)
publicKeys := []auth.PublicKeyInfo{
{
KeyID: "key1",
Algorithm: "RS256",
},
{
KeyID: "key2",
Algorithm: "RS256",
},
}
cases := []struct {
desc string
tokenizerRes []auth.PublicKeyInfo
tokenizerErr error
expectedResult []auth.PublicKeyInfo
}{
{
desc: "retrieve JWKS successfully",
tokenizerRes: publicKeys,
tokenizerErr: nil,
expectedResult: publicKeys,
},
{
desc: "retrieve JWKS with tokenizer error",
tokenizerRes: nil,
tokenizerErr: svcerr.ErrViewEntity,
expectedResult: nil,
},
{
desc: "retrieve JWKS with empty keys",
tokenizerRes: []auth.PublicKeyInfo{},
tokenizerErr: nil,
expectedResult: []auth.PublicKeyInfo{},
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
tokenizerCall := tokenizer.On("RetrieveJWKS").Return(tc.tokenizerRes, tc.tokenizerErr)
result := svc.RetrieveJWKS()
assert.Equal(t, tc.expectedResult, result, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.expectedResult, result))
tokenizerCall.Unset()
})
}
}
func TestListUserRefreshTokens(t *testing.T) {
svc, _ := newService(t)
tokenInfos := []auth.TokenInfo{
{
ID: "token1",
Description: "Session 1",
},
{
ID: "token2",
Description: "Session 2",
},
}
cases := []struct {
desc string
userID string
cacheRes []auth.TokenInfo
cacheErr error
expectedRes []auth.TokenInfo
err error
}{
{
desc: "list user refresh tokens successfully",
userID: userID,
cacheRes: tokenInfos,
cacheErr: nil,
expectedRes: tokenInfos,
err: nil,
},
{
desc: "list user refresh tokens with cache error",
userID: userID,
cacheRes: nil,
cacheErr: svcerr.ErrViewEntity,
expectedRes: nil,
err: svcerr.ErrViewEntity,
},
{
desc: "list user refresh tokens with empty result",
userID: userID,
cacheRes: []auth.TokenInfo{},
cacheErr: nil,
expectedRes: []auth.TokenInfo{},
err: nil,
},
{
desc: "list user refresh tokens with invalid user ID",
userID: "",
cacheRes: nil,
cacheErr: svcerr.ErrViewEntity,
expectedRes: nil,
err: svcerr.ErrViewEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
cacheCall := tokensCache.On("ListUserTokens", mock.Anything, tc.userID).Return(tc.cacheRes, tc.cacheErr)
result, err := svc.ListUserRefreshTokens(context.Background(), tc.userID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected error %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.expectedRes, result, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.expectedRes, result))
cacheCall.Unset()
})
}
}
func TestAuthorize(t *testing.T) {
svc, _ := newService(t)
+7 -3
View File
@@ -28,6 +28,7 @@ export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/private.key"
```
The tokenizer will:
- Issue new tokens signed with the active key
- Verify tokens using the active key
- Return one public key in JWKS endpoint
@@ -42,6 +43,7 @@ export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key"
```
The tokenizer will:
- Issue new tokens signed with the active key
- Verify tokens using both active and retiring keys
- Return both public keys in JWKS endpoint
@@ -103,10 +105,12 @@ The grace period should be longer than your longest-lived access token duration.
- Store private keys with `0600` permissions
- Use cryptographically secure key generation:
```bash
openssl genpkey -algorithm Ed25519 -out private.key
chmod 600 private.key
```
- Rotate keys regularly:
- Standard environments: every 90 days
- High-security environments: every 30 days
@@ -139,7 +143,7 @@ rm ./keys/key-2024.pem
### Active key not found
```
```bash
Error: active key file not found: ./keys/active.key
```
@@ -149,7 +153,7 @@ Error: active key file not found: ./keys/active.key
If the retiring key path is set but the file is missing or invalid, the tokenizer logs a warning but continues with only the active key:
```
```bash
WARN: failed to load retiring key, continuing without it
```
@@ -157,7 +161,7 @@ This is by design - a missing retiring key won't prevent startup.
### Invalid key format
```
```bash
Error: failed to parse private key
```
+14 -16
View File
@@ -24,8 +24,6 @@ import (
"github.com/lestrrat-go/jwx/v2/jwt"
)
const patPrefix = "pat"
var (
errLoadingPrivateKey = errors.New("failed to load private key")
errDuplicateRetiringKeyID = errors.New("retiring key ID matches active key ID")
@@ -95,8 +93,8 @@ func NewTokenizer(activeKeyPath, retiringKeyPath string, idProvider supermq.IDPr
return mgr, nil
}
func (km *tokenizer) Issue(key auth.Key) (string, error) {
if km.activeKey == nil {
func (tok *tokenizer) Issue(key auth.Key) (string, error) {
if tok.activeKey == nil {
return "", errNoActiveKey
}
@@ -105,11 +103,11 @@ func (km *tokenizer) Issue(key auth.Key) (string, error) {
return "", err
}
headers := jws.NewHeaders()
if err := headers.Set(jwk.KeyIDKey, km.activeKey.id); err != nil {
if err := headers.Set(jwk.KeyIDKey, tok.activeKey.id); err != nil {
return "", err
}
signedBytes, err := jwt.Sign(tkn, jwt.WithKey(jwa.EdDSA, km.activeKey.privateKey, jws.WithProtectedHeaders(headers)))
signedBytes, err := jwt.Sign(tkn, jwt.WithKey(jwa.EdDSA, tok.activeKey.privateKey, jws.WithProtectedHeaders(headers)))
if err != nil {
return "", err
}
@@ -117,17 +115,17 @@ func (km *tokenizer) Issue(key auth.Key) (string, error) {
return string(signedBytes), nil
}
func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) {
if len(tokenString) >= 3 && tokenString[:3] == patPrefix {
func (tok *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) {
if len(tokenString) >= 3 && tokenString[:3] == smqjwt.PatPrefix {
return auth.Key{Type: auth.PersonalAccessToken}, nil
}
set := jwk.NewSet()
if err := set.AddKey(km.activeKey.publicKey); err != nil {
if err := set.AddKey(tok.activeKey.publicKey); err != nil {
return auth.Key{}, err
}
if km.retiringKey != nil {
if err := set.AddKey(km.retiringKey.publicKey); err != nil {
if tok.retiringKey != nil {
if err := set.AddKey(tok.retiringKey.publicKey); err != nil {
return auth.Key{}, err
}
}
@@ -148,17 +146,17 @@ func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, e
return smqjwt.ToKey(tkn)
}
func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
func (tok *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
publicKeys := make([]auth.PublicKeyInfo, 0, 2)
if km.activeKey != nil {
if pkInfo := extractPublicKeyInfo(km.activeKey); pkInfo != nil {
if tok.activeKey != nil {
if pkInfo := extractPublicKeyInfo(tok.activeKey); pkInfo != nil {
publicKeys = append(publicKeys, *pkInfo)
}
}
if km.retiringKey != nil {
if pkInfo := extractPublicKeyInfo(km.retiringKey); pkInfo != nil {
if tok.retiringKey != nil {
if pkInfo := extractPublicKeyInfo(tok.retiringKey); pkInfo != nil {
publicKeys = append(publicKeys, *pkInfo)
}
}
+7 -13
View File
@@ -14,12 +14,6 @@ import (
"github.com/lestrrat-go/jwx/v2/jwt"
)
const (
patPrefix = "pat"
)
var errJWTExpiryKey = errors.New(`"exp" not satisfied`)
type tokenizer struct {
algorithm jwa.KeyAlgorithm
secret []byte
@@ -41,13 +35,13 @@ func NewTokenizer(algorithm string, secret []byte) (auth.Tokenizer, error) {
}, nil
}
func (km *tokenizer) Issue(key auth.Key) (string, error) {
func (tok *tokenizer) Issue(key auth.Key) (string, error) {
tkn, err := smqjwt.BuildToken(key)
if err != nil {
return "", err
}
signedBytes, err := jwt.Sign(tkn, jwt.WithKey(km.algorithm, km.secret))
signedBytes, err := jwt.Sign(tkn, jwt.WithKey(tok.algorithm, tok.secret))
if err != nil {
return "", err
}
@@ -55,18 +49,18 @@ func (km *tokenizer) Issue(key auth.Key) (string, error) {
return string(signedBytes), nil
}
func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) {
if len(tokenString) >= 3 && tokenString[:3] == patPrefix {
func (tok *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) {
if len(tokenString) >= 3 && tokenString[:3] == smqjwt.PatPrefix {
return auth.Key{Type: auth.PersonalAccessToken}, nil
}
tkn, err := jwt.Parse(
[]byte(tokenString),
jwt.WithValidate(true),
jwt.WithKey(km.algorithm, km.secret),
jwt.WithKey(tok.algorithm, tok.secret),
)
if err != nil {
if errors.Contains(err, errJWTExpiryKey) {
if errors.Contains(err, smqjwt.ErrJWTExpiryKey) {
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, auth.ErrExpiry)
}
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
@@ -79,6 +73,6 @@ func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, e
return smqjwt.ToKey(tkn)
}
func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
func (tok *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
return nil, auth.ErrPublicKeysNotSupported
}
+4
View File
@@ -18,6 +18,9 @@ var (
// ErrJSONHandle indicates an error in handling JSON.
ErrJSONHandle = errors.New("failed to perform operation JSON")
// ErrJWTExpiryKey indicates that the "exp" claim in the JWT token is not satisfied.
ErrJWTExpiryKey = errors.New(`"exp" not satisfied`)
errInvalidType = errors.New("invalid token type")
errInvalidRole = errors.New("invalid role")
errInvalidVerified = errors.New("invalid verified")
@@ -28,6 +31,7 @@ const (
TokenType = "type"
RoleField = "role"
VerifiedField = "verified"
PatPrefix = "pat"
)
// ToKey converts a JWT token to an auth.Key by extracting claims.
+7 -3
View File
@@ -292,17 +292,21 @@ func validateKeyConfig(isSymmetric bool, cfg config, l *slog.Logger) error {
}
func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration, tokenizer auth.Tokenizer, idProvider supermq.IDProvider) (auth.Service, error) {
cache := cache.NewPatsCache(cacheClient, keyDuration)
patsCache := cache.NewPatsCache(cacheClient, keyDuration)
tokensCache, err := cache.NewUserActiveTokensCache(cacheClient, keyDuration)
if err != nil {
return nil, err
}
database := pgclient.NewDatabase(db, dbConfig, tracer)
keysRepo := apostgres.New(database)
patsRepo := apostgres.NewPatRepo(database, cache)
patsRepo := apostgres.NewPatRepo(database, patsCache)
hasher := hasher.New()
pEvaluator := spicedb.NewPolicyEvaluator(spicedbClient, logger)
pService := spicedb.NewPolicyService(spicedbClient, logger)
svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, tokenizer, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
svc := auth.New(keysRepo, patsRepo, nil, tokensCache, hasher, idProvider, tokenizer, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
svc = middleware.NewLogging(svc, logger)
counter, latency := prometheus.MakeMetrics("auth", "api")
svc = middleware.NewMetrics(svc, counter, latency)
+1 -1
View File
@@ -399,7 +399,7 @@ func createAdmin(ctx context.Context, c config, repo users.Repository, hsr users
if _, err = repo.Save(ctx, user); err != nil {
return "", err
}
if _, err = svc.IssueToken(ctx, c.AdminUsername, c.AdminPassword); err != nil {
if _, err = svc.IssueToken(ctx, c.AdminUsername, c.AdminPassword, ""); err != nil {
return "", err
}
return user.ID, nil
+2
View File
@@ -95,6 +95,8 @@ services:
- supermq-base-net
volumes:
- supermq-auth-redis-volume:/data
- ./redis/redis.conf:/etc/redis/redis.conf:ro
command: ["redis-server", "/etc/redis/redis.conf"]
auth:
image: docker.io/supermq/auth:${SMQ_RELEASE_TAG}
+14
View File
@@ -0,0 +1,14 @@
# Copyright (c) Abstract Machines
# SPDX-License-Identifier: Apache-2.0
# Enable AOF persistence
appendonly yes
appendfilename "appendonly.aof"
appendfsync everysec
# Enable periodic snapshots
save 300 10
# Persist data in Docker volume
dir /data
+26
View File
@@ -9,6 +9,9 @@ option go_package = "github.com/absmach/supermq/api/grpc/token/v1";
service TokenService {
rpc Issue(IssueReq) returns (Token) {}
rpc Refresh(RefreshReq) returns (Token) {}
rpc Revoke(RevokeReq) returns (RevokeRes) {}
rpc ListUserRefreshTokens(ListUserRefreshTokensReq)
returns (ListUserRefreshTokensRes) {}
}
message IssueReq {
@@ -16,6 +19,7 @@ message IssueReq {
uint32 user_role = 2;
uint32 type = 3;
bool verified = 4;
string description = 5;
}
message RefreshReq {
@@ -23,6 +27,11 @@ message RefreshReq {
bool verified = 2;
}
message RevokeReq {
string token_id = 1;
string user_id = 2;
}
// If a token is not carrying any information itself, the type
// field can be used to determine how to validate the token.
// Also, different tokens can be encoded in different ways.
@@ -31,3 +40,20 @@ message Token {
optional string refresh_token = 2;
string access_type = 3;
}
message RevokeRes{
}
message ListUserRefreshTokensReq {
string user_id = 1;
}
message ListUserRefreshTokensRes {
repeated RefreshToken refresh_tokens = 1;
}
message RefreshToken {
string id = 1;
string description = 2;
}
+3 -2
View File
@@ -21,8 +21,9 @@ type Token struct {
}
type Login struct {
Username string `json:"username"`
Password string `json:"password"`
Username string `json:"username"`
Password string `json:"password"`
Description string `json:"description,omitempty"`
}
func (sdk mgSDK) CreateToken(ctx context.Context, lt Login) (Token, errors.SDKError) {
+2 -2
View File
@@ -101,12 +101,12 @@ func TestIssueToken(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("IssueToken", mock.Anything, tc.login.Username, tc.login.Password).Return(tc.svcRes, tc.svcErr)
svcCall := svc.On("IssueToken", mock.Anything, tc.login.Username, tc.login.Password, tc.login.Description).Return(tc.svcRes, tc.svcErr)
resp, err := mgsdk.CreateToken(context.Background(), tc.login)
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.response, resp)
if tc.err == nil {
ok := svcCall.Parent.AssertCalled(t, "IssueToken", mock.Anything, tc.login.Username, tc.login.Password)
ok := svcCall.Parent.AssertCalled(t, "IssueToken", mock.Anything, tc.login.Username, tc.login.Password, tc.login.Description)
assert.True(t, ok)
}
svcCall.Unset()
+1 -1
View File
@@ -59,6 +59,7 @@ packages:
Cache:
Hasher:
KeyRepository:
UserActiveTokensCache:
Tokenizer:
PATS:
PATSRepository:
@@ -145,4 +146,3 @@ packages:
github.com/absmach/supermq/notifications:
interfaces:
Notifier:
+210 -1
View File
@@ -2384,7 +2384,9 @@ func TestIssueToken(t *testing.T) {
defer us.Close()
validUsername := "valid"
validDescription := "test token"
dataFormat := `{"username": "%s", "password": "%s"}`
dataFormatWithDesc := `{"username": "%s", "password": "%s", "description": "%s"}`
cases := []struct {
desc string
@@ -2400,6 +2402,13 @@ func TestIssueToken(t *testing.T) {
status: http.StatusCreated,
err: nil,
},
{
desc: "issue token with valid identity, secret and description",
data: fmt.Sprintf(dataFormatWithDesc, validUsername, secret, validDescription),
contentType: contentType,
status: http.StatusCreated,
err: nil,
},
{
desc: "issue token with empty identity",
data: fmt.Sprintf(dataFormat, "", secret),
@@ -2447,7 +2456,7 @@ func TestIssueToken(t *testing.T) {
body: strings.NewReader(tc.data),
}
svcCall := svc.On("IssueToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&grpcTokenV1.Token{AccessToken: validToken}, tc.err)
svcCall := svc.On("IssueToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&grpcTokenV1.Token{AccessToken: validToken}, tc.err)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
if tc.err != nil {
@@ -2565,6 +2574,206 @@ func TestRefreshToken(t *testing.T) {
}
}
func TestRevokeRefreshToken(t *testing.T) {
us, svc, authn := newUsersServer()
defer us.Close()
cases := []struct {
desc string
data string
contentType string
token string
authnRes smqauthn.Session
authnErr error
status int
svcErr error
err error
}{
{
desc: "revoke refresh token with valid token",
data: fmt.Sprintf(`{"token_id": "%s"}`, validToken),
contentType: contentType,
token: validToken,
authnRes: verifiedSession,
status: http.StatusNoContent,
err: nil,
},
{
desc: "revoke refresh token with invalid token",
data: fmt.Sprintf(`{"token_id": "%s"}`, validToken),
contentType: contentType,
token: inValidToken,
status: http.StatusUnauthorized,
authnErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "revoke refresh token with empty token",
data: fmt.Sprintf(`{"token_id": "%s"}`, validToken),
contentType: contentType,
token: "",
status: http.StatusUnauthorized,
authnErr: svcerr.ErrAuthentication,
err: apiutil.ErrBearerToken,
},
{
desc: "revoke refresh token with empty token id",
data: `{"token_id": ""}`,
contentType: contentType,
token: validToken,
authnRes: verifiedSession,
status: http.StatusBadRequest,
err: apiutil.ErrMissingID,
},
{
desc: "revoke refresh token with malformed data",
data: fmt.Sprintf(`{"token_id": %s}`, validToken),
contentType: contentType,
token: validToken,
authnRes: verifiedSession,
status: http.StatusBadRequest,
err: apiutil.ErrMalformedRequestBody,
},
{
desc: "revoke refresh token with invalid content type",
data: fmt.Sprintf(`{"token_id": "%s"}`, validToken),
contentType: "application/xml",
token: validToken,
authnRes: verifiedSession,
status: http.StatusUnsupportedMediaType,
err: apiutil.ErrUnsupportedContentType,
},
{
desc: "revoke refresh token with service error",
data: fmt.Sprintf(`{"token_id": "%s"}`, validToken),
contentType: contentType,
token: validToken,
authnRes: verifiedSession,
status: http.StatusUnprocessableEntity,
svcErr: svcerr.ErrViewEntity,
err: svcerr.ErrViewEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
req := testRequest{
user: us.Client(),
method: http.MethodPost,
url: fmt.Sprintf("%s/users/tokens/revoke", us.URL),
contentType: tc.contentType,
body: strings.NewReader(tc.data),
token: tc.token,
}
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("RevokeRefreshToken", mock.Anything, tc.authnRes, mock.Anything).Return(tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
if tc.err != nil {
var resBody respBody
err = json.NewDecoder(res.Body).Decode(&resBody)
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
if resBody.Err != "" || resBody.Message != "" {
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authnCall.Unset()
})
}
}
func TestListActiveRefreshTokens(t *testing.T) {
us, svc, authn := newUsersServer()
defer us.Close()
cases := []struct {
desc string
token string
authnRes smqauthn.Session
authnErr error
status int
svcRes *grpcTokenV1.ListUserRefreshTokensRes
svcErr error
err error
}{
{
desc: "list active refresh tokens with valid token",
token: validToken,
authnRes: verifiedSession,
status: http.StatusOK,
svcRes: &grpcTokenV1.ListUserRefreshTokensRes{
RefreshTokens: []*grpcTokenV1.RefreshToken{
{Id: "token1", Description: "token-1"},
{Id: "token2", Description: "token-2"},
},
},
err: nil,
},
{
desc: "list active refresh tokens with invalid token",
token: inValidToken,
status: http.StatusUnauthorized,
authnErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "list active refresh tokens with empty token",
token: "",
status: http.StatusUnauthorized,
authnErr: svcerr.ErrAuthentication,
err: apiutil.ErrBearerToken,
},
{
desc: "list active refresh tokens with service error",
token: validToken,
authnRes: verifiedSession,
status: http.StatusUnprocessableEntity,
svcErr: svcerr.ErrViewEntity,
err: svcerr.ErrViewEntity,
},
{
desc: "list active refresh tokens with empty list",
token: validToken,
authnRes: verifiedSession,
status: http.StatusOK,
svcRes: &grpcTokenV1.ListUserRefreshTokensRes{
RefreshTokens: []*grpcTokenV1.RefreshToken{},
},
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
req := testRequest{
user: us.Client(),
method: http.MethodGet,
url: fmt.Sprintf("%s/users/tokens/refresh-tokens", us.URL),
token: tc.token,
}
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
svcCall := svc.On("ListActiveRefreshTokens", mock.Anything, tc.authnRes).Return(tc.svcRes, tc.svcErr)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
if tc.err != nil {
var resBody respBody
err = json.NewDecoder(res.Body).Decode(&resBody)
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
if resBody.Err != "" || resBody.Message != "" {
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
}
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
svcCall.Unset()
authnCall.Unset()
})
}
}
func TestEnable(t *testing.T) {
us, svc, authn := newUsersServer()
defer us.Close()
+38 -1
View File
@@ -416,7 +416,7 @@ func issueTokenEndpoint(svc users.Service) endpoint.Endpoint {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
token, err := svc.IssueToken(ctx, req.Username, req.Password)
token, err := svc.IssueToken(ctx, req.Username, req.Password, req.Description)
if err != nil {
return nil, err
}
@@ -454,6 +454,43 @@ func refreshTokenEndpoint(svc users.Service) endpoint.Endpoint {
}
}
func revokeRefreshTokenEndpoint(svc users.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(revokeTokenReq)
if err := req.validate(); err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
err := svc.RevokeRefreshToken(ctx, session, req.TokenID)
if err != nil {
return nil, err
}
return revokeRes{}, nil
}
}
func listActiveRefreshTokensEndpoint(svc users.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
if !ok {
return nil, svcerr.ErrAuthentication
}
refreshTokens, err := svc.ListActiveRefreshTokens(ctx, session)
if err != nil {
return nil, err
}
return listRefreshTokensRes{RefreshTokens: refreshTokens.GetRefreshTokens()}, nil
}
}
func enableEndpoint(svc users.Service) endpoint.Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(changeUserStatusReq)
+15 -2
View File
@@ -270,8 +270,9 @@ func (req changeUserStatusReq) validate() error {
}
type loginUserReq struct {
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
Description string `json:"description,omitempty"`
}
func (req loginUserReq) validate() error {
@@ -297,6 +298,18 @@ func (req tokenReq) validate() error {
return nil
}
type revokeTokenReq struct {
TokenID string `json:"token_id,omitempty"`
}
func (req revokeTokenReq) validate() error {
if req.TokenID == "" {
return apiutil.ErrMissingID
}
return nil
}
type passResetReq struct {
Email string `json:"email"`
}
+33 -1
View File
@@ -8,6 +8,7 @@ import (
"net/http"
"github.com/absmach/supermq"
grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
"github.com/absmach/supermq/users"
)
@@ -25,8 +26,9 @@ var (
_ supermq.Response = (*passResetReqRes)(nil)
_ supermq.Response = (*passChangeRes)(nil)
_ supermq.Response = (*updateUserRes)(nil)
_ supermq.Response = (*tokenRes)(nil)
_ supermq.Response = (*revokeRes)(nil)
_ supermq.Response = (*deleteUserRes)(nil)
_ supermq.Response = (*listRefreshTokensRes)(nil)
)
type pageRes struct {
@@ -80,6 +82,36 @@ func (res tokenRes) Empty() bool {
return res.AccessToken == "" || res.RefreshToken == ""
}
type revokeRes struct{}
func (res revokeRes) Code() int {
return http.StatusNoContent
}
func (res revokeRes) Headers() map[string]string {
return map[string]string{}
}
func (res revokeRes) Empty() bool {
return true
}
type listRefreshTokensRes struct {
RefreshTokens []*grpcTokenV1.RefreshToken `json:"refresh_tokens"`
}
func (res listRefreshTokensRes) Code() int {
return http.StatusOK
}
func (res listRefreshTokensRes) Headers() map[string]string {
return map[string]string{}
}
func (res listRefreshTokensRes) Empty() bool {
return false
}
type sendVerificationRes struct{}
func (res sendVerificationRes) Code() int {
+29
View File
@@ -78,6 +78,18 @@ func usersHandler(svc users.Service, authn smqauthn.AuthNMiddleware, tokenClient
api.EncodeResponse,
opts...,
), "refresh_token").ServeHTTP)
r.Post("/tokens/revoke", otelhttp.NewHandler(kithttp.NewServer(
revokeRefreshTokenEndpoint(svc),
decodeRevokeRefreshToken,
api.EncodeResponse,
opts...,
), "revoke_refresh_token").ServeHTTP)
r.Get("/tokens/refresh-tokens", otelhttp.NewHandler(kithttp.NewServer(
listActiveRefreshTokensEndpoint(svc),
decodeListActiveRefreshTokens,
api.EncodeResponse,
opts...,
), "list_active_refresh_tokens").ServeHTTP)
r.Patch("/{id}/email", otelhttp.NewHandler(kithttp.NewServer(
updateEmailEndpoint(svc),
decodeUpdateUserEmail,
@@ -532,6 +544,23 @@ func decodeRefreshToken(_ context.Context, r *http.Request) (any, error) {
return req, nil
}
func decodeRevokeRefreshToken(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
}
var req revokeTokenReq
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
}
return req, nil
}
func decodeListActiveRefreshTokens(_ context.Context, r *http.Request) (any, error) {
return nil, nil
}
func decodeCreateUserReq(_ context.Context, r *http.Request) (any, error) {
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
+15 -2
View File
@@ -29,11 +29,10 @@ const (
profileView = userPrefix + "view_profile"
userList = userPrefix + "list"
userSearch = userPrefix + "search"
userListByGroup = userPrefix + "list_by_group"
userIdentify = userPrefix + "identify"
generateResetToken = userPrefix + "generate_reset_token"
issueToken = userPrefix + "issue_token"
refreshToken = userPrefix + "refresh_token"
revokeRefreshToken = userPrefix + "revoke_refresh_token"
resetSecret = userPrefix + "reset_secret"
sendPasswordReset = userPrefix + "send_password_reset"
oauthCallback = userPrefix + "oauth_callback"
@@ -56,6 +55,7 @@ var (
_ events.Event = (*identifyUserEvent)(nil)
_ events.Event = (*issueTokenEvent)(nil)
_ events.Event = (*refreshTokenEvent)(nil)
_ events.Event = (*revokeRefreshTokenEvent)(nil)
_ events.Event = (*resetSecretEvent)(nil)
_ events.Event = (*sendPasswordResetEvent)(nil)
_ events.Event = (*oauthCallbackEvent)(nil)
@@ -492,6 +492,19 @@ func (rte refreshTokenEvent) Encode() (map[string]any, error) {
}, nil
}
type revokeRefreshTokenEvent struct {
tokenID string
requestID string
}
func (rrte revokeRefreshTokenEvent) Encode() (map[string]any, error) {
return map[string]any{
"operation": revokeRefreshToken,
"token_id": rrte.tokenID,
"request_id": rrte.requestID,
}, nil
}
type resetSecretEvent struct {
requestID string
}
+46 -27
View File
@@ -15,31 +15,32 @@ import (
)
const (
supermqPrefix = "supermq."
createStream = supermqPrefix + userCreate
sendVerificationStream = supermqPrefix + userSendVerification
verifyEmailStream = supermqPrefix + userVerifyEmail
updateStream = supermqPrefix + userUpdate
updateRoleStream = supermqPrefix + userUpdateRole
updateTagsStream = supermqPrefix + userUpdateTags
updateSecretStream = supermqPrefix + userUpdateSecret
updateUsernameStream = supermqPrefix + userUpdateUsername
updatePictureStream = supermqPrefix + userUpdateProfilePicture
UpdateEmailStream = supermqPrefix + userUpdateEmail
enableStream = supermqPrefix + userEnable
disableStream = supermqPrefix + userDisable
viewStream = supermqPrefix + userView
viewProfileStream = supermqPrefix + profileView
listStream = supermqPrefix + userList
searchStream = supermqPrefix + userSearch
identifyStream = supermqPrefix + userIdentify
issueTokenStream = supermqPrefix + issueToken
refreshTokenStream = supermqPrefix + refreshToken
resetSecretStream = supermqPrefix + resetSecret
sendPasswordResetStream = supermqPrefix + sendPasswordReset
oauthStream = supermqPrefix + oauthCallback
addPolicyStream = supermqPrefix + addClientPolicy
deleteStream = supermqPrefix + deleteUser
supermqPrefix = "supermq."
createStream = supermqPrefix + userCreate
sendVerificationStream = supermqPrefix + userSendVerification
verifyEmailStream = supermqPrefix + userVerifyEmail
updateStream = supermqPrefix + userUpdate
updateRoleStream = supermqPrefix + userUpdateRole
updateTagsStream = supermqPrefix + userUpdateTags
updateSecretStream = supermqPrefix + userUpdateSecret
updateUsernameStream = supermqPrefix + userUpdateUsername
updatePictureStream = supermqPrefix + userUpdateProfilePicture
UpdateEmailStream = supermqPrefix + userUpdateEmail
enableStream = supermqPrefix + userEnable
disableStream = supermqPrefix + userDisable
viewStream = supermqPrefix + userView
viewProfileStream = supermqPrefix + profileView
listStream = supermqPrefix + userList
searchStream = supermqPrefix + userSearch
identifyStream = supermqPrefix + userIdentify
issueTokenStream = supermqPrefix + issueToken
refreshTokenStream = supermqPrefix + refreshToken
revokeRefreshTokenStream = supermqPrefix + revokeRefreshToken
resetSecretStream = supermqPrefix + resetSecret
sendPasswordResetStream = supermqPrefix + sendPasswordReset
oauthStream = supermqPrefix + oauthCallback
addPolicyStream = supermqPrefix + addClientPolicy
deleteStream = supermqPrefix + deleteUser
)
var _ users.Service = (*eventStore)(nil)
@@ -350,8 +351,8 @@ func (es *eventStore) SendPasswordReset(ctx context.Context, email string) error
return es.Publish(ctx, sendPasswordResetStream, event)
}
func (es *eventStore) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) {
token, err := es.svc.IssueToken(ctx, username, secret)
func (es *eventStore) IssueToken(ctx context.Context, username, secret, description string) (*grpcTokenV1.Token, error) {
token, err := es.svc.IssueToken(ctx, username, secret, description)
if err != nil {
return token, err
}
@@ -385,6 +386,24 @@ func (es *eventStore) RefreshToken(ctx context.Context, session authn.Session, r
return token, nil
}
func (es *eventStore) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error {
err := es.svc.RevokeRefreshToken(ctx, session, tokenID)
if err != nil {
return err
}
event := revokeRefreshTokenEvent{
tokenID: tokenID,
requestID: middleware.GetReqID(ctx),
}
return es.Publish(ctx, revokeRefreshTokenStream, event)
}
func (es *eventStore) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
return es.svc.ListActiveRefreshTokens(ctx, session)
}
func (es *eventStore) ResetSecret(ctx context.Context, session authn.Session, secret string) error {
if err := es.svc.ResetSecret(ctx, session, secret); err != nil {
return err
+105 -16
View File
@@ -902,22 +902,24 @@ func TestIssueToken(t *testing.T) {
}
cases := []struct {
desc string
username string
secret string
svcRes *grpcTokenV1.Token
svcErr error
resp *grpcTokenV1.Token
err error
desc string
username string
secret string
description string
svcRes *grpcTokenV1.Token
svcErr error
resp *grpcTokenV1.Token
err error
}{
{
desc: "publish successfully",
username: validUser.Credentials.Username,
secret: validUser.Credentials.Secret,
svcRes: validToken,
svcErr: nil,
resp: validToken,
err: nil,
desc: "publish successfully",
username: validUser.Credentials.Username,
secret: validUser.Credentials.Secret,
description: "valid token",
svcRes: validToken,
svcErr: nil,
resp: validToken,
err: nil,
},
{
desc: "failed to publish with service error",
@@ -932,8 +934,8 @@ func TestIssueToken(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("IssueToken", validCtx, tc.username, tc.secret).Return(tc.svcRes, tc.svcErr)
resp, err := nsvc.IssueToken(validCtx, tc.username, tc.secret)
svcCall := svc.On("IssueToken", validCtx, tc.username, tc.secret, tc.description).Return(tc.svcRes, tc.svcErr)
resp, err := nsvc.IssueToken(validCtx, tc.username, tc.secret, tc.description)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
svcCall.Unset()
@@ -990,6 +992,93 @@ func TestRefreshToken(t *testing.T) {
}
}
func TestRevokeRefreshToken(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
validTokenID := "validTokenID"
cases := []struct {
desc string
session authn.Session
tokenID string
svcErr error
err error
}{
{
desc: "publish successfully",
session: validSession,
tokenID: validTokenID,
svcErr: nil,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
tokenID: validTokenID,
svcErr: svcerr.ErrUpdateEntity,
err: svcerr.ErrUpdateEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("RevokeRefreshToken", validCtx, tc.session, tc.tokenID).Return(tc.svcErr)
err := nsvc.RevokeRefreshToken(validCtx, tc.session, tc.tokenID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}
func TestListActiveRefreshTokens(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
validTokensList := &grpcTokenV1.ListUserRefreshTokensRes{
RefreshTokens: []*grpcTokenV1.RefreshToken{
{Id: "token1", Description: "token1"},
{Id: "token2", Description: "token2"},
},
}
cases := []struct {
desc string
session authn.Session
svcRes *grpcTokenV1.ListUserRefreshTokensRes
svcErr error
resp *grpcTokenV1.ListUserRefreshTokensRes
err error
}{
{
desc: "publish successfully",
session: validSession,
svcRes: validTokensList,
svcErr: nil,
resp: validTokensList,
err: nil,
},
{
desc: "failed to publish with service error",
session: validSession,
svcRes: nil,
svcErr: svcerr.ErrViewEntity,
resp: nil,
err: svcerr.ErrViewEntity,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("ListActiveRefreshTokens", validCtx, tc.session).Return(tc.svcRes, tc.svcErr)
resp, err := nsvc.ListActiveRefreshTokens(validCtx, tc.session)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
svcCall.Unset()
})
}
}
func TestResetSecret(t *testing.T) {
svc, nsvc := newEventStoreMiddleware(t)
+10 -2
View File
@@ -162,14 +162,22 @@ func (am *authorizationMiddleware) Identify(ctx context.Context, session authn.S
return am.svc.Identify(ctx, session)
}
func (am *authorizationMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) {
return am.svc.IssueToken(ctx, username, secret)
func (am *authorizationMiddleware) IssueToken(ctx context.Context, username, secret, description string) (*grpcTokenV1.Token, error) {
return am.svc.IssueToken(ctx, username, secret, description)
}
func (am *authorizationMiddleware) RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*grpcTokenV1.Token, error) {
return am.svc.RefreshToken(ctx, session, refreshToken)
}
func (am *authorizationMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error {
return am.svc.RevokeRefreshToken(ctx, session, tokenID)
}
func (am *authorizationMiddleware) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
return am.svc.ListActiveRefreshTokens(ctx, session)
}
func (am *authorizationMiddleware) OAuthCallback(ctx context.Context, user users.User) (users.User, error) {
return am.svc.OAuthCallback(ctx, user)
}
+41 -2
View File
@@ -89,7 +89,7 @@ func (lm *loggingMiddleware) VerifyEmail(ctx context.Context, verificationToken
// IssueToken logs the issue_token request. It logs the username type and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) IssueToken(ctx context.Context, username, secret string) (t *grpcTokenV1.Token, err error) {
func (lm *loggingMiddleware) IssueToken(ctx context.Context, username, secret, description string) (t *grpcTokenV1.Token, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
@@ -105,7 +105,7 @@ func (lm *loggingMiddleware) IssueToken(ctx context.Context, username, secret st
}
lm.logger.Info("Issue token completed successfully", args...)
}(time.Now())
return lm.svc.IssueToken(ctx, username, secret)
return lm.svc.IssueToken(ctx, username, secret, description)
}
// RefreshToken logs the refresh_token request. It logs the refreshtoken, token type and the time it took to complete the request.
@@ -129,6 +129,45 @@ func (lm *loggingMiddleware) RefreshToken(ctx context.Context, session authn.Ses
return lm.svc.RefreshToken(ctx, session, refreshToken)
}
// RevokeRefreshToken logs the revoke_refresh_token request. It logs the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("request_id", middleware.GetReqID(ctx)),
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("Revoke refresh token failed", args...)
return
}
lm.logger.Info("Revoke refresh token completed successfully", args...)
}(time.Now())
return lm.svc.RevokeRefreshToken(ctx, session, tokenID)
}
// ListActiveRefreshTokens logs the list_active_refresh_tokens request. It logs the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (tokens *grpcTokenV1.ListUserRefreshTokensRes, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("request_id", middleware.GetReqID(ctx)),
}
if tokens != nil {
args = append(args, slog.Int("tokens_count", len(tokens.GetRefreshTokens())))
}
if err != nil {
args = append(args, slog.String("error", err.Error()))
lm.logger.Warn("List active refresh tokens failed", args...)
return
}
lm.logger.Info("List active refresh tokens completed successfully", args...)
}(time.Now())
return lm.svc.ListActiveRefreshTokens(ctx, session)
}
// View logs the view_user request. It logs the user id and the time it took to complete the request.
// If the request fails, it logs the error.
func (lm *loggingMiddleware) View(ctx context.Context, session authn.Session, id string) (c users.User, err error) {
+20 -2
View File
@@ -58,12 +58,12 @@ func (ms *metricsMiddleware) VerifyEmail(ctx context.Context, verificationToken
}
// IssueToken instruments IssueToken method with metrics.
func (ms *metricsMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) {
func (ms *metricsMiddleware) IssueToken(ctx context.Context, username, secret, description string) (*grpcTokenV1.Token, error) {
defer func(begin time.Time) {
ms.counter.With("method", "issue_token").Add(1)
ms.latency.With("method", "issue_token").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.IssueToken(ctx, username, secret)
return ms.svc.IssueToken(ctx, username, secret, description)
}
// RefreshToken instruments RefreshToken method with metrics.
@@ -75,6 +75,24 @@ func (ms *metricsMiddleware) RefreshToken(ctx context.Context, session authn.Ses
return ms.svc.RefreshToken(ctx, session, refreshToken)
}
// RevokeRefreshToken instruments RevokeRefreshToken method with metrics.
func (ms *metricsMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error {
defer func(begin time.Time) {
ms.counter.With("method", "revoke_refresh_token").Add(1)
ms.latency.With("method", "revoke_refresh_token").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.RevokeRefreshToken(ctx, session, tokenID)
}
// ListActiveRefreshTokens instruments ListActiveRefreshTokens method with metrics.
func (ms *metricsMiddleware) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
defer func(begin time.Time) {
ms.counter.With("method", "list_active_refresh_tokens").Add(1)
ms.latency.With("method", "list_active_refresh_tokens").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ListActiveRefreshTokens(ctx, session)
}
// View instruments View method with metrics.
func (ms *metricsMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) {
defer func(begin time.Time) {
+18 -2
View File
@@ -50,11 +50,11 @@ func (tm *tracingMiddleware) VerifyEmail(ctx context.Context, verificationToken
}
// IssueToken traces the "IssueToken" operation of the wrapped users.Service.
func (tm *tracingMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) {
func (tm *tracingMiddleware) IssueToken(ctx context.Context, username, secret, description string) (*grpcTokenV1.Token, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_issue_token", trace.WithAttributes(attribute.String("username", username)))
defer span.End()
return tm.svc.IssueToken(ctx, username, secret)
return tm.svc.IssueToken(ctx, username, secret, description)
}
// RefreshToken traces the "RefreshToken" operation of the wrapped users.Service.
@@ -65,6 +65,22 @@ func (tm *tracingMiddleware) RefreshToken(ctx context.Context, session authn.Ses
return tm.svc.RefreshToken(ctx, session, refreshToken)
}
// RevokeRefreshToken traces the "RevokeRefreshToken" operation of the wrapped users.Service.
func (tm *tracingMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_revoke_refresh_token")
defer span.End()
return tm.svc.RevokeRefreshToken(ctx, session, tokenID)
}
// ListActiveRefreshTokens traces the "ListActiveRefreshTokens" operation of the wrapped users.Service.
func (tm *tracingMiddleware) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_list_active_refresh_tokens")
defer span.End()
return tm.svc.ListActiveRefreshTokens(ctx, session)
}
// View traces the "View" operation of the wrapped users.Service.
func (tm *tracingMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) {
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_view_user", trace.WithAttributes(attribute.String("id", id)))
+149 -12
View File
@@ -318,8 +318,8 @@ func (_c *Service_Identify_Call) RunAndReturn(run func(ctx context.Context, sess
}
// IssueToken provides a mock function for the type Service
func (_mock *Service) IssueToken(ctx context.Context, identity string, secret string) (*v1.Token, error) {
ret := _mock.Called(ctx, identity, secret)
func (_mock *Service) IssueToken(ctx context.Context, identity string, secret string, description string) (*v1.Token, error) {
ret := _mock.Called(ctx, identity, secret, description)
if len(ret) == 0 {
panic("no return value specified for IssueToken")
@@ -327,18 +327,18 @@ func (_mock *Service) IssueToken(ctx context.Context, identity string, secret st
var r0 *v1.Token
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (*v1.Token, error)); ok {
return returnFunc(ctx, identity, secret)
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (*v1.Token, error)); ok {
return returnFunc(ctx, identity, secret, description)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) *v1.Token); ok {
r0 = returnFunc(ctx, identity, secret)
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) *v1.Token); ok {
r0 = returnFunc(ctx, identity, secret, description)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*v1.Token)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
r1 = returnFunc(ctx, identity, secret)
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
r1 = returnFunc(ctx, identity, secret, description)
} else {
r1 = ret.Error(1)
}
@@ -354,11 +354,12 @@ type Service_IssueToken_Call struct {
// - ctx context.Context
// - identity string
// - secret string
func (_e *Service_Expecter) IssueToken(ctx interface{}, identity interface{}, secret interface{}) *Service_IssueToken_Call {
return &Service_IssueToken_Call{Call: _e.mock.On("IssueToken", ctx, identity, secret)}
// - description string
func (_e *Service_Expecter) IssueToken(ctx interface{}, identity interface{}, secret interface{}, description interface{}) *Service_IssueToken_Call {
return &Service_IssueToken_Call{Call: _e.mock.On("IssueToken", ctx, identity, secret, description)}
}
func (_c *Service_IssueToken_Call) Run(run func(ctx context.Context, identity string, secret string)) *Service_IssueToken_Call {
func (_c *Service_IssueToken_Call) Run(run func(ctx context.Context, identity string, secret string, description string)) *Service_IssueToken_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
@@ -372,10 +373,15 @@ func (_c *Service_IssueToken_Call) Run(run func(ctx context.Context, identity st
if args[2] != nil {
arg2 = args[2].(string)
}
var arg3 string
if args[3] != nil {
arg3 = args[3].(string)
}
run(
arg0,
arg1,
arg2,
arg3,
)
})
return _c
@@ -386,7 +392,75 @@ func (_c *Service_IssueToken_Call) Return(token *v1.Token, err error) *Service_I
return _c
}
func (_c *Service_IssueToken_Call) RunAndReturn(run func(ctx context.Context, identity string, secret string) (*v1.Token, error)) *Service_IssueToken_Call {
func (_c *Service_IssueToken_Call) RunAndReturn(run func(ctx context.Context, identity string, secret string, description string) (*v1.Token, error)) *Service_IssueToken_Call {
_c.Call.Return(run)
return _c
}
// ListActiveRefreshTokens provides a mock function for the type Service
func (_mock *Service) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*v1.ListUserRefreshTokensRes, error) {
ret := _mock.Called(ctx, session)
if len(ret) == 0 {
panic("no return value specified for ListActiveRefreshTokens")
}
var r0 *v1.ListUserRefreshTokensRes
var r1 error
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) (*v1.ListUserRefreshTokensRes, error)); ok {
return returnFunc(ctx, session)
}
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) *v1.ListUserRefreshTokensRes); ok {
r0 = returnFunc(ctx, session)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*v1.ListUserRefreshTokensRes)
}
}
if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session) error); ok {
r1 = returnFunc(ctx, session)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Service_ListActiveRefreshTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListActiveRefreshTokens'
type Service_ListActiveRefreshTokens_Call struct {
*mock.Call
}
// ListActiveRefreshTokens is a helper method to define mock.On call
// - ctx context.Context
// - session authn.Session
func (_e *Service_Expecter) ListActiveRefreshTokens(ctx interface{}, session interface{}) *Service_ListActiveRefreshTokens_Call {
return &Service_ListActiveRefreshTokens_Call{Call: _e.mock.On("ListActiveRefreshTokens", ctx, session)}
}
func (_c *Service_ListActiveRefreshTokens_Call) Run(run func(ctx context.Context, session authn.Session)) *Service_ListActiveRefreshTokens_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 authn.Session
if args[1] != nil {
arg1 = args[1].(authn.Session)
}
run(
arg0,
arg1,
)
})
return _c
}
func (_c *Service_ListActiveRefreshTokens_Call) Return(listUserRefreshTokensRes *v1.ListUserRefreshTokensRes, err error) *Service_ListActiveRefreshTokens_Call {
_c.Call.Return(listUserRefreshTokensRes, err)
return _c
}
func (_c *Service_ListActiveRefreshTokens_Call) RunAndReturn(run func(ctx context.Context, session authn.Session) (*v1.ListUserRefreshTokensRes, error)) *Service_ListActiveRefreshTokens_Call {
_c.Call.Return(run)
return _c
}
@@ -801,6 +875,69 @@ func (_c *Service_ResetSecret_Call) RunAndReturn(run func(ctx context.Context, s
return _c
}
// RevokeRefreshToken provides a mock function for the type Service
func (_mock *Service) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error {
ret := _mock.Called(ctx, session, tokenID)
if len(ret) == 0 {
panic("no return value specified for RevokeRefreshToken")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok {
r0 = returnFunc(ctx, session, tokenID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Service_RevokeRefreshToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeRefreshToken'
type Service_RevokeRefreshToken_Call struct {
*mock.Call
}
// RevokeRefreshToken is a helper method to define mock.On call
// - ctx context.Context
// - session authn.Session
// - tokenID string
func (_e *Service_Expecter) RevokeRefreshToken(ctx interface{}, session interface{}, tokenID interface{}) *Service_RevokeRefreshToken_Call {
return &Service_RevokeRefreshToken_Call{Call: _e.mock.On("RevokeRefreshToken", ctx, session, tokenID)}
}
func (_c *Service_RevokeRefreshToken_Call) Run(run func(ctx context.Context, session authn.Session, tokenID string)) *Service_RevokeRefreshToken_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
arg0 = args[0].(context.Context)
}
var arg1 authn.Session
if args[1] != nil {
arg1 = args[1].(authn.Session)
}
var arg2 string
if args[2] != nil {
arg2 = args[2].(string)
}
run(
arg0,
arg1,
arg2,
)
})
return _c
}
func (_c *Service_RevokeRefreshToken_Call) Return(err error) *Service_RevokeRefreshToken_Call {
_c.Call.Return(err)
return _c
}
func (_c *Service_RevokeRefreshToken_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, tokenID string) error) *Service_RevokeRefreshToken_Call {
_c.Call.Return(run)
return _c
}
// SearchUsers provides a mock function for the type Service
func (_mock *Service) SearchUsers(ctx context.Context, pm users.Page) (users.UsersPage, error) {
ret := _mock.Called(ctx, pm)
+45 -3
View File
@@ -183,7 +183,7 @@ func (svc service) VerifyEmail(ctx context.Context, token string) (User, error)
return user, nil
}
func (svc service) IssueToken(ctx context.Context, identity, secret string) (*grpcTokenV1.Token, error) {
func (svc service) IssueToken(ctx context.Context, identity, secret, description string) (*grpcTokenV1.Token, error) {
var dbUser User
var err error
@@ -205,7 +205,13 @@ func (svc service) IssueToken(ctx context.Context, identity, secret string) (*gr
return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrLogin, err)
}
token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, UserRole: uint32(dbUser.Role + 1), Type: uint32(smqauth.AccessKey), Verified: !dbUser.VerifiedAt.IsZero()})
token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{
UserId: dbUser.ID,
UserRole: uint32(dbUser.Role + 1),
Type: uint32(smqauth.AccessKey),
Verified: !dbUser.VerifiedAt.IsZero(),
Description: description,
})
if err != nil {
return &grpcTokenV1.Token{}, errors.Wrap(errIssueToken, err)
}
@@ -229,6 +235,42 @@ func (svc service) RefreshToken(ctx context.Context, session authn.Session, refr
return token, nil
}
func (svc service) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error {
dbUser, err := svc.users.RetrieveByID(ctx, session.UserID)
if err != nil {
return errors.Wrap(svcerr.ErrAuthentication, err)
}
if dbUser.Status == DisabledStatus {
return errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser)
}
_, err = svc.token.Revoke(ctx, &grpcTokenV1.RevokeReq{UserId: session.UserID, TokenId: tokenID})
if err != nil {
if errors.Contains(err, svcerr.ErrNotFound) {
return errors.Wrap(svcerr.ErrNotFound, err)
}
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
return nil
}
func (svc service) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
dbUser, err := svc.users.RetrieveByID(ctx, session.UserID)
if err != nil {
return nil, errors.Wrap(svcerr.ErrAuthentication, err)
}
if dbUser.Status == DisabledStatus {
return nil, errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser)
}
refreshTokens, err := svc.token.ListUserRefreshTokens(ctx, &grpcTokenV1.ListUserRefreshTokensReq{UserId: session.UserID})
if err != nil {
return nil, errors.Wrap(svcerr.ErrAuthentication, err)
}
return refreshTokens, nil
}
func (svc service) View(ctx context.Context, session authn.Session, id string) (User, error) {
user, err := svc.users.RetrieveByID(ctx, id)
if err != nil {
@@ -453,7 +495,7 @@ func (svc service) UpdateSecret(ctx context.Context, session authn.Session, oldS
if err != nil {
return User{}, errors.Wrap(svcerr.ErrViewEntity, err)
}
if _, err := svc.IssueToken(ctx, dbUser.Credentials.Username, oldSecret); err != nil {
if _, err := svc.IssueToken(ctx, dbUser.Credentials.Username, oldSecret, ""); err != nil {
return User{}, err
}
newSecret, err = svc.hasher.Hash(newSecret)
+176 -1
View File
@@ -1615,7 +1615,7 @@ func TestIssueToken(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
repoCall := cRepo.On("RetrieveByUsername", context.Background(), tc.user.Credentials.Username).Return(tc.retrieveByUsernameResponse, tc.retrieveByUsernameErr)
authCall := auth.On("Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr)
token, err := svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret)
token, err := svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret, "")
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
@@ -1710,6 +1710,181 @@ func TestRefreshToken(t *testing.T) {
}
}
func TestRevokeRefreshToken(t *testing.T) {
svc, authsvc, crepo, _, _ := newService()
rUser := user
rUser.Credentials.Secret, _ = phasher.Hash(user.Credentials.Secret)
cases := []struct {
desc string
session authn.Session
tokenID string
revokeResp *grpcTokenV1.RevokeRes
revokeErr error
repoResp users.User
repoErr error
err error
}{
{
desc: "revoke refresh token successfully",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
tokenID: validToken,
revokeResp: &grpcTokenV1.RevokeRes{},
repoResp: rUser,
err: nil,
},
{
desc: "revoke refresh token with empty domain id",
session: authn.Session{UserID: validID},
tokenID: validToken,
revokeResp: &grpcTokenV1.RevokeRes{},
repoResp: rUser,
err: nil,
},
{
desc: "revoke refresh token for non-existing user",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
tokenID: validToken,
repoErr: repoerr.ErrNotFound,
err: repoerr.ErrNotFound,
},
{
desc: "revoke refresh token for disabled user",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
tokenID: validToken,
repoResp: users.User{Status: users.DisabledStatus},
err: svcerr.ErrAuthentication,
},
{
desc: "revoke refresh token with revoke service error",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
tokenID: validToken,
revokeResp: &grpcTokenV1.RevokeRes{},
revokeErr: svcerr.ErrAuthorization,
repoResp: rUser,
err: svcerr.ErrAuthorization,
},
{
desc: "revoke refresh token not found",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
tokenID: validToken,
revokeResp: &grpcTokenV1.RevokeRes{},
revokeErr: svcerr.ErrNotFound,
repoResp: rUser,
err: svcerr.ErrNotFound,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr)
authCall := authsvc.On("Revoke", context.Background(), &grpcTokenV1.RevokeReq{UserId: tc.session.UserID, TokenId: tc.tokenID}).Return(tc.revokeResp, tc.revokeErr)
err := svc.RevokeRefreshToken(context.Background(), tc.session, tc.tokenID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
ok = authCall.Parent.AssertCalled(t, "Revoke", context.Background(), &grpcTokenV1.RevokeReq{UserId: tc.session.UserID, TokenId: tc.tokenID})
assert.True(t, ok, fmt.Sprintf("Revoke was not called on %s", tc.desc))
}
repoCall.Unset()
authCall.Unset()
})
}
}
func TestListActiveRefreshTokens(t *testing.T) {
svc, authsvc, crepo, _, _ := newService()
rUser := user
rUser.Credentials.Secret, _ = phasher.Hash(user.Credentials.Secret)
cases := []struct {
desc string
session authn.Session
listResp *grpcTokenV1.ListUserRefreshTokensRes
listErr error
repoResp users.User
repoErr error
expectedTokens int
err error
}{
{
desc: "list active refresh tokens successfully",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
listResp: &grpcTokenV1.ListUserRefreshTokensRes{
RefreshTokens: []*grpcTokenV1.RefreshToken{
{Id: "token1", Description: "token1"},
{Id: "token2", Description: "token2"},
},
},
repoResp: rUser,
expectedTokens: 2,
err: nil,
},
{
desc: "list active refresh tokens with empty domain id",
session: authn.Session{UserID: validID},
listResp: &grpcTokenV1.ListUserRefreshTokensRes{
RefreshTokens: []*grpcTokenV1.RefreshToken{
{Id: "token1", Description: "token1"},
},
},
repoResp: rUser,
expectedTokens: 1,
err: nil,
},
{
desc: "list active refresh tokens for non-existing user",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
repoErr: repoerr.ErrNotFound,
err: repoerr.ErrNotFound,
},
{
desc: "list active refresh tokens for disabled user",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
repoResp: users.User{Status: users.DisabledStatus},
err: svcerr.ErrAuthentication,
},
{
desc: "list active refresh tokens with list service error",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
listResp: &grpcTokenV1.ListUserRefreshTokensRes{},
listErr: svcerr.ErrAuthentication,
repoResp: rUser,
err: svcerr.ErrAuthentication,
},
{
desc: "list active refresh tokens with empty list",
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
listResp: &grpcTokenV1.ListUserRefreshTokensRes{RefreshTokens: []*grpcTokenV1.RefreshToken{}},
repoResp: rUser,
expectedTokens: 0,
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr)
authCall := authsvc.On("ListUserRefreshTokens", context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.session.UserID}).Return(tc.listResp, tc.listErr)
tokens, err := svc.ListActiveRefreshTokens(context.Background(), tc.session)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
assert.NotNil(t, tokens, fmt.Sprintf("%s: expected tokens not to be nil\n", tc.desc))
assert.Equal(t, tc.expectedTokens, len(tokens.GetRefreshTokens()), fmt.Sprintf("%s: expected %d tokens got %d\n", tc.desc, tc.expectedTokens, len(tokens.GetRefreshTokens())))
ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID)
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
ok = authCall.Parent.AssertCalled(t, "ListUserRefreshTokens", context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.session.UserID})
assert.True(t, ok, fmt.Sprintf("ListUserRefreshTokens was not called on %s", tc.desc))
}
repoCall.Unset()
authCall.Unset()
})
}
}
func TestSendPasswordReset(t *testing.T) {
svc, auth, cRepo, _, e := newService()
+7 -1
View File
@@ -264,13 +264,19 @@ type Service interface {
Identify(ctx context.Context, session authn.Session) (string, error)
// IssueToken issues a new access and refresh token when provided with either a username or email.
IssueToken(ctx context.Context, identity, secret string) (*grpcTokenV1.Token, error)
IssueToken(ctx context.Context, identity, secret, description string) (*grpcTokenV1.Token, error)
// RefreshToken refreshes expired access tokens.
// After an access token expires, the refresh token is used to get
// a new pair of access and refresh tokens.
RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*grpcTokenV1.Token, error)
// RevokeRefreshToken revokes a refresh token by its ID.
RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error
// ListActiveRefreshTokens lists all active refresh tokens for the authenticated user.
ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error)
// OAuthCallback handles the callback from any supported OAuth provider.
// It processes the OAuth tokens and either signs in or signs up the user based on the provided state.
OAuthCallback(ctx context.Context, user User) (User, error)