mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-22 20:00:22 +00:00
SMQ-1672 - Revoke refresh token (#3241)
Signed-off-by: Felix Gateru <felix.gateru@gmail.com> Signed-off-by: nyagamunene <stevenyaga2014@gmail.com> Co-authored-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
@@ -165,11 +165,8 @@ mocks: $(MOCKERY)
|
||||
|
||||
$(MOCKERY):
|
||||
@mkdir -p $(GOBIN)
|
||||
@mkdir -p mockery
|
||||
@echo ">> downloading mockery $(MOCKERY_VERSION)..."
|
||||
@curl -sL https://github.com/vektra/mockery/releases/download/v$(MOCKERY_VERSION)/mockery_$(MOCKERY_VERSION)_Linux_x86_64.tar.gz | tar -xz -C mockery
|
||||
@mv mockery/mockery $(GOBIN)
|
||||
@rm -r mockery
|
||||
@echo ">> installing mockery $(MOCKERY_VERSION)..."
|
||||
@go install github.com/vektra/mockery/v3@v$(MOCKERY_VERSION)
|
||||
|
||||
DIRS = consumers readers postgres internal
|
||||
test: mocks
|
||||
|
||||
+283
-23
@@ -30,6 +30,7 @@ type IssueReq struct {
|
||||
UserRole uint32 `protobuf:"varint,2,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"`
|
||||
Type uint32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"`
|
||||
Verified bool `protobuf:"varint,4,opt,name=verified,proto3" json:"verified,omitempty"`
|
||||
Description string `protobuf:"bytes,5,opt,name=description,proto3" json:"description,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -92,6 +93,13 @@ func (x *IssueReq) GetVerified() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (x *IssueReq) GetDescription() string {
|
||||
if x != nil {
|
||||
return x.Description
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type RefreshReq struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
RefreshToken string `protobuf:"bytes,1,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"`
|
||||
@@ -144,6 +152,58 @@ func (x *RefreshReq) GetVerified() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type RevokeReq struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
TokenId string `protobuf:"bytes,1,opt,name=token_id,json=tokenId,proto3" json:"token_id,omitempty"`
|
||||
UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *RevokeReq) Reset() {
|
||||
*x = RevokeReq{}
|
||||
mi := &file_token_v1_token_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *RevokeReq) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RevokeReq) ProtoMessage() {}
|
||||
|
||||
func (x *RevokeReq) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_token_v1_token_proto_msgTypes[2]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RevokeReq.ProtoReflect.Descriptor instead.
|
||||
func (*RevokeReq) Descriptor() ([]byte, []int) {
|
||||
return file_token_v1_token_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (x *RevokeReq) GetTokenId() string {
|
||||
if x != nil {
|
||||
return x.TokenId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RevokeReq) GetUserId() string {
|
||||
if x != nil {
|
||||
return x.UserId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// If a token is not carrying any information itself, the type
|
||||
// field can be used to determine how to validate the token.
|
||||
// Also, different tokens can be encoded in different ways.
|
||||
@@ -158,7 +218,7 @@ type Token struct {
|
||||
|
||||
func (x *Token) Reset() {
|
||||
*x = Token{}
|
||||
mi := &file_token_v1_token_proto_msgTypes[2]
|
||||
mi := &file_token_v1_token_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
@@ -170,7 +230,7 @@ func (x *Token) String() string {
|
||||
func (*Token) ProtoMessage() {}
|
||||
|
||||
func (x *Token) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_token_v1_token_proto_msgTypes[2]
|
||||
mi := &file_token_v1_token_proto_msgTypes[3]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
@@ -183,7 +243,7 @@ func (x *Token) ProtoReflect() protoreflect.Message {
|
||||
|
||||
// Deprecated: Use Token.ProtoReflect.Descriptor instead.
|
||||
func (*Token) Descriptor() ([]byte, []int) {
|
||||
return file_token_v1_token_proto_rawDescGZIP(), []int{2}
|
||||
return file_token_v1_token_proto_rawDescGZIP(), []int{3}
|
||||
}
|
||||
|
||||
func (x *Token) GetAccessToken() string {
|
||||
@@ -207,29 +267,219 @@ func (x *Token) GetAccessType() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type RevokeRes struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *RevokeRes) Reset() {
|
||||
*x = RevokeRes{}
|
||||
mi := &file_token_v1_token_proto_msgTypes[4]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *RevokeRes) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RevokeRes) ProtoMessage() {}
|
||||
|
||||
func (x *RevokeRes) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_token_v1_token_proto_msgTypes[4]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RevokeRes.ProtoReflect.Descriptor instead.
|
||||
func (*RevokeRes) Descriptor() ([]byte, []int) {
|
||||
return file_token_v1_token_proto_rawDescGZIP(), []int{4}
|
||||
}
|
||||
|
||||
type ListUserRefreshTokensReq struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *ListUserRefreshTokensReq) Reset() {
|
||||
*x = ListUserRefreshTokensReq{}
|
||||
mi := &file_token_v1_token_proto_msgTypes[5]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *ListUserRefreshTokensReq) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ListUserRefreshTokensReq) ProtoMessage() {}
|
||||
|
||||
func (x *ListUserRefreshTokensReq) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_token_v1_token_proto_msgTypes[5]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ListUserRefreshTokensReq.ProtoReflect.Descriptor instead.
|
||||
func (*ListUserRefreshTokensReq) Descriptor() ([]byte, []int) {
|
||||
return file_token_v1_token_proto_rawDescGZIP(), []int{5}
|
||||
}
|
||||
|
||||
func (x *ListUserRefreshTokensReq) GetUserId() string {
|
||||
if x != nil {
|
||||
return x.UserId
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type ListUserRefreshTokensRes struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
RefreshTokens []*RefreshToken `protobuf:"bytes,1,rep,name=refresh_tokens,json=refreshTokens,proto3" json:"refresh_tokens,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *ListUserRefreshTokensRes) Reset() {
|
||||
*x = ListUserRefreshTokensRes{}
|
||||
mi := &file_token_v1_token_proto_msgTypes[6]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *ListUserRefreshTokensRes) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*ListUserRefreshTokensRes) ProtoMessage() {}
|
||||
|
||||
func (x *ListUserRefreshTokensRes) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_token_v1_token_proto_msgTypes[6]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use ListUserRefreshTokensRes.ProtoReflect.Descriptor instead.
|
||||
func (*ListUserRefreshTokensRes) Descriptor() ([]byte, []int) {
|
||||
return file_token_v1_token_proto_rawDescGZIP(), []int{6}
|
||||
}
|
||||
|
||||
func (x *ListUserRefreshTokensRes) GetRefreshTokens() []*RefreshToken {
|
||||
if x != nil {
|
||||
return x.RefreshTokens
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type RefreshToken struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
|
||||
Description string `protobuf:"bytes,2,opt,name=description,proto3" json:"description,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *RefreshToken) Reset() {
|
||||
*x = RefreshToken{}
|
||||
mi := &file_token_v1_token_proto_msgTypes[7]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
|
||||
func (x *RefreshToken) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RefreshToken) ProtoMessage() {}
|
||||
|
||||
func (x *RefreshToken) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_token_v1_token_proto_msgTypes[7]
|
||||
if x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RefreshToken.ProtoReflect.Descriptor instead.
|
||||
func (*RefreshToken) Descriptor() ([]byte, []int) {
|
||||
return file_token_v1_token_proto_rawDescGZIP(), []int{7}
|
||||
}
|
||||
|
||||
func (x *RefreshToken) GetId() string {
|
||||
if x != nil {
|
||||
return x.Id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *RefreshToken) GetDescription() string {
|
||||
if x != nil {
|
||||
return x.Description
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_token_v1_token_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_token_v1_token_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x14token/v1/token.proto\x12\btoken.v1\"p\n" +
|
||||
"\x14token/v1/token.proto\x12\btoken.v1\"\x92\x01\n" +
|
||||
"\bIssueReq\x12\x17\n" +
|
||||
"\auser_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n" +
|
||||
"\tuser_role\x18\x02 \x01(\rR\buserRole\x12\x12\n" +
|
||||
"\x04type\x18\x03 \x01(\rR\x04type\x12\x1a\n" +
|
||||
"\bverified\x18\x04 \x01(\bR\bverified\"M\n" +
|
||||
"\bverified\x18\x04 \x01(\bR\bverified\x12 \n" +
|
||||
"\vdescription\x18\x05 \x01(\tR\vdescription\"M\n" +
|
||||
"\n" +
|
||||
"RefreshReq\x12#\n" +
|
||||
"\rrefresh_token\x18\x01 \x01(\tR\frefreshToken\x12\x1a\n" +
|
||||
"\bverified\x18\x02 \x01(\bR\bverified\"\x87\x01\n" +
|
||||
"\bverified\x18\x02 \x01(\bR\bverified\"?\n" +
|
||||
"\tRevokeReq\x12\x19\n" +
|
||||
"\btoken_id\x18\x01 \x01(\tR\atokenId\x12\x17\n" +
|
||||
"\auser_id\x18\x02 \x01(\tR\x06userId\"\x87\x01\n" +
|
||||
"\x05Token\x12!\n" +
|
||||
"\faccess_token\x18\x01 \x01(\tR\vaccessToken\x12(\n" +
|
||||
"\rrefresh_token\x18\x02 \x01(\tH\x00R\frefreshToken\x88\x01\x01\x12\x1f\n" +
|
||||
"\vaccess_type\x18\x03 \x01(\tR\n" +
|
||||
"accessTypeB\x10\n" +
|
||||
"\x0e_refresh_token2r\n" +
|
||||
"\x0e_refresh_token\"\v\n" +
|
||||
"\tRevokeRes\"3\n" +
|
||||
"\x18ListUserRefreshTokensReq\x12\x17\n" +
|
||||
"\auser_id\x18\x01 \x01(\tR\x06userId\"Y\n" +
|
||||
"\x18ListUserRefreshTokensRes\x12=\n" +
|
||||
"\x0erefresh_tokens\x18\x01 \x03(\v2\x16.token.v1.RefreshTokenR\rrefreshTokens\"@\n" +
|
||||
"\fRefreshToken\x12\x0e\n" +
|
||||
"\x02id\x18\x01 \x01(\tR\x02id\x12 \n" +
|
||||
"\vdescription\x18\x02 \x01(\tR\vdescription2\x8b\x02\n" +
|
||||
"\fTokenService\x12.\n" +
|
||||
"\x05Issue\x12\x12.token.v1.IssueReq\x1a\x0f.token.v1.Token\"\x00\x122\n" +
|
||||
"\aRefresh\x12\x14.token.v1.RefreshReq\x1a\x0f.token.v1.Token\"\x00B.Z,github.com/absmach/supermq/api/grpc/token/v1b\x06proto3"
|
||||
"\aRefresh\x12\x14.token.v1.RefreshReq\x1a\x0f.token.v1.Token\"\x00\x124\n" +
|
||||
"\x06Revoke\x12\x13.token.v1.RevokeReq\x1a\x13.token.v1.RevokeRes\"\x00\x12a\n" +
|
||||
"\x15ListUserRefreshTokens\x12\".token.v1.ListUserRefreshTokensReq\x1a\".token.v1.ListUserRefreshTokensRes\"\x00B.Z,github.com/absmach/supermq/api/grpc/token/v1b\x06proto3"
|
||||
|
||||
var (
|
||||
file_token_v1_token_proto_rawDescOnce sync.Once
|
||||
@@ -243,22 +493,32 @@ func file_token_v1_token_proto_rawDescGZIP() []byte {
|
||||
return file_token_v1_token_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_token_v1_token_proto_msgTypes = make([]protoimpl.MessageInfo, 3)
|
||||
var file_token_v1_token_proto_msgTypes = make([]protoimpl.MessageInfo, 8)
|
||||
var file_token_v1_token_proto_goTypes = []any{
|
||||
(*IssueReq)(nil), // 0: token.v1.IssueReq
|
||||
(*RefreshReq)(nil), // 1: token.v1.RefreshReq
|
||||
(*Token)(nil), // 2: token.v1.Token
|
||||
(*IssueReq)(nil), // 0: token.v1.IssueReq
|
||||
(*RefreshReq)(nil), // 1: token.v1.RefreshReq
|
||||
(*RevokeReq)(nil), // 2: token.v1.RevokeReq
|
||||
(*Token)(nil), // 3: token.v1.Token
|
||||
(*RevokeRes)(nil), // 4: token.v1.RevokeRes
|
||||
(*ListUserRefreshTokensReq)(nil), // 5: token.v1.ListUserRefreshTokensReq
|
||||
(*ListUserRefreshTokensRes)(nil), // 6: token.v1.ListUserRefreshTokensRes
|
||||
(*RefreshToken)(nil), // 7: token.v1.RefreshToken
|
||||
}
|
||||
var file_token_v1_token_proto_depIdxs = []int32{
|
||||
0, // 0: token.v1.TokenService.Issue:input_type -> token.v1.IssueReq
|
||||
1, // 1: token.v1.TokenService.Refresh:input_type -> token.v1.RefreshReq
|
||||
2, // 2: token.v1.TokenService.Issue:output_type -> token.v1.Token
|
||||
2, // 3: token.v1.TokenService.Refresh:output_type -> token.v1.Token
|
||||
2, // [2:4] is the sub-list for method output_type
|
||||
0, // [0:2] is the sub-list for method input_type
|
||||
0, // [0:0] is the sub-list for extension type_name
|
||||
0, // [0:0] is the sub-list for extension extendee
|
||||
0, // [0:0] is the sub-list for field type_name
|
||||
7, // 0: token.v1.ListUserRefreshTokensRes.refresh_tokens:type_name -> token.v1.RefreshToken
|
||||
0, // 1: token.v1.TokenService.Issue:input_type -> token.v1.IssueReq
|
||||
1, // 2: token.v1.TokenService.Refresh:input_type -> token.v1.RefreshReq
|
||||
2, // 3: token.v1.TokenService.Revoke:input_type -> token.v1.RevokeReq
|
||||
5, // 4: token.v1.TokenService.ListUserRefreshTokens:input_type -> token.v1.ListUserRefreshTokensReq
|
||||
3, // 5: token.v1.TokenService.Issue:output_type -> token.v1.Token
|
||||
3, // 6: token.v1.TokenService.Refresh:output_type -> token.v1.Token
|
||||
4, // 7: token.v1.TokenService.Revoke:output_type -> token.v1.RevokeRes
|
||||
6, // 8: token.v1.TokenService.ListUserRefreshTokens:output_type -> token.v1.ListUserRefreshTokensRes
|
||||
5, // [5:9] is the sub-list for method output_type
|
||||
1, // [1:5] is the sub-list for method input_type
|
||||
1, // [1:1] is the sub-list for extension type_name
|
||||
1, // [1:1] is the sub-list for extension extendee
|
||||
0, // [0:1] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_token_v1_token_proto_init() }
|
||||
@@ -266,14 +526,14 @@ func file_token_v1_token_proto_init() {
|
||||
if File_token_v1_token_proto != nil {
|
||||
return
|
||||
}
|
||||
file_token_v1_token_proto_msgTypes[2].OneofWrappers = []any{}
|
||||
file_token_v1_token_proto_msgTypes[3].OneofWrappers = []any{}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_token_v1_token_proto_rawDesc), len(file_token_v1_token_proto_rawDesc)),
|
||||
NumEnums: 0,
|
||||
NumMessages: 3,
|
||||
NumMessages: 8,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
|
||||
@@ -22,8 +22,10 @@ import (
|
||||
const _ = grpc.SupportPackageIsVersion9
|
||||
|
||||
const (
|
||||
TokenService_Issue_FullMethodName = "/token.v1.TokenService/Issue"
|
||||
TokenService_Refresh_FullMethodName = "/token.v1.TokenService/Refresh"
|
||||
TokenService_Issue_FullMethodName = "/token.v1.TokenService/Issue"
|
||||
TokenService_Refresh_FullMethodName = "/token.v1.TokenService/Refresh"
|
||||
TokenService_Revoke_FullMethodName = "/token.v1.TokenService/Revoke"
|
||||
TokenService_ListUserRefreshTokens_FullMethodName = "/token.v1.TokenService/ListUserRefreshTokens"
|
||||
)
|
||||
|
||||
// TokenServiceClient is the client API for TokenService service.
|
||||
@@ -32,6 +34,8 @@ const (
|
||||
type TokenServiceClient interface {
|
||||
Issue(ctx context.Context, in *IssueReq, opts ...grpc.CallOption) (*Token, error)
|
||||
Refresh(ctx context.Context, in *RefreshReq, opts ...grpc.CallOption) (*Token, error)
|
||||
Revoke(ctx context.Context, in *RevokeReq, opts ...grpc.CallOption) (*RevokeRes, error)
|
||||
ListUserRefreshTokens(ctx context.Context, in *ListUserRefreshTokensReq, opts ...grpc.CallOption) (*ListUserRefreshTokensRes, error)
|
||||
}
|
||||
|
||||
type tokenServiceClient struct {
|
||||
@@ -62,12 +66,34 @@ func (c *tokenServiceClient) Refresh(ctx context.Context, in *RefreshReq, opts .
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *tokenServiceClient) Revoke(ctx context.Context, in *RevokeReq, opts ...grpc.CallOption) (*RevokeRes, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(RevokeRes)
|
||||
err := c.cc.Invoke(ctx, TokenService_Revoke_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *tokenServiceClient) ListUserRefreshTokens(ctx context.Context, in *ListUserRefreshTokensReq, opts ...grpc.CallOption) (*ListUserRefreshTokensRes, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(ListUserRefreshTokensRes)
|
||||
err := c.cc.Invoke(ctx, TokenService_ListUserRefreshTokens_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// TokenServiceServer is the server API for TokenService service.
|
||||
// All implementations must embed UnimplementedTokenServiceServer
|
||||
// for forward compatibility.
|
||||
type TokenServiceServer interface {
|
||||
Issue(context.Context, *IssueReq) (*Token, error)
|
||||
Refresh(context.Context, *RefreshReq) (*Token, error)
|
||||
Revoke(context.Context, *RevokeReq) (*RevokeRes, error)
|
||||
ListUserRefreshTokens(context.Context, *ListUserRefreshTokensReq) (*ListUserRefreshTokensRes, error)
|
||||
mustEmbedUnimplementedTokenServiceServer()
|
||||
}
|
||||
|
||||
@@ -84,6 +110,12 @@ func (UnimplementedTokenServiceServer) Issue(context.Context, *IssueReq) (*Token
|
||||
func (UnimplementedTokenServiceServer) Refresh(context.Context, *RefreshReq) (*Token, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method Refresh not implemented")
|
||||
}
|
||||
func (UnimplementedTokenServiceServer) Revoke(context.Context, *RevokeReq) (*RevokeRes, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method Revoke not implemented")
|
||||
}
|
||||
func (UnimplementedTokenServiceServer) ListUserRefreshTokens(context.Context, *ListUserRefreshTokensReq) (*ListUserRefreshTokensRes, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method ListUserRefreshTokens not implemented")
|
||||
}
|
||||
func (UnimplementedTokenServiceServer) mustEmbedUnimplementedTokenServiceServer() {}
|
||||
func (UnimplementedTokenServiceServer) testEmbeddedByValue() {}
|
||||
|
||||
@@ -141,6 +173,42 @@ func _TokenService_Refresh_Handler(srv interface{}, ctx context.Context, dec fun
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _TokenService_Revoke_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(RevokeReq)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(TokenServiceServer).Revoke(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: TokenService_Revoke_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(TokenServiceServer).Revoke(ctx, req.(*RevokeReq))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _TokenService_ListUserRefreshTokens_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(ListUserRefreshTokensReq)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(TokenServiceServer).ListUserRefreshTokens(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: TokenService_ListUserRefreshTokens_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(TokenServiceServer).ListUserRefreshTokens(ctx, req.(*ListUserRefreshTokensReq))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// TokenService_ServiceDesc is the grpc.ServiceDesc for TokenService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
@@ -156,6 +224,14 @@ var TokenService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "Refresh",
|
||||
Handler: _TokenService_Refresh_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Revoke",
|
||||
Handler: _TokenService_Revoke_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "ListUserRefreshTokens",
|
||||
Handler: _TokenService_ListUserRefreshTokens_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{},
|
||||
Metadata: "token/v1/token.proto",
|
||||
|
||||
@@ -610,6 +610,57 @@ paths:
|
||||
"500":
|
||||
$ref: "#/components/responses/ServiceError"
|
||||
|
||||
/users/tokens/revoke:
|
||||
post:
|
||||
operationId: revokeRefreshToken
|
||||
summary: Revoke Refresh Token
|
||||
description: |
|
||||
Revokes a specific refresh token by its ID. This invalidates the
|
||||
refresh token so it can no longer be used to obtain new access tokens.
|
||||
tags:
|
||||
- Users
|
||||
security:
|
||||
- bearerAuth: []
|
||||
requestBody:
|
||||
$ref: "#/components/requestBodies/RevokeRefreshTokenReq"
|
||||
responses:
|
||||
"204":
|
||||
description: Refresh token revoked successfully.
|
||||
"400":
|
||||
description: Failed due to malformed JSON.
|
||||
"401":
|
||||
description: Missing or invalid access token provided.
|
||||
"404":
|
||||
description: A non-existent entity request.
|
||||
"415":
|
||||
description: Missing or invalid content type.
|
||||
"422":
|
||||
description: Database can't process request.
|
||||
"500":
|
||||
$ref: "#/components/responses/ServiceError"
|
||||
|
||||
/users/tokens/refresh-tokens:
|
||||
get:
|
||||
operationId: listActiveRefreshTokens
|
||||
summary: List Active Refresh Tokens
|
||||
description: |
|
||||
Lists all active refresh token sessions for the currently authenticated user.
|
||||
tags:
|
||||
- Users
|
||||
security:
|
||||
- bearerAuth: []
|
||||
responses:
|
||||
"200":
|
||||
$ref: "#/components/responses/RefreshTokensPageRes"
|
||||
"400":
|
||||
description: Failed due to malformed JSON.
|
||||
"401":
|
||||
description: Missing or invalid access token provided.
|
||||
"404":
|
||||
description: A non-existent entity request.
|
||||
"500":
|
||||
$ref: "#/components/responses/ServiceError"
|
||||
|
||||
/users/send-verification:
|
||||
post:
|
||||
operationId: sendVerification
|
||||
@@ -1042,6 +1093,30 @@ components:
|
||||
- username
|
||||
- password
|
||||
|
||||
RefreshToken:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
format: uuid
|
||||
example: "bb7edb32-2eac-4aad-aebe-ed96fe073879"
|
||||
description: Unique identifier of the refresh token.
|
||||
description:
|
||||
type: string
|
||||
example: "Chrome browser session"
|
||||
description: Description of the refresh token session.
|
||||
|
||||
RefreshTokensPage:
|
||||
type: object
|
||||
properties:
|
||||
refresh_tokens:
|
||||
type: array
|
||||
items:
|
||||
$ref: "#/components/schemas/RefreshToken"
|
||||
description: List of active refresh tokens.
|
||||
required:
|
||||
- refresh_tokens
|
||||
|
||||
Error:
|
||||
type: object
|
||||
properties:
|
||||
@@ -1463,6 +1538,22 @@ components:
|
||||
format: jwt
|
||||
description: Reset token generated and sent in email.
|
||||
|
||||
RevokeRefreshTokenReq:
|
||||
description: JSON-formatted document describing the refresh token to revoke.
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
token_id:
|
||||
type: string
|
||||
format: uuid
|
||||
example: "bb7edb32-2eac-4aad-aebe-ed96fe073879"
|
||||
description: The unique identifier of the refresh token to revoke.
|
||||
required:
|
||||
- token_id
|
||||
|
||||
PasswordChange:
|
||||
description: Password change data. User can change its password.
|
||||
required: true
|
||||
@@ -1574,6 +1665,13 @@ components:
|
||||
example: access
|
||||
description: User access token type.
|
||||
|
||||
RefreshTokensPageRes:
|
||||
description: List of active refresh tokens for the authenticated user.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: "#/components/schemas/RefreshTokensPage"
|
||||
|
||||
HealthRes:
|
||||
description: Service Health Check.
|
||||
content:
|
||||
|
||||
@@ -18,14 +18,16 @@ import (
|
||||
const tokenSvcName = "token.v1.TokenService"
|
||||
|
||||
type tokenGrpcClient struct {
|
||||
issue endpoint.Endpoint
|
||||
refresh endpoint.Endpoint
|
||||
timeout time.Duration
|
||||
issue endpoint.Endpoint
|
||||
refresh endpoint.Endpoint
|
||||
revoke endpoint.Endpoint
|
||||
listUserRefreshTokens endpoint.Endpoint
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
var _ grpcTokenV1.TokenServiceClient = (*tokenGrpcClient)(nil)
|
||||
|
||||
// NewAuthClient returns new auth gRPC client instance.
|
||||
// NewTokenClient returns new token gRPC client instance.
|
||||
func NewTokenClient(conn *grpc.ClientConn, timeout time.Duration) grpcTokenV1.TokenServiceClient {
|
||||
return &tokenGrpcClient{
|
||||
issue: kitgrpc.NewClient(
|
||||
@@ -44,6 +46,22 @@ func NewTokenClient(conn *grpc.ClientConn, timeout time.Duration) grpcTokenV1.To
|
||||
decodeRefreshResponse,
|
||||
grpcTokenV1.Token{},
|
||||
).Endpoint(),
|
||||
revoke: kitgrpc.NewClient(
|
||||
conn,
|
||||
tokenSvcName,
|
||||
"Revoke",
|
||||
encodeRevokeRequest,
|
||||
decodeRevokeResponse,
|
||||
grpcTokenV1.RevokeRes{},
|
||||
).Endpoint(),
|
||||
listUserRefreshTokens: kitgrpc.NewClient(
|
||||
conn,
|
||||
tokenSvcName,
|
||||
"ListUserRefreshTokens",
|
||||
encodeListUserRefreshTokensRequest,
|
||||
decodeListUserRefreshTokensResponse,
|
||||
grpcTokenV1.ListUserRefreshTokensRes{},
|
||||
).Endpoint(),
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
@@ -53,10 +71,11 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR
|
||||
defer cancel()
|
||||
|
||||
res, err := client.issue(ctx, issueReq{
|
||||
userID: req.GetUserId(),
|
||||
userRole: auth.Role(req.GetUserRole()),
|
||||
keyType: auth.KeyType(req.GetType()),
|
||||
verified: req.GetVerified(),
|
||||
userID: req.GetUserId(),
|
||||
userRole: auth.Role(req.GetUserRole()),
|
||||
keyType: auth.KeyType(req.GetType()),
|
||||
verified: req.GetVerified(),
|
||||
description: req.GetDescription(),
|
||||
})
|
||||
if err != nil {
|
||||
return &grpcTokenV1.Token{}, grpcapi.DecodeError(err)
|
||||
@@ -67,10 +86,11 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR
|
||||
func encodeIssueRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(issueReq)
|
||||
return &grpcTokenV1.IssueReq{
|
||||
UserId: req.userID,
|
||||
UserRole: uint32(req.userRole),
|
||||
Type: uint32(req.keyType),
|
||||
Verified: req.verified,
|
||||
UserId: req.userID,
|
||||
UserRole: uint32(req.userRole),
|
||||
Type: uint32(req.keyType),
|
||||
Verified: req.verified,
|
||||
Description: req.description,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -97,3 +117,43 @@ func encodeRefreshRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
func decodeRefreshResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
return grpcRes, nil
|
||||
}
|
||||
|
||||
func (client tokenGrpcClient) Revoke(ctx context.Context, req *grpcTokenV1.RevokeReq, _ ...grpc.CallOption) (*grpcTokenV1.RevokeRes, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, client.timeout)
|
||||
defer cancel()
|
||||
|
||||
res, err := client.revoke(ctx, revokeReq{userID: req.GetUserId(), tokenID: req.GetTokenId()})
|
||||
if err != nil {
|
||||
return &grpcTokenV1.RevokeRes{}, grpcapi.DecodeError(err)
|
||||
}
|
||||
return res.(*grpcTokenV1.RevokeRes), nil
|
||||
}
|
||||
|
||||
func encodeRevokeRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(revokeReq)
|
||||
return &grpcTokenV1.RevokeReq{UserId: req.userID, TokenId: req.tokenID}, nil
|
||||
}
|
||||
|
||||
func decodeRevokeResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
return grpcRes, nil
|
||||
}
|
||||
|
||||
func (client tokenGrpcClient) ListUserRefreshTokens(ctx context.Context, req *grpcTokenV1.ListUserRefreshTokensReq, _ ...grpc.CallOption) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, client.timeout)
|
||||
defer cancel()
|
||||
|
||||
res, err := client.listUserRefreshTokens(ctx, listUserRefreshTokensReq{userID: req.GetUserId()})
|
||||
if err != nil {
|
||||
return &grpcTokenV1.ListUserRefreshTokensRes{}, grpcapi.DecodeError(err)
|
||||
}
|
||||
return res.(*grpcTokenV1.ListUserRefreshTokensRes), nil
|
||||
}
|
||||
|
||||
func encodeListUserRefreshTokensRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(listUserRefreshTokensReq)
|
||||
return &grpcTokenV1.ListUserRefreshTokensReq{UserId: req.userID}, nil
|
||||
}
|
||||
|
||||
func decodeListUserRefreshTokensResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
return grpcRes, nil
|
||||
}
|
||||
|
||||
@@ -18,10 +18,11 @@ func issueEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
}
|
||||
|
||||
key := auth.Key{
|
||||
Type: req.keyType,
|
||||
Subject: req.userID,
|
||||
Role: req.userRole,
|
||||
Verified: req.verified,
|
||||
Type: req.keyType,
|
||||
Subject: req.userID,
|
||||
Role: req.userRole,
|
||||
Verified: req.verified,
|
||||
Description: req.description,
|
||||
}
|
||||
tkn, err := svc.Issue(ctx, "", key)
|
||||
if err != nil {
|
||||
@@ -56,3 +57,34 @@ func refreshEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
return ret, nil
|
||||
}
|
||||
}
|
||||
|
||||
func revokeEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(revokeReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err := svc.RevokeToken(ctx, req.userID, req.tokenID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func listUserRefreshTokensEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(listUserRefreshTokensReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return listUserRefreshTokensRes{}, err
|
||||
}
|
||||
|
||||
refreshTokens, err := svc.ListUserRefreshTokens(ctx, req.userID)
|
||||
if err != nil {
|
||||
return listUserRefreshTokensRes{}, err
|
||||
}
|
||||
|
||||
return listUserRefreshTokensRes{refreshTokens: refreshTokens}, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,24 +24,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
port = 8082
|
||||
secret = "secret"
|
||||
email = "test@example.com"
|
||||
id = "testID"
|
||||
clientsType = "clients"
|
||||
usersType = "users"
|
||||
description = "Description"
|
||||
groupName = "smqx"
|
||||
adminPermission = "admin"
|
||||
|
||||
authoritiesObj = "authorities"
|
||||
memberRelation = "member"
|
||||
loginDuration = 30 * time.Minute
|
||||
refreshDuration = 24 * time.Hour
|
||||
invalidDuration = 7 * 24 * time.Hour
|
||||
validToken = "valid"
|
||||
inValidToken = "invalid"
|
||||
validPolicy = "valid"
|
||||
port = 8082
|
||||
validToken = "valid"
|
||||
inValidToken = "invalid"
|
||||
invalidID = "invalid"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -63,9 +49,9 @@ func startGRPCServer(svc auth.Service, port int) *grpc.Server {
|
||||
|
||||
func TestIssue(t *testing.T) {
|
||||
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
defer conn.Close()
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
|
||||
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
|
||||
defer conn.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
@@ -127,9 +113,9 @@ func TestIssue(t *testing.T) {
|
||||
|
||||
func TestRefresh(t *testing.T) {
|
||||
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
defer conn.Close()
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
|
||||
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
|
||||
defer conn.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
@@ -161,9 +147,99 @@ func TestRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err)
|
||||
svcCall := svc.On("Issue", mock.Anything, mock.Anything, mock.Anything).Return(tc.issueResponse, tc.err)
|
||||
_, err := grpcClient.Refresh(context.Background(), &grpcTokenV1.RefreshReq{RefreshToken: tc.token})
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevoke(t *testing.T) {
|
||||
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
|
||||
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
|
||||
defer conn.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
id string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "revoke token with valid id",
|
||||
id: validID,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "revoke token with invalid id",
|
||||
id: invalidID,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "revoke token with empty id",
|
||||
id: "",
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "revoke already revoked token",
|
||||
id: validID,
|
||||
err: svcerr.ErrConflict,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
svcCall := svc.On("RevokeToken", mock.Anything, mock.Anything, tc.id).Return(tc.err)
|
||||
_, err := grpcClient.Revoke(context.Background(), &grpcTokenV1.RevokeReq{TokenId: tc.id})
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserRefreshTokens(t *testing.T) {
|
||||
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
|
||||
grpcClient := grpcapi.NewTokenClient(conn, time.Second)
|
||||
defer conn.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userID string
|
||||
listResponse []auth.TokenInfo
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list tokens for user with valid id",
|
||||
userID: validID,
|
||||
listResponse: []auth.TokenInfo{
|
||||
{ID: testsutil.GenerateUUID(&testing.T{}), Description: "Token 1"},
|
||||
{ID: testsutil.GenerateUUID(&testing.T{}), Description: "Token 2"},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list tokens for user with empty list",
|
||||
userID: validID,
|
||||
listResponse: []auth.TokenInfo{},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list tokens with invalid user id",
|
||||
userID: invalidID,
|
||||
listResponse: nil,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "list tokens with empty user id",
|
||||
userID: "",
|
||||
listResponse: nil,
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
svcCall := svc.On("ListUserRefreshTokens", mock.Anything, tc.userID).Return(tc.listResponse, tc.err)
|
||||
_, err := grpcClient.ListUserRefreshTokens(context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.userID})
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,10 +9,11 @@ import (
|
||||
)
|
||||
|
||||
type issueReq struct {
|
||||
userID string
|
||||
userRole auth.Role
|
||||
keyType auth.KeyType
|
||||
verified bool
|
||||
userID string
|
||||
userRole auth.Role
|
||||
keyType auth.KeyType
|
||||
verified bool
|
||||
description string
|
||||
}
|
||||
|
||||
func (req issueReq) validate() error {
|
||||
@@ -38,3 +39,28 @@ func (req refreshReq) validate() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type revokeReq struct {
|
||||
userID string
|
||||
tokenID string
|
||||
}
|
||||
|
||||
func (req revokeReq) validate() error {
|
||||
if req.tokenID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type listUserRefreshTokensReq struct {
|
||||
userID string
|
||||
}
|
||||
|
||||
func (req listUserRefreshTokensReq) validate() error {
|
||||
if req.userID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,8 +3,14 @@
|
||||
|
||||
package token
|
||||
|
||||
import "github.com/absmach/supermq/auth"
|
||||
|
||||
type issueRes struct {
|
||||
accessToken string
|
||||
refreshToken string
|
||||
accessType string
|
||||
}
|
||||
|
||||
type listUserRefreshTokensRes struct {
|
||||
refreshTokens []auth.TokenInfo
|
||||
}
|
||||
|
||||
@@ -16,11 +16,13 @@ var _ grpcTokenV1.TokenServiceServer = (*tokenGrpcServer)(nil)
|
||||
|
||||
type tokenGrpcServer struct {
|
||||
grpcTokenV1.UnimplementedTokenServiceServer
|
||||
issue kitgrpc.Handler
|
||||
refresh kitgrpc.Handler
|
||||
issue kitgrpc.Handler
|
||||
refresh kitgrpc.Handler
|
||||
revoke kitgrpc.Handler
|
||||
listUserRefreshTokens kitgrpc.Handler
|
||||
}
|
||||
|
||||
// NewAuthServer returns new AuthnServiceServer instance.
|
||||
// NewTokenServer returns new TokenServiceServer instance.
|
||||
func NewTokenServer(svc auth.Service) grpcTokenV1.TokenServiceServer {
|
||||
return &tokenGrpcServer{
|
||||
issue: kitgrpc.NewServer(
|
||||
@@ -33,6 +35,16 @@ func NewTokenServer(svc auth.Service) grpcTokenV1.TokenServiceServer {
|
||||
decodeRefreshRequest,
|
||||
encodeIssueResponse,
|
||||
),
|
||||
revoke: kitgrpc.NewServer(
|
||||
(revokeEndpoint(svc)),
|
||||
decodeRevokeRequest,
|
||||
encodeRevokeResponse,
|
||||
),
|
||||
listUserRefreshTokens: kitgrpc.NewServer(
|
||||
(listUserRefreshTokensEndpoint(svc)),
|
||||
decodeListUserRefreshTokensRequest,
|
||||
encodeListUserRefreshTokensResponse,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,10 +67,11 @@ func (s *tokenGrpcServer) Refresh(ctx context.Context, req *grpcTokenV1.RefreshR
|
||||
func decodeIssueRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*grpcTokenV1.IssueReq)
|
||||
return issueReq{
|
||||
userID: req.GetUserId(),
|
||||
userRole: auth.Role(req.GetUserRole()),
|
||||
keyType: auth.KeyType(req.GetType()),
|
||||
verified: req.Verified,
|
||||
userID: req.GetUserId(),
|
||||
userRole: auth.Role(req.GetUserRole()),
|
||||
keyType: auth.KeyType(req.GetType()),
|
||||
verified: req.Verified,
|
||||
description: req.GetDescription(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -76,3 +89,49 @@ func encodeIssueResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
AccessType: res.accessType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *tokenGrpcServer) Revoke(ctx context.Context, req *grpcTokenV1.RevokeReq) (*grpcTokenV1.RevokeRes, error) {
|
||||
_, res, err := s.revoke.ServeGRPC(ctx, req)
|
||||
if err != nil {
|
||||
return nil, grpcapi.EncodeError(err)
|
||||
}
|
||||
return res.(*grpcTokenV1.RevokeRes), nil
|
||||
}
|
||||
|
||||
func decodeRevokeRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*grpcTokenV1.RevokeReq)
|
||||
return revokeReq{userID: req.GetUserId(), tokenID: req.GetTokenId()}, nil
|
||||
}
|
||||
|
||||
func encodeRevokeResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
return &grpcTokenV1.RevokeRes{}, nil
|
||||
}
|
||||
|
||||
func (s *tokenGrpcServer) ListUserRefreshTokens(ctx context.Context, req *grpcTokenV1.ListUserRefreshTokensReq) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
|
||||
_, res, err := s.listUserRefreshTokens.ServeGRPC(ctx, req)
|
||||
if err != nil {
|
||||
return nil, grpcapi.EncodeError(err)
|
||||
}
|
||||
return res.(*grpcTokenV1.ListUserRefreshTokensRes), nil
|
||||
}
|
||||
|
||||
func decodeListUserRefreshTokensRequest(_ context.Context, grpcReq any) (any, error) {
|
||||
req := grpcReq.(*grpcTokenV1.ListUserRefreshTokensReq)
|
||||
return listUserRefreshTokensReq{userID: req.GetUserId()}, nil
|
||||
}
|
||||
|
||||
func encodeListUserRefreshTokensResponse(_ context.Context, grpcRes any) (any, error) {
|
||||
res := grpcRes.(listUserRefreshTokensRes)
|
||||
|
||||
refreshTokens := make([]*grpcTokenV1.RefreshToken, len(res.refreshTokens))
|
||||
for i, refreshToken := range res.refreshTokens {
|
||||
refreshTokens[i] = &grpcTokenV1.RefreshToken{
|
||||
Id: refreshToken.ID,
|
||||
Description: refreshToken.Description,
|
||||
}
|
||||
}
|
||||
|
||||
return &grpcTokenV1.ListUserRefreshTokensRes{
|
||||
RefreshTokens: refreshTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
Vendored
+2
@@ -1,4 +1,6 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Package cache contains the domain concept definitions needed to
|
||||
// support SuperMQ auth cache service functionality.
|
||||
package cache
|
||||
|
||||
Vendored
+132
@@ -0,0 +1,132 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
refreshPrefix = "refresh_tokens:"
|
||||
scoreNegInf = "-inf"
|
||||
scorePosInf = "+inf"
|
||||
)
|
||||
|
||||
var _ auth.UserActiveTokensCache = (*tokensCache)(nil)
|
||||
|
||||
type tokensCache struct {
|
||||
client *redis.Client
|
||||
keyDuration time.Duration
|
||||
}
|
||||
|
||||
// NewUserActiveTokensCache returns redis auth cache implementation.
|
||||
func NewUserActiveTokensCache(client *redis.Client, duration time.Duration) (auth.UserActiveTokensCache, error) {
|
||||
if duration == 0 {
|
||||
return nil, errors.New("token cache duration must not be zero")
|
||||
}
|
||||
return &tokensCache{
|
||||
client: client,
|
||||
keyDuration: duration,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SaveActive saves an active refresh token ID for a user with optional description.
|
||||
func (tc *tokensCache) SaveActive(ctx context.Context, userID, tokenID, description string, expiry time.Time) error {
|
||||
ttl := min(tc.keyDuration, time.Until(expiry))
|
||||
|
||||
pipe := tc.client.TxPipeline()
|
||||
|
||||
pipe.Set(ctx, tokenKey(tokenID), description, ttl)
|
||||
pipe.ZAdd(ctx, userTokensKey(userID), redis.Z{
|
||||
Score: float64(time.Now().Add(ttl).Unix()),
|
||||
Member: tokenID,
|
||||
})
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// IsActive checks if the token ID is active for the given user.
|
||||
func (tc *tokensCache) IsActive(ctx context.Context, tokenID string) (bool, error) {
|
||||
count, err := tc.client.Exists(ctx, tokenKey(tokenID)).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// ListUserTokens lists all active refresh token IDs with descriptions for a user.
|
||||
func (tc *tokensCache) ListUserTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) {
|
||||
key := userTokensKey(userID)
|
||||
now := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
|
||||
pipe := tc.client.TxPipeline()
|
||||
pipe.ZRemRangeByScore(ctx, key, scoreNegInf, now)
|
||||
zrangeCmd := pipe.ZRangeByScore(ctx, key, &redis.ZRangeBy{Min: "(" + now, Max: scorePosInf})
|
||||
if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokenIDs, err := zrangeCmd.Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(tokenIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
getPipe := tc.client.Pipeline()
|
||||
getCmds := make([]*redis.StringCmd, len(tokenIDs))
|
||||
for i, tokenID := range tokenIDs {
|
||||
getCmds[i] = getPipe.Get(ctx, tokenKey(tokenID))
|
||||
}
|
||||
|
||||
if _, err = getPipe.Exec(ctx); err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
valid := make([]auth.TokenInfo, 0, len(tokenIDs))
|
||||
for i, cmd := range getCmds {
|
||||
description, err := cmd.Result()
|
||||
if err == redis.Nil {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
valid = append(valid, auth.TokenInfo{
|
||||
ID: tokenIDs[i],
|
||||
Description: description,
|
||||
})
|
||||
}
|
||||
|
||||
return valid, nil
|
||||
}
|
||||
|
||||
// RemoveActive removes an active refresh token ID for a user.
|
||||
func (tc *tokensCache) RemoveActive(ctx context.Context, userID, tokenID string) error {
|
||||
pipe := tc.client.TxPipeline()
|
||||
pipe.Del(ctx, tokenKey(tokenID))
|
||||
pipe.ZRem(ctx, userTokensKey(userID), tokenID)
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func tokenKey(tokenID string) string {
|
||||
return refreshPrefix + "token:" + tokenID
|
||||
}
|
||||
|
||||
func userTokensKey(userID string) string {
|
||||
return refreshPrefix + "user_tokens:" + userID
|
||||
}
|
||||
Vendored
+288
@@ -0,0 +1,288 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package cache_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/auth/cache"
|
||||
"github.com/absmach/supermq/internal/testsutil"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var (
|
||||
storeClient *redis.Client
|
||||
storeURL string
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
code := testsutil.RunRedisTest(m, &storeClient, &storeURL)
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func setupRedisTokensClient() auth.UserActiveTokensCache {
|
||||
tc, err := cache.NewUserActiveTokensCache(storeClient, 10*time.Minute)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return tc
|
||||
}
|
||||
|
||||
func TestTokenSave(t *testing.T) {
|
||||
storeClient.FlushAll(context.Background())
|
||||
tokensCache := setupRedisTokensClient()
|
||||
|
||||
userID := testsutil.GenerateUUID(t)
|
||||
tokenID := testsutil.GenerateUUID(t)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userID string
|
||||
tokenID string
|
||||
description string
|
||||
expiry time.Time
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "Save active token",
|
||||
userID: userID,
|
||||
tokenID: tokenID,
|
||||
description: "Test token",
|
||||
expiry: time.Now().Add(10 * time.Minute),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "Save already cached token",
|
||||
userID: userID,
|
||||
tokenID: tokenID,
|
||||
description: "Updated token",
|
||||
expiry: time.Now().Add(10 * time.Minute),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "Save another token for same user",
|
||||
userID: userID,
|
||||
tokenID: testsutil.GenerateUUID(t),
|
||||
description: "Another token",
|
||||
expiry: time.Now().Add(10 * time.Minute),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "Save token with empty id",
|
||||
userID: userID,
|
||||
tokenID: "",
|
||||
description: "Empty ID token",
|
||||
expiry: time.Now().Add(10 * time.Minute),
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "Save token with empty description",
|
||||
userID: userID,
|
||||
tokenID: testsutil.GenerateUUID(t),
|
||||
description: "",
|
||||
expiry: time.Now().Add(10 * time.Minute),
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
err := tokensCache.SaveActive(context.Background(), tc.userID, tc.tokenID, tc.description, tc.expiry)
|
||||
if err == nil {
|
||||
ok, err := tokensCache.IsActive(context.Background(), tc.tokenID)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenContains(t *testing.T) {
|
||||
storeClient.FlushAll(context.Background())
|
||||
tokensCache := setupRedisTokensClient()
|
||||
|
||||
userID := testsutil.GenerateUUID(t)
|
||||
tokenID := testsutil.GenerateUUID(t)
|
||||
|
||||
err := tokensCache.SaveActive(context.Background(), userID, tokenID, "Test token", time.Now().Add(10*time.Minute))
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userID string
|
||||
tokenID string
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
desc: "IsActive for existing token",
|
||||
userID: userID,
|
||||
tokenID: tokenID,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
desc: "IsActive for non existing token",
|
||||
userID: userID,
|
||||
tokenID: testsutil.GenerateUUID(t),
|
||||
},
|
||||
{
|
||||
desc: "IsActive with empty token id",
|
||||
userID: userID,
|
||||
tokenID: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
ok, err := tokensCache.IsActive(context.Background(), tc.tokenID)
|
||||
if tc.ok {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tc.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenRemove(t *testing.T) {
|
||||
storeClient.FlushAll(context.Background())
|
||||
tokensCache := setupRedisTokensClient()
|
||||
|
||||
userID := testsutil.GenerateUUID(t)
|
||||
num := 10
|
||||
var tokenIDs []string
|
||||
for i := range num {
|
||||
tokenID := testsutil.GenerateUUID(t)
|
||||
err := tokensCache.SaveActive(context.Background(), userID, tokenID, fmt.Sprintf("Token %d", i), time.Now().Add(10*time.Minute))
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err))
|
||||
tokenIDs = append(tokenIDs, tokenID)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userID string
|
||||
tokenID string
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "Remove an existing token from cache",
|
||||
userID: userID,
|
||||
tokenID: tokenIDs[0],
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "Remove token with empty id from cache",
|
||||
userID: userID,
|
||||
tokenID: "",
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "Remove non existing id from cache",
|
||||
userID: userID,
|
||||
tokenID: testsutil.GenerateUUID(t),
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
err := tokensCache.RemoveActive(context.Background(), tc.userID, tc.tokenID)
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
if err == nil {
|
||||
ok, err := tokensCache.IsActive(context.Background(), tc.tokenID)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserTokens(t *testing.T) {
|
||||
storeClient.FlushAll(context.Background())
|
||||
tokensCache := setupRedisTokensClient()
|
||||
|
||||
userID := testsutil.GenerateUUID(t)
|
||||
userID2 := testsutil.GenerateUUID(t)
|
||||
num := 5
|
||||
var expectedTokens []auth.TokenInfo
|
||||
|
||||
for i := range num {
|
||||
tokenID := testsutil.GenerateUUID(t)
|
||||
description := fmt.Sprintf("Token %d", i)
|
||||
err := tokensCache.SaveActive(context.Background(), userID, tokenID, description, time.Now().Add(10*time.Minute))
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err))
|
||||
expectedTokens = append(expectedTokens, auth.TokenInfo{
|
||||
ID: tokenID,
|
||||
Description: description,
|
||||
})
|
||||
}
|
||||
|
||||
tokenID2 := testsutil.GenerateUUID(t)
|
||||
desc2 := "User 2 token"
|
||||
err := tokensCache.SaveActive(context.Background(), userID2, tokenID2, desc2, time.Now().Add(10*time.Minute))
|
||||
assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err))
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userID string
|
||||
expectedCount int
|
||||
expectedTokens []auth.TokenInfo
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "List all tokens for user with multiple tokens",
|
||||
userID: userID,
|
||||
expectedCount: num,
|
||||
expectedTokens: expectedTokens,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "List tokens for user with single token",
|
||||
userID: userID2,
|
||||
expectedCount: 1,
|
||||
expectedTokens: []auth.TokenInfo{{ID: tokenID2, Description: desc2}},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "List tokens for user with no tokens",
|
||||
userID: testsutil.GenerateUUID(t),
|
||||
expectedCount: 0,
|
||||
expectedTokens: nil,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
tokens, err := tokensCache.ListUserTokens(context.Background(), tc.userID)
|
||||
assert.True(t, errors.Contains(err, tc.err))
|
||||
assert.Equal(t, tc.expectedCount, len(tokens))
|
||||
if tc.expectedTokens != nil {
|
||||
assert.ElementsMatch(t, tc.expectedTokens, tokens)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Cleanup expired tokens from list", func(t *testing.T) {
|
||||
// Remove one token directly from Redis to simulate expiration
|
||||
err := tokensCache.RemoveActive(context.Background(), userID, expectedTokens[0].ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// List should now return only valid tokens
|
||||
tokens, err := tokensCache.ListUserTokens(context.Background(), userID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, num-1, len(tokens))
|
||||
|
||||
// Check that the removed token is not in the list
|
||||
for _, token := range tokens {
|
||||
assert.NotEqual(t, expectedTokens[0].ID, token.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
+26
-1
@@ -5,13 +5,16 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnsupportedKeyAlgorithm = errors.New("unsupported key algorithm")
|
||||
ErrInvalidSymmetricKey = errors.New("invalid symmetric key")
|
||||
ErrPublicKeysNotSupported = errors.New("public keys not supported for symmetric algorithm")
|
||||
ErrRevokedToken = errors.NewAuthNError("token is revoked")
|
||||
)
|
||||
|
||||
// PublicKeyInfo represents a public key for external distribution via JWKS.
|
||||
@@ -33,6 +36,7 @@ type PublicKeyInfo struct {
|
||||
// Implementations manage underlying cryptographic operations and key distribution.
|
||||
type Tokenizer interface {
|
||||
// Issue creates a signed token string from the given key claims.
|
||||
// For RefreshKey types, the token ID is stored as active in the cache.
|
||||
Issue(key Key) (token string, err error)
|
||||
|
||||
// Parse verifies and parses a token string (JWT or PAT), returning the extracted claims.
|
||||
@@ -45,6 +49,27 @@ type Tokenizer interface {
|
||||
RetrieveJWKS() ([]PublicKeyInfo, error)
|
||||
}
|
||||
|
||||
// UserActiveTokensCache represents a cache repository for managing active refresh tokens per user.
|
||||
type UserActiveTokensCache interface {
|
||||
// SaveActive saves an active refresh token ID for a user with optional description.
|
||||
SaveActive(ctx context.Context, userID, tokenID, description string, expiry time.Time) error
|
||||
|
||||
// IsActive checks if the token ID is active.
|
||||
IsActive(ctx context.Context, tokenID string) (bool, error)
|
||||
|
||||
// ListUserTokens lists all active token IDs with descriptions for a given user.
|
||||
ListUserTokens(ctx context.Context, userID string) ([]TokenInfo, error)
|
||||
|
||||
// RemoveActive removes an active refresh token ID.
|
||||
RemoveActive(ctx context.Context, userID, tokenID string) error
|
||||
}
|
||||
|
||||
// TokenInfo represents information about an active refresh token.
|
||||
type TokenInfo struct {
|
||||
ID string `json:"id"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// IsSymmetricAlgorithm determines if the given algorithm is symmetric (HMAC-based).
|
||||
// Returns true for HMAC algorithms (HS256, HS384, HS512).
|
||||
// Returns false for asymmetric algorithms (EdDSA).
|
||||
|
||||
+9
-8
@@ -81,14 +81,15 @@ func (r Role) Validate() bool {
|
||||
|
||||
// Key represents API key.
|
||||
type Key struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type KeyType `json:"type,omitempty"`
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
Subject string `json:"subject,omitempty"` // user ID
|
||||
Role Role `json:"role,omitempty"`
|
||||
IssuedAt time.Time `json:"issued_at,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||
Verified bool `json:"verified,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type KeyType `json:"type,omitempty"`
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
Subject string `json:"subject,omitempty"` // user ID
|
||||
Role Role `json:"role,omitempty"`
|
||||
IssuedAt time.Time `json:"issued_at,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||
Verified bool `json:"verified,omitempty"`
|
||||
Description string `json:"description,omitempty"` // Optional description for refresh tokens
|
||||
}
|
||||
|
||||
func (key Key) String() string {
|
||||
|
||||
@@ -110,6 +110,42 @@ func (lm *loggingMiddleware) RetrieveJWKS() (jwks []auth.PublicKeyInfo) {
|
||||
return lm.svc.RetrieveJWKS()
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) RevokeToken(ctx context.Context, userID, tokenID string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("user_id", userID),
|
||||
slog.String("token_id", tokenID),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Revoke token failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Revoke token completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.RevokeToken(ctx, userID, tokenID)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) ListUserRefreshTokens(ctx context.Context, userID string) (tokens []auth.TokenInfo, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("user_id", userID),
|
||||
slog.Int("tokens_count", len(tokens)),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("List user refresh tokens failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("List user refresh tokens completed successfully", args...)
|
||||
}(time.Now())
|
||||
|
||||
return lm.svc.ListUserRefreshTokens(ctx, userID)
|
||||
}
|
||||
|
||||
func (lm *loggingMiddleware) Authorize(ctx context.Context, pr policies.Policy, patAuthz *auth.PATAuthz) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
|
||||
@@ -40,6 +40,15 @@ func (ms *metricsMiddleware) Issue(ctx context.Context, token string, key auth.K
|
||||
return ms.svc.Issue(ctx, token, key)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) RevokeToken(ctx context.Context, userID, tokenID string) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "revoke_token").Add(1)
|
||||
ms.latency.With("method", "revoke_token").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return ms.svc.RevokeToken(ctx, userID, tokenID)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Revoke(ctx context.Context, token, id string) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "revoke_key").Add(1)
|
||||
@@ -75,6 +84,15 @@ func (ms *metricsMiddleware) RetrieveJWKS() []auth.PublicKeyInfo {
|
||||
return ms.svc.RetrieveJWKS()
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) ListUserRefreshTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "list_user_refresh_tokens").Add(1)
|
||||
ms.latency.With("method", "list_user_refresh_tokens").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
|
||||
return ms.svc.ListUserRefreshTokens(ctx, userID)
|
||||
}
|
||||
|
||||
func (ms *metricsMiddleware) Authorize(ctx context.Context, pr policies.Policy, patAuthz *auth.PATAuthz) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "authorize").Add(1)
|
||||
|
||||
@@ -65,6 +65,25 @@ func (tm *tracingMiddleware) RetrieveJWKS() []auth.PublicKeyInfo {
|
||||
return tm.svc.RetrieveJWKS()
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) RevokeToken(ctx context.Context, userID, tokenID string) error {
|
||||
ctx, span := tm.tracer.Start(ctx, "revoke_token", trace.WithAttributes(
|
||||
attribute.String("user_id", userID),
|
||||
attribute.String("token_id", tokenID),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.RevokeToken(ctx, userID, tokenID)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) ListUserRefreshTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) {
|
||||
ctx, span := tm.tracer.Start(ctx, "list_user_refresh_tokens", trace.WithAttributes(
|
||||
attribute.String("user_id", userID),
|
||||
))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.ListUserRefreshTokens(ctx, userID)
|
||||
}
|
||||
|
||||
func (tm *tracingMiddleware) Authorize(ctx context.Context, pr policies.Policy, patAuthz *auth.PATAuthz) error {
|
||||
attributes := []attribute.KeyValue{
|
||||
attribute.String("subject", pr.Subject),
|
||||
|
||||
@@ -758,6 +758,74 @@ func (_c *Service_ListScopes_Call) RunAndReturn(run func(ctx context.Context, to
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListUserRefreshTokens provides a mock function for the type Service
|
||||
func (_mock *Service) ListUserRefreshTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) {
|
||||
ret := _mock.Called(ctx, userID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListUserRefreshTokens")
|
||||
}
|
||||
|
||||
var r0 []auth.TokenInfo
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) ([]auth.TokenInfo, error)); ok {
|
||||
return returnFunc(ctx, userID)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) []auth.TokenInfo); ok {
|
||||
r0 = returnFunc(ctx, userID)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]auth.TokenInfo)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = returnFunc(ctx, userID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Service_ListUserRefreshTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserRefreshTokens'
|
||||
type Service_ListUserRefreshTokens_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListUserRefreshTokens is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - userID string
|
||||
func (_e *Service_Expecter) ListUserRefreshTokens(ctx interface{}, userID interface{}) *Service_ListUserRefreshTokens_Call {
|
||||
return &Service_ListUserRefreshTokens_Call{Call: _e.mock.On("ListUserRefreshTokens", ctx, userID)}
|
||||
}
|
||||
|
||||
func (_c *Service_ListUserRefreshTokens_Call) Run(run func(ctx context.Context, userID string)) *Service_ListUserRefreshTokens_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_ListUserRefreshTokens_Call) Return(tokenInfos []auth.TokenInfo, err error) *Service_ListUserRefreshTokens_Call {
|
||||
_c.Call.Return(tokenInfos, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_ListUserRefreshTokens_Call) RunAndReturn(run func(ctx context.Context, userID string) ([]auth.TokenInfo, error)) *Service_ListUserRefreshTokens_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveAllPAT provides a mock function for the type Service
|
||||
func (_mock *Service) RemoveAllPAT(ctx context.Context, token string) error {
|
||||
ret := _mock.Called(ctx, token)
|
||||
@@ -1350,6 +1418,69 @@ func (_c *Service_RevokePATSecret_Call) RunAndReturn(run func(ctx context.Contex
|
||||
return _c
|
||||
}
|
||||
|
||||
// RevokeToken provides a mock function for the type Service
|
||||
func (_mock *Service) RevokeToken(ctx context.Context, userID string, tokenID string) error {
|
||||
ret := _mock.Called(ctx, userID, tokenID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RevokeToken")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
|
||||
r0 = returnFunc(ctx, userID, tokenID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_RevokeToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeToken'
|
||||
type Service_RevokeToken_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RevokeToken is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - userID string
|
||||
// - tokenID string
|
||||
func (_e *Service_Expecter) RevokeToken(ctx interface{}, userID interface{}, tokenID interface{}) *Service_RevokeToken_Call {
|
||||
return &Service_RevokeToken_Call{Call: _e.mock.On("RevokeToken", ctx, userID, tokenID)}
|
||||
}
|
||||
|
||||
func (_c *Service_RevokeToken_Call) Run(run func(ctx context.Context, userID string, tokenID string)) *Service_RevokeToken_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_RevokeToken_Call) Return(err error) *Service_RevokeToken_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_RevokeToken_Call) RunAndReturn(run func(ctx context.Context, userID string, tokenID string) error) *Service_RevokeToken_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdatePATDescription provides a mock function for the type Service
|
||||
func (_mock *Service) UpdatePATDescription(ctx context.Context, token string, patID string, description string) (auth.PAT, error) {
|
||||
ret := _mock.Called(ctx, token, patID, description)
|
||||
|
||||
@@ -126,6 +126,89 @@ func (_c *TokenServiceClient_Issue_Call) RunAndReturn(run func(ctx context.Conte
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListUserRefreshTokens provides a mock function for the type TokenServiceClient
|
||||
func (_mock *TokenServiceClient) ListUserRefreshTokens(ctx context.Context, in *v1.ListUserRefreshTokensReq, opts ...grpc.CallOption) (*v1.ListUserRefreshTokensRes, error) {
|
||||
var tmpRet mock.Arguments
|
||||
if len(opts) > 0 {
|
||||
tmpRet = _mock.Called(ctx, in, opts)
|
||||
} else {
|
||||
tmpRet = _mock.Called(ctx, in)
|
||||
}
|
||||
ret := tmpRet
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListUserRefreshTokens")
|
||||
}
|
||||
|
||||
var r0 *v1.ListUserRefreshTokensRes
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.ListUserRefreshTokensReq, ...grpc.CallOption) (*v1.ListUserRefreshTokensRes, error)); ok {
|
||||
return returnFunc(ctx, in, opts...)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.ListUserRefreshTokensReq, ...grpc.CallOption) *v1.ListUserRefreshTokensRes); ok {
|
||||
r0 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*v1.ListUserRefreshTokensRes)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.ListUserRefreshTokensReq, ...grpc.CallOption) error); ok {
|
||||
r1 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// TokenServiceClient_ListUserRefreshTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserRefreshTokens'
|
||||
type TokenServiceClient_ListUserRefreshTokens_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListUserRefreshTokens is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *v1.ListUserRefreshTokensReq
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *TokenServiceClient_Expecter) ListUserRefreshTokens(ctx interface{}, in interface{}, opts ...interface{}) *TokenServiceClient_ListUserRefreshTokens_Call {
|
||||
return &TokenServiceClient_ListUserRefreshTokens_Call{Call: _e.mock.On("ListUserRefreshTokens",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *TokenServiceClient_ListUserRefreshTokens_Call) Run(run func(ctx context.Context, in *v1.ListUserRefreshTokensReq, opts ...grpc.CallOption)) *TokenServiceClient_ListUserRefreshTokens_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *v1.ListUserRefreshTokensReq
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*v1.ListUserRefreshTokensReq)
|
||||
}
|
||||
var arg2 []grpc.CallOption
|
||||
var variadicArgs []grpc.CallOption
|
||||
if len(args) > 2 {
|
||||
variadicArgs = args[2].([]grpc.CallOption)
|
||||
}
|
||||
arg2 = variadicArgs
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2...,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *TokenServiceClient_ListUserRefreshTokens_Call) Return(listUserRefreshTokensRes *v1.ListUserRefreshTokensRes, err error) *TokenServiceClient_ListUserRefreshTokens_Call {
|
||||
_c.Call.Return(listUserRefreshTokensRes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *TokenServiceClient_ListUserRefreshTokens_Call) RunAndReturn(run func(ctx context.Context, in *v1.ListUserRefreshTokensReq, opts ...grpc.CallOption) (*v1.ListUserRefreshTokensRes, error)) *TokenServiceClient_ListUserRefreshTokens_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Refresh provides a mock function for the type TokenServiceClient
|
||||
func (_mock *TokenServiceClient) Refresh(ctx context.Context, in *v1.RefreshReq, opts ...grpc.CallOption) (*v1.Token, error) {
|
||||
var tmpRet mock.Arguments
|
||||
@@ -208,3 +291,86 @@ func (_c *TokenServiceClient_Refresh_Call) RunAndReturn(run func(ctx context.Con
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Revoke provides a mock function for the type TokenServiceClient
|
||||
func (_mock *TokenServiceClient) Revoke(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption) (*v1.RevokeRes, error) {
|
||||
var tmpRet mock.Arguments
|
||||
if len(opts) > 0 {
|
||||
tmpRet = _mock.Called(ctx, in, opts)
|
||||
} else {
|
||||
tmpRet = _mock.Called(ctx, in)
|
||||
}
|
||||
ret := tmpRet
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for Revoke")
|
||||
}
|
||||
|
||||
var r0 *v1.RevokeRes
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) (*v1.RevokeRes, error)); ok {
|
||||
return returnFunc(ctx, in, opts...)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) *v1.RevokeRes); ok {
|
||||
r0 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*v1.RevokeRes)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, *v1.RevokeReq, ...grpc.CallOption) error); ok {
|
||||
r1 = returnFunc(ctx, in, opts...)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// TokenServiceClient_Revoke_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Revoke'
|
||||
type TokenServiceClient_Revoke_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Revoke is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - in *v1.RevokeReq
|
||||
// - opts ...grpc.CallOption
|
||||
func (_e *TokenServiceClient_Expecter) Revoke(ctx interface{}, in interface{}, opts ...interface{}) *TokenServiceClient_Revoke_Call {
|
||||
return &TokenServiceClient_Revoke_Call{Call: _e.mock.On("Revoke",
|
||||
append([]interface{}{ctx, in}, opts...)...)}
|
||||
}
|
||||
|
||||
func (_c *TokenServiceClient_Revoke_Call) Run(run func(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption)) *TokenServiceClient_Revoke_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 *v1.RevokeReq
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(*v1.RevokeReq)
|
||||
}
|
||||
var arg2 []grpc.CallOption
|
||||
var variadicArgs []grpc.CallOption
|
||||
if len(args) > 2 {
|
||||
variadicArgs = args[2].([]grpc.CallOption)
|
||||
}
|
||||
arg2 = variadicArgs
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2...,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *TokenServiceClient_Revoke_Call) Return(revokeRes *v1.RevokeRes, err error) *TokenServiceClient_Revoke_Call {
|
||||
_c.Call.Return(revokeRes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *TokenServiceClient_Revoke_Call) RunAndReturn(run func(ctx context.Context, in *v1.RevokeReq, opts ...grpc.CallOption) (*v1.RevokeRes, error)) *TokenServiceClient_Revoke_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -0,0 +1,316 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
// Code generated by mockery; DO NOT EDIT.
|
||||
// github.com/vektra/mockery
|
||||
// template: testify
|
||||
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// NewUserActiveTokensCache creates a new instance of UserActiveTokensCache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewUserActiveTokensCache(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *UserActiveTokensCache {
|
||||
mock := &UserActiveTokensCache{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// UserActiveTokensCache is an autogenerated mock type for the UserActiveTokensCache type
|
||||
type UserActiveTokensCache struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type UserActiveTokensCache_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *UserActiveTokensCache) EXPECT() *UserActiveTokensCache_Expecter {
|
||||
return &UserActiveTokensCache_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// IsActive provides a mock function for the type UserActiveTokensCache
|
||||
func (_mock *UserActiveTokensCache) IsActive(ctx context.Context, tokenID string) (bool, error) {
|
||||
ret := _mock.Called(ctx, tokenID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for IsActive")
|
||||
}
|
||||
|
||||
var r0 bool
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) (bool, error)); ok {
|
||||
return returnFunc(ctx, tokenID)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) bool); ok {
|
||||
r0 = returnFunc(ctx, tokenID)
|
||||
} else {
|
||||
r0 = ret.Get(0).(bool)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = returnFunc(ctx, tokenID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// UserActiveTokensCache_IsActive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsActive'
|
||||
type UserActiveTokensCache_IsActive_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// IsActive is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - tokenID string
|
||||
func (_e *UserActiveTokensCache_Expecter) IsActive(ctx interface{}, tokenID interface{}) *UserActiveTokensCache_IsActive_Call {
|
||||
return &UserActiveTokensCache_IsActive_Call{Call: _e.mock.On("IsActive", ctx, tokenID)}
|
||||
}
|
||||
|
||||
func (_c *UserActiveTokensCache_IsActive_Call) Run(run func(ctx context.Context, tokenID string)) *UserActiveTokensCache_IsActive_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *UserActiveTokensCache_IsActive_Call) Return(b bool, err error) *UserActiveTokensCache_IsActive_Call {
|
||||
_c.Call.Return(b, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *UserActiveTokensCache_IsActive_Call) RunAndReturn(run func(ctx context.Context, tokenID string) (bool, error)) *UserActiveTokensCache_IsActive_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListUserTokens provides a mock function for the type UserActiveTokensCache
|
||||
func (_mock *UserActiveTokensCache) ListUserTokens(ctx context.Context, userID string) ([]auth.TokenInfo, error) {
|
||||
ret := _mock.Called(ctx, userID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListUserTokens")
|
||||
}
|
||||
|
||||
var r0 []auth.TokenInfo
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) ([]auth.TokenInfo, error)); ok {
|
||||
return returnFunc(ctx, userID)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string) []auth.TokenInfo); ok {
|
||||
r0 = returnFunc(ctx, userID)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]auth.TokenInfo)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = returnFunc(ctx, userID)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// UserActiveTokensCache_ListUserTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserTokens'
|
||||
type UserActiveTokensCache_ListUserTokens_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListUserTokens is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - userID string
|
||||
func (_e *UserActiveTokensCache_Expecter) ListUserTokens(ctx interface{}, userID interface{}) *UserActiveTokensCache_ListUserTokens_Call {
|
||||
return &UserActiveTokensCache_ListUserTokens_Call{Call: _e.mock.On("ListUserTokens", ctx, userID)}
|
||||
}
|
||||
|
||||
func (_c *UserActiveTokensCache_ListUserTokens_Call) Run(run func(ctx context.Context, userID string)) *UserActiveTokensCache_ListUserTokens_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *UserActiveTokensCache_ListUserTokens_Call) Return(tokenInfos []auth.TokenInfo, err error) *UserActiveTokensCache_ListUserTokens_Call {
|
||||
_c.Call.Return(tokenInfos, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *UserActiveTokensCache_ListUserTokens_Call) RunAndReturn(run func(ctx context.Context, userID string) ([]auth.TokenInfo, error)) *UserActiveTokensCache_ListUserTokens_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveActive provides a mock function for the type UserActiveTokensCache
|
||||
func (_mock *UserActiveTokensCache) RemoveActive(ctx context.Context, userID string, tokenID string) error {
|
||||
ret := _mock.Called(ctx, userID, tokenID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RemoveActive")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
|
||||
r0 = returnFunc(ctx, userID, tokenID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// UserActiveTokensCache_RemoveActive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveActive'
|
||||
type UserActiveTokensCache_RemoveActive_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RemoveActive is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - userID string
|
||||
// - tokenID string
|
||||
func (_e *UserActiveTokensCache_Expecter) RemoveActive(ctx interface{}, userID interface{}, tokenID interface{}) *UserActiveTokensCache_RemoveActive_Call {
|
||||
return &UserActiveTokensCache_RemoveActive_Call{Call: _e.mock.On("RemoveActive", ctx, userID, tokenID)}
|
||||
}
|
||||
|
||||
func (_c *UserActiveTokensCache_RemoveActive_Call) Run(run func(ctx context.Context, userID string, tokenID string)) *UserActiveTokensCache_RemoveActive_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *UserActiveTokensCache_RemoveActive_Call) Return(err error) *UserActiveTokensCache_RemoveActive_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *UserActiveTokensCache_RemoveActive_Call) RunAndReturn(run func(ctx context.Context, userID string, tokenID string) error) *UserActiveTokensCache_RemoveActive_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SaveActive provides a mock function for the type UserActiveTokensCache
|
||||
func (_mock *UserActiveTokensCache) SaveActive(ctx context.Context, userID string, tokenID string, description string, expiry time.Time) error {
|
||||
ret := _mock.Called(ctx, userID, tokenID, description, expiry)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for SaveActive")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, time.Time) error); ok {
|
||||
r0 = returnFunc(ctx, userID, tokenID, description, expiry)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// UserActiveTokensCache_SaveActive_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveActive'
|
||||
type UserActiveTokensCache_SaveActive_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SaveActive is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - userID string
|
||||
// - tokenID string
|
||||
// - description string
|
||||
// - expiry time.Time
|
||||
func (_e *UserActiveTokensCache_Expecter) SaveActive(ctx interface{}, userID interface{}, tokenID interface{}, description interface{}, expiry interface{}) *UserActiveTokensCache_SaveActive_Call {
|
||||
return &UserActiveTokensCache_SaveActive_Call{Call: _e.mock.On("SaveActive", ctx, userID, tokenID, description, expiry)}
|
||||
}
|
||||
|
||||
func (_c *UserActiveTokensCache_SaveActive_Call) Run(run func(ctx context.Context, userID string, tokenID string, description string, expiry time.Time)) *UserActiveTokensCache_SaveActive_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 string
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(string)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
var arg3 string
|
||||
if args[3] != nil {
|
||||
arg3 = args[3].(string)
|
||||
}
|
||||
var arg4 time.Time
|
||||
if args[4] != nil {
|
||||
arg4 = args[4].(time.Time)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
arg4,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *UserActiveTokensCache_SaveActive_Call) Return(err error) *UserActiveTokensCache_SaveActive_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *UserActiveTokensCache_SaveActive_Call) RunAndReturn(run func(ctx context.Context, userID string, tokenID string, description string, expiry time.Time) error) *UserActiveTokensCache_SaveActive_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) Abstract Machines
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package postgres
|
||||
|
||||
import "github.com/absmach/supermq/pkg/errors"
|
||||
|
||||
var _ errors.Mapper = (*duplicateErrors)(nil)
|
||||
|
||||
type duplicateErrors struct{}
|
||||
|
||||
// GetError maps constraint names to known errors.
|
||||
func (d duplicateErrors) GetError(constraint string) (error, bool) {
|
||||
switch constraint {
|
||||
case "revoked_tokens_pkey":
|
||||
return errors.NewRequestError("revoked token already exists"), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func NewDuplicateErrors() errors.Mapper {
|
||||
return duplicateErrors{}
|
||||
}
|
||||
+54
-9
@@ -29,13 +29,16 @@ var (
|
||||
// ErrExpiry indicates that the token is expired.
|
||||
ErrExpiry = errors.New("token is expired")
|
||||
|
||||
errIssueUser = errors.New("failed to issue new login key")
|
||||
errIssueTmp = errors.New("failed to issue new temporary key")
|
||||
errRevoke = errors.New("failed to remove key")
|
||||
errRetrieve = errors.New("failed to retrieve key data")
|
||||
errIdentify = errors.New("failed to validate token")
|
||||
errPlatform = errors.New("invalid platform id")
|
||||
errRoleAuth = errors.New("failed to authorize user role")
|
||||
errIssueUser = errors.New("failed to issue new login key")
|
||||
errIssueTmp = errors.New("failed to issue new temporary key")
|
||||
errRevoke = errors.New("failed to remove key")
|
||||
errRetrieve = errors.New("failed to retrieve key data")
|
||||
errIdentify = errors.New("failed to validate token")
|
||||
errPlatform = errors.New("invalid platform id")
|
||||
errRoleAuth = errors.New("failed to authorize user role")
|
||||
errSaveRefreshKey = errors.NewServiceError("failed to save refresh key")
|
||||
errRevokeRefreshKey = errors.NewServiceError("failed to revoke refresh key")
|
||||
errListRefreshKeys = errors.NewServiceError("failed to list refresh keys")
|
||||
|
||||
errMalformedPAT = errors.New("malformed personal access token")
|
||||
errFailedToParseUUID = errors.New("failed to parse string to UUID")
|
||||
@@ -67,6 +70,9 @@ type Authn interface {
|
||||
// Issue issues a new Key, returning its token value alongside.
|
||||
Issue(ctx context.Context, token string, key Key) (Token, error)
|
||||
|
||||
// RevokeToken revokes the refresh token by its ID.
|
||||
RevokeToken(ctx context.Context, userID, tokenID string) error
|
||||
|
||||
// Revoke removes the Key with the provided id that is
|
||||
// issued by the user identified by the provided key.
|
||||
Revoke(ctx context.Context, token, id string) error
|
||||
@@ -82,6 +88,9 @@ type Authn interface {
|
||||
|
||||
// RetrieveJWKS retrieves public keys to validate issued tokens.
|
||||
RetrieveJWKS() []PublicKeyInfo
|
||||
|
||||
// ListUserRefreshTokens lists all active refresh token sessions for a user.
|
||||
ListUserRefreshTokens(ctx context.Context, userID string) ([]TokenInfo, error)
|
||||
}
|
||||
|
||||
// Service specifies an API that must be fulfilled by the domain service
|
||||
@@ -100,6 +109,7 @@ type service struct {
|
||||
keys KeyRepository
|
||||
pats PATSRepository
|
||||
cache Cache
|
||||
tokensCache UserActiveTokensCache
|
||||
hasher Hasher
|
||||
idProvider supermq.IDProvider
|
||||
evaluator policies.Evaluator
|
||||
@@ -111,12 +121,13 @@ type service struct {
|
||||
}
|
||||
|
||||
// New instantiates the auth service implementation.
|
||||
func New(keys KeyRepository, pats PATSRepository, cache Cache, hasher Hasher, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service {
|
||||
func New(keys KeyRepository, pats PATSRepository, cache Cache, tokensCache UserActiveTokensCache, hasher Hasher, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service {
|
||||
return &service{
|
||||
tokenizer: tokenizer,
|
||||
keys: keys,
|
||||
pats: pats,
|
||||
cache: cache,
|
||||
tokensCache: tokensCache,
|
||||
hasher: hasher,
|
||||
idProvider: idp,
|
||||
evaluator: policyEvaluator,
|
||||
@@ -143,6 +154,14 @@ func (svc service) Issue(ctx context.Context, token string, key Key) (Token, err
|
||||
}
|
||||
}
|
||||
|
||||
func (svc service) RevokeToken(ctx context.Context, userID, tokenID string) error {
|
||||
if err := svc.tokensCache.RemoveActive(ctx, userID, tokenID); err != nil {
|
||||
return errors.Wrap(errRevokeRefreshKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc service) Revoke(ctx context.Context, token, id string) error {
|
||||
issuerID, _, err := svc.authenticate(ctx, token)
|
||||
if err != nil {
|
||||
@@ -205,6 +224,15 @@ func (svc service) RetrieveJWKS() []PublicKeyInfo {
|
||||
return keys
|
||||
}
|
||||
|
||||
func (svc service) ListUserRefreshTokens(ctx context.Context, userID string) ([]TokenInfo, error) {
|
||||
tokenInfo, err := svc.tokensCache.ListUserTokens(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(errListRefreshKeys, err)
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
func (svc service) Authorize(ctx context.Context, pr policies.Policy, patAuthz *PATAuthz) error {
|
||||
if patAuthz != nil {
|
||||
if err := svc.AuthorizePAT(ctx, patAuthz.UserID, patAuthz.PatID, patAuthz.EntityType, patAuthz.Domain, patAuthz.Operation, patAuthz.EntityID); err != nil {
|
||||
@@ -265,10 +293,20 @@ func (svc service) accessKey(ctx context.Context, key Key) (Token, error) {
|
||||
|
||||
key.ExpiresAt = time.Now().UTC().Add(svc.refreshDuration)
|
||||
key.Type = RefreshKey
|
||||
id, err := svc.idProvider.ID()
|
||||
if err != nil {
|
||||
return Token{}, errors.Wrap(errIssueTmp, err)
|
||||
}
|
||||
key.ID = id
|
||||
refresh, err := svc.tokenizer.Issue(key)
|
||||
if err != nil {
|
||||
return Token{}, errors.Wrap(errIssueTmp, err)
|
||||
}
|
||||
if key.Subject != "" && key.ExpiresAt.After(time.Now()) {
|
||||
if err := svc.tokensCache.SaveActive(ctx, key.Subject, key.ID, key.Description, key.ExpiresAt); err != nil {
|
||||
return Token{}, errors.Wrap(errSaveRefreshKey, err)
|
||||
}
|
||||
}
|
||||
|
||||
return Token{AccessToken: access, RefreshToken: refresh}, nil
|
||||
}
|
||||
@@ -298,6 +336,13 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token
|
||||
if k.Type != RefreshKey {
|
||||
return Token{}, errIssueUser
|
||||
}
|
||||
ok, err := svc.tokensCache.IsActive(ctx, key.ID)
|
||||
if err != nil {
|
||||
return Token{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
if !ok {
|
||||
return Token{}, ErrRevokedToken
|
||||
}
|
||||
key.ID = k.ID
|
||||
key.Type = AccessKey
|
||||
key.Subject = k.Subject
|
||||
@@ -313,7 +358,7 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token
|
||||
return Token{}, errors.Wrap(errIssueTmp, err)
|
||||
}
|
||||
|
||||
key.ExpiresAt = time.Now().UTC().Add(svc.refreshDuration)
|
||||
key.ExpiresAt = k.ExpiresAt
|
||||
key.Type = RefreshKey
|
||||
refresh, err := svc.tokenizer.Issue(key)
|
||||
if err != nil {
|
||||
|
||||
+232
-12
@@ -54,18 +54,20 @@ var (
|
||||
)
|
||||
|
||||
var (
|
||||
krepo *mocks.KeyRepository
|
||||
pService *policymocks.Service
|
||||
pEvaluator *policymocks.Evaluator
|
||||
patsrepo *mocks.PATSRepository
|
||||
cache *mocks.Cache
|
||||
hasher *mocks.Hasher
|
||||
tokenizer *mocks.Tokenizer
|
||||
krepo *mocks.KeyRepository
|
||||
pService *policymocks.Service
|
||||
pEvaluator *policymocks.Evaluator
|
||||
patsrepo *mocks.PATSRepository
|
||||
cache *mocks.Cache
|
||||
tokensCache *mocks.UserActiveTokensCache
|
||||
hasher *mocks.Hasher
|
||||
tokenizer *mocks.Tokenizer
|
||||
)
|
||||
|
||||
func newService(t *testing.T) (auth.Service, string) {
|
||||
krepo = new(mocks.KeyRepository)
|
||||
cache = new(mocks.Cache)
|
||||
tokensCache = new(mocks.UserActiveTokensCache)
|
||||
pService = new(policymocks.Service)
|
||||
pEvaluator = new(policymocks.Evaluator)
|
||||
patsrepo = new(mocks.PATSRepository)
|
||||
@@ -76,7 +78,7 @@ func newService(t *testing.T) (auth.Service, string) {
|
||||
token, _, err := signToken(t, issuerName, accessKey, false)
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing access key expected to succeed: %s", err))
|
||||
|
||||
return auth.New(krepo, patsrepo, cache, hasher, idProvider, tokenizer, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token
|
||||
return auth.New(krepo, patsrepo, cache, tokensCache, hasher, idProvider, tokenizer, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token
|
||||
}
|
||||
|
||||
func TestIssue(t *testing.T) {
|
||||
@@ -133,7 +135,7 @@ func TestIssue(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.tokenizerErr)
|
||||
tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.tokenizerErr)
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
|
||||
Subject: tc.key.Subject,
|
||||
SubjectType: policies.UserType,
|
||||
@@ -154,6 +156,7 @@ func TestIssue(t *testing.T) {
|
||||
saveResponse auth.Key
|
||||
token string
|
||||
tokenizerErr error
|
||||
cacheErr error
|
||||
saveErr error
|
||||
roleCheckErr error
|
||||
err error
|
||||
@@ -169,11 +172,24 @@ func TestIssue(t *testing.T) {
|
||||
token: accessToken,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue access key with cache error",
|
||||
key: auth.Key{
|
||||
Type: auth.AccessKey,
|
||||
Subject: userID,
|
||||
Role: auth.UserRole,
|
||||
IssuedAt: time.Now(),
|
||||
},
|
||||
token: accessToken,
|
||||
cacheErr: svcerr.ErrCreateEntity,
|
||||
err: svcerr.ErrCreateEntity,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases2 {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.tokenizerErr)
|
||||
tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.tokenizerErr)
|
||||
repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr)
|
||||
cacheCall := tokensCache.On("SaveActive", context.Background(), tc.key.Subject, mock.Anything, tc.key.Description, mock.Anything).Return(tc.cacheErr)
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
|
||||
Subject: tc.key.Subject,
|
||||
SubjectType: policies.UserType,
|
||||
@@ -186,6 +202,7 @@ func TestIssue(t *testing.T) {
|
||||
tokenizerCall.Unset()
|
||||
repoCall.Unset()
|
||||
policyCall.Unset()
|
||||
cacheCall.Unset()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -265,7 +282,7 @@ func TestIssue(t *testing.T) {
|
||||
}
|
||||
for _, tc := range cases3 {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.issueErr)
|
||||
tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.issueErr)
|
||||
tokenizerCall1 := tokenizer.On("Parse", mock.Anything, tc.token).Return(tc.parseRes, tc.parseErr)
|
||||
repoCall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, tc.saveErr)
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
|
||||
@@ -292,6 +309,8 @@ func TestIssue(t *testing.T) {
|
||||
parseErr error
|
||||
roleCheckErr error
|
||||
issueErr error
|
||||
cacheRes bool
|
||||
cacheErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
@@ -304,6 +323,7 @@ func TestIssue(t *testing.T) {
|
||||
},
|
||||
token: refreshToken,
|
||||
parseRes: refreshkey,
|
||||
cacheRes: true,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
@@ -339,15 +359,46 @@ func TestIssue(t *testing.T) {
|
||||
Role: auth.UserRole,
|
||||
},
|
||||
token: refreshToken,
|
||||
cacheRes: true,
|
||||
parseRes: refreshkey,
|
||||
roleCheckErr: errRoleAuth,
|
||||
err: errRoleAuth,
|
||||
},
|
||||
{
|
||||
desc: "issue refresh key with revoked refresh token",
|
||||
key: auth.Key{
|
||||
Type: auth.RefreshKey,
|
||||
IssuedAt: time.Now(),
|
||||
Subject: userID,
|
||||
Role: auth.UserRole,
|
||||
},
|
||||
token: refreshToken,
|
||||
parseRes: refreshkey,
|
||||
cacheRes: false,
|
||||
cacheErr: nil,
|
||||
err: auth.ErrRevokedToken,
|
||||
},
|
||||
{
|
||||
desc: "issue refresh key with cache error",
|
||||
key: auth.Key{
|
||||
Type: auth.RefreshKey,
|
||||
IssuedAt: time.Now(),
|
||||
Subject: userID,
|
||||
Role: auth.UserRole,
|
||||
},
|
||||
token: refreshToken,
|
||||
parseRes: refreshkey,
|
||||
cacheRes: false,
|
||||
cacheErr: svcerr.ErrCreateEntity,
|
||||
err: svcerr.ErrCreateEntity,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases4 {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
tokenizerCall := tokenizer.On("Issue", mock.Anything).Return(tc.token, tc.issueErr)
|
||||
tokenizerCall := tokenizer.On("Issue", mock.Anything, mock.Anything).Return(tc.token, tc.issueErr)
|
||||
tokenizerCall1 := tokenizer.On("Parse", mock.Anything, tc.token).Return(tc.parseRes, tc.parseErr)
|
||||
tokenizerCall2 := tokenizer.On("Revoke", mock.Anything, tc.token).Return(tc.parseErr)
|
||||
cacheCall := tokensCache.On("IsActive", context.Background(), tc.key.ID).Return(tc.cacheRes, tc.cacheErr)
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, policies.Policy{
|
||||
Subject: tc.key.Subject,
|
||||
SubjectType: policies.UserType,
|
||||
@@ -359,7 +410,9 @@ func TestIssue(t *testing.T) {
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
|
||||
tokenizerCall.Unset()
|
||||
tokenizerCall1.Unset()
|
||||
tokenizerCall2.Unset()
|
||||
policyCall.Unset()
|
||||
cacheCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -642,6 +695,173 @@ func TestIdentify(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevokeToken(t *testing.T) {
|
||||
svc, _ := newService(t)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userID string
|
||||
tokenID string
|
||||
removeErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "revoke token successfully",
|
||||
userID: validID,
|
||||
tokenID: "validTokenID",
|
||||
removeErr: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "revoke token with cache error",
|
||||
userID: validID,
|
||||
tokenID: "validTokenID",
|
||||
removeErr: svcerr.ErrRemoveEntity,
|
||||
err: svcerr.ErrRemoveEntity,
|
||||
},
|
||||
{
|
||||
desc: "revoke token with empty token ID",
|
||||
userID: validID,
|
||||
tokenID: "",
|
||||
removeErr: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "revoke token not found",
|
||||
userID: validID,
|
||||
tokenID: "nonExistentTokenID",
|
||||
removeErr: svcerr.ErrNotFound,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
cacheCall := tokensCache.On("RemoveActive", mock.Anything, tc.userID, tc.tokenID).Return(tc.removeErr)
|
||||
err := svc.RevokeToken(context.Background(), tc.userID, tc.tokenID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err))
|
||||
cacheCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveJWKS(t *testing.T) {
|
||||
svc, _ := newService(t)
|
||||
|
||||
publicKeys := []auth.PublicKeyInfo{
|
||||
{
|
||||
KeyID: "key1",
|
||||
Algorithm: "RS256",
|
||||
},
|
||||
{
|
||||
KeyID: "key2",
|
||||
Algorithm: "RS256",
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
tokenizerRes []auth.PublicKeyInfo
|
||||
tokenizerErr error
|
||||
expectedResult []auth.PublicKeyInfo
|
||||
}{
|
||||
{
|
||||
desc: "retrieve JWKS successfully",
|
||||
tokenizerRes: publicKeys,
|
||||
tokenizerErr: nil,
|
||||
expectedResult: publicKeys,
|
||||
},
|
||||
{
|
||||
desc: "retrieve JWKS with tokenizer error",
|
||||
tokenizerRes: nil,
|
||||
tokenizerErr: svcerr.ErrViewEntity,
|
||||
expectedResult: nil,
|
||||
},
|
||||
{
|
||||
desc: "retrieve JWKS with empty keys",
|
||||
tokenizerRes: []auth.PublicKeyInfo{},
|
||||
tokenizerErr: nil,
|
||||
expectedResult: []auth.PublicKeyInfo{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
tokenizerCall := tokenizer.On("RetrieveJWKS").Return(tc.tokenizerRes, tc.tokenizerErr)
|
||||
result := svc.RetrieveJWKS()
|
||||
assert.Equal(t, tc.expectedResult, result, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.expectedResult, result))
|
||||
tokenizerCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListUserRefreshTokens(t *testing.T) {
|
||||
svc, _ := newService(t)
|
||||
|
||||
tokenInfos := []auth.TokenInfo{
|
||||
{
|
||||
ID: "token1",
|
||||
Description: "Session 1",
|
||||
},
|
||||
{
|
||||
ID: "token2",
|
||||
Description: "Session 2",
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
userID string
|
||||
cacheRes []auth.TokenInfo
|
||||
cacheErr error
|
||||
expectedRes []auth.TokenInfo
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list user refresh tokens successfully",
|
||||
userID: userID,
|
||||
cacheRes: tokenInfos,
|
||||
cacheErr: nil,
|
||||
expectedRes: tokenInfos,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user refresh tokens with cache error",
|
||||
userID: userID,
|
||||
cacheRes: nil,
|
||||
cacheErr: svcerr.ErrViewEntity,
|
||||
expectedRes: nil,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
{
|
||||
desc: "list user refresh tokens with empty result",
|
||||
userID: userID,
|
||||
cacheRes: []auth.TokenInfo{},
|
||||
cacheErr: nil,
|
||||
expectedRes: []auth.TokenInfo{},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list user refresh tokens with invalid user ID",
|
||||
userID: "",
|
||||
cacheRes: nil,
|
||||
cacheErr: svcerr.ErrViewEntity,
|
||||
expectedRes: nil,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
cacheCall := tokensCache.On("ListUserTokens", mock.Anything, tc.userID).Return(tc.cacheRes, tc.cacheErr)
|
||||
result, err := svc.ListUserRefreshTokens(context.Background(), tc.userID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected error %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.expectedRes, result, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.expectedRes, result))
|
||||
cacheCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorize(t *testing.T) {
|
||||
svc, _ := newService(t)
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ export SMQ_AUTH_KEYS_ACTIVE_KEY_PATH="./keys/private.key"
|
||||
```
|
||||
|
||||
The tokenizer will:
|
||||
|
||||
- Issue new tokens signed with the active key
|
||||
- Verify tokens using the active key
|
||||
- Return one public key in JWKS endpoint
|
||||
@@ -42,6 +43,7 @@ export SMQ_AUTH_KEYS_RETIRING_KEY_PATH="./keys/retiring.key"
|
||||
```
|
||||
|
||||
The tokenizer will:
|
||||
|
||||
- Issue new tokens signed with the active key
|
||||
- Verify tokens using both active and retiring keys
|
||||
- Return both public keys in JWKS endpoint
|
||||
@@ -103,10 +105,12 @@ The grace period should be longer than your longest-lived access token duration.
|
||||
|
||||
- Store private keys with `0600` permissions
|
||||
- Use cryptographically secure key generation:
|
||||
|
||||
```bash
|
||||
openssl genpkey -algorithm Ed25519 -out private.key
|
||||
chmod 600 private.key
|
||||
```
|
||||
|
||||
- Rotate keys regularly:
|
||||
- Standard environments: every 90 days
|
||||
- High-security environments: every 30 days
|
||||
@@ -139,7 +143,7 @@ rm ./keys/key-2024.pem
|
||||
|
||||
### Active key not found
|
||||
|
||||
```
|
||||
```bash
|
||||
Error: active key file not found: ./keys/active.key
|
||||
```
|
||||
|
||||
@@ -149,7 +153,7 @@ Error: active key file not found: ./keys/active.key
|
||||
|
||||
If the retiring key path is set but the file is missing or invalid, the tokenizer logs a warning but continues with only the active key:
|
||||
|
||||
```
|
||||
```bash
|
||||
WARN: failed to load retiring key, continuing without it
|
||||
```
|
||||
|
||||
@@ -157,7 +161,7 @@ This is by design - a missing retiring key won't prevent startup.
|
||||
|
||||
### Invalid key format
|
||||
|
||||
```
|
||||
```bash
|
||||
Error: failed to parse private key
|
||||
```
|
||||
|
||||
|
||||
@@ -24,8 +24,6 @@ import (
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
const patPrefix = "pat"
|
||||
|
||||
var (
|
||||
errLoadingPrivateKey = errors.New("failed to load private key")
|
||||
errDuplicateRetiringKeyID = errors.New("retiring key ID matches active key ID")
|
||||
@@ -95,8 +93,8 @@ func NewTokenizer(activeKeyPath, retiringKeyPath string, idProvider supermq.IDPr
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func (km *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
if km.activeKey == nil {
|
||||
func (tok *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
if tok.activeKey == nil {
|
||||
return "", errNoActiveKey
|
||||
}
|
||||
|
||||
@@ -105,11 +103,11 @@ func (km *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
headers := jws.NewHeaders()
|
||||
if err := headers.Set(jwk.KeyIDKey, km.activeKey.id); err != nil {
|
||||
if err := headers.Set(jwk.KeyIDKey, tok.activeKey.id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signedBytes, err := jwt.Sign(tkn, jwt.WithKey(jwa.EdDSA, km.activeKey.privateKey, jws.WithProtectedHeaders(headers)))
|
||||
signedBytes, err := jwt.Sign(tkn, jwt.WithKey(jwa.EdDSA, tok.activeKey.privateKey, jws.WithProtectedHeaders(headers)))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -117,17 +115,17 @@ func (km *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
return string(signedBytes), nil
|
||||
}
|
||||
|
||||
func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) {
|
||||
if len(tokenString) >= 3 && tokenString[:3] == patPrefix {
|
||||
func (tok *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) {
|
||||
if len(tokenString) >= 3 && tokenString[:3] == smqjwt.PatPrefix {
|
||||
return auth.Key{Type: auth.PersonalAccessToken}, nil
|
||||
}
|
||||
|
||||
set := jwk.NewSet()
|
||||
if err := set.AddKey(km.activeKey.publicKey); err != nil {
|
||||
if err := set.AddKey(tok.activeKey.publicKey); err != nil {
|
||||
return auth.Key{}, err
|
||||
}
|
||||
if km.retiringKey != nil {
|
||||
if err := set.AddKey(km.retiringKey.publicKey); err != nil {
|
||||
if tok.retiringKey != nil {
|
||||
if err := set.AddKey(tok.retiringKey.publicKey); err != nil {
|
||||
return auth.Key{}, err
|
||||
}
|
||||
}
|
||||
@@ -148,17 +146,17 @@ func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, e
|
||||
return smqjwt.ToKey(tkn)
|
||||
}
|
||||
|
||||
func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
|
||||
func (tok *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
|
||||
publicKeys := make([]auth.PublicKeyInfo, 0, 2)
|
||||
|
||||
if km.activeKey != nil {
|
||||
if pkInfo := extractPublicKeyInfo(km.activeKey); pkInfo != nil {
|
||||
if tok.activeKey != nil {
|
||||
if pkInfo := extractPublicKeyInfo(tok.activeKey); pkInfo != nil {
|
||||
publicKeys = append(publicKeys, *pkInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if km.retiringKey != nil {
|
||||
if pkInfo := extractPublicKeyInfo(km.retiringKey); pkInfo != nil {
|
||||
if tok.retiringKey != nil {
|
||||
if pkInfo := extractPublicKeyInfo(tok.retiringKey); pkInfo != nil {
|
||||
publicKeys = append(publicKeys, *pkInfo)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,12 +14,6 @@ import (
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
patPrefix = "pat"
|
||||
)
|
||||
|
||||
var errJWTExpiryKey = errors.New(`"exp" not satisfied`)
|
||||
|
||||
type tokenizer struct {
|
||||
algorithm jwa.KeyAlgorithm
|
||||
secret []byte
|
||||
@@ -41,13 +35,13 @@ func NewTokenizer(algorithm string, secret []byte) (auth.Tokenizer, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (km *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
func (tok *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
tkn, err := smqjwt.BuildToken(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signedBytes, err := jwt.Sign(tkn, jwt.WithKey(km.algorithm, km.secret))
|
||||
signedBytes, err := jwt.Sign(tkn, jwt.WithKey(tok.algorithm, tok.secret))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -55,18 +49,18 @@ func (km *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
return string(signedBytes), nil
|
||||
}
|
||||
|
||||
func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) {
|
||||
if len(tokenString) >= 3 && tokenString[:3] == patPrefix {
|
||||
func (tok *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, error) {
|
||||
if len(tokenString) >= 3 && tokenString[:3] == smqjwt.PatPrefix {
|
||||
return auth.Key{Type: auth.PersonalAccessToken}, nil
|
||||
}
|
||||
|
||||
tkn, err := jwt.Parse(
|
||||
[]byte(tokenString),
|
||||
jwt.WithValidate(true),
|
||||
jwt.WithKey(km.algorithm, km.secret),
|
||||
jwt.WithKey(tok.algorithm, tok.secret),
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Contains(err, errJWTExpiryKey) {
|
||||
if errors.Contains(err, smqjwt.ErrJWTExpiryKey) {
|
||||
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, auth.ErrExpiry)
|
||||
}
|
||||
return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
@@ -79,6 +73,6 @@ func (km *tokenizer) Parse(ctx context.Context, tokenString string) (auth.Key, e
|
||||
return smqjwt.ToKey(tkn)
|
||||
}
|
||||
|
||||
func (km *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
|
||||
func (tok *tokenizer) RetrieveJWKS() ([]auth.PublicKeyInfo, error) {
|
||||
return nil, auth.ErrPublicKeysNotSupported
|
||||
}
|
||||
|
||||
@@ -18,6 +18,9 @@ var (
|
||||
// ErrJSONHandle indicates an error in handling JSON.
|
||||
ErrJSONHandle = errors.New("failed to perform operation JSON")
|
||||
|
||||
// ErrJWTExpiryKey indicates that the "exp" claim in the JWT token is not satisfied.
|
||||
ErrJWTExpiryKey = errors.New(`"exp" not satisfied`)
|
||||
|
||||
errInvalidType = errors.New("invalid token type")
|
||||
errInvalidRole = errors.New("invalid role")
|
||||
errInvalidVerified = errors.New("invalid verified")
|
||||
@@ -28,6 +31,7 @@ const (
|
||||
TokenType = "type"
|
||||
RoleField = "role"
|
||||
VerifiedField = "verified"
|
||||
PatPrefix = "pat"
|
||||
)
|
||||
|
||||
// ToKey converts a JWT token to an auth.Key by extracting claims.
|
||||
|
||||
+7
-3
@@ -292,17 +292,21 @@ func validateKeyConfig(isSymmetric bool, cfg config, l *slog.Logger) error {
|
||||
}
|
||||
|
||||
func newService(db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration, tokenizer auth.Tokenizer, idProvider supermq.IDProvider) (auth.Service, error) {
|
||||
cache := cache.NewPatsCache(cacheClient, keyDuration)
|
||||
patsCache := cache.NewPatsCache(cacheClient, keyDuration)
|
||||
tokensCache, err := cache.NewUserActiveTokensCache(cacheClient, keyDuration)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
database := pgclient.NewDatabase(db, dbConfig, tracer)
|
||||
keysRepo := apostgres.New(database)
|
||||
patsRepo := apostgres.NewPatRepo(database, cache)
|
||||
patsRepo := apostgres.NewPatRepo(database, patsCache)
|
||||
hasher := hasher.New()
|
||||
|
||||
pEvaluator := spicedb.NewPolicyEvaluator(spicedbClient, logger)
|
||||
pService := spicedb.NewPolicyService(spicedbClient, logger)
|
||||
|
||||
svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, tokenizer, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
|
||||
svc := auth.New(keysRepo, patsRepo, nil, tokensCache, hasher, idProvider, tokenizer, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
|
||||
svc = middleware.NewLogging(svc, logger)
|
||||
counter, latency := prometheus.MakeMetrics("auth", "api")
|
||||
svc = middleware.NewMetrics(svc, counter, latency)
|
||||
|
||||
+1
-1
@@ -399,7 +399,7 @@ func createAdmin(ctx context.Context, c config, repo users.Repository, hsr users
|
||||
if _, err = repo.Save(ctx, user); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err = svc.IssueToken(ctx, c.AdminUsername, c.AdminPassword); err != nil {
|
||||
if _, err = svc.IssueToken(ctx, c.AdminUsername, c.AdminPassword, ""); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return user.ID, nil
|
||||
|
||||
@@ -95,6 +95,8 @@ services:
|
||||
- supermq-base-net
|
||||
volumes:
|
||||
- supermq-auth-redis-volume:/data
|
||||
- ./redis/redis.conf:/etc/redis/redis.conf:ro
|
||||
command: ["redis-server", "/etc/redis/redis.conf"]
|
||||
|
||||
auth:
|
||||
image: docker.io/supermq/auth:${SMQ_RELEASE_TAG}
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) Abstract Machines
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Enable AOF persistence
|
||||
appendonly yes
|
||||
appendfilename "appendonly.aof"
|
||||
appendfsync everysec
|
||||
|
||||
# Enable periodic snapshots
|
||||
save 300 10
|
||||
|
||||
# Persist data in Docker volume
|
||||
dir /data
|
||||
|
||||
@@ -9,6 +9,9 @@ option go_package = "github.com/absmach/supermq/api/grpc/token/v1";
|
||||
service TokenService {
|
||||
rpc Issue(IssueReq) returns (Token) {}
|
||||
rpc Refresh(RefreshReq) returns (Token) {}
|
||||
rpc Revoke(RevokeReq) returns (RevokeRes) {}
|
||||
rpc ListUserRefreshTokens(ListUserRefreshTokensReq)
|
||||
returns (ListUserRefreshTokensRes) {}
|
||||
}
|
||||
|
||||
message IssueReq {
|
||||
@@ -16,6 +19,7 @@ message IssueReq {
|
||||
uint32 user_role = 2;
|
||||
uint32 type = 3;
|
||||
bool verified = 4;
|
||||
string description = 5;
|
||||
}
|
||||
|
||||
message RefreshReq {
|
||||
@@ -23,6 +27,11 @@ message RefreshReq {
|
||||
bool verified = 2;
|
||||
}
|
||||
|
||||
message RevokeReq {
|
||||
string token_id = 1;
|
||||
string user_id = 2;
|
||||
}
|
||||
|
||||
// If a token is not carrying any information itself, the type
|
||||
// field can be used to determine how to validate the token.
|
||||
// Also, different tokens can be encoded in different ways.
|
||||
@@ -31,3 +40,20 @@ message Token {
|
||||
optional string refresh_token = 2;
|
||||
string access_type = 3;
|
||||
}
|
||||
|
||||
message RevokeRes{
|
||||
|
||||
}
|
||||
|
||||
message ListUserRefreshTokensReq {
|
||||
string user_id = 1;
|
||||
}
|
||||
|
||||
message ListUserRefreshTokensRes {
|
||||
repeated RefreshToken refresh_tokens = 1;
|
||||
}
|
||||
|
||||
message RefreshToken {
|
||||
string id = 1;
|
||||
string description = 2;
|
||||
}
|
||||
|
||||
+3
-2
@@ -21,8 +21,9 @@ type Token struct {
|
||||
}
|
||||
|
||||
type Login struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
func (sdk mgSDK) CreateToken(ctx context.Context, lt Login) (Token, errors.SDKError) {
|
||||
|
||||
@@ -101,12 +101,12 @@ func TestIssueToken(t *testing.T) {
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("IssueToken", mock.Anything, tc.login.Username, tc.login.Password).Return(tc.svcRes, tc.svcErr)
|
||||
svcCall := svc.On("IssueToken", mock.Anything, tc.login.Username, tc.login.Password, tc.login.Description).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := mgsdk.CreateToken(context.Background(), tc.login)
|
||||
assert.Equal(t, tc.err, err)
|
||||
assert.Equal(t, tc.response, resp)
|
||||
if tc.err == nil {
|
||||
ok := svcCall.Parent.AssertCalled(t, "IssueToken", mock.Anything, tc.login.Username, tc.login.Password)
|
||||
ok := svcCall.Parent.AssertCalled(t, "IssueToken", mock.Anything, tc.login.Username, tc.login.Password, tc.login.Description)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
svcCall.Unset()
|
||||
|
||||
@@ -59,6 +59,7 @@ packages:
|
||||
Cache:
|
||||
Hasher:
|
||||
KeyRepository:
|
||||
UserActiveTokensCache:
|
||||
Tokenizer:
|
||||
PATS:
|
||||
PATSRepository:
|
||||
@@ -145,4 +146,3 @@ packages:
|
||||
github.com/absmach/supermq/notifications:
|
||||
interfaces:
|
||||
Notifier:
|
||||
|
||||
|
||||
+210
-1
@@ -2384,7 +2384,9 @@ func TestIssueToken(t *testing.T) {
|
||||
defer us.Close()
|
||||
|
||||
validUsername := "valid"
|
||||
validDescription := "test token"
|
||||
dataFormat := `{"username": "%s", "password": "%s"}`
|
||||
dataFormatWithDesc := `{"username": "%s", "password": "%s", "description": "%s"}`
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
@@ -2400,6 +2402,13 @@ func TestIssueToken(t *testing.T) {
|
||||
status: http.StatusCreated,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token with valid identity, secret and description",
|
||||
data: fmt.Sprintf(dataFormatWithDesc, validUsername, secret, validDescription),
|
||||
contentType: contentType,
|
||||
status: http.StatusCreated,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "issue token with empty identity",
|
||||
data: fmt.Sprintf(dataFormat, "", secret),
|
||||
@@ -2447,7 +2456,7 @@ func TestIssueToken(t *testing.T) {
|
||||
body: strings.NewReader(tc.data),
|
||||
}
|
||||
|
||||
svcCall := svc.On("IssueToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&grpcTokenV1.Token{AccessToken: validToken}, tc.err)
|
||||
svcCall := svc.On("IssueToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&grpcTokenV1.Token{AccessToken: validToken}, tc.err)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
if tc.err != nil {
|
||||
@@ -2565,6 +2574,206 @@ func TestRefreshToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevokeRefreshToken(t *testing.T) {
|
||||
us, svc, authn := newUsersServer()
|
||||
defer us.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
data string
|
||||
contentType string
|
||||
token string
|
||||
authnRes smqauthn.Session
|
||||
authnErr error
|
||||
status int
|
||||
svcErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "revoke refresh token with valid token",
|
||||
data: fmt.Sprintf(`{"token_id": "%s"}`, validToken),
|
||||
contentType: contentType,
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusNoContent,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "revoke refresh token with invalid token",
|
||||
data: fmt.Sprintf(`{"token_id": "%s"}`, validToken),
|
||||
contentType: contentType,
|
||||
token: inValidToken,
|
||||
status: http.StatusUnauthorized,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "revoke refresh token with empty token",
|
||||
data: fmt.Sprintf(`{"token_id": "%s"}`, validToken),
|
||||
contentType: contentType,
|
||||
token: "",
|
||||
status: http.StatusUnauthorized,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "revoke refresh token with empty token id",
|
||||
data: `{"token_id": ""}`,
|
||||
contentType: contentType,
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrMissingID,
|
||||
},
|
||||
{
|
||||
desc: "revoke refresh token with malformed data",
|
||||
data: fmt.Sprintf(`{"token_id": %s}`, validToken),
|
||||
contentType: contentType,
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusBadRequest,
|
||||
err: apiutil.ErrMalformedRequestBody,
|
||||
},
|
||||
{
|
||||
desc: "revoke refresh token with invalid content type",
|
||||
data: fmt.Sprintf(`{"token_id": "%s"}`, validToken),
|
||||
contentType: "application/xml",
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusUnsupportedMediaType,
|
||||
err: apiutil.ErrUnsupportedContentType,
|
||||
},
|
||||
{
|
||||
desc: "revoke refresh token with service error",
|
||||
data: fmt.Sprintf(`{"token_id": "%s"}`, validToken),
|
||||
contentType: contentType,
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusUnprocessableEntity,
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
user: us.Client(),
|
||||
method: http.MethodPost,
|
||||
url: fmt.Sprintf("%s/users/tokens/revoke", us.URL),
|
||||
contentType: tc.contentType,
|
||||
body: strings.NewReader(tc.data),
|
||||
token: tc.token,
|
||||
}
|
||||
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("RevokeRefreshToken", mock.Anything, tc.authnRes, mock.Anything).Return(tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
if tc.err != nil {
|
||||
var resBody respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&resBody)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if resBody.Err != "" || resBody.Message != "" {
|
||||
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authnCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListActiveRefreshTokens(t *testing.T) {
|
||||
us, svc, authn := newUsersServer()
|
||||
defer us.Close()
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
token string
|
||||
authnRes smqauthn.Session
|
||||
authnErr error
|
||||
status int
|
||||
svcRes *grpcTokenV1.ListUserRefreshTokensRes
|
||||
svcErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list active refresh tokens with valid token",
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusOK,
|
||||
svcRes: &grpcTokenV1.ListUserRefreshTokensRes{
|
||||
RefreshTokens: []*grpcTokenV1.RefreshToken{
|
||||
{Id: "token1", Description: "token-1"},
|
||||
{Id: "token2", Description: "token-2"},
|
||||
},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list active refresh tokens with invalid token",
|
||||
token: inValidToken,
|
||||
status: http.StatusUnauthorized,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "list active refresh tokens with empty token",
|
||||
token: "",
|
||||
status: http.StatusUnauthorized,
|
||||
authnErr: svcerr.ErrAuthentication,
|
||||
err: apiutil.ErrBearerToken,
|
||||
},
|
||||
{
|
||||
desc: "list active refresh tokens with service error",
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusUnprocessableEntity,
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
{
|
||||
desc: "list active refresh tokens with empty list",
|
||||
token: validToken,
|
||||
authnRes: verifiedSession,
|
||||
status: http.StatusOK,
|
||||
svcRes: &grpcTokenV1.ListUserRefreshTokensRes{
|
||||
RefreshTokens: []*grpcTokenV1.RefreshToken{},
|
||||
},
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
req := testRequest{
|
||||
user: us.Client(),
|
||||
method: http.MethodGet,
|
||||
url: fmt.Sprintf("%s/users/tokens/refresh-tokens", us.URL),
|
||||
token: tc.token,
|
||||
}
|
||||
authnCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr)
|
||||
svcCall := svc.On("ListActiveRefreshTokens", mock.Anything, tc.authnRes).Return(tc.svcRes, tc.svcErr)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
if tc.err != nil {
|
||||
var resBody respBody
|
||||
err = json.NewDecoder(res.Body).Decode(&resBody)
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error while decoding response body: %s", tc.desc, err))
|
||||
if resBody.Err != "" || resBody.Message != "" {
|
||||
err = errors.Wrap(errors.New(resBody.Err), errors.New(resBody.Message))
|
||||
}
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
}
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
svcCall.Unset()
|
||||
authnCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnable(t *testing.T) {
|
||||
us, svc, authn := newUsersServer()
|
||||
defer us.Close()
|
||||
|
||||
+38
-1
@@ -416,7 +416,7 @@ func issueTokenEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
token, err := svc.IssueToken(ctx, req.Username, req.Password)
|
||||
token, err := svc.IssueToken(ctx, req.Username, req.Password, req.Description)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -454,6 +454,43 @@ func refreshTokenEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
}
|
||||
}
|
||||
|
||||
func revokeRefreshTokenEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(revokeTokenReq)
|
||||
if err := req.validate(); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, err)
|
||||
}
|
||||
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
err := svc.RevokeRefreshToken(ctx, session, req.TokenID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return revokeRes{}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func listActiveRefreshTokensEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
session, ok := ctx.Value(authn.SessionKey).(authn.Session)
|
||||
if !ok {
|
||||
return nil, svcerr.ErrAuthentication
|
||||
}
|
||||
|
||||
refreshTokens, err := svc.ListActiveRefreshTokens(ctx, session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return listRefreshTokensRes{RefreshTokens: refreshTokens.GetRefreshTokens()}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func enableEndpoint(svc users.Service) endpoint.Endpoint {
|
||||
return func(ctx context.Context, request any) (any, error) {
|
||||
req := request.(changeUserStatusReq)
|
||||
|
||||
+15
-2
@@ -270,8 +270,9 @@ func (req changeUserStatusReq) validate() error {
|
||||
}
|
||||
|
||||
type loginUserReq struct {
|
||||
Username string `json:"username,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
func (req loginUserReq) validate() error {
|
||||
@@ -297,6 +298,18 @@ func (req tokenReq) validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type revokeTokenReq struct {
|
||||
TokenID string `json:"token_id,omitempty"`
|
||||
}
|
||||
|
||||
func (req revokeTokenReq) validate() error {
|
||||
if req.TokenID == "" {
|
||||
return apiutil.ErrMissingID
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type passResetReq struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
+33
-1
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/absmach/supermq"
|
||||
grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1"
|
||||
"github.com/absmach/supermq/users"
|
||||
)
|
||||
|
||||
@@ -25,8 +26,9 @@ var (
|
||||
_ supermq.Response = (*passResetReqRes)(nil)
|
||||
_ supermq.Response = (*passChangeRes)(nil)
|
||||
_ supermq.Response = (*updateUserRes)(nil)
|
||||
_ supermq.Response = (*tokenRes)(nil)
|
||||
_ supermq.Response = (*revokeRes)(nil)
|
||||
_ supermq.Response = (*deleteUserRes)(nil)
|
||||
_ supermq.Response = (*listRefreshTokensRes)(nil)
|
||||
)
|
||||
|
||||
type pageRes struct {
|
||||
@@ -80,6 +82,36 @@ func (res tokenRes) Empty() bool {
|
||||
return res.AccessToken == "" || res.RefreshToken == ""
|
||||
}
|
||||
|
||||
type revokeRes struct{}
|
||||
|
||||
func (res revokeRes) Code() int {
|
||||
return http.StatusNoContent
|
||||
}
|
||||
|
||||
func (res revokeRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res revokeRes) Empty() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type listRefreshTokensRes struct {
|
||||
RefreshTokens []*grpcTokenV1.RefreshToken `json:"refresh_tokens"`
|
||||
}
|
||||
|
||||
func (res listRefreshTokensRes) Code() int {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
func (res listRefreshTokensRes) Headers() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
|
||||
func (res listRefreshTokensRes) Empty() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type sendVerificationRes struct{}
|
||||
|
||||
func (res sendVerificationRes) Code() int {
|
||||
|
||||
@@ -78,6 +78,18 @@ func usersHandler(svc users.Service, authn smqauthn.AuthNMiddleware, tokenClient
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "refresh_token").ServeHTTP)
|
||||
r.Post("/tokens/revoke", otelhttp.NewHandler(kithttp.NewServer(
|
||||
revokeRefreshTokenEndpoint(svc),
|
||||
decodeRevokeRefreshToken,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "revoke_refresh_token").ServeHTTP)
|
||||
r.Get("/tokens/refresh-tokens", otelhttp.NewHandler(kithttp.NewServer(
|
||||
listActiveRefreshTokensEndpoint(svc),
|
||||
decodeListActiveRefreshTokens,
|
||||
api.EncodeResponse,
|
||||
opts...,
|
||||
), "list_active_refresh_tokens").ServeHTTP)
|
||||
r.Patch("/{id}/email", otelhttp.NewHandler(kithttp.NewServer(
|
||||
updateEmailEndpoint(svc),
|
||||
decodeUpdateUserEmail,
|
||||
@@ -532,6 +544,23 @@ func decodeRefreshToken(_ context.Context, r *http.Request) (any, error) {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeRevokeRefreshToken(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
|
||||
}
|
||||
|
||||
var req revokeTokenReq
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
return nil, errors.Wrap(apiutil.ErrMalformedRequestBody, err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeListActiveRefreshTokens(_ context.Context, r *http.Request) (any, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func decodeCreateUserReq(_ context.Context, r *http.Request) (any, error) {
|
||||
if !strings.Contains(r.Header.Get("Content-Type"), api.ContentType) {
|
||||
return nil, errors.Wrap(apiutil.ErrValidation, apiutil.ErrUnsupportedContentType)
|
||||
|
||||
+15
-2
@@ -29,11 +29,10 @@ const (
|
||||
profileView = userPrefix + "view_profile"
|
||||
userList = userPrefix + "list"
|
||||
userSearch = userPrefix + "search"
|
||||
userListByGroup = userPrefix + "list_by_group"
|
||||
userIdentify = userPrefix + "identify"
|
||||
generateResetToken = userPrefix + "generate_reset_token"
|
||||
issueToken = userPrefix + "issue_token"
|
||||
refreshToken = userPrefix + "refresh_token"
|
||||
revokeRefreshToken = userPrefix + "revoke_refresh_token"
|
||||
resetSecret = userPrefix + "reset_secret"
|
||||
sendPasswordReset = userPrefix + "send_password_reset"
|
||||
oauthCallback = userPrefix + "oauth_callback"
|
||||
@@ -56,6 +55,7 @@ var (
|
||||
_ events.Event = (*identifyUserEvent)(nil)
|
||||
_ events.Event = (*issueTokenEvent)(nil)
|
||||
_ events.Event = (*refreshTokenEvent)(nil)
|
||||
_ events.Event = (*revokeRefreshTokenEvent)(nil)
|
||||
_ events.Event = (*resetSecretEvent)(nil)
|
||||
_ events.Event = (*sendPasswordResetEvent)(nil)
|
||||
_ events.Event = (*oauthCallbackEvent)(nil)
|
||||
@@ -492,6 +492,19 @@ func (rte refreshTokenEvent) Encode() (map[string]any, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
type revokeRefreshTokenEvent struct {
|
||||
tokenID string
|
||||
requestID string
|
||||
}
|
||||
|
||||
func (rrte revokeRefreshTokenEvent) Encode() (map[string]any, error) {
|
||||
return map[string]any{
|
||||
"operation": revokeRefreshToken,
|
||||
"token_id": rrte.tokenID,
|
||||
"request_id": rrte.requestID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type resetSecretEvent struct {
|
||||
requestID string
|
||||
}
|
||||
|
||||
+46
-27
@@ -15,31 +15,32 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
supermqPrefix = "supermq."
|
||||
createStream = supermqPrefix + userCreate
|
||||
sendVerificationStream = supermqPrefix + userSendVerification
|
||||
verifyEmailStream = supermqPrefix + userVerifyEmail
|
||||
updateStream = supermqPrefix + userUpdate
|
||||
updateRoleStream = supermqPrefix + userUpdateRole
|
||||
updateTagsStream = supermqPrefix + userUpdateTags
|
||||
updateSecretStream = supermqPrefix + userUpdateSecret
|
||||
updateUsernameStream = supermqPrefix + userUpdateUsername
|
||||
updatePictureStream = supermqPrefix + userUpdateProfilePicture
|
||||
UpdateEmailStream = supermqPrefix + userUpdateEmail
|
||||
enableStream = supermqPrefix + userEnable
|
||||
disableStream = supermqPrefix + userDisable
|
||||
viewStream = supermqPrefix + userView
|
||||
viewProfileStream = supermqPrefix + profileView
|
||||
listStream = supermqPrefix + userList
|
||||
searchStream = supermqPrefix + userSearch
|
||||
identifyStream = supermqPrefix + userIdentify
|
||||
issueTokenStream = supermqPrefix + issueToken
|
||||
refreshTokenStream = supermqPrefix + refreshToken
|
||||
resetSecretStream = supermqPrefix + resetSecret
|
||||
sendPasswordResetStream = supermqPrefix + sendPasswordReset
|
||||
oauthStream = supermqPrefix + oauthCallback
|
||||
addPolicyStream = supermqPrefix + addClientPolicy
|
||||
deleteStream = supermqPrefix + deleteUser
|
||||
supermqPrefix = "supermq."
|
||||
createStream = supermqPrefix + userCreate
|
||||
sendVerificationStream = supermqPrefix + userSendVerification
|
||||
verifyEmailStream = supermqPrefix + userVerifyEmail
|
||||
updateStream = supermqPrefix + userUpdate
|
||||
updateRoleStream = supermqPrefix + userUpdateRole
|
||||
updateTagsStream = supermqPrefix + userUpdateTags
|
||||
updateSecretStream = supermqPrefix + userUpdateSecret
|
||||
updateUsernameStream = supermqPrefix + userUpdateUsername
|
||||
updatePictureStream = supermqPrefix + userUpdateProfilePicture
|
||||
UpdateEmailStream = supermqPrefix + userUpdateEmail
|
||||
enableStream = supermqPrefix + userEnable
|
||||
disableStream = supermqPrefix + userDisable
|
||||
viewStream = supermqPrefix + userView
|
||||
viewProfileStream = supermqPrefix + profileView
|
||||
listStream = supermqPrefix + userList
|
||||
searchStream = supermqPrefix + userSearch
|
||||
identifyStream = supermqPrefix + userIdentify
|
||||
issueTokenStream = supermqPrefix + issueToken
|
||||
refreshTokenStream = supermqPrefix + refreshToken
|
||||
revokeRefreshTokenStream = supermqPrefix + revokeRefreshToken
|
||||
resetSecretStream = supermqPrefix + resetSecret
|
||||
sendPasswordResetStream = supermqPrefix + sendPasswordReset
|
||||
oauthStream = supermqPrefix + oauthCallback
|
||||
addPolicyStream = supermqPrefix + addClientPolicy
|
||||
deleteStream = supermqPrefix + deleteUser
|
||||
)
|
||||
|
||||
var _ users.Service = (*eventStore)(nil)
|
||||
@@ -350,8 +351,8 @@ func (es *eventStore) SendPasswordReset(ctx context.Context, email string) error
|
||||
return es.Publish(ctx, sendPasswordResetStream, event)
|
||||
}
|
||||
|
||||
func (es *eventStore) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) {
|
||||
token, err := es.svc.IssueToken(ctx, username, secret)
|
||||
func (es *eventStore) IssueToken(ctx context.Context, username, secret, description string) (*grpcTokenV1.Token, error) {
|
||||
token, err := es.svc.IssueToken(ctx, username, secret, description)
|
||||
if err != nil {
|
||||
return token, err
|
||||
}
|
||||
@@ -385,6 +386,24 @@ func (es *eventStore) RefreshToken(ctx context.Context, session authn.Session, r
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (es *eventStore) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error {
|
||||
err := es.svc.RevokeRefreshToken(ctx, session, tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
event := revokeRefreshTokenEvent{
|
||||
tokenID: tokenID,
|
||||
requestID: middleware.GetReqID(ctx),
|
||||
}
|
||||
|
||||
return es.Publish(ctx, revokeRefreshTokenStream, event)
|
||||
}
|
||||
|
||||
func (es *eventStore) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
|
||||
return es.svc.ListActiveRefreshTokens(ctx, session)
|
||||
}
|
||||
|
||||
func (es *eventStore) ResetSecret(ctx context.Context, session authn.Session, secret string) error {
|
||||
if err := es.svc.ResetSecret(ctx, session, secret); err != nil {
|
||||
return err
|
||||
|
||||
+105
-16
@@ -902,22 +902,24 @@ func TestIssueToken(t *testing.T) {
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
username string
|
||||
secret string
|
||||
svcRes *grpcTokenV1.Token
|
||||
svcErr error
|
||||
resp *grpcTokenV1.Token
|
||||
err error
|
||||
desc string
|
||||
username string
|
||||
secret string
|
||||
description string
|
||||
svcRes *grpcTokenV1.Token
|
||||
svcErr error
|
||||
resp *grpcTokenV1.Token
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
username: validUser.Credentials.Username,
|
||||
secret: validUser.Credentials.Secret,
|
||||
svcRes: validToken,
|
||||
svcErr: nil,
|
||||
resp: validToken,
|
||||
err: nil,
|
||||
desc: "publish successfully",
|
||||
username: validUser.Credentials.Username,
|
||||
secret: validUser.Credentials.Secret,
|
||||
description: "valid token",
|
||||
svcRes: validToken,
|
||||
svcErr: nil,
|
||||
resp: validToken,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
@@ -932,8 +934,8 @@ func TestIssueToken(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("IssueToken", validCtx, tc.username, tc.secret).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := nsvc.IssueToken(validCtx, tc.username, tc.secret)
|
||||
svcCall := svc.On("IssueToken", validCtx, tc.username, tc.secret, tc.description).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := nsvc.IssueToken(validCtx, tc.username, tc.secret, tc.description)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
|
||||
svcCall.Unset()
|
||||
@@ -990,6 +992,93 @@ func TestRefreshToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevokeRefreshToken(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
validTokenID := "validTokenID"
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
tokenID string
|
||||
svcErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
tokenID: validTokenID,
|
||||
svcErr: nil,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
tokenID: validTokenID,
|
||||
svcErr: svcerr.ErrUpdateEntity,
|
||||
err: svcerr.ErrUpdateEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("RevokeRefreshToken", validCtx, tc.session, tc.tokenID).Return(tc.svcErr)
|
||||
err := nsvc.RevokeRefreshToken(validCtx, tc.session, tc.tokenID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListActiveRefreshTokens(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
validCtx := context.WithValue(context.Background(), middleware.RequestIDKey, testsutil.GenerateUUID(t))
|
||||
validTokensList := &grpcTokenV1.ListUserRefreshTokensRes{
|
||||
RefreshTokens: []*grpcTokenV1.RefreshToken{
|
||||
{Id: "token1", Description: "token1"},
|
||||
{Id: "token2", Description: "token2"},
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
svcRes *grpcTokenV1.ListUserRefreshTokensRes
|
||||
svcErr error
|
||||
resp *grpcTokenV1.ListUserRefreshTokensRes
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "publish successfully",
|
||||
session: validSession,
|
||||
svcRes: validTokensList,
|
||||
svcErr: nil,
|
||||
resp: validTokensList,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "failed to publish with service error",
|
||||
session: validSession,
|
||||
svcRes: nil,
|
||||
svcErr: svcerr.ErrViewEntity,
|
||||
resp: nil,
|
||||
err: svcerr.ErrViewEntity,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("ListActiveRefreshTokens", validCtx, tc.session).Return(tc.svcRes, tc.svcErr)
|
||||
resp, err := nsvc.ListActiveRefreshTokens(validCtx, tc.session)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
assert.Equal(t, tc.resp, resp, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.resp, resp))
|
||||
svcCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResetSecret(t *testing.T) {
|
||||
svc, nsvc := newEventStoreMiddleware(t)
|
||||
|
||||
|
||||
@@ -162,14 +162,22 @@ func (am *authorizationMiddleware) Identify(ctx context.Context, session authn.S
|
||||
return am.svc.Identify(ctx, session)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) {
|
||||
return am.svc.IssueToken(ctx, username, secret)
|
||||
func (am *authorizationMiddleware) IssueToken(ctx context.Context, username, secret, description string) (*grpcTokenV1.Token, error) {
|
||||
return am.svc.IssueToken(ctx, username, secret, description)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*grpcTokenV1.Token, error) {
|
||||
return am.svc.RefreshToken(ctx, session, refreshToken)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error {
|
||||
return am.svc.RevokeRefreshToken(ctx, session, tokenID)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
|
||||
return am.svc.ListActiveRefreshTokens(ctx, session)
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) OAuthCallback(ctx context.Context, user users.User) (users.User, error) {
|
||||
return am.svc.OAuthCallback(ctx, user)
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ func (lm *loggingMiddleware) VerifyEmail(ctx context.Context, verificationToken
|
||||
|
||||
// IssueToken logs the issue_token request. It logs the username type and the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) IssueToken(ctx context.Context, username, secret string) (t *grpcTokenV1.Token, err error) {
|
||||
func (lm *loggingMiddleware) IssueToken(ctx context.Context, username, secret, description string) (t *grpcTokenV1.Token, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
@@ -105,7 +105,7 @@ func (lm *loggingMiddleware) IssueToken(ctx context.Context, username, secret st
|
||||
}
|
||||
lm.logger.Info("Issue token completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.IssueToken(ctx, username, secret)
|
||||
return lm.svc.IssueToken(ctx, username, secret, description)
|
||||
}
|
||||
|
||||
// RefreshToken logs the refresh_token request. It logs the refreshtoken, token type and the time it took to complete the request.
|
||||
@@ -129,6 +129,45 @@ func (lm *loggingMiddleware) RefreshToken(ctx context.Context, session authn.Ses
|
||||
return lm.svc.RefreshToken(ctx, session, refreshToken)
|
||||
}
|
||||
|
||||
// RevokeRefreshToken logs the revoke_refresh_token request. It logs the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) (err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("Revoke refresh token failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("Revoke refresh token completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.RevokeRefreshToken(ctx, session, tokenID)
|
||||
}
|
||||
|
||||
// ListActiveRefreshTokens logs the list_active_refresh_tokens request. It logs the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (tokens *grpcTokenV1.ListUserRefreshTokensRes, err error) {
|
||||
defer func(begin time.Time) {
|
||||
args := []any{
|
||||
slog.String("duration", time.Since(begin).String()),
|
||||
slog.String("request_id", middleware.GetReqID(ctx)),
|
||||
}
|
||||
if tokens != nil {
|
||||
args = append(args, slog.Int("tokens_count", len(tokens.GetRefreshTokens())))
|
||||
}
|
||||
if err != nil {
|
||||
args = append(args, slog.String("error", err.Error()))
|
||||
lm.logger.Warn("List active refresh tokens failed", args...)
|
||||
return
|
||||
}
|
||||
lm.logger.Info("List active refresh tokens completed successfully", args...)
|
||||
}(time.Now())
|
||||
return lm.svc.ListActiveRefreshTokens(ctx, session)
|
||||
}
|
||||
|
||||
// View logs the view_user request. It logs the user id and the time it took to complete the request.
|
||||
// If the request fails, it logs the error.
|
||||
func (lm *loggingMiddleware) View(ctx context.Context, session authn.Session, id string) (c users.User, err error) {
|
||||
|
||||
@@ -58,12 +58,12 @@ func (ms *metricsMiddleware) VerifyEmail(ctx context.Context, verificationToken
|
||||
}
|
||||
|
||||
// IssueToken instruments IssueToken method with metrics.
|
||||
func (ms *metricsMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) {
|
||||
func (ms *metricsMiddleware) IssueToken(ctx context.Context, username, secret, description string) (*grpcTokenV1.Token, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "issue_token").Add(1)
|
||||
ms.latency.With("method", "issue_token").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.IssueToken(ctx, username, secret)
|
||||
return ms.svc.IssueToken(ctx, username, secret, description)
|
||||
}
|
||||
|
||||
// RefreshToken instruments RefreshToken method with metrics.
|
||||
@@ -75,6 +75,24 @@ func (ms *metricsMiddleware) RefreshToken(ctx context.Context, session authn.Ses
|
||||
return ms.svc.RefreshToken(ctx, session, refreshToken)
|
||||
}
|
||||
|
||||
// RevokeRefreshToken instruments RevokeRefreshToken method with metrics.
|
||||
func (ms *metricsMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "revoke_refresh_token").Add(1)
|
||||
ms.latency.With("method", "revoke_refresh_token").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.RevokeRefreshToken(ctx, session, tokenID)
|
||||
}
|
||||
|
||||
// ListActiveRefreshTokens instruments ListActiveRefreshTokens method with metrics.
|
||||
func (ms *metricsMiddleware) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
|
||||
defer func(begin time.Time) {
|
||||
ms.counter.With("method", "list_active_refresh_tokens").Add(1)
|
||||
ms.latency.With("method", "list_active_refresh_tokens").Observe(time.Since(begin).Seconds())
|
||||
}(time.Now())
|
||||
return ms.svc.ListActiveRefreshTokens(ctx, session)
|
||||
}
|
||||
|
||||
// View instruments View method with metrics.
|
||||
func (ms *metricsMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) {
|
||||
defer func(begin time.Time) {
|
||||
|
||||
@@ -50,11 +50,11 @@ func (tm *tracingMiddleware) VerifyEmail(ctx context.Context, verificationToken
|
||||
}
|
||||
|
||||
// IssueToken traces the "IssueToken" operation of the wrapped users.Service.
|
||||
func (tm *tracingMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) {
|
||||
func (tm *tracingMiddleware) IssueToken(ctx context.Context, username, secret, description string) (*grpcTokenV1.Token, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_issue_token", trace.WithAttributes(attribute.String("username", username)))
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.IssueToken(ctx, username, secret)
|
||||
return tm.svc.IssueToken(ctx, username, secret, description)
|
||||
}
|
||||
|
||||
// RefreshToken traces the "RefreshToken" operation of the wrapped users.Service.
|
||||
@@ -65,6 +65,22 @@ func (tm *tracingMiddleware) RefreshToken(ctx context.Context, session authn.Ses
|
||||
return tm.svc.RefreshToken(ctx, session, refreshToken)
|
||||
}
|
||||
|
||||
// RevokeRefreshToken traces the "RevokeRefreshToken" operation of the wrapped users.Service.
|
||||
func (tm *tracingMiddleware) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_revoke_refresh_token")
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.RevokeRefreshToken(ctx, session, tokenID)
|
||||
}
|
||||
|
||||
// ListActiveRefreshTokens traces the "ListActiveRefreshTokens" operation of the wrapped users.Service.
|
||||
func (tm *tracingMiddleware) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_list_active_refresh_tokens")
|
||||
defer span.End()
|
||||
|
||||
return tm.svc.ListActiveRefreshTokens(ctx, session)
|
||||
}
|
||||
|
||||
// View traces the "View" operation of the wrapped users.Service.
|
||||
func (tm *tracingMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) {
|
||||
ctx, span := tracing.StartSpan(ctx, tm.tracer, "svc_view_user", trace.WithAttributes(attribute.String("id", id)))
|
||||
|
||||
+149
-12
@@ -318,8 +318,8 @@ func (_c *Service_Identify_Call) RunAndReturn(run func(ctx context.Context, sess
|
||||
}
|
||||
|
||||
// IssueToken provides a mock function for the type Service
|
||||
func (_mock *Service) IssueToken(ctx context.Context, identity string, secret string) (*v1.Token, error) {
|
||||
ret := _mock.Called(ctx, identity, secret)
|
||||
func (_mock *Service) IssueToken(ctx context.Context, identity string, secret string, description string) (*v1.Token, error) {
|
||||
ret := _mock.Called(ctx, identity, secret, description)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for IssueToken")
|
||||
@@ -327,18 +327,18 @@ func (_mock *Service) IssueToken(ctx context.Context, identity string, secret st
|
||||
|
||||
var r0 *v1.Token
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) (*v1.Token, error)); ok {
|
||||
return returnFunc(ctx, identity, secret)
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) (*v1.Token, error)); ok {
|
||||
return returnFunc(ctx, identity, secret, description)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string) *v1.Token); ok {
|
||||
r0 = returnFunc(ctx, identity, secret)
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string) *v1.Token); ok {
|
||||
r0 = returnFunc(ctx, identity, secret, description)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*v1.Token)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string) error); ok {
|
||||
r1 = returnFunc(ctx, identity, secret)
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok {
|
||||
r1 = returnFunc(ctx, identity, secret, description)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
@@ -354,11 +354,12 @@ type Service_IssueToken_Call struct {
|
||||
// - ctx context.Context
|
||||
// - identity string
|
||||
// - secret string
|
||||
func (_e *Service_Expecter) IssueToken(ctx interface{}, identity interface{}, secret interface{}) *Service_IssueToken_Call {
|
||||
return &Service_IssueToken_Call{Call: _e.mock.On("IssueToken", ctx, identity, secret)}
|
||||
// - description string
|
||||
func (_e *Service_Expecter) IssueToken(ctx interface{}, identity interface{}, secret interface{}, description interface{}) *Service_IssueToken_Call {
|
||||
return &Service_IssueToken_Call{Call: _e.mock.On("IssueToken", ctx, identity, secret, description)}
|
||||
}
|
||||
|
||||
func (_c *Service_IssueToken_Call) Run(run func(ctx context.Context, identity string, secret string)) *Service_IssueToken_Call {
|
||||
func (_c *Service_IssueToken_Call) Run(run func(ctx context.Context, identity string, secret string, description string)) *Service_IssueToken_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
@@ -372,10 +373,15 @@ func (_c *Service_IssueToken_Call) Run(run func(ctx context.Context, identity st
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
var arg3 string
|
||||
if args[3] != nil {
|
||||
arg3 = args[3].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
@@ -386,7 +392,75 @@ func (_c *Service_IssueToken_Call) Return(token *v1.Token, err error) *Service_I
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_IssueToken_Call) RunAndReturn(run func(ctx context.Context, identity string, secret string) (*v1.Token, error)) *Service_IssueToken_Call {
|
||||
func (_c *Service_IssueToken_Call) RunAndReturn(run func(ctx context.Context, identity string, secret string, description string) (*v1.Token, error)) *Service_IssueToken_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListActiveRefreshTokens provides a mock function for the type Service
|
||||
func (_mock *Service) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*v1.ListUserRefreshTokensRes, error) {
|
||||
ret := _mock.Called(ctx, session)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListActiveRefreshTokens")
|
||||
}
|
||||
|
||||
var r0 *v1.ListUserRefreshTokensRes
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) (*v1.ListUserRefreshTokensRes, error)); ok {
|
||||
return returnFunc(ctx, session)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session) *v1.ListUserRefreshTokensRes); ok {
|
||||
r0 = returnFunc(ctx, session)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*v1.ListUserRefreshTokensRes)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(context.Context, authn.Session) error); ok {
|
||||
r1 = returnFunc(ctx, session)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// Service_ListActiveRefreshTokens_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListActiveRefreshTokens'
|
||||
type Service_ListActiveRefreshTokens_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListActiveRefreshTokens is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - session authn.Session
|
||||
func (_e *Service_Expecter) ListActiveRefreshTokens(ctx interface{}, session interface{}) *Service_ListActiveRefreshTokens_Call {
|
||||
return &Service_ListActiveRefreshTokens_Call{Call: _e.mock.On("ListActiveRefreshTokens", ctx, session)}
|
||||
}
|
||||
|
||||
func (_c *Service_ListActiveRefreshTokens_Call) Run(run func(ctx context.Context, session authn.Session)) *Service_ListActiveRefreshTokens_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 authn.Session
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(authn.Session)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_ListActiveRefreshTokens_Call) Return(listUserRefreshTokensRes *v1.ListUserRefreshTokensRes, err error) *Service_ListActiveRefreshTokens_Call {
|
||||
_c.Call.Return(listUserRefreshTokensRes, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_ListActiveRefreshTokens_Call) RunAndReturn(run func(ctx context.Context, session authn.Session) (*v1.ListUserRefreshTokensRes, error)) *Service_ListActiveRefreshTokens_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
@@ -801,6 +875,69 @@ func (_c *Service_ResetSecret_Call) RunAndReturn(run func(ctx context.Context, s
|
||||
return _c
|
||||
}
|
||||
|
||||
// RevokeRefreshToken provides a mock function for the type Service
|
||||
func (_mock *Service) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error {
|
||||
ret := _mock.Called(ctx, session, tokenID)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RevokeRefreshToken")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if returnFunc, ok := ret.Get(0).(func(context.Context, authn.Session, string) error); ok {
|
||||
r0 = returnFunc(ctx, session, tokenID)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
return r0
|
||||
}
|
||||
|
||||
// Service_RevokeRefreshToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeRefreshToken'
|
||||
type Service_RevokeRefreshToken_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RevokeRefreshToken is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - session authn.Session
|
||||
// - tokenID string
|
||||
func (_e *Service_Expecter) RevokeRefreshToken(ctx interface{}, session interface{}, tokenID interface{}) *Service_RevokeRefreshToken_Call {
|
||||
return &Service_RevokeRefreshToken_Call{Call: _e.mock.On("RevokeRefreshToken", ctx, session, tokenID)}
|
||||
}
|
||||
|
||||
func (_c *Service_RevokeRefreshToken_Call) Run(run func(ctx context.Context, session authn.Session, tokenID string)) *Service_RevokeRefreshToken_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
var arg0 context.Context
|
||||
if args[0] != nil {
|
||||
arg0 = args[0].(context.Context)
|
||||
}
|
||||
var arg1 authn.Session
|
||||
if args[1] != nil {
|
||||
arg1 = args[1].(authn.Session)
|
||||
}
|
||||
var arg2 string
|
||||
if args[2] != nil {
|
||||
arg2 = args[2].(string)
|
||||
}
|
||||
run(
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
)
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_RevokeRefreshToken_Call) Return(err error) *Service_RevokeRefreshToken_Call {
|
||||
_c.Call.Return(err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *Service_RevokeRefreshToken_Call) RunAndReturn(run func(ctx context.Context, session authn.Session, tokenID string) error) *Service_RevokeRefreshToken_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SearchUsers provides a mock function for the type Service
|
||||
func (_mock *Service) SearchUsers(ctx context.Context, pm users.Page) (users.UsersPage, error) {
|
||||
ret := _mock.Called(ctx, pm)
|
||||
|
||||
+45
-3
@@ -183,7 +183,7 @@ func (svc service) VerifyEmail(ctx context.Context, token string) (User, error)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (svc service) IssueToken(ctx context.Context, identity, secret string) (*grpcTokenV1.Token, error) {
|
||||
func (svc service) IssueToken(ctx context.Context, identity, secret, description string) (*grpcTokenV1.Token, error) {
|
||||
var dbUser User
|
||||
var err error
|
||||
|
||||
@@ -205,7 +205,13 @@ func (svc service) IssueToken(ctx context.Context, identity, secret string) (*gr
|
||||
return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrLogin, err)
|
||||
}
|
||||
|
||||
token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, UserRole: uint32(dbUser.Role + 1), Type: uint32(smqauth.AccessKey), Verified: !dbUser.VerifiedAt.IsZero()})
|
||||
token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{
|
||||
UserId: dbUser.ID,
|
||||
UserRole: uint32(dbUser.Role + 1),
|
||||
Type: uint32(smqauth.AccessKey),
|
||||
Verified: !dbUser.VerifiedAt.IsZero(),
|
||||
Description: description,
|
||||
})
|
||||
if err != nil {
|
||||
return &grpcTokenV1.Token{}, errors.Wrap(errIssueToken, err)
|
||||
}
|
||||
@@ -229,6 +235,42 @@ func (svc service) RefreshToken(ctx context.Context, session authn.Session, refr
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (svc service) RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error {
|
||||
dbUser, err := svc.users.RetrieveByID(ctx, session.UserID)
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
if dbUser.Status == DisabledStatus {
|
||||
return errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser)
|
||||
}
|
||||
_, err = svc.token.Revoke(ctx, &grpcTokenV1.RevokeReq{UserId: session.UserID, TokenId: tokenID})
|
||||
if err != nil {
|
||||
if errors.Contains(err, svcerr.ErrNotFound) {
|
||||
return errors.Wrap(svcerr.ErrNotFound, err)
|
||||
}
|
||||
return errors.Wrap(svcerr.ErrRemoveEntity, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc service) ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error) {
|
||||
dbUser, err := svc.users.RetrieveByID(ctx, session.UserID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
if dbUser.Status == DisabledStatus {
|
||||
return nil, errors.Wrap(svcerr.ErrAuthentication, errLoginDisableUser)
|
||||
}
|
||||
|
||||
refreshTokens, err := svc.token.ListUserRefreshTokens(ctx, &grpcTokenV1.ListUserRefreshTokensReq{UserId: session.UserID})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(svcerr.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
return refreshTokens, nil
|
||||
}
|
||||
|
||||
func (svc service) View(ctx context.Context, session authn.Session, id string) (User, error) {
|
||||
user, err := svc.users.RetrieveByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -453,7 +495,7 @@ func (svc service) UpdateSecret(ctx context.Context, session authn.Session, oldS
|
||||
if err != nil {
|
||||
return User{}, errors.Wrap(svcerr.ErrViewEntity, err)
|
||||
}
|
||||
if _, err := svc.IssueToken(ctx, dbUser.Credentials.Username, oldSecret); err != nil {
|
||||
if _, err := svc.IssueToken(ctx, dbUser.Credentials.Username, oldSecret, ""); err != nil {
|
||||
return User{}, err
|
||||
}
|
||||
newSecret, err = svc.hasher.Hash(newSecret)
|
||||
|
||||
+176
-1
@@ -1615,7 +1615,7 @@ func TestIssueToken(t *testing.T) {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := cRepo.On("RetrieveByUsername", context.Background(), tc.user.Credentials.Username).Return(tc.retrieveByUsernameResponse, tc.retrieveByUsernameErr)
|
||||
authCall := auth.On("Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr)
|
||||
token, err := svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret)
|
||||
token, err := svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret, "")
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotEmpty(t, token.GetAccessToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetAccessToken()))
|
||||
@@ -1710,6 +1710,181 @@ func TestRefreshToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevokeRefreshToken(t *testing.T) {
|
||||
svc, authsvc, crepo, _, _ := newService()
|
||||
|
||||
rUser := user
|
||||
rUser.Credentials.Secret, _ = phasher.Hash(user.Credentials.Secret)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
tokenID string
|
||||
revokeResp *grpcTokenV1.RevokeRes
|
||||
revokeErr error
|
||||
repoResp users.User
|
||||
repoErr error
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "revoke refresh token successfully",
|
||||
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
|
||||
tokenID: validToken,
|
||||
revokeResp: &grpcTokenV1.RevokeRes{},
|
||||
repoResp: rUser,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "revoke refresh token with empty domain id",
|
||||
session: authn.Session{UserID: validID},
|
||||
tokenID: validToken,
|
||||
revokeResp: &grpcTokenV1.RevokeRes{},
|
||||
repoResp: rUser,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "revoke refresh token for non-existing user",
|
||||
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
|
||||
tokenID: validToken,
|
||||
repoErr: repoerr.ErrNotFound,
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "revoke refresh token for disabled user",
|
||||
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
|
||||
tokenID: validToken,
|
||||
repoResp: users.User{Status: users.DisabledStatus},
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "revoke refresh token with revoke service error",
|
||||
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
|
||||
tokenID: validToken,
|
||||
revokeResp: &grpcTokenV1.RevokeRes{},
|
||||
revokeErr: svcerr.ErrAuthorization,
|
||||
repoResp: rUser,
|
||||
err: svcerr.ErrAuthorization,
|
||||
},
|
||||
{
|
||||
desc: "revoke refresh token not found",
|
||||
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
|
||||
tokenID: validToken,
|
||||
revokeResp: &grpcTokenV1.RevokeRes{},
|
||||
revokeErr: svcerr.ErrNotFound,
|
||||
repoResp: rUser,
|
||||
err: svcerr.ErrNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr)
|
||||
authCall := authsvc.On("Revoke", context.Background(), &grpcTokenV1.RevokeReq{UserId: tc.session.UserID, TokenId: tc.tokenID}).Return(tc.revokeResp, tc.revokeErr)
|
||||
err := svc.RevokeRefreshToken(context.Background(), tc.session, tc.tokenID)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
ok = authCall.Parent.AssertCalled(t, "Revoke", context.Background(), &grpcTokenV1.RevokeReq{UserId: tc.session.UserID, TokenId: tc.tokenID})
|
||||
assert.True(t, ok, fmt.Sprintf("Revoke was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
authCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListActiveRefreshTokens(t *testing.T) {
|
||||
svc, authsvc, crepo, _, _ := newService()
|
||||
|
||||
rUser := user
|
||||
rUser.Credentials.Secret, _ = phasher.Hash(user.Credentials.Secret)
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
session authn.Session
|
||||
listResp *grpcTokenV1.ListUserRefreshTokensRes
|
||||
listErr error
|
||||
repoResp users.User
|
||||
repoErr error
|
||||
expectedTokens int
|
||||
err error
|
||||
}{
|
||||
{
|
||||
desc: "list active refresh tokens successfully",
|
||||
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
|
||||
listResp: &grpcTokenV1.ListUserRefreshTokensRes{
|
||||
RefreshTokens: []*grpcTokenV1.RefreshToken{
|
||||
{Id: "token1", Description: "token1"},
|
||||
{Id: "token2", Description: "token2"},
|
||||
},
|
||||
},
|
||||
repoResp: rUser,
|
||||
expectedTokens: 2,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list active refresh tokens with empty domain id",
|
||||
session: authn.Session{UserID: validID},
|
||||
listResp: &grpcTokenV1.ListUserRefreshTokensRes{
|
||||
RefreshTokens: []*grpcTokenV1.RefreshToken{
|
||||
{Id: "token1", Description: "token1"},
|
||||
},
|
||||
},
|
||||
repoResp: rUser,
|
||||
expectedTokens: 1,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "list active refresh tokens for non-existing user",
|
||||
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
|
||||
repoErr: repoerr.ErrNotFound,
|
||||
err: repoerr.ErrNotFound,
|
||||
},
|
||||
{
|
||||
desc: "list active refresh tokens for disabled user",
|
||||
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
|
||||
repoResp: users.User{Status: users.DisabledStatus},
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "list active refresh tokens with list service error",
|
||||
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
|
||||
listResp: &grpcTokenV1.ListUserRefreshTokensRes{},
|
||||
listErr: svcerr.ErrAuthentication,
|
||||
repoResp: rUser,
|
||||
err: svcerr.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "list active refresh tokens with empty list",
|
||||
session: authn.Session{DomainUserID: validID, UserID: validID, DomainID: validID},
|
||||
listResp: &grpcTokenV1.ListUserRefreshTokensRes{RefreshTokens: []*grpcTokenV1.RefreshToken{}},
|
||||
repoResp: rUser,
|
||||
expectedTokens: 0,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
repoCall := crepo.On("RetrieveByID", context.Background(), tc.session.UserID).Return(tc.repoResp, tc.repoErr)
|
||||
authCall := authsvc.On("ListUserRefreshTokens", context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.session.UserID}).Return(tc.listResp, tc.listErr)
|
||||
tokens, err := svc.ListActiveRefreshTokens(context.Background(), tc.session)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
assert.NotNil(t, tokens, fmt.Sprintf("%s: expected tokens not to be nil\n", tc.desc))
|
||||
assert.Equal(t, tc.expectedTokens, len(tokens.GetRefreshTokens()), fmt.Sprintf("%s: expected %d tokens got %d\n", tc.desc, tc.expectedTokens, len(tokens.GetRefreshTokens())))
|
||||
ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", context.Background(), tc.session.UserID)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc))
|
||||
ok = authCall.Parent.AssertCalled(t, "ListUserRefreshTokens", context.Background(), &grpcTokenV1.ListUserRefreshTokensReq{UserId: tc.session.UserID})
|
||||
assert.True(t, ok, fmt.Sprintf("ListUserRefreshTokens was not called on %s", tc.desc))
|
||||
}
|
||||
repoCall.Unset()
|
||||
authCall.Unset()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendPasswordReset(t *testing.T) {
|
||||
svc, auth, cRepo, _, e := newService()
|
||||
|
||||
|
||||
+7
-1
@@ -264,13 +264,19 @@ type Service interface {
|
||||
Identify(ctx context.Context, session authn.Session) (string, error)
|
||||
|
||||
// IssueToken issues a new access and refresh token when provided with either a username or email.
|
||||
IssueToken(ctx context.Context, identity, secret string) (*grpcTokenV1.Token, error)
|
||||
IssueToken(ctx context.Context, identity, secret, description string) (*grpcTokenV1.Token, error)
|
||||
|
||||
// RefreshToken refreshes expired access tokens.
|
||||
// After an access token expires, the refresh token is used to get
|
||||
// a new pair of access and refresh tokens.
|
||||
RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*grpcTokenV1.Token, error)
|
||||
|
||||
// RevokeRefreshToken revokes a refresh token by its ID.
|
||||
RevokeRefreshToken(ctx context.Context, session authn.Session, tokenID string) error
|
||||
|
||||
// ListActiveRefreshTokens lists all active refresh tokens for the authenticated user.
|
||||
ListActiveRefreshTokens(ctx context.Context, session authn.Session) (*grpcTokenV1.ListUserRefreshTokensRes, error)
|
||||
|
||||
// OAuthCallback handles the callback from any supported OAuth provider.
|
||||
// It processes the OAuth tokens and either signs in or signs up the user based on the provided state.
|
||||
OAuthCallback(ctx context.Context, user User) (User, error)
|
||||
|
||||
Reference in New Issue
Block a user