SMQ-2604 - Change PAT repo implementation (#2680)

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
Signed-off-by: Arvindh <arvindh91@gmail.com>
Co-authored-by: Arvindh <arvindh91@gmail.com>
This commit is contained in:
Steve Munene
2025-02-27 17:38:34 +03:00
committed by GitHub
parent 56829c15d8
commit 17b5224090
47 changed files with 2585 additions and 2690 deletions
+46 -59
View File
@@ -236,16 +236,15 @@ func (x *AuthZReq) GetObjectType() string {
}
type AuthZPatReq struct {
state protoimpl.MessageState `protogen:"open.v1"`
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // User id
PatId string `protobuf:"bytes,2,opt,name=pat_id,json=patId,proto3" json:"pat_id,omitempty"` // Pat id
PlatformEntityType uint32 `protobuf:"varint,3,opt,name=platform_entity_type,json=platformEntityType,proto3" json:"platform_entity_type,omitempty"` // Platform entity type
OptionalDomainId string `protobuf:"bytes,4,opt,name=optional_domain_id,json=optionalDomainId,proto3" json:"optional_domain_id,omitempty"` // Optional domain id
OptionalDomainEntityType uint32 `protobuf:"varint,5,opt,name=optional_domain_entity_type,json=optionalDomainEntityType,proto3" json:"optional_domain_entity_type,omitempty"` // Optional domain entity type
Operation uint32 `protobuf:"varint,6,opt,name=operation,proto3" json:"operation,omitempty"` // Operation
EntityIds []string `protobuf:"bytes,7,rep,name=entity_ids,json=entityIds,proto3" json:"entity_ids,omitempty"` // EntityIDs
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
state protoimpl.MessageState `protogen:"open.v1"`
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // User id
PatId string `protobuf:"bytes,2,opt,name=pat_id,json=patId,proto3" json:"pat_id,omitempty"` // Pat id
EntityType uint32 `protobuf:"varint,3,opt,name=entity_type,json=entityType,proto3" json:"entity_type,omitempty"` // Entity type
OptionalDomainId string `protobuf:"bytes,4,opt,name=optional_domain_id,json=optionalDomainId,proto3" json:"optional_domain_id,omitempty"` // Optional domain id
Operation uint32 `protobuf:"varint,6,opt,name=operation,proto3" json:"operation,omitempty"` // Operation
EntityId string `protobuf:"bytes,7,opt,name=entity_id,json=entityId,proto3" json:"entity_id,omitempty"` // EntityID
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthZPatReq) Reset() {
@@ -292,9 +291,9 @@ func (x *AuthZPatReq) GetPatId() string {
return ""
}
func (x *AuthZPatReq) GetPlatformEntityType() uint32 {
func (x *AuthZPatReq) GetEntityType() uint32 {
if x != nil {
return x.PlatformEntityType
return x.EntityType
}
return 0
}
@@ -306,13 +305,6 @@ func (x *AuthZPatReq) GetOptionalDomainId() string {
return ""
}
func (x *AuthZPatReq) GetOptionalDomainEntityType() uint32 {
if x != nil {
return x.OptionalDomainEntityType
}
return 0
}
func (x *AuthZPatReq) GetOperation() uint32 {
if x != nil {
return x.Operation
@@ -320,11 +312,11 @@ func (x *AuthZPatReq) GetOperation() uint32 {
return 0
}
func (x *AuthZPatReq) GetEntityIds() []string {
func (x *AuthZPatReq) GetEntityId() string {
if x != nil {
return x.EntityIds
return x.EntityId
}
return nil
return ""
}
type AuthZRes struct {
@@ -409,47 +401,42 @@ var file_auth_v1_auth_proto_rawDesc = []byte{
0x06, 0x6f, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6f,
0x62, 0x6a, 0x65, 0x63, 0x74, 0x12, 0x1f, 0x0a, 0x0b, 0x6f, 0x62, 0x6a, 0x65, 0x63, 0x74, 0x5f,
0x74, 0x79, 0x70, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6f, 0x62, 0x6a, 0x65,
0x63, 0x74, 0x54, 0x79, 0x70, 0x65, 0x22, 0x99, 0x02, 0x0a, 0x0b, 0x41, 0x75, 0x74, 0x68, 0x5a,
0x63, 0x74, 0x54, 0x79, 0x70, 0x65, 0x22, 0xc7, 0x01, 0x0a, 0x0b, 0x41, 0x75, 0x74, 0x68, 0x5a,
0x50, 0x61, 0x74, 0x52, 0x65, 0x71, 0x12, 0x17, 0x0a, 0x07, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69,
0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x64, 0x12,
0x15, 0x0a, 0x06, 0x70, 0x61, 0x74, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52,
0x05, 0x70, 0x61, 0x74, 0x49, 0x64, 0x12, 0x30, 0x0a, 0x14, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f,
0x72, 0x6d, 0x5f, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03,
0x20, 0x01, 0x28, 0x0d, 0x52, 0x12, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x45, 0x6e,
0x74, 0x69, 0x74, 0x79, 0x54, 0x79, 0x70, 0x65, 0x12, 0x2c, 0x0a, 0x12, 0x6f, 0x70, 0x74, 0x69,
0x6f, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x04,
0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x44, 0x6f,
0x6d, 0x61, 0x69, 0x6e, 0x49, 0x64, 0x12, 0x3d, 0x0a, 0x1b, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e,
0x61, 0x6c, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79,
0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x18, 0x6f, 0x70, 0x74,
0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x45, 0x6e, 0x74, 0x69, 0x74,
0x79, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69,
0x05, 0x70, 0x61, 0x74, 0x49, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79,
0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x65, 0x6e, 0x74,
0x69, 0x74, 0x79, 0x54, 0x79, 0x70, 0x65, 0x12, 0x2c, 0x0a, 0x12, 0x6f, 0x70, 0x74, 0x69, 0x6f,
0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20,
0x01, 0x28, 0x09, 0x52, 0x10, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x44, 0x6f, 0x6d,
0x61, 0x69, 0x6e, 0x49, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69,
0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74,
0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x69, 0x64,
0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x49,
0x64, 0x73, 0x22, 0x3a, 0x0a, 0x08, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x12, 0x1e,
0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01,
0x28, 0x08, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x12, 0x0e,
0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x32, 0xf0,
0x01, 0x0a, 0x0b, 0x41, 0x75, 0x74, 0x68, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33,
0x0a, 0x09, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x12, 0x11, 0x2e, 0x61, 0x75,
0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x71, 0x1a, 0x11,
0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65,
0x73, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65,
0x50, 0x41, 0x54, 0x12, 0x14, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75,
0x74, 0x68, 0x5a, 0x50, 0x61, 0x74, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68,
0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x22, 0x00, 0x12, 0x36,
0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x11,
0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65,
0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68,
0x4e, 0x52, 0x65, 0x73, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e,
0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x50, 0x41, 0x54, 0x12, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68,
0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61,
0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x73, 0x22,
0x00, 0x42, 0x2d, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f,
0x61, 0x62, 0x73, 0x6d, 0x61, 0x63, 0x68, 0x2f, 0x73, 0x75, 0x70, 0x65, 0x72, 0x6d, 0x71, 0x2f,
0x61, 0x70, 0x69, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x2f, 0x76, 0x31,
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x69, 0x6f, 0x6e, 0x12, 0x1b, 0x0a, 0x09, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x5f, 0x69, 0x64,
0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x49, 0x64,
0x22, 0x3a, 0x0a, 0x08, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x12, 0x1e, 0x0a, 0x0a,
0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08,
0x52, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x12, 0x0e, 0x0a, 0x02,
0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x32, 0xf0, 0x01, 0x0a,
0x0b, 0x41, 0x75, 0x74, 0x68, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x09,
0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x12, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68,
0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61,
0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x22,
0x00, 0x12, 0x39, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x50, 0x41,
0x54, 0x12, 0x14, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68,
0x5a, 0x50, 0x61, 0x74, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76,
0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x5a, 0x52, 0x65, 0x73, 0x22, 0x00, 0x12, 0x36, 0x0a, 0x0c,
0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x12, 0x11, 0x2e, 0x61,
0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x71, 0x1a,
0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52,
0x65, 0x73, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x0f, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69,
0x63, 0x61, 0x74, 0x65, 0x50, 0x41, 0x54, 0x12, 0x11, 0x2e, 0x61, 0x75, 0x74, 0x68, 0x2e, 0x76,
0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x71, 0x1a, 0x11, 0x2e, 0x61, 0x75, 0x74,
0x68, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x75, 0x74, 0x68, 0x4e, 0x52, 0x65, 0x73, 0x22, 0x00, 0x42,
0x2d, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x61, 0x62,
0x73, 0x6d, 0x61, 0x63, 0x68, 0x2f, 0x73, 0x75, 0x70, 0x65, 0x72, 0x6d, 0x71, 0x2f, 0x61, 0x70,
0x69, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x2f, 0x76, 0x31, 0x62, 0x06,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
+12 -14
View File
@@ -151,13 +151,12 @@ func (client authGrpcClient) AuthorizePAT(ctx context.Context, req *grpcAuthV1.A
defer cancel()
res, err := client.authorizePAT(ctx, authPATReq{
userID: req.GetUserId(),
patID: req.GetPatId(),
platformEntityType: auth.PlatformEntityType(req.GetPlatformEntityType()),
optionalDomainID: req.GetOptionalDomainId(),
optionalDomainEntityType: auth.DomainEntityType(req.GetOptionalDomainEntityType()),
operation: auth.OperationType(req.GetOperation()),
entityIDs: req.GetEntityIds(),
userID: req.GetUserId(),
patID: req.GetPatId(),
entityType: auth.EntityType(req.GetEntityType()),
optionalDomainID: req.GetOptionalDomainId(),
operation: auth.Operation(req.GetOperation()),
entityID: req.GetEntityId(),
})
if err != nil {
return &grpcAuthV1.AuthZRes{}, grpcapi.DecodeError(err)
@@ -170,12 +169,11 @@ func (client authGrpcClient) AuthorizePAT(ctx context.Context, req *grpcAuthV1.A
func encodeAuthorizePATRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(authPATReq)
return &grpcAuthV1.AuthZPatReq{
UserId: req.userID,
PatId: req.patID,
PlatformEntityType: uint32(req.platformEntityType),
OptionalDomainId: req.optionalDomainID,
OptionalDomainEntityType: uint32(req.optionalDomainEntityType),
Operation: uint32(req.operation),
EntityIds: req.entityIDs,
UserId: req.userID,
PatId: req.patID,
EntityType: uint32(req.entityType),
OptionalDomainId: req.optionalDomainID,
Operation: uint32(req.operation),
EntityId: req.entityID,
}, nil
}
+1 -1
View File
@@ -74,7 +74,7 @@ func authorizePATEndpoint(svc auth.Service) endpoint.Endpoint {
if err := req.validate(); err != nil {
return authorizeRes{}, err
}
err := svc.AuthorizePAT(ctx, req.userID, req.patID, req.platformEntityType, req.optionalDomainID, req.optionalDomainEntityType, req.operation, req.entityIDs...)
err := svc.AuthorizePAT(ctx, req.userID, req.patID, req.entityType, req.optionalDomainID, req.operation, req.entityID)
if err != nil {
return authorizeRes{authorized: false}, err
}
+22 -26
View File
@@ -301,13 +301,12 @@ func TestAuthorizePAT(t *testing.T) {
desc: "authorize user with authorized token",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZPatReq{
UserId: id,
PatId: id,
PlatformEntityType: uint32(auth.PlatformDomainsScope),
OptionalDomainId: domainID,
OptionalDomainEntityType: uint32(auth.DomainClientsScope),
Operation: uint32(auth.CreateOp),
EntityIds: []string{clientID},
UserId: id,
PatId: id,
EntityType: uint32(auth.ClientsType),
OptionalDomainId: domainID,
Operation: uint32(auth.CreateOp),
EntityId: clientID,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: true},
err: nil,
@@ -316,13 +315,12 @@ func TestAuthorizePAT(t *testing.T) {
desc: "authorize user with unauthorized token",
token: inValidPATToken,
authRequest: &grpcAuthV1.AuthZPatReq{
UserId: id,
PatId: id,
PlatformEntityType: uint32(auth.PlatformDomainsScope),
OptionalDomainId: domainID,
OptionalDomainEntityType: uint32(auth.DomainClientsScope),
Operation: uint32(auth.CreateOp),
EntityIds: []string{clientID},
UserId: id,
PatId: id,
EntityType: uint32(auth.ClientsType),
OptionalDomainId: domainID,
Operation: uint32(auth.CreateOp),
EntityId: clientID,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: svcerr.ErrAuthorization,
@@ -331,12 +329,11 @@ func TestAuthorizePAT(t *testing.T) {
desc: "authorize user with missing user id",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZPatReq{
PatId: id,
PlatformEntityType: uint32(auth.PlatformDomainsScope),
OptionalDomainId: domainID,
OptionalDomainEntityType: uint32(auth.DomainClientsScope),
Operation: uint32(auth.CreateOp),
EntityIds: []string{clientID},
PatId: id,
EntityType: uint32(auth.ClientsType),
OptionalDomainId: domainID,
Operation: uint32(auth.CreateOp),
EntityId: clientID,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingUserID,
@@ -345,12 +342,11 @@ func TestAuthorizePAT(t *testing.T) {
desc: "authorize user with missing pat id",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZPatReq{
UserId: id,
PlatformEntityType: uint32(auth.PlatformDomainsScope),
OptionalDomainId: domainID,
OptionalDomainEntityType: uint32(auth.DomainClientsScope),
Operation: uint32(auth.CreateOp),
EntityIds: []string{clientID},
UserId: id,
EntityType: uint32(auth.ClientsType),
OptionalDomainId: domainID,
Operation: uint32(auth.CreateOp),
EntityId: clientID,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingPATID,
+6 -7
View File
@@ -52,13 +52,12 @@ func (req authReq) validate() error {
}
type authPATReq struct {
userID string
patID string
platformEntityType auth.PlatformEntityType
optionalDomainID string
optionalDomainEntityType auth.DomainEntityType
operation auth.OperationType
entityIDs []string
userID string
patID string
entityType auth.EntityType
optionalDomainID string
operation auth.Operation
entityID string
}
func (req authPATReq) validate() error {
+6 -7
View File
@@ -112,13 +112,12 @@ func encodeAuthorizeResponse(_ context.Context, grpcRes interface{}) (interface{
func decodeAuthorizePATRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*grpcAuthV1.AuthZPatReq)
return authPATReq{
userID: req.GetUserId(),
patID: req.GetPatId(),
platformEntityType: auth.PlatformEntityType(req.GetPlatformEntityType()),
optionalDomainID: req.GetOptionalDomainId(),
optionalDomainEntityType: auth.DomainEntityType(req.GetOptionalDomainEntityType()),
operation: auth.OperationType(req.GetOperation()),
entityIDs: req.GetEntityIds(),
userID: req.GetUserId(),
patID: req.GetPatId(),
entityType: auth.EntityType(req.GetEntityType()),
optionalDomainID: req.GetOptionalDomainId(),
operation: auth.Operation(req.GetOperation()),
entityID: req.GetEntityId(),
}, nil
}
+2 -1
View File
@@ -70,13 +70,14 @@ func (tr testRequest) make() (*http.Response, error) {
func newService() (auth.Service, *mocks.KeyRepository) {
krepo := new(mocks.KeyRepository)
pRepo := new(mocks.PATSRepository)
cache := new(mocks.Cache)
hash := new(mocks.Hasher)
idProvider := uuid.NewMock()
pService := new(policymocks.Service)
pEvaluator := new(policymocks.Evaluator)
t := jwt.New([]byte(secret))
return auth.New(krepo, pRepo, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), krepo
return auth.New(krepo, pRepo, cache, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), krepo
}
func newServer(svc auth.Service) *httptest.Server {
+48 -13
View File
@@ -17,7 +17,7 @@ func createPATEndpoint(svc auth.Service) endpoint.Endpoint {
return nil, err
}
pat, err := svc.CreatePAT(ctx, req.token, req.Name, req.Description, req.Duration, req.Scope)
pat, err := svc.CreatePAT(ctx, req.token, req.Name, req.Description, req.Duration)
if err != nil {
return nil, err
}
@@ -140,48 +140,83 @@ func revokePATSecretEndpoint(svc auth.Service) endpoint.Endpoint {
}
}
func addPATScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint {
func clearAllPATEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(addPatScopeEntryReq)
req := request.(clearAllPATReq)
if err := req.validate(); err != nil {
return nil, err
}
scope, err := svc.AddPATScopeEntry(ctx, req.token, req.id, req.PlatformEntityType, req.OptionalDomainID, req.OptionalDomainEntityType, req.Operation, req.EntityIDs...)
if err := svc.RemoveAllPAT(ctx, req.token); err != nil {
return nil, err
}
return clearAllRes{}, nil
}
}
func addScopeEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(addScopeReq)
if err := req.validate(); err != nil {
return nil, err
}
err := svc.AddScope(ctx, req.token, req.id, req.Scopes)
if err != nil {
return nil, err
}
return addPatScopeEntryRes{scope}, nil
return scopeRes{}, nil
}
}
func removePATScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint {
func removeScopeEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(removePatScopeEntryReq)
req := request.(removeScopeReq)
if err := req.validate(); err != nil {
return nil, err
}
scope, err := svc.RemovePATScopeEntry(ctx, req.token, req.id, req.PlatformEntityType, req.OptionalDomainID, req.OptionalDomainEntityType, req.Operation, req.EntityIDs...)
err := svc.RemoveScope(ctx, req.token, req.id, req.ScopesID...)
if err != nil {
return nil, err
}
return removePatScopeEntryRes{scope}, nil
return scopeRes{}, nil
}
}
func clearPATAllScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint {
func clearAllScopeEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(clearAllScopeEntryReq)
req := request.(clearAllScopeReq)
if err := req.validate(); err != nil {
return nil, err
}
if err := svc.ClearPATAllScopeEntry(ctx, req.token, req.id); err != nil {
if err := svc.RemovePATAllScope(ctx, req.token, req.id); err != nil {
return nil, err
}
return clearAllScopeEntryRes{}, nil
return clearAllRes{}, nil
}
}
func listScopesEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(listScopesReq)
if err := req.validate(); err != nil {
return nil, err
}
pm := auth.ScopesPageMeta{
Limit: req.limit,
Offset: req.offset,
PatID: req.patID,
}
scopesPage, err := svc.ListScopes(ctx, req.token, pm)
if err != nil {
return nil, err
}
return listScopeRes{scopesPage}, nil
}
}
+78 -97
View File
@@ -10,6 +10,7 @@ import (
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
)
type createPatReq struct {
@@ -17,15 +18,13 @@ type createPatReq struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Duration time.Duration `json:"duration,omitempty"`
Scope auth.Scope `json:"scope,omitempty"`
}
func (cpr *createPatReq) UnmarshalJSON(data []byte) error {
var temp struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Duration string `json:"duration,omitempty"`
Scope auth.Scope `json:"scope,omitempty"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Duration string `json:"duration,omitempty"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
@@ -37,7 +36,6 @@ func (cpr *createPatReq) UnmarshalJSON(data []byte) error {
cpr.Name = temp.Name
cpr.Description = temp.Description
cpr.Duration = duration
cpr.Scope = temp.Scope
return nil
}
@@ -63,7 +61,7 @@ func (req retrievePatReq) validate() (err error) {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
return apiutil.ErrMissingPATID
}
return nil
}
@@ -79,7 +77,7 @@ func (req updatePatNameReq) validate() (err error) {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
return apiutil.ErrMissingPATID
}
if strings.TrimSpace(req.Name) == "" {
return apiutil.ErrMissingName
@@ -98,7 +96,7 @@ func (req updatePatDescriptionReq) validate() (err error) {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
return apiutil.ErrMissingPATID
}
if strings.TrimSpace(req.Description) == "" {
return apiutil.ErrMissingDescription
@@ -129,7 +127,7 @@ func (req deletePatReq) validate() (err error) {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
return apiutil.ErrMissingPATID
}
return nil
}
@@ -161,7 +159,7 @@ func (req resetPatSecretReq) validate() (err error) {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
return apiutil.ErrMissingPATID
}
return nil
}
@@ -176,128 +174,111 @@ func (req revokePatSecretReq) validate() (err error) {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
return apiutil.ErrMissingPATID
}
return nil
}
type addPatScopeEntryReq struct {
token string
id string
PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"`
OptionalDomainID string `json:"optional_domain_id,omitempty"`
OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"`
Operation auth.OperationType `json:"operation,omitempty"`
EntityIDs []string `json:"entity_ids,omitempty"`
type clearAllPATReq struct {
token string
}
func (apser *addPatScopeEntryReq) UnmarshalJSON(data []byte) error {
var temp struct {
PlatformEntityType string `json:"platform_entity_type,omitempty"`
OptionalDomainID string `json:"optional_domain_id,omitempty"`
OptionalDomainEntityType string `json:"optional_domain_entity_type,omitempty"`
Operation string `json:"operation,omitempty"`
EntityIDs []string `json:"entity_ids,omitempty"`
func (req clearAllPATReq) validate() error {
if req.token == "" {
return apiutil.ErrBearerToken
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
pet, err := auth.ParsePlatformEntityType(temp.PlatformEntityType)
if err != nil {
return err
}
odt, err := auth.ParseDomainEntityType(temp.OptionalDomainEntityType)
if err != nil {
return err
}
op, err := auth.ParseOperationType(temp.Operation)
if err != nil {
return err
}
apser.PlatformEntityType = pet
apser.OptionalDomainID = temp.OptionalDomainID
apser.OptionalDomainEntityType = odt
apser.Operation = op
apser.EntityIDs = temp.EntityIDs
return nil
}
func (req addPatScopeEntryReq) validate() (err error) {
type addScopeReq struct {
token string
id string
Scopes []auth.Scope `json:"scopes,omitempty"`
}
func (aser *addScopeReq) UnmarshalJSON(data []byte) error {
type Alias addScopeReq
aux := &struct {
*Alias
}{
Alias: (*Alias)(aser),
}
if err := json.Unmarshal(data, aux); err != nil {
return err
}
return nil
}
func (req addScopeReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
return apiutil.ErrMissingPATID
}
if len(req.Scopes) == 0 {
return apiutil.ErrValidation
}
for _, scope := range req.Scopes {
if err := scope.Validate(); err != nil {
return errors.Wrap(apiutil.ErrValidation, err)
}
}
return nil
}
type removePatScopeEntryReq struct {
token string
id string
PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"`
OptionalDomainID string `json:"optional_domain_id,omitempty"`
OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"`
Operation auth.OperationType `json:"operation,omitempty"`
EntityIDs []string `json:"entity_ids,omitempty"`
type removeScopeReq struct {
token string
id string
ScopesID []string `json:"scopes_id,omitempty"`
}
func (rpser *removePatScopeEntryReq) UnmarshalJSON(data []byte) error {
var temp struct {
PlatformEntityType string `json:"platform_entity_type,omitempty"`
OptionalDomainID string `json:"optional_domain_id,omitempty"`
OptionalDomainEntityType string `json:"optional_domain_entity_type,omitempty"`
Operation string `json:"operation,omitempty"`
EntityIDs []string `json:"entity_ids,omitempty"`
}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
pet, err := auth.ParsePlatformEntityType(temp.PlatformEntityType)
if err != nil {
return err
}
odt, err := auth.ParseDomainEntityType(temp.OptionalDomainEntityType)
if err != nil {
return err
}
op, err := auth.ParseOperationType(temp.Operation)
if err != nil {
return err
}
rpser.PlatformEntityType = pet
rpser.OptionalDomainID = temp.OptionalDomainID
rpser.OptionalDomainEntityType = odt
rpser.Operation = op
rpser.EntityIDs = temp.EntityIDs
return nil
}
func (req removePatScopeEntryReq) validate() (err error) {
func (req removeScopeReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
return apiutil.ErrMissingPATID
}
if len(req.ScopesID) == 0 {
return apiutil.ErrValidation
}
return nil
}
type clearAllScopeEntryReq struct {
type clearAllScopeReq struct {
token string
id string
}
func (req clearAllScopeEntryReq) validate() (err error) {
func (req clearAllScopeReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.id == "" {
return apiutil.ErrMissingID
return apiutil.ErrMissingPATID
}
return nil
}
type listScopesReq struct {
token string
offset uint64
limit uint64
patID string
}
func (req listScopesReq) validate() (err error) {
if req.token == "" {
return apiutil.ErrBearerToken
}
if req.patID == "" {
return apiutil.ErrMissingPATID
}
return nil
}
+42 -45
View File
@@ -18,13 +18,12 @@ var (
_ supermq.Response = (*deletePatRes)(nil)
_ supermq.Response = (*resetPatSecretRes)(nil)
_ supermq.Response = (*revokePatSecretRes)(nil)
_ supermq.Response = (*addPatScopeEntryRes)(nil)
_ supermq.Response = (*removePatScopeEntryRes)(nil)
_ supermq.Response = (*clearAllScopeEntryRes)(nil)
_ supermq.Response = (*scopeRes)(nil)
_ supermq.Response = (*clearAllRes)(nil)
)
type createPatRes struct {
auth.PAT
auth.PAT `json:",inline"`
}
func (res createPatRes) Code() int {
@@ -40,7 +39,7 @@ func (res createPatRes) Empty() bool {
}
type retrievePatRes struct {
auth.PAT
auth.PAT `json:",inline"`
}
func (res retrievePatRes) Code() int {
@@ -56,7 +55,7 @@ func (res retrievePatRes) Empty() bool {
}
type updatePatNameRes struct {
auth.PAT
auth.PAT `json:",inline"`
}
func (res updatePatNameRes) Code() int {
@@ -72,7 +71,7 @@ func (res updatePatNameRes) Empty() bool {
}
type updatePatDescriptionRes struct {
auth.PAT
auth.PAT `json:",inline"`
}
func (res updatePatDescriptionRes) Code() int {
@@ -88,7 +87,7 @@ func (res updatePatDescriptionRes) Empty() bool {
}
type listPatsRes struct {
auth.PATSPage
auth.PATSPage `json:",inline"`
}
func (res listPatsRes) Code() int {
@@ -118,7 +117,7 @@ func (res deletePatRes) Empty() bool {
}
type resetPatSecretRes struct {
auth.PAT
auth.PAT `json:",inline"`
}
func (res resetPatSecretRes) Code() int {
@@ -147,48 +146,46 @@ func (res revokePatSecretRes) Empty() bool {
return true
}
type addPatScopeEntryRes struct {
auth.Scope
}
type scopeRes struct{}
func (res addPatScopeEntryRes) Code() int {
func (res scopeRes) Code() int {
return http.StatusOK
}
func (res addPatScopeEntryRes) Headers() map[string]string {
func (res scopeRes) Headers() map[string]string {
return map[string]string{}
}
func (res addPatScopeEntryRes) Empty() bool {
return false
}
type removePatScopeEntryRes struct {
auth.Scope
}
func (res removePatScopeEntryRes) Code() int {
return http.StatusOK
}
func (res removePatScopeEntryRes) Headers() map[string]string {
return map[string]string{}
}
func (res removePatScopeEntryRes) Empty() bool {
return false
}
type clearAllScopeEntryRes struct{}
func (res clearAllScopeEntryRes) Code() int {
return http.StatusOK
}
func (res clearAllScopeEntryRes) Headers() map[string]string {
return map[string]string{}
}
func (res clearAllScopeEntryRes) Empty() bool {
func (res scopeRes) Empty() bool {
return true
}
type clearAllRes struct{}
func (res clearAllRes) Code() int {
return http.StatusOK
}
func (res clearAllRes) Headers() map[string]string {
return map[string]string{}
}
func (res clearAllRes) Empty() bool {
return true
}
type listScopeRes struct {
auth.ScopesPage `json:",inline"`
}
func (res listScopeRes) Code() int {
return http.StatusOK
}
func (res listScopeRes) Headers() map[string]string {
return map[string]string{}
}
func (res listScopeRes) Empty() bool {
return false
}
+79 -34
View File
@@ -44,6 +44,13 @@ func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux {
opts...,
).ServeHTTP)
r.Delete("/", kithttp.NewServer(
clearAllPATEndpoint(svc),
decodeClearAllPATRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Route("/{id}", func(r chi.Router) {
r.Get("/", kithttp.NewServer(
retrievePATEndpoint(svc),
@@ -91,22 +98,29 @@ func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux {
r.Route("/scope", func(r chi.Router) {
r.Patch("/add", kithttp.NewServer(
addPATScopeEntryEndpoint(svc),
decodeAddPATScopeEntryRequest,
addScopeEndpoint(svc),
decodeAddScopeRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Get("/", kithttp.NewServer(
listScopesEndpoint(svc),
decodeListScopeRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Patch("/remove", kithttp.NewServer(
removePATScopeEntryEndpoint(svc),
decodeRemovePATScopeEntryRequest,
removeScopeEndpoint(svc),
decodeRemoveScopeRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
r.Delete("/", kithttp.NewServer(
clearPATAllScopeEntryEndpoint(svc),
decodeClearPATAllScopeEntryRequest,
clearAllScopeEndpoint(svc),
decodeClearAllScopeRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
@@ -243,7 +257,18 @@ func decodeRevokePATSecretRequest(_ context.Context, r *http.Request) (interface
}, nil
}
func decodeAddPATScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) {
func decodeClearAllPATRequest(_ context.Context, r *http.Request) (interface{}, error) {
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
return clearAllPATReq{
token: token,
}, nil
}
func decodeAddScopeRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
@@ -253,7 +278,51 @@ func decodeAddPATScopeEntryRequest(_ context.Context, r *http.Request) (interfac
return nil, apiutil.ErrUnsupportedTokenType
}
req := addPatScopeEntryReq{
req := addScopeReq{
token: token,
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
}
return req, nil
}
func decodeListScopeRequest(_ context.Context, r *http.Request) (interface{}, error) {
l, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
o, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset)
if err != nil {
return nil, errors.Wrap(apiutil.ErrValidation, err)
}
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
req := listScopesReq{
token: token,
limit: l,
offset: o,
patID: chi.URLParam(r, "id"),
}
return req, nil
}
func decodeRemoveScopeRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
req := removeScopeReq{
token: token,
id: chi.URLParam(r, "id"),
}
@@ -263,37 +332,13 @@ func decodeAddPATScopeEntryRequest(_ context.Context, r *http.Request) (interfac
return req, nil
}
func decodeRemovePATScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
func decodeClearAllScopeRequest(_ context.Context, r *http.Request) (interface{}, error) {
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
req := removePatScopeEntryReq{
token: token,
id: chi.URLParam(r, "id"),
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
}
return req, nil
}
func decodeClearPATAllScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}
token := apiutil.ExtractBearerToken(r)
if strings.HasPrefix(token, patPrefix) {
return nil, apiutil.ErrUnsupportedTokenType
}
return clearAllScopeEntryReq{
return clearAllScopeReq{
token: token,
id: chi.URLParam(r, "id"),
}, nil
+64 -51
View File
@@ -125,14 +125,13 @@ func (lm *loggingMiddleware) Authorize(ctx context.Context, pr policies.Policy)
return lm.svc.Authorize(ctx, pr)
}
func (lm *loggingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (pa auth.PAT, err error) {
func (lm *loggingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration) (pa auth.PAT, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("name", name),
slog.String("description", description),
slog.String("pat_duration", duration.String()),
slog.String("scope", scope.String()),
}
if err != nil {
args = append(args, slog.Any("error", err))
@@ -141,7 +140,7 @@ func (lm *loggingMiddleware) CreatePAT(ctx context.Context, token, name, descrip
}
lm.logger.Info("Create PAT completed successfully", args...)
}(time.Now())
return lm.svc.CreatePAT(ctx, token, name, description, duration, scope)
return lm.svc.CreatePAT(ctx, token, name, description, duration)
}
func (lm *loggingMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (pa auth.PAT, err error) {
@@ -211,6 +210,24 @@ func (lm *loggingMiddleware) ListPATS(ctx context.Context, token string, pm auth
return lm.svc.ListPATS(ctx, token, pm)
}
func (lm *loggingMiddleware) ListScopes(ctx context.Context, token string, pm auth.ScopesPageMeta) (pp auth.ScopesPage, err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.Uint64("limit", pm.Limit),
slog.Uint64("offset", pm.Offset),
slog.String("pat_id", pm.PatID),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("List Scopes failed", args...)
return
}
lm.logger.Info("List Scopes completed successfully", args...)
}(time.Now())
return lm.svc.ListScopes(ctx, token, pm)
}
func (lm *loggingMiddleware) DeletePAT(ctx context.Context, token, patID string) (err error) {
defer func(begin time.Time) {
args := []any{
@@ -260,37 +277,56 @@ func (lm *loggingMiddleware) RevokePATSecret(ctx context.Context, token, patID s
return lm.svc.RevokePATSecret(ctx, token, patID)
}
func (lm *loggingMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (sc auth.Scope, err error) {
func (lm *loggingMiddleware) RemoveAllPAT(ctx context.Context, token string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("pat_id", patID),
slog.String("platform_entity_type", platformEntityType.String()),
slog.String("optional_domain_id", optionalDomainID),
slog.String("optional_domain_entity_type", optionalDomainEntityType.String()),
slog.String("operation", operation.String()),
slog.Any("entities", entityIDs),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Add entry to PAT scope failed", args...)
lm.logger.Warn("Remove all PAT failed", args...)
return
}
lm.logger.Info("Add entry to PAT scope completed successfully", args...)
lm.logger.Info("Remove all of PAT completed successfully", args...)
}(time.Now())
return lm.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
return lm.svc.RemoveAllPAT(ctx, token)
}
func (lm *loggingMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (sc auth.Scope, err error) {
func (lm *loggingMiddleware) AddScope(ctx context.Context, token, patID string, scopes []auth.Scope) (err error) {
defer func(begin time.Time) {
var groupArgs []any
for _, s := range scopes {
groupArgs = append(groupArgs, slog.String("entity_type", s.EntityType.String()))
groupArgs = append(groupArgs, slog.String("optional_domain_id", s.OptionalDomainID))
groupArgs = append(groupArgs, slog.String("operation", s.Operation.String()))
groupArgs = append(groupArgs, slog.String("entity_id", s.EntityID))
}
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("pat_id", patID),
slog.String("platform_entity_type", platformEntityType.String()),
slog.String("optional_domain_id", optionalDomainID),
slog.String("optional_domain_entity_type", optionalDomainEntityType.String()),
slog.String("operation", operation.String()),
slog.Any("entities", entityIDs),
slog.Group("scope", groupArgs...),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Add PAT scope failed", args...)
return
}
lm.logger.Info("Add PAT scope completed successfully", args...)
}(time.Now())
return lm.svc.AddScope(ctx, token, patID, scopes)
}
func (lm *loggingMiddleware) RemoveScope(ctx context.Context, token, patID string, scopesID ...string) (err error) {
defer func(begin time.Time) {
var groupArgs []any
for _, s := range scopesID {
groupArgs = append(groupArgs, slog.String("scope_id", s))
}
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("pat_id", patID),
slog.Group("scope", groupArgs...),
}
if err != nil {
args = append(args, slog.Any("error", err))
@@ -299,10 +335,10 @@ func (lm *loggingMiddleware) RemovePATScopeEntry(ctx context.Context, token, pat
}
lm.logger.Info("Remove entry from PAT scope completed successfully", args...)
}(time.Now())
return lm.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
return lm.svc.RemoveScope(ctx, token, patID, scopesID...)
}
func (lm *loggingMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) (err error) {
func (lm *loggingMiddleware) RemovePATAllScope(ctx context.Context, token, patID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
@@ -310,12 +346,12 @@ func (lm *loggingMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, p
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Clear all entry from PAT scope failed", args...)
lm.logger.Warn("Remove all scopes from PAT failed", args...)
return
}
lm.logger.Info("Clear all entry from PAT scope completed successfully", args...)
lm.logger.Info("Remove all scopes from PAT completed successfully", args...)
}(time.Now())
return lm.svc.ClearPATAllScopeEntry(ctx, token, patID)
return lm.svc.RemovePATAllScope(ctx, token, patID)
}
func (lm *loggingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (pa auth.PAT, err error) {
@@ -333,15 +369,14 @@ func (lm *loggingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (p
return lm.svc.IdentifyPAT(ctx, paToken)
}
func (lm *loggingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (err error) {
func (lm *loggingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("platform_entity_type", platformEntityType.String()),
slog.String("entity_type", entityType.String()),
slog.String("optional_domain_id", optionalDomainID),
slog.String("optional_domain_entity_type", optionalDomainEntityType.String()),
slog.String("operation", operation.String()),
slog.Any("entities", entityIDs),
slog.String("entities", entityID),
}
if err != nil {
args = append(args, slog.Any("error", err))
@@ -350,27 +385,5 @@ func (lm *loggingMiddleware) AuthorizePAT(ctx context.Context, userID, patID str
}
lm.logger.Info("Authorize PAT completed successfully", args...)
}(time.Now())
return lm.svc.AuthorizePAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
func (lm *loggingMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("user_id", userID),
slog.String("pat_id", patID),
slog.String("platform_entity_type", platformEntityType.String()),
slog.String("optional_domain_id", optionalDomainID),
slog.String("optional_domain_entity_type", optionalDomainEntityType.String()),
slog.String("operation", operation.String()),
slog.Any("entities", entityIDs),
}
if err != nil {
args = append(args, slog.Any("error", err))
lm.logger.Warn("Check PAT failed complete successfully", args...)
return
}
lm.logger.Info("Check PAT completed successfully", args...)
}(time.Now())
return lm.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
return lm.svc.AuthorizePAT(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
}
+32 -24
View File
@@ -75,12 +75,12 @@ func (ms *metricsMiddleware) Authorize(ctx context.Context, pr policies.Policy)
return ms.svc.Authorize(ctx, pr)
}
func (ms *metricsMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) {
func (ms *metricsMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration) (auth.PAT, error) {
defer func(begin time.Time) {
ms.counter.With("method", "create_pat").Add(1)
ms.latency.With("method", "create_pat").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.CreatePAT(ctx, token, name, description, duration, scope)
return ms.svc.CreatePAT(ctx, token, name, description, duration)
}
func (ms *metricsMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (auth.PAT, error) {
@@ -115,6 +115,14 @@ func (ms *metricsMiddleware) ListPATS(ctx context.Context, token string, pm auth
return ms.svc.ListPATS(ctx, token, pm)
}
func (ms *metricsMiddleware) ListScopes(ctx context.Context, token string, pm auth.ScopesPageMeta) (auth.ScopesPage, error) {
defer func(begin time.Time) {
ms.counter.With("method", "list_scopes").Add(1)
ms.latency.With("method", "list_scopes").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ListScopes(ctx, token, pm)
}
func (ms *metricsMiddleware) DeletePAT(ctx context.Context, token, patID string) error {
defer func(begin time.Time) {
ms.counter.With("method", "delete_pat").Add(1)
@@ -139,28 +147,36 @@ func (ms *metricsMiddleware) RevokePATSecret(ctx context.Context, token, patID s
return ms.svc.RevokePATSecret(ctx, token, patID)
}
func (ms *metricsMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
func (ms *metricsMiddleware) RemoveAllPAT(ctx context.Context, token string) error {
defer func(begin time.Time) {
ms.counter.With("method", "add_pat_scope_entry").Add(1)
ms.latency.With("method", "add_pat_scope_entry").Observe(time.Since(begin).Seconds())
ms.counter.With("method", "clear_all_pat").Add(1)
ms.latency.With("method", "clear_all_pat").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
return ms.svc.RemoveAllPAT(ctx, token)
}
func (ms *metricsMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
func (ms *metricsMiddleware) AddScope(ctx context.Context, token, patID string, scopes []auth.Scope) error {
defer func(begin time.Time) {
ms.counter.With("method", "remove_pat_scope_entry").Add(1)
ms.latency.With("method", "remove_pat_scope_entry").Observe(time.Since(begin).Seconds())
ms.counter.With("method", "add_pat_scope").Add(1)
ms.latency.With("method", "add_pat_scope").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
return ms.svc.AddScope(ctx, token, patID, scopes)
}
func (ms *metricsMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error {
func (ms *metricsMiddleware) RemoveScope(ctx context.Context, token, patID string, scopesID ...string) error {
defer func(begin time.Time) {
ms.counter.With("method", "clear_pat_all_scope_entry").Add(1)
ms.latency.With("method", "clear_pat_all_scope_entry").Observe(time.Since(begin).Seconds())
ms.counter.With("method", "remove_pat_scope").Add(1)
ms.latency.With("method", "remove_pat_scope").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.ClearPATAllScopeEntry(ctx, token, patID)
return ms.svc.RemoveScope(ctx, token, patID, scopesID...)
}
func (ms *metricsMiddleware) RemovePATAllScope(ctx context.Context, token, patID string) error {
defer func(begin time.Time) {
ms.counter.With("method", "clear_pat_all_scope").Add(1)
ms.latency.With("method", "clear_pat_all_scope").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.RemovePATAllScope(ctx, token, patID)
}
func (ms *metricsMiddleware) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) {
@@ -171,18 +187,10 @@ func (ms *metricsMiddleware) IdentifyPAT(ctx context.Context, paToken string) (a
return ms.svc.IdentifyPAT(ctx, paToken)
}
func (ms *metricsMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
func (ms *metricsMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error {
defer func(begin time.Time) {
ms.counter.With("method", "authorize_pat").Add(1)
ms.latency.With("method", "authorize_pat").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.AuthorizePAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
func (ms *metricsMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
defer func(begin time.Time) {
ms.counter.With("method", "check_pat").Add(1)
ms.latency.With("method", "check_pat").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
return ms.svc.AuthorizePAT(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
}
-6
View File
@@ -1,6 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package bolt contains PAT repository implementations using
// bolt as the underlying database.
package bolt
-21
View File
@@ -1,21 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package bolt contains PAT repository implementations using
// bolt as the underlying database.
package bolt
import (
"github.com/absmach/supermq/pkg/errors"
bolt "go.etcd.io/bbolt"
)
var errInit = errors.New("failed to initialize BoltDB")
func Init(tx *bolt.Tx, bucket string) error {
_, err := tx.CreateBucketIfNotExists([]byte(bucket))
if err != nil {
return errors.Wrap(errInit, err)
}
return nil
}
-812
View File
@@ -1,812 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package bolt
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"strings"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
bolt "go.etcd.io/bbolt"
)
const (
idKey = "id"
userKey = "user"
nameKey = "name"
descriptionKey = "description"
secretKey = "secret_key"
scopeKey = "scope"
issuedAtKey = "issued_at"
expiresAtKey = "expires_at"
updatedAtKey = "updated_at"
lastUsedAtKey = "last_used_at"
revokedKey = "revoked"
revokedAtKey = "revoked_at"
platformEntitiesKey = "platform_entities"
patKey = "pat"
keySeparator = ":"
anyID = "*"
)
var (
activateValue = []byte{0x00}
revokedValue = []byte{0x01}
entityValue = []byte{0x02}
anyIDValue = []byte{0x03}
selectedIDsValue = []byte{0x04}
)
type patRepo struct {
db *bolt.DB
bucketName string
}
// NewPATSRepository instantiates a bolt
// implementation of PAT repository.
func NewPATSRepository(db *bolt.DB, bucketName string) auth.PATSRepository {
return &patRepo{
db: db,
bucketName: bucketName,
}
}
func (pr *patRepo) Save(ctx context.Context, pat auth.PAT) error {
idxKey := []byte(pat.User + keySeparator + patKey + keySeparator + pat.ID)
kv, err := patToKeyValue(pat)
if err != nil {
return err
}
return pr.db.Update(func(tx *bolt.Tx) error {
rootBucket, err := pr.retrieveRootBucket(tx)
if err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
b, err := pr.createUserBucket(rootBucket, pat.User)
if err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
for key, value := range kv {
fullKey := []byte(pat.ID + keySeparator + key)
if err := b.Put(fullKey, value); err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
}
if err := rootBucket.Put(idxKey, []byte(pat.ID)); err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
return nil
})
}
func (pr *patRepo) Retrieve(ctx context.Context, userID, patID string) (auth.PAT, error) {
prefix := []byte(patID + keySeparator)
kv := map[string][]byte{}
if err := pr.db.View(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity)
if err != nil {
return err
}
c := b.Cursor()
for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
kv[string(k)] = v
}
return nil
}); err != nil {
return auth.PAT{}, err
}
return keyValueToPAT(kv)
}
func (pr *patRepo) RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, bool, error) {
revoked := true
expired := false
keySecret := patID + keySeparator + secretKey
keyRevoked := patID + keySeparator + revokedKey
keyExpiresAt := patID + keySeparator + expiresAtKey
var secretHash string
if err := pr.db.View(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity)
if err != nil {
return err
}
secretHash = string(b.Get([]byte(keySecret)))
revoked = bytesToBoolean(b.Get([]byte(keyRevoked)))
expiresAt := bytesToTime(b.Get([]byte(keyExpiresAt)))
expired = time.Now().After(expiresAt)
return nil
}); err != nil {
return "", true, true, err
}
return secretHash, revoked, expired, nil
}
func (pr *patRepo) UpdateName(ctx context.Context, userID, patID, name string) (auth.PAT, error) {
return pr.updatePATField(ctx, userID, patID, nameKey, []byte(name))
}
func (pr *patRepo) UpdateDescription(ctx context.Context, userID, patID, description string) (auth.PAT, error) {
return pr.updatePATField(ctx, userID, patID, descriptionKey, []byte(description))
}
func (pr *patRepo) UpdateTokenHash(ctx context.Context, userID, patID, tokenHash string, expiryAt time.Time) (auth.PAT, error) {
prefix := []byte(patID + keySeparator)
kv := map[string][]byte{}
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity)
if err != nil {
return err
}
if err := b.Put([]byte(patID+keySeparator+secretKey), []byte(tokenHash)); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if err := b.Put([]byte(patID+keySeparator+expiresAtKey), timeToBytes(expiryAt)); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if err := b.Put([]byte(patID+keySeparator+updatedAtKey), timeToBytes(time.Now())); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
c := b.Cursor()
for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
kv[string(k)] = v
}
return nil
}); err != nil {
return auth.PAT{}, err
}
return keyValueToPAT(kv)
}
func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSPageMeta) (auth.PATSPage, error) {
prefix := []byte(userID + keySeparator + patKey + keySeparator)
patIDs := []string{}
if err := pr.db.View(func(tx *bolt.Tx) error {
b, err := pr.retrieveRootBucket(tx)
if err != nil {
return errors.Wrap(repoerr.ErrViewEntity, err)
}
c := b.Cursor()
for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
if v != nil {
patIDs = append(patIDs, string(v))
}
}
return nil
}); err != nil {
return auth.PATSPage{}, err
}
total := len(patIDs)
var pats []auth.PAT
patsPage := auth.PATSPage{
Total: uint64(total),
Limit: pm.Limit,
Offset: pm.Offset,
PATS: pats,
}
if int(pm.Offset) >= total {
return patsPage, nil
}
aLimit := pm.Limit
if rLimit := total - int(pm.Offset); int(pm.Limit) > rLimit {
aLimit = uint64(rLimit)
}
for i := pm.Offset; i < pm.Offset+aLimit; i++ {
if int(i) < total {
pat, err := pr.Retrieve(ctx, userID, patIDs[i])
if err != nil {
return patsPage, err
}
patsPage.PATS = append(patsPage.PATS, pat)
}
}
return patsPage, nil
}
func (pr *patRepo) Revoke(ctx context.Context, userID, patID string) error {
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity)
if err != nil {
return err
}
if err := b.Put([]byte(patID+keySeparator+revokedKey), revokedValue); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if err := b.Put([]byte(patID+keySeparator+revokedAtKey), timeToBytes(time.Now())); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return nil
}); err != nil {
return err
}
return nil
}
func (pr *patRepo) Reactivate(ctx context.Context, userID, patID string) error {
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity)
if err != nil {
return err
}
if err := b.Put([]byte(patID+keySeparator+revokedKey), activateValue); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if err := b.Put([]byte(patID+keySeparator+revokedAtKey), []byte{}); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return nil
}); err != nil {
return err
}
return nil
}
func (pr *patRepo) Remove(ctx context.Context, userID, patID string) error {
prefix := []byte(patID + keySeparator)
idxKey := []byte(userID + keySeparator + patKey + keySeparator + patID)
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrRemoveEntity)
if err != nil {
return err
}
c := b.Cursor()
for k, _ := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, _ = c.Next() {
if err := b.Delete(k); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
}
rb, err := pr.retrieveRootBucket(tx)
if err != nil {
return err
}
if err := rb.Delete(idxKey); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
return nil
}); err != nil {
return err
}
return nil
}
func (pr *patRepo) AddScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
prefix := []byte(patID + keySeparator + scopeKey)
rKV := make(map[string][]byte)
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrCreateEntity)
if err != nil {
return err
}
kv, err := scopeEntryToKeyValue(platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if err != nil {
return err
}
for key, value := range kv {
fullKey := []byte(patID + keySeparator + key)
if err := b.Put(fullKey, value); err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
}
c := b.Cursor()
for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
rKV[string(k)] = v
}
return nil
}); err != nil {
return auth.Scope{}, err
}
return parseKeyValueToScope(rKV)
}
func (pr *patRepo) RemoveScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
if len(entityIDs) == 0 {
return auth.Scope{}, repoerr.ErrMalformedEntity
}
prefix := []byte(patID + keySeparator + scopeKey)
rKV := make(map[string][]byte)
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrRemoveEntity)
if err != nil {
return err
}
kv, err := scopeEntryToKeyValue(platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if err != nil {
return err
}
for key := range kv {
fullKey := []byte(patID + keySeparator + key)
if err := b.Delete(fullKey); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
}
c := b.Cursor()
for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
rKV[string(k)] = v
}
return nil
}); err != nil {
return auth.Scope{}, err
}
return parseKeyValueToScope(rKV)
}
func (pr *patRepo) CheckScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
return pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity)
if err != nil {
return errors.Wrap(repoerr.ErrViewEntity, err)
}
srootKey, err := scopeRootKey(platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
if err != nil {
return errors.Wrap(repoerr.ErrViewEntity, err)
}
rootKey := patID + keySeparator + srootKey
if value := b.Get([]byte(rootKey)); bytes.Equal(value, anyIDValue) {
return nil
}
for _, entity := range entityIDs {
value := b.Get([]byte(rootKey + keySeparator + entity))
if !bytes.Equal(value, entityValue) {
return repoerr.ErrNotFound
}
}
return nil
})
}
func (pr *patRepo) RemoveAllScopeEntry(ctx context.Context, userID, patID string) error {
return nil
}
func (pr *patRepo) updatePATField(_ context.Context, userID, patID, key string, value []byte) (auth.PAT, error) {
prefix := []byte(patID + keySeparator)
kv := map[string][]byte{}
if err := pr.db.Update(func(tx *bolt.Tx) error {
b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity)
if err != nil {
return err
}
if err := b.Put([]byte(patID+keySeparator+key), value); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if err := b.Put([]byte(patID+keySeparator+updatedAtKey), timeToBytes(time.Now())); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
c := b.Cursor()
for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() {
kv[string(k)] = v
}
return nil
}); err != nil {
return auth.PAT{}, err
}
return keyValueToPAT(kv)
}
func (pr *patRepo) createUserBucket(rootBucket *bolt.Bucket, userID string) (*bolt.Bucket, error) {
userBucket, err := rootBucket.CreateBucketIfNotExists([]byte(userID))
if err != nil {
return nil, errors.Wrap(repoerr.ErrCreateEntity, fmt.Errorf("failed to retrieve or create bucket for user %s : %w", userID, err))
}
return userBucket, nil
}
func (pr *patRepo) retrieveUserBucket(tx *bolt.Tx, userID, patID string, wrap error) (*bolt.Bucket, error) {
rootBucket, err := pr.retrieveRootBucket(tx)
if err != nil {
return nil, errors.Wrap(wrap, err)
}
vPatID := rootBucket.Get([]byte(userID + keySeparator + patKey + keySeparator + patID))
if vPatID == nil {
return nil, repoerr.ErrNotFound
}
userBucket := rootBucket.Bucket([]byte(userID))
if userBucket == nil {
return nil, errors.Wrap(wrap, fmt.Errorf("user %s not found", userID))
}
return userBucket, nil
}
func (pr *patRepo) retrieveRootBucket(tx *bolt.Tx) (*bolt.Bucket, error) {
rootBucket := tx.Bucket([]byte(pr.bucketName))
if rootBucket == nil {
return nil, fmt.Errorf("bucket %s not found", pr.bucketName)
}
return rootBucket, nil
}
func patToKeyValue(pat auth.PAT) (map[string][]byte, error) {
kv := map[string][]byte{
idKey: []byte(pat.ID),
userKey: []byte(pat.User),
nameKey: []byte(pat.Name),
descriptionKey: []byte(pat.Description),
secretKey: []byte(pat.Secret),
issuedAtKey: timeToBytes(pat.IssuedAt),
expiresAtKey: timeToBytes(pat.ExpiresAt),
updatedAtKey: timeToBytes(pat.UpdatedAt),
lastUsedAtKey: timeToBytes(pat.LastUsedAt),
revokedKey: booleanToBytes(pat.Revoked),
revokedAtKey: timeToBytes(pat.RevokedAt),
}
scopeKV, err := scopeToKeyValue(pat.Scope)
if err != nil {
return nil, err
}
for k, v := range scopeKV {
kv[k] = v
}
return kv, nil
}
func scopeToKeyValue(scope auth.Scope) (map[string][]byte, error) {
kv := map[string][]byte{}
for opType, scopeValue := range scope.Users {
tempKV, err := scopeEntryToKeyValue(auth.PlatformUsersScope, "", auth.DomainNullScope, opType, scopeValue.Values()...)
if err != nil {
return nil, err
}
for k, v := range tempKV {
kv[k] = v
}
}
for opType, scopeValue := range scope.Dashboard {
tempKV, err := scopeEntryToKeyValue(auth.PlatformDashBoardScope, "", auth.DomainNullScope, opType, scopeValue.Values()...)
if err != nil {
return nil, err
}
for k, v := range tempKV {
kv[k] = v
}
}
for opType, scopeValue := range scope.Messaging {
tempKV, err := scopeEntryToKeyValue(auth.PlatformMesagingScope, "", auth.DomainNullScope, opType, scopeValue.Values()...)
if err != nil {
return nil, err
}
for k, v := range tempKV {
kv[k] = v
}
}
for domainID, domainScope := range scope.Domains {
for opType, scopeValue := range domainScope.DomainManagement {
tempKV, err := scopeEntryToKeyValue(auth.PlatformDomainsScope, domainID, auth.DomainManagementScope, opType, scopeValue.Values()...)
if err != nil {
return nil, errors.Wrap(repoerr.ErrCreateEntity, err)
}
for k, v := range tempKV {
kv[k] = v
}
}
for entityType, scope := range domainScope.Entities {
for opType, scopeValue := range scope {
tempKV, err := scopeEntryToKeyValue(auth.PlatformDomainsScope, domainID, entityType, opType, scopeValue.Values()...)
if err != nil {
return nil, errors.Wrap(repoerr.ErrCreateEntity, err)
}
for k, v := range tempKV {
kv[k] = v
}
}
}
}
return kv, nil
}
func scopeEntryToKeyValue(platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (map[string][]byte, error) {
if len(entityIDs) == 0 {
return nil, repoerr.ErrMalformedEntity
}
rootKey, err := scopeRootKey(platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
if err != nil {
return nil, err
}
if len(entityIDs) == 1 && entityIDs[0] == anyID {
return map[string][]byte{rootKey: anyIDValue}, nil
}
kv := map[string][]byte{rootKey: selectedIDsValue}
for _, entryID := range entityIDs {
if entryID == anyID {
return nil, repoerr.ErrMalformedEntity
}
kv[rootKey+keySeparator+entryID] = entityValue
}
return kv, nil
}
func scopeRootKey(platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType) (string, error) {
op, err := operation.ValidString()
if err != nil {
return "", errors.Wrap(repoerr.ErrMalformedEntity, err)
}
var rootKey strings.Builder
rootKey.WriteString(scopeKey)
rootKey.WriteString(keySeparator)
rootKey.WriteString(platformEntityType.String())
rootKey.WriteString(keySeparator)
switch platformEntityType {
case auth.PlatformUsersScope:
rootKey.WriteString(op)
case auth.PlatformDashBoardScope:
rootKey.WriteString(op)
case auth.PlatformMesagingScope:
rootKey.WriteString(op)
case auth.PlatformDomainsScope:
if optionalDomainID == "" {
return "", fmt.Errorf("failed to add platform %s scope: invalid domain id", platformEntityType.String())
}
odet, err := optionalDomainEntityType.ValidString()
if err != nil {
return "", errors.Wrap(repoerr.ErrMalformedEntity, err)
}
rootKey.WriteString(optionalDomainID)
rootKey.WriteString(keySeparator)
rootKey.WriteString(odet)
rootKey.WriteString(keySeparator)
rootKey.WriteString(op)
default:
return "", errors.Wrap(repoerr.ErrMalformedEntity, fmt.Errorf("invalid platform entity type %s", platformEntityType.String()))
}
return rootKey.String(), nil
}
func keyValueToBasicPAT(kv map[string][]byte) auth.PAT {
var pat auth.PAT
for k, v := range kv {
switch {
case strings.HasSuffix(k, keySeparator+idKey):
pat.ID = string(v)
case strings.HasSuffix(k, keySeparator+userKey):
pat.User = string(v)
case strings.HasSuffix(k, keySeparator+nameKey):
pat.Name = string(v)
case strings.HasSuffix(k, keySeparator+descriptionKey):
pat.Description = string(v)
case strings.HasSuffix(k, keySeparator+issuedAtKey):
pat.IssuedAt = bytesToTime(v)
case strings.HasSuffix(k, keySeparator+expiresAtKey):
pat.ExpiresAt = bytesToTime(v)
case strings.HasSuffix(k, keySeparator+updatedAtKey):
pat.UpdatedAt = bytesToTime(v)
case strings.HasSuffix(k, keySeparator+lastUsedAtKey):
pat.LastUsedAt = bytesToTime(v)
case strings.HasSuffix(k, keySeparator+revokedKey):
pat.Revoked = bytesToBoolean(v)
case strings.HasSuffix(k, keySeparator+revokedAtKey):
pat.RevokedAt = bytesToTime(v)
}
}
return pat
}
func keyValueToPAT(kv map[string][]byte) (auth.PAT, error) {
pat := keyValueToBasicPAT(kv)
scope, err := parseKeyValueToScope(kv)
if err != nil {
return auth.PAT{}, err
}
pat.Scope = scope
return pat, nil
}
func parseKeyValueToScope(kv map[string][]byte) (auth.Scope, error) {
scope := auth.Scope{
Domains: make(map[string]auth.DomainScope),
}
for key, value := range kv {
if strings.Index(key, keySeparator+scopeKey+keySeparator) > 0 {
keyParts := strings.Split(key, keySeparator)
platformEntityType, err := auth.ParsePlatformEntityType(keyParts[2])
if err != nil {
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
switch platformEntityType {
case auth.PlatformUsersScope:
scope.Users, err = parseOperation(platformEntityType, scope.Users, key, keyParts, value)
if err != nil {
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
case auth.PlatformDashBoardScope:
scope.Dashboard, err = parseOperation(platformEntityType, scope.Dashboard, key, keyParts, value)
if err != nil {
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
case auth.PlatformMesagingScope:
scope.Messaging, err = parseOperation(platformEntityType, scope.Messaging, key, keyParts, value)
if err != nil {
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
case auth.PlatformDomainsScope:
if len(keyParts) < 6 {
return auth.Scope{}, fmt.Errorf("invalid scope key format: %s", key)
}
domainID := keyParts[3]
if scope.Domains == nil {
scope.Domains = make(map[string]auth.DomainScope)
}
if _, ok := scope.Domains[domainID]; !ok {
scope.Domains[domainID] = auth.DomainScope{}
}
domainScope := scope.Domains[domainID]
entityType := keyParts[4]
switch entityType {
case auth.DomainManagementScope.String():
domainScope.DomainManagement, err = parseOperation(platformEntityType, domainScope.DomainManagement, key, keyParts, value)
if err != nil {
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
default:
etype, err := auth.ParseDomainEntityType(entityType)
if err != nil {
return auth.Scope{}, fmt.Errorf("key %s invalid entity type %s : %w", key, entityType, err)
}
if domainScope.Entities == nil {
domainScope.Entities = make(map[auth.DomainEntityType]auth.OperationScope)
}
if _, ok := domainScope.Entities[etype]; !ok {
domainScope.Entities[etype] = auth.OperationScope{}
}
entityOperationScope := domainScope.Entities[etype]
entityOperationScope, err = parseOperation(platformEntityType, entityOperationScope, key, keyParts, value)
if err != nil {
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
domainScope.Entities[etype] = entityOperationScope
}
scope.Domains[domainID] = domainScope
default:
return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, fmt.Errorf("invalid platform entity type : %s", platformEntityType.String()))
}
}
}
return scope, nil
}
func parseOperation(platformEntityType auth.PlatformEntityType, opScope auth.OperationScope, key string, keyParts []string, value []byte) (auth.OperationScope, error) {
if opScope == nil {
opScope = make(map[auth.OperationType]auth.ScopeValue)
}
if err := validateOperation(platformEntityType, opScope, key, keyParts, value); err != nil {
return auth.OperationScope{}, err
}
switch string(value) {
case string(entityValue):
opType, err := auth.ParseOperationType(keyParts[len(keyParts)-2])
if err != nil {
return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
entityID := keyParts[len(keyParts)-1]
if _, oValueExists := opScope[opType]; !oValueExists {
opScope[opType] = &auth.SelectedIDs{}
}
oValue := opScope[opType]
if err := oValue.AddValues(entityID); err != nil {
return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity value %v : %w", key, entityID, err)
}
opScope[opType] = oValue
case string(anyIDValue):
opType, err := auth.ParseOperationType(keyParts[len(keyParts)-1])
if err != nil {
return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
if oValue, oValueExists := opScope[opType]; oValueExists && oValue != nil {
if _, ok := oValue.(*auth.AnyIDs); !ok {
return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity anyIDs scope value : key already initialized with different type", key)
}
}
opScope[opType] = &auth.AnyIDs{}
case string(selectedIDsValue):
opType, err := auth.ParseOperationType(keyParts[len(keyParts)-1])
if err != nil {
return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
oValue, oValueExists := opScope[opType]
if oValueExists && oValue != nil {
if _, ok := oValue.(*auth.SelectedIDs); !ok {
return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity selectedIDs scope value : key already initialized with different type", key)
}
}
if !oValueExists {
opScope[opType] = &auth.SelectedIDs{}
}
default:
return auth.OperationScope{}, fmt.Errorf("key %s have invalid value %v", key, value)
}
return opScope, nil
}
func validateOperation(platformEntityType auth.PlatformEntityType, opScope auth.OperationScope, key string, keyParts []string, value []byte) error {
expectedKeyPartsLength := 0
switch string(value) {
case string(entityValue):
switch platformEntityType {
case auth.PlatformDomainsScope:
expectedKeyPartsLength = 7
case auth.PlatformUsersScope, auth.PlatformDashBoardScope, auth.PlatformMesagingScope:
expectedKeyPartsLength = 5
default:
return fmt.Errorf("invalid platform entity type : %s", platformEntityType.String())
}
case string(selectedIDsValue), string(anyIDValue):
switch platformEntityType {
case auth.PlatformDomainsScope:
expectedKeyPartsLength = 6
case auth.PlatformUsersScope, auth.PlatformDashBoardScope, auth.PlatformMesagingScope:
expectedKeyPartsLength = 4
default:
return fmt.Errorf("invalid platform entity type : %s", platformEntityType.String())
}
default:
return fmt.Errorf("key %s have invalid value %v", key, value)
}
if len(keyParts) != expectedKeyPartsLength {
return fmt.Errorf("invalid scope key format: %s", key)
}
return nil
}
func timeToBytes(t time.Time) []byte {
timeBytes := make([]byte, 8)
binary.BigEndian.PutUint64(timeBytes, uint64(t.Unix()))
return timeBytes
}
func bytesToTime(b []byte) time.Time {
timeAtSeconds := binary.BigEndian.Uint64(b)
return time.Unix(int64(timeAtSeconds), 0)
}
func booleanToBytes(b bool) []byte {
if b {
return []byte{1}
}
return []byte{0}
}
func bytesToBoolean(b []byte) bool {
if len(b) > 1 || b[0] != activateValue[0] {
return true
}
return false
}
+4
View File
@@ -0,0 +1,4 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package cache
+120
View File
@@ -0,0 +1,120 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package cache
import (
"context"
"fmt"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
"github.com/redis/go-redis/v9"
)
type patCache struct {
client *redis.Client
duration time.Duration
}
func NewPatsCache(client *redis.Client, duration time.Duration) auth.Cache {
return &patCache{
client: client,
duration: duration,
}
}
func (pc *patCache) Save(ctx context.Context, userID string, scopes []auth.Scope) error {
for _, sc := range scopes {
key := generateKey(userID, sc.PatID, sc.OptionalDomainID, sc.EntityType, sc.Operation, sc.EntityID)
if err := pc.client.Set(ctx, key, sc.ID, pc.duration).Err(); err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
}
return nil
}
func (pc *patCache) CheckScope(ctx context.Context, userID, patID, optionalDomainID string, entityType auth.EntityType, operation auth.Operation, entityID string) bool {
exactKey := fmt.Sprintf("pat:%s:%s:%s:%s:%s:%s", userID, patID, entityType, optionalDomainID, operation, entityID)
wildcardKey := fmt.Sprintf("pat:%s:%s:%s:%s:%s:*", userID, patID, entityType, operation, operation)
res, err := pc.client.Exists(ctx, exactKey, wildcardKey).Result()
if err != nil {
return false
}
return res > 0
}
func (pc *patCache) Remove(ctx context.Context, userID string, scopeIDs []string) error {
if len(scopeIDs) == 0 {
return repoerr.ErrRemoveEntity
}
pattern := fmt.Sprintf("pat:%s:*", userID)
iter := pc.client.Scan(ctx, 0, pattern, 0).Iterator()
for iter.Next(ctx) {
key := iter.Val()
val, err := pc.client.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
continue
}
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
for _, scopeID := range scopeIDs {
if val == scopeID {
if err := pc.client.Del(ctx, key).Err(); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
break
}
}
}
if err := iter.Err(); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
return nil
}
func (pc *patCache) RemoveUserAllScope(ctx context.Context, userID string) error {
pattern := fmt.Sprintf("pat:%s:*", userID)
iter := pc.client.Scan(ctx, 0, pattern, 0).Iterator()
for iter.Next(ctx) {
if err := pc.client.Del(ctx, iter.Val()).Err(); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
}
if err := iter.Err(); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
return nil
}
func (pc *patCache) RemoveAllScope(ctx context.Context, userID, patID string) error {
pattern := fmt.Sprintf("pat:%s:%s", userID, patID)
iter := pc.client.Scan(ctx, 0, pattern, 0).Iterator()
for iter.Next(ctx) {
if err := pc.client.Del(ctx, iter.Val()).Err(); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
}
if err := iter.Err(); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
return nil
}
func generateKey(userID, patID, optionalDomainId string, entityType auth.EntityType, operation auth.Operation, entityID string) string {
return fmt.Sprintf("pat:%s:%s:%s:%s:%s:%s", userID, patID, entityType, optionalDomainId, operation, entityID)
}
+122
View File
@@ -0,0 +1,122 @@
// Code generated by mockery v2.43.2. DO NOT EDIT.
// Copyright (c) Abstract Machines
package mocks
import (
context "context"
auth "github.com/absmach/supermq/auth"
mock "github.com/stretchr/testify/mock"
)
// Cache is an autogenerated mock type for the Cache type
type Cache struct {
mock.Mock
}
// CheckScope provides a mock function with given fields: ctx, userID, patID, optionalDomainID, entityType, operation, entityID
func (_m *Cache) CheckScope(ctx context.Context, userID string, patID string, optionalDomainID string, entityType auth.EntityType, operation auth.Operation, entityID string) bool {
ret := _m.Called(ctx, userID, patID, optionalDomainID, entityType, operation, entityID)
if len(ret) == 0 {
panic("no return value specified for CheckScope")
}
var r0 bool
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, auth.EntityType, auth.Operation, string) bool); ok {
r0 = rf(ctx, userID, patID, optionalDomainID, entityType, operation, entityID)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// Remove provides a mock function with given fields: ctx, userID, scopesID
func (_m *Cache) Remove(ctx context.Context, userID string, scopesID []string) error {
ret := _m.Called(ctx, userID, scopesID)
if len(ret) == 0 {
panic("no return value specified for Remove")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, []string) error); ok {
r0 = rf(ctx, userID, scopesID)
} else {
r0 = ret.Error(0)
}
return r0
}
// RemoveAllScope provides a mock function with given fields: ctx, userID, patID
func (_m *Cache) RemoveAllScope(ctx context.Context, userID string, patID string) error {
ret := _m.Called(ctx, userID, patID)
if len(ret) == 0 {
panic("no return value specified for RemoveAllScope")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, userID, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// RemoveUserAllScope provides a mock function with given fields: ctx, userID
func (_m *Cache) RemoveUserAllScope(ctx context.Context, userID string) error {
ret := _m.Called(ctx, userID)
if len(ret) == 0 {
panic("no return value specified for RemoveUserAllScope")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, userID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Save provides a mock function with given fields: ctx, userID, scopes
func (_m *Cache) Save(ctx context.Context, userID string, scopes []auth.Scope) error {
ret := _m.Called(ctx, userID, scopes)
if len(ret) == 0 {
panic("no return value specified for Save")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, []auth.Scope) error); ok {
r0 = rf(ctx, userID, scopes)
} else {
r0 = ret.Error(0)
}
return r0
}
// NewCache creates a new instance of Cache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewCache(t interface {
mock.TestingT
Cleanup(func())
}) *Cache {
mock := &Cache{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
+96 -109
View File
@@ -19,59 +19,35 @@ type PATS struct {
mock.Mock
}
// AddPATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATS) AddPATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
// AddScope provides a mock function with given fields: ctx, token, patID, scopes
func (_m *PATS) AddScope(ctx context.Context, token string, patID string, scopes []auth.Scope) error {
ret := _m.Called(ctx, token, patID, scopes)
if len(ret) == 0 {
panic("no return value specified for AddPATScopeEntry")
panic("no return value specified for AddScope")
}
var r0 auth.Scope
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok {
return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok {
r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, []auth.Scope) error); ok {
r0 = rf(ctx, token, patID, scopes)
} else {
r0 = ret.Get(0).(auth.Scope)
r0 = ret.Error(0)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r1 = ret.Error(1)
}
return r0, r1
return r0
}
// AuthorizePAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATS) AuthorizePAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
// AuthorizePAT provides a mock function with given fields: ctx, userID, patID, entityType, optionalDomainID, operation, entityID
func (_m *PATS) AuthorizePAT(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error {
ret := _m.Called(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
if len(ret) == 0 {
panic("no return value specified for AuthorizePAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.EntityType, string, auth.Operation, string) error); ok {
r0 = rf(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
} else {
r0 = ret.Error(0)
}
@@ -79,52 +55,9 @@ func (_m *PATS) AuthorizePAT(ctx context.Context, userID string, patID string, p
return r0
}
// CheckPAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATS) CheckPAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for CheckPAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// ClearPATAllScopeEntry provides a mock function with given fields: ctx, token, patID
func (_m *PATS) ClearPATAllScopeEntry(ctx context.Context, token string, patID string) error {
ret := _m.Called(ctx, token, patID)
if len(ret) == 0 {
panic("no return value specified for ClearPATAllScopeEntry")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, token, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration, scope
func (_m *PATS) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) {
ret := _m.Called(ctx, token, name, description, duration, scope)
// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration
func (_m *PATS) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration) (auth.PAT, error) {
ret := _m.Called(ctx, token, name, description, duration)
if len(ret) == 0 {
panic("no return value specified for CreatePAT")
@@ -132,17 +65,17 @@ func (_m *PATS) CreatePAT(ctx context.Context, token string, name string, descri
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) (auth.PAT, error)); ok {
return rf(ctx, token, name, description, duration, scope)
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration) (auth.PAT, error)); ok {
return rf(ctx, token, name, description, duration)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) auth.PAT); ok {
r0 = rf(ctx, token, name, description, duration, scope)
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration) auth.PAT); ok {
r0 = rf(ctx, token, name, description, duration)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration, auth.Scope) error); ok {
r1 = rf(ctx, token, name, description, duration, scope)
if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration) error); ok {
r1 = rf(ctx, token, name, description, duration)
} else {
r1 = ret.Error(1)
}
@@ -224,34 +157,27 @@ func (_m *PATS) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta
return r0, r1
}
// RemovePATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATS) RemovePATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
// ListScopes provides a mock function with given fields: ctx, token, pm
func (_m *PATS) ListScopes(ctx context.Context, token string, pm auth.ScopesPageMeta) (auth.ScopesPage, error) {
ret := _m.Called(ctx, token, pm)
if len(ret) == 0 {
panic("no return value specified for RemovePATScopeEntry")
panic("no return value specified for ListScopes")
}
var r0 auth.Scope
var r0 auth.ScopesPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok {
return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if rf, ok := ret.Get(0).(func(context.Context, string, auth.ScopesPageMeta) (auth.ScopesPage, error)); ok {
return rf(ctx, token, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok {
r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if rf, ok := ret.Get(0).(func(context.Context, string, auth.ScopesPageMeta) auth.ScopesPage); ok {
r0 = rf(ctx, token, pm)
} else {
r0 = ret.Get(0).(auth.Scope)
r0 = ret.Get(0).(auth.ScopesPage)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if rf, ok := ret.Get(1).(func(context.Context, string, auth.ScopesPageMeta) error); ok {
r1 = rf(ctx, token, pm)
} else {
r1 = ret.Error(1)
}
@@ -259,6 +185,67 @@ func (_m *PATS) RemovePATScopeEntry(ctx context.Context, token string, patID str
return r0, r1
}
// RemoveAllPAT provides a mock function with given fields: ctx, token
func (_m *PATS) RemoveAllPAT(ctx context.Context, token string) error {
ret := _m.Called(ctx, token)
if len(ret) == 0 {
panic("no return value specified for RemoveAllPAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, token)
} else {
r0 = ret.Error(0)
}
return r0
}
// RemovePATAllScope provides a mock function with given fields: ctx, token, patID
func (_m *PATS) RemovePATAllScope(ctx context.Context, token string, patID string) error {
ret := _m.Called(ctx, token, patID)
if len(ret) == 0 {
panic("no return value specified for RemovePATAllScope")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, token, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// RemoveScope provides a mock function with given fields: ctx, token, patID, scopeIDs
func (_m *PATS) RemoveScope(ctx context.Context, token string, patID string, scopeIDs ...string) error {
_va := make([]interface{}, len(scopeIDs))
for _i := range scopeIDs {
_va[_i] = scopeIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, token, patID)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for RemoveScope")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, ...string) error); ok {
r0 = rf(ctx, token, patID, scopeIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// ResetPATSecret provides a mock function with given fields: ctx, token, patID, duration
func (_m *PATS) ResetPATSecret(ctx context.Context, token string, patID string, duration time.Duration) (auth.PAT, error) {
ret := _m.Called(ctx, token, patID, duration)
+88 -76
View File
@@ -19,59 +19,35 @@ type PATSRepository struct {
mock.Mock
}
// AddScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATSRepository) AddScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
// AddScope provides a mock function with given fields: ctx, userID, scopes
func (_m *PATSRepository) AddScope(ctx context.Context, userID string, scopes []auth.Scope) error {
ret := _m.Called(ctx, userID, scopes)
if len(ret) == 0 {
panic("no return value specified for AddScopeEntry")
}
var r0 auth.Scope
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok {
return rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Get(0).(auth.Scope)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r1 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// CheckScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATSRepository) CheckScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for CheckScopeEntry")
panic("no return value specified for AddScope")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if rf, ok := ret.Get(0).(func(context.Context, string, []auth.Scope) error); ok {
r0 = rf(ctx, userID, scopes)
} else {
r0 = ret.Error(0)
}
return r0
}
// CheckScope provides a mock function with given fields: ctx, userID, patID, entityType, optionalDomainID, operation, entityID
func (_m *PATSRepository) CheckScope(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error {
ret := _m.Called(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
if len(ret) == 0 {
panic("no return value specified for CheckScope")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.EntityType, string, auth.Operation, string) error); ok {
r0 = rf(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
} else {
r0 = ret.Error(0)
}
@@ -115,17 +91,17 @@ func (_m *PATSRepository) Remove(ctx context.Context, userID string, patID strin
return r0
}
// RemoveAllScopeEntry provides a mock function with given fields: ctx, userID, patID
func (_m *PATSRepository) RemoveAllScopeEntry(ctx context.Context, userID string, patID string) error {
ret := _m.Called(ctx, userID, patID)
// RemoveAllPAT provides a mock function with given fields: ctx, userID
func (_m *PATSRepository) RemoveAllPAT(ctx context.Context, userID string) error {
ret := _m.Called(ctx, userID)
if len(ret) == 0 {
panic("no return value specified for RemoveAllScopeEntry")
panic("no return value specified for RemoveAllPAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, userID, patID)
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, userID)
} else {
r0 = ret.Error(0)
}
@@ -133,39 +109,47 @@ func (_m *PATSRepository) RemoveAllScopeEntry(ctx context.Context, userID string
return r0
}
// RemoveScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *PATSRepository) RemoveScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
// RemoveAllScope provides a mock function with given fields: ctx, patID
func (_m *PATSRepository) RemoveAllScope(ctx context.Context, patID string) error {
ret := _m.Called(ctx, patID)
if len(ret) == 0 {
panic("no return value specified for RemoveAllScope")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// RemoveScope provides a mock function with given fields: ctx, userID, scopesIDs
func (_m *PATSRepository) RemoveScope(ctx context.Context, userID string, scopesIDs ...string) error {
_va := make([]interface{}, len(scopesIDs))
for _i := range scopesIDs {
_va[_i] = scopesIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, ctx, userID)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for RemoveScopeEntry")
panic("no return value specified for RemoveScope")
}
var r0 auth.Scope
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok {
return rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, ...string) error); ok {
r0 = rf(ctx, userID, scopesIDs...)
} else {
r0 = ret.Get(0).(auth.Scope)
r0 = ret.Error(0)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r1 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r1 = ret.Error(1)
}
return r0, r1
return r0
}
// Retrieve provides a mock function with given fields: ctx, userID, patID
@@ -224,6 +208,34 @@ func (_m *PATSRepository) RetrieveAll(ctx context.Context, userID string, pm aut
return r0, r1
}
// RetrieveScope provides a mock function with given fields: ctx, pm
func (_m *PATSRepository) RetrieveScope(ctx context.Context, pm auth.ScopesPageMeta) (auth.ScopesPage, error) {
ret := _m.Called(ctx, pm)
if len(ret) == 0 {
panic("no return value specified for RetrieveScope")
}
var r0 auth.ScopesPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, auth.ScopesPageMeta) (auth.ScopesPage, error)); ok {
return rf(ctx, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, auth.ScopesPageMeta) auth.ScopesPage); ok {
r0 = rf(ctx, pm)
} else {
r0 = ret.Get(0).(auth.ScopesPage)
}
if rf, ok := ret.Get(1).(func(context.Context, auth.ScopesPageMeta) error); ok {
r1 = rf(ctx, pm)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// RetrieveSecretAndRevokeStatus provides a mock function with given fields: ctx, userID, patID
func (_m *PATSRepository) RetrieveSecretAndRevokeStatus(ctx context.Context, userID string, patID string) (string, bool, bool, error) {
ret := _m.Called(ctx, userID, patID)
+96 -109
View File
@@ -21,39 +21,22 @@ type Service struct {
mock.Mock
}
// AddPATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *Service) AddPATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
// AddScope provides a mock function with given fields: ctx, token, patID, scopes
func (_m *Service) AddScope(ctx context.Context, token string, patID string, scopes []auth.Scope) error {
ret := _m.Called(ctx, token, patID, scopes)
if len(ret) == 0 {
panic("no return value specified for AddPATScopeEntry")
panic("no return value specified for AddScope")
}
var r0 auth.Scope
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok {
return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok {
r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, []auth.Scope) error); ok {
r0 = rf(ctx, token, patID, scopes)
} else {
r0 = ret.Get(0).(auth.Scope)
r0 = ret.Error(0)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r1 = ret.Error(1)
}
return r0, r1
return r0
}
// Authorize provides a mock function with given fields: ctx, pr
@@ -74,24 +57,17 @@ func (_m *Service) Authorize(ctx context.Context, pr policies.Policy) error {
return r0
}
// AuthorizePAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *Service) AuthorizePAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
// AuthorizePAT provides a mock function with given fields: ctx, userID, patID, entityType, optionalDomainID, operation, entityID
func (_m *Service) AuthorizePAT(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error {
ret := _m.Called(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
if len(ret) == 0 {
panic("no return value specified for AuthorizePAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.EntityType, string, auth.Operation, string) error); ok {
r0 = rf(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
} else {
r0 = ret.Error(0)
}
@@ -99,52 +75,9 @@ func (_m *Service) AuthorizePAT(ctx context.Context, userID string, patID string
return r0
}
// CheckPAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *Service) CheckPAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for CheckPAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// ClearPATAllScopeEntry provides a mock function with given fields: ctx, token, patID
func (_m *Service) ClearPATAllScopeEntry(ctx context.Context, token string, patID string) error {
ret := _m.Called(ctx, token, patID)
if len(ret) == 0 {
panic("no return value specified for ClearPATAllScopeEntry")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, token, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration, scope
func (_m *Service) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) {
ret := _m.Called(ctx, token, name, description, duration, scope)
// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration
func (_m *Service) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration) (auth.PAT, error) {
ret := _m.Called(ctx, token, name, description, duration)
if len(ret) == 0 {
panic("no return value specified for CreatePAT")
@@ -152,17 +85,17 @@ func (_m *Service) CreatePAT(ctx context.Context, token string, name string, des
var r0 auth.PAT
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) (auth.PAT, error)); ok {
return rf(ctx, token, name, description, duration, scope)
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration) (auth.PAT, error)); ok {
return rf(ctx, token, name, description, duration)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) auth.PAT); ok {
r0 = rf(ctx, token, name, description, duration, scope)
if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration) auth.PAT); ok {
r0 = rf(ctx, token, name, description, duration)
} else {
r0 = ret.Get(0).(auth.PAT)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration, auth.Scope) error); ok {
r1 = rf(ctx, token, name, description, duration, scope)
if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration) error); ok {
r1 = rf(ctx, token, name, description, duration)
} else {
r1 = ret.Error(1)
}
@@ -300,34 +233,27 @@ func (_m *Service) ListPATS(ctx context.Context, token string, pm auth.PATSPageM
return r0, r1
}
// RemovePATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs
func (_m *Service) RemovePATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
_va := make([]interface{}, len(entityIDs))
for _i := range entityIDs {
_va[_i] = entityIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
// ListScopes provides a mock function with given fields: ctx, token, pm
func (_m *Service) ListScopes(ctx context.Context, token string, pm auth.ScopesPageMeta) (auth.ScopesPage, error) {
ret := _m.Called(ctx, token, pm)
if len(ret) == 0 {
panic("no return value specified for RemovePATScopeEntry")
panic("no return value specified for ListScopes")
}
var r0 auth.Scope
var r0 auth.ScopesPage
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok {
return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if rf, ok := ret.Get(0).(func(context.Context, string, auth.ScopesPageMeta) (auth.ScopesPage, error)); ok {
return rf(ctx, token, pm)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok {
r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if rf, ok := ret.Get(0).(func(context.Context, string, auth.ScopesPageMeta) auth.ScopesPage); ok {
r0 = rf(ctx, token, pm)
} else {
r0 = ret.Get(0).(auth.Scope)
r0 = ret.Get(0).(auth.ScopesPage)
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok {
r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if rf, ok := ret.Get(1).(func(context.Context, string, auth.ScopesPageMeta) error); ok {
r1 = rf(ctx, token, pm)
} else {
r1 = ret.Error(1)
}
@@ -335,6 +261,67 @@ func (_m *Service) RemovePATScopeEntry(ctx context.Context, token string, patID
return r0, r1
}
// RemoveAllPAT provides a mock function with given fields: ctx, token
func (_m *Service) RemoveAllPAT(ctx context.Context, token string) error {
ret := _m.Called(ctx, token)
if len(ret) == 0 {
panic("no return value specified for RemoveAllPAT")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, token)
} else {
r0 = ret.Error(0)
}
return r0
}
// RemovePATAllScope provides a mock function with given fields: ctx, token, patID
func (_m *Service) RemovePATAllScope(ctx context.Context, token string, patID string) error {
ret := _m.Called(ctx, token, patID)
if len(ret) == 0 {
panic("no return value specified for RemovePATAllScope")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok {
r0 = rf(ctx, token, patID)
} else {
r0 = ret.Error(0)
}
return r0
}
// RemoveScope provides a mock function with given fields: ctx, token, patID, scopeIDs
func (_m *Service) RemoveScope(ctx context.Context, token string, patID string, scopeIDs ...string) error {
_va := make([]interface{}, len(scopeIDs))
for _i := range scopeIDs {
_va[_i] = scopeIDs[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, token, patID)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for RemoveScope")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, ...string) error); ok {
r0 = rf(ctx, token, patID, scopeIDs...)
} else {
r0 = ret.Error(0)
}
return r0
}
// ResetPATSecret provides a mock function with given fields: ctx, token, patID, duration
func (_m *Service) ResetPATSecret(ctx context.Context, token string, patID string, duration time.Duration) (auth.PAT, error) {
ret := _m.Called(ctx, token, patID, duration)
+213 -561
View File
@@ -7,18 +7,18 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/absmach/supermq/pkg/errors"
)
var errAddEntityToAnyIDs = errors.New("could not add entity id to any ID scope value")
const AnyIDs = "*"
// Define OperationType.
type OperationType uint32
type Operation uint32
const (
CreateOp OperationType = iota
CreateOp Operation = iota
ReadOp
ListOp
UpdateOp
@@ -41,8 +41,8 @@ const (
SubscribeOpStr = "subscribe"
)
func (ot OperationType) String() string {
switch ot {
func (op Operation) String() string {
switch op {
case CreateOp:
return createOpStr
case ReadOp:
@@ -62,20 +62,20 @@ func (ot OperationType) String() string {
case SubscribeOp:
return SubscribeOpStr
default:
return fmt.Sprintf("unknown operation type %d", ot)
return fmt.Sprintf("unknown operation type %d", op)
}
}
func (ot OperationType) ValidString() (string, error) {
str := ot.String()
if str == fmt.Sprintf("unknown operation type %d", ot) {
func (op Operation) ValidString() (string, error) {
str := op.String()
if str == fmt.Sprintf("unknown operation type %d", op) {
return "", errors.New(str)
}
return str, nil
}
func ParseOperationType(ot string) (OperationType, error) {
switch ot {
func ParseOperation(op string) (Operation, error) {
switch op {
case createOpStr:
return CreateOp, nil
case readOpStr:
@@ -95,592 +95,205 @@ func ParseOperationType(ot string) (OperationType, error) {
case SubscribeOpStr:
return SubscribeOp, nil
default:
return 0, fmt.Errorf("unknown operation type %s", ot)
return 0, fmt.Errorf("unknown operation type %s", op)
}
}
func (ot OperationType) MarshalJSON() ([]byte, error) {
return []byte(ot.String()), nil
func (op Operation) MarshalJSON() ([]byte, error) {
return json.Marshal(op.String())
}
func (ot OperationType) MarshalText() (text []byte, err error) {
return []byte(ot.String()), nil
}
func (ot *OperationType) UnmarshalText(data []byte) (err error) {
*ot, err = ParseOperationType(string(data))
func (op *Operation) UnmarshalJSON(data []byte) error {
str := strings.Trim(string(data), "\"")
val, err := ParseOperation(str)
*op = val
return err
}
// Define DomainEntityType.
type DomainEntityType uint32
func (op Operation) MarshalText() (text []byte, err error) {
return []byte(op.String()), nil
}
func (op *Operation) UnmarshalText(data []byte) (err error) {
str := strings.Trim(string(data), "\"")
*op, err = ParseOperation(str)
return err
}
type EntityType uint32
const (
DomainManagementScope DomainEntityType = iota
DomainGroupsScope
DomainChannelsScope
DomainClientsScope
DomainNullScope
GroupsType EntityType = iota
ChannelsType
ClientsType
DomainsType
UsersType
DashboardType
MessagesType
)
const (
domainManagementScopeStr = "domain_management"
domainGroupsScopeStr = "groups"
domainChannelsScopeStr = "channels"
domainClientsScopeStr = "clients"
GroupsScopeStr = "groups"
ChannelsScopeStr = "channels"
ClientsScopeStr = "clients"
DomainsStr = "domains"
UsersStr = "users"
DashboardsStr = "dashboards"
MessagesStr = "messages"
)
func (det DomainEntityType) String() string {
switch det {
case DomainManagementScope:
return domainManagementScopeStr
case DomainGroupsScope:
return domainGroupsScopeStr
case DomainChannelsScope:
return domainChannelsScopeStr
case DomainClientsScope:
return domainClientsScopeStr
func (et EntityType) String() string {
switch et {
case GroupsType:
return GroupsScopeStr
case ChannelsType:
return ChannelsScopeStr
case ClientsType:
return ClientsScopeStr
case DomainsType:
return DomainsStr
case UsersType:
return UsersStr
case DashboardType:
return DashboardsStr
case MessagesType:
return MessagesStr
default:
return fmt.Sprintf("unknown domain entity type %d", det)
return fmt.Sprintf("unknown domain entity type %d", et)
}
}
func (det DomainEntityType) ValidString() (string, error) {
str := det.String()
if str == fmt.Sprintf("unknown operation type %d", det) {
func (et EntityType) ValidString() (string, error) {
str := et.String()
if str == fmt.Sprintf("unknown operation type %d", et) {
return "", errors.New(str)
}
return str, nil
}
func ParseDomainEntityType(det string) (DomainEntityType, error) {
switch det {
case domainManagementScopeStr:
return DomainManagementScope, nil
case domainGroupsScopeStr:
return DomainGroupsScope, nil
case domainChannelsScopeStr:
return DomainChannelsScope, nil
case domainClientsScopeStr:
return DomainClientsScope, nil
func ParseEntityType(et string) (EntityType, error) {
switch et {
case GroupsScopeStr:
return GroupsType, nil
case ChannelsScopeStr:
return ChannelsType, nil
case ClientsScopeStr:
return ClientsType, nil
case DomainsStr:
return DomainsType, nil
case UsersStr:
return UsersType, nil
case DashboardsStr:
return DashboardType, nil
default:
return 0, fmt.Errorf("unknown domain entity type %s", det)
return 0, fmt.Errorf("unknown domain entity type %s", et)
}
}
func (det DomainEntityType) MarshalJSON() ([]byte, error) {
return []byte(det.String()), nil
func (et EntityType) MarshalJSON() ([]byte, error) {
return json.Marshal(et.String())
}
func (det DomainEntityType) MarshalText() ([]byte, error) {
return []byte(det.String()), nil
}
func (det *DomainEntityType) UnmarshalText(data []byte) (err error) {
*det, err = ParseDomainEntityType(string(data))
func (et *EntityType) UnmarshalJSON(data []byte) error {
str := strings.Trim(string(data), "\"")
val, err := ParseEntityType(str)
*et = val
return err
}
// Define DomainEntityType.
type PlatformEntityType uint32
const (
PlatformUsersScope PlatformEntityType = iota
PlatformDomainsScope
PlatformDashBoardScope
PlatformMesagingScope
)
const (
platformUsersScopeStr = "users"
platformDomainsScopeStr = "domains"
PlatformDashBoardScopeStr = "dashboard"
PlatformMesagingScopeStr = "messaging"
)
func (pet PlatformEntityType) String() string {
switch pet {
case PlatformUsersScope:
return platformUsersScopeStr
case PlatformDomainsScope:
return platformDomainsScopeStr
case PlatformDashBoardScope:
return PlatformDashBoardScopeStr
case PlatformMesagingScope:
return PlatformMesagingScopeStr
default:
return fmt.Sprintf("unknown platform entity type %d", pet)
}
func (et EntityType) MarshalText() ([]byte, error) {
return []byte(et.String()), nil
}
func (pet PlatformEntityType) ValidString() (string, error) {
str := pet.String()
if str == fmt.Sprintf("unknown platform entity type %d", pet) {
return "", errors.New(str)
}
return str, nil
}
func ParsePlatformEntityType(pet string) (PlatformEntityType, error) {
switch pet {
case platformUsersScopeStr:
return PlatformUsersScope, nil
case platformDomainsScopeStr:
return PlatformDomainsScope, nil
default:
return 0, fmt.Errorf("unknown platform entity type %s", pet)
}
}
func (pet PlatformEntityType) MarshalJSON() ([]byte, error) {
return []byte(pet.String()), nil
}
func (pet PlatformEntityType) MarshalText() (text []byte, err error) {
return []byte(pet.String()), nil
}
func (pet *PlatformEntityType) UnmarshalText(data []byte) (err error) {
*pet, err = ParsePlatformEntityType(string(data))
func (et *EntityType) UnmarshalText(data []byte) (err error) {
str := strings.Trim(string(data), "\"")
*et, err = ParseEntityType(str)
return err
}
// ScopeValue interface for Any entity ids or for sets of entity ids.
type ScopeValue interface {
Contains(id string) bool
Values() []string
AddValues(ids ...string) error
RemoveValues(ids ...string) error
}
// AnyIDs implements ScopeValue for any entity id value.
type AnyIDs struct{}
func (s AnyIDs) Contains(id string) bool { return true }
func (s AnyIDs) Values() []string { return []string{"*"} }
func (s *AnyIDs) AddValues(ids ...string) error { return errAddEntityToAnyIDs }
func (s *AnyIDs) RemoveValues(ids ...string) error { return errAddEntityToAnyIDs }
// SelectedIDs implements ScopeValue for sets of entity ids.
type SelectedIDs map[string]struct{}
func (s SelectedIDs) Contains(id string) bool { _, ok := s[id]; return ok }
func (s SelectedIDs) Values() []string {
values := []string{}
for value := range s {
values = append(values, value)
}
return values
}
func (s *SelectedIDs) AddValues(ids ...string) error {
if *s == nil {
*s = make(SelectedIDs)
}
for _, id := range ids {
(*s)[id] = struct{}{}
}
return nil
}
func (s *SelectedIDs) RemoveValues(ids ...string) error {
if *s == nil {
return nil
}
for _, id := range ids {
delete(*s, id)
}
return nil
}
// OperationScope contains map of OperationType with value of AnyIDs or SelectedIDs.
type OperationScope map[OperationType]ScopeValue
func (os *OperationScope) UnmarshalJSON(data []byte) error {
type tempOperationScope map[OperationType]json.RawMessage
var tempScope tempOperationScope
if err := json.Unmarshal(data, &tempScope); err != nil {
return err
}
// Initialize the Operations map
*os = OperationScope{}
for opType, rawMessage := range tempScope {
var stringValue string
var stringArrayValue []string
// Try to unmarshal as string
if err := json.Unmarshal(rawMessage, &stringValue); err == nil {
if err := os.Add(opType, stringValue); err != nil {
return err
}
continue
}
// Try to unmarshal as []string
if err := json.Unmarshal(rawMessage, &stringArrayValue); err == nil {
if err := os.Add(opType, stringArrayValue...); err != nil {
return err
}
continue
}
// If neither unmarshalling succeeded, return an error
return fmt.Errorf("invalid ScopeValue for OperationType %v", opType)
}
return nil
}
func (os OperationScope) MarshalJSON() ([]byte, error) {
tempOperationScope := make(map[OperationType]interface{})
for oType, scope := range os {
value := scope.Values()
if len(value) == 1 && value[0] == "*" {
tempOperationScope[oType] = "*"
continue
}
tempOperationScope[oType] = value
}
b, err := json.Marshal(tempOperationScope)
if err != nil {
return nil, err
}
return b, nil
}
func (os *OperationScope) Add(operation OperationType, entityIDs ...string) error {
var value ScopeValue
if os == nil {
os = &OperationScope{}
}
if len(entityIDs) == 0 {
return fmt.Errorf("entity ID is missing")
}
switch {
case len(entityIDs) == 1 && entityIDs[0] == "*":
value = &AnyIDs{}
default:
var sids SelectedIDs
for _, entityID := range entityIDs {
if entityID == "*" {
return fmt.Errorf("list contains wildcard")
}
if sids == nil {
sids = make(SelectedIDs)
}
sids[entityID] = struct{}{}
}
value = &sids
}
(*os)[operation] = value
return nil
}
func (os *OperationScope) Delete(operation OperationType, entityIDs ...string) error {
if os == nil {
return nil
}
opEntityIDs, exists := (*os)[operation]
if !exists {
return nil
}
if len(entityIDs) == 0 {
return fmt.Errorf("failed to delete operation %s: entity ID is missing", operation.String())
}
switch eIDs := opEntityIDs.(type) {
case *AnyIDs:
if !(len(entityIDs) == 1 && entityIDs[0] == "*") {
return fmt.Errorf("failed to delete operation %s: invalid list", operation.String())
}
delete((*os), operation)
return nil
case *SelectedIDs:
for _, entityID := range entityIDs {
if !eIDs.Contains(entityID) {
return fmt.Errorf("failed to delete operation %s: invalid entity ID in list", operation.String())
}
}
for _, entityID := range entityIDs {
delete(*eIDs, entityID)
if len(*eIDs) == 0 {
delete((*os), operation)
}
}
return nil
default:
return fmt.Errorf("failed to delete operation: invalid entity id type %d", operation)
}
}
func (os *OperationScope) Check(operation OperationType, entityIDs ...string) bool {
if os == nil {
return false
}
if scopeValue, ok := (*os)[operation]; ok {
if len(entityIDs) == 0 {
_, ok := scopeValue.(*AnyIDs)
return ok
}
for _, entityID := range entityIDs {
if !scopeValue.Contains(entityID) {
return false
}
}
return true
}
return false
}
type DomainScope struct {
DomainManagement OperationScope `json:"domain_management,omitempty"`
Entities map[DomainEntityType]OperationScope `json:"entities,omitempty"`
}
// Add entry in Domain scope.
func (ds *DomainScope) Add(domainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error {
if ds == nil {
return fmt.Errorf("failed to add domain %s scope: domain_scope is nil and not initialized", domainEntityType)
}
if domainEntityType < DomainManagementScope || domainEntityType > DomainClientsScope {
return fmt.Errorf("failed to add domain %d scope: invalid domain entity type", domainEntityType)
}
if domainEntityType == DomainManagementScope {
if err := ds.DomainManagement.Add(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete domain management scope: %w", err)
}
}
if ds.Entities == nil {
ds.Entities = make(map[DomainEntityType]OperationScope)
}
opReg, ok := ds.Entities[domainEntityType]
if !ok {
opReg = OperationScope{}
}
if err := opReg.Add(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to add domain %s scope: %w ", domainEntityType.String(), err)
}
ds.Entities[domainEntityType] = opReg
return nil
}
// Delete entry in Domain scope.
func (ds *DomainScope) Delete(domainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error {
if ds == nil {
return nil
}
if domainEntityType < DomainManagementScope || domainEntityType > DomainClientsScope {
return fmt.Errorf("failed to delete domain %d scope: invalid domain entity type", domainEntityType)
}
if ds.Entities == nil {
return nil
}
if domainEntityType == DomainManagementScope {
if err := ds.DomainManagement.Delete(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete domain management scope: %w", err)
}
}
os, exists := ds.Entities[domainEntityType]
if !exists {
return nil
}
if err := os.Delete(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete domain %s scope: %w", domainEntityType.String(), err)
}
if len(os) == 0 {
delete(ds.Entities, domainEntityType)
}
return nil
}
// Check entry in Domain scope.
func (ds *DomainScope) Check(domainEntityType DomainEntityType, operation OperationType, ids ...string) bool {
if ds.Entities == nil {
return false
}
if domainEntityType < DomainManagementScope || domainEntityType > DomainClientsScope {
return false
}
if domainEntityType == DomainManagementScope {
return ds.DomainManagement.Check(operation, ids...)
}
os, exists := ds.Entities[domainEntityType]
if !exists {
return false
}
return os.Check(operation, ids...)
}
// Example Scope as JSON
//
// {
// "users": {
// "create": ["*"],
// "read": ["*"],
// "list": ["*"],
// "update": ["*"],
// "delete": ["*"]
// },
// "domains": {
// "domain_1": {
// "entities": {
// "groups": {
// "create": ["*"] // this for all groups in domain
// },
// "channels": {
// // for particular channel in domain
// "delete": [
// "channel1",
// "channel2"
// ]
// },
// "things": {
// "update": ["*"] // this for all things in domain
// }
// }
// }
// }
// }
// [
// {
// "optional_domain_id": "domain_1",
// "entity_type": "groups",
// "operation": "create",
// "entity_id": "*"
// },
// {
// "optional_domain_id": "domain_1",
// "entity_type": "channels",
// "operation": "delete",
// "entity_id": "channel1"
// },
// {
// "optional_domain_id": "domain_1",
// "entity_type": "things",
// "operation": "update",
// "entity_id": "*"
// }
// ]
type Scope struct {
Users OperationScope `json:"users,omitempty"`
Domains map[string]DomainScope `json:"domains,omitempty"`
Dashboard OperationScope `json:"dashboard,omitempty"`
Messaging OperationScope `json:"messaging,omitempty"`
ID string `json:"id,omitempty"`
PatID string `json:"pat_id,omitempty"`
OptionalDomainID string `json:"optional_domain_id,omitempty"`
EntityType EntityType `json:"entity_type,omitempty"`
EntityID string `json:"entity_id,omitempty"`
Operation Operation `json:"operation,omitempty"`
}
// Add entry in Domain scope.
func (s *Scope) Add(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error {
if s == nil {
return fmt.Errorf("failed to add platform %s scope: scope is nil and not initialized", platformEntityType.String())
}
switch platformEntityType {
case PlatformUsersScope:
if err := s.Users.Add(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err)
}
case PlatformDashBoardScope:
if err := s.Dashboard.Add(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err)
}
case PlatformMesagingScope:
if err := s.Messaging.Add(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err)
}
case PlatformDomainsScope:
if optionalDomainID == "" {
return fmt.Errorf("failed to add platform %s scope: invalid domain id", platformEntityType.String())
}
if len(s.Domains) == 0 {
s.Domains = make(map[string]DomainScope)
}
ds, ok := s.Domains[optionalDomainID]
if !ok {
ds = DomainScope{}
}
if err := ds.Add(optionalDomainEntityType, operation, entityIDs...); err != nil {
return fmt.Errorf("failed to add platform %s id %s scope : %w", platformEntityType.String(), optionalDomainID, err)
}
s.Domains[optionalDomainID] = ds
default:
return fmt.Errorf("failed to add platform %d scope: invalid platform entity type ", platformEntityType)
}
return nil
}
// Delete entry in Domain scope.
func (s *Scope) Delete(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error {
if s == nil {
return nil
}
switch platformEntityType {
case PlatformUsersScope:
if err := s.Users.Delete(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err)
}
case PlatformDashBoardScope:
if err := s.Dashboard.Delete(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err)
}
case PlatformMesagingScope:
if err := s.Messaging.Delete(operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err)
}
case PlatformDomainsScope:
if optionalDomainID == "" {
return fmt.Errorf("failed to delete platform %s scope: invalid domain id", platformEntityType.String())
}
ds, ok := s.Domains[optionalDomainID]
if !ok {
return nil
}
if err := ds.Delete(optionalDomainEntityType, operation, entityIDs...); err != nil {
return fmt.Errorf("failed to delete platform %s id %s scope : %w", platformEntityType.String(), optionalDomainID, err)
}
default:
return fmt.Errorf("failed to add platform %d scope: invalid platform entity type ", platformEntityType)
}
return nil
}
// Check entry in Domain scope.
func (s *Scope) Check(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) bool {
func (s *Scope) Authorized(entityType EntityType, optionalDomainID string, operation Operation, entityID string) bool {
if s == nil {
return false
}
switch platformEntityType {
case PlatformUsersScope:
return s.Users.Check(operation, entityIDs...)
case PlatformDashBoardScope:
return s.Dashboard.Check(operation, entityIDs...)
case PlatformMesagingScope:
return s.Messaging.Check(operation, entityIDs...)
case PlatformDomainsScope:
ds, ok := s.Domains[optionalDomainID]
if !ok {
return false
}
return ds.Check(optionalDomainEntityType, operation, entityIDs...)
default:
if s.EntityType != entityType {
return false
}
if optionalDomainID != "" && s.OptionalDomainID != optionalDomainID {
return false
}
if s.Operation != operation {
return false
}
if s.EntityID == "*" {
return true
}
if s.EntityID == entityID {
return true
}
return false
}
func (s *Scope) String() string {
str, err := json.Marshal(s) // , "", " ")
if err != nil {
return fmt.Sprintf("failed to convert scope to string: json marshal error :%s", err.Error())
func (s *Scope) Validate() error {
if s == nil {
return errInvalidScope
}
return string(str)
if s.EntityID == "" {
return errors.New("missing entityID")
}
switch s.EntityType {
case ChannelsType, GroupsType, ClientsType:
if s.OptionalDomainID == "" {
return errors.New("missing domainID")
}
}
return nil
}
// PAT represents Personal Access Token.
type PAT struct {
ID string `json:"id,omitempty"`
User string `json:"user,omitempty"`
User string `json:"user_id,omitempty"`
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Secret string `json:"secret,omitempty"`
Scope Scope `json:"scope,omitempty"`
IssuedAt time.Time `json:"issued_at,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
UpdatedAt time.Time `json:"updated_at,omitempty"`
@@ -697,7 +310,29 @@ type PATSPage struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
PATS []PAT `json:"pats"`
PATS []PAT `json:"pats,omitempty"`
}
type ScopesPageMeta struct {
Offset uint64 `json:"offset"`
Limit uint64 `json:"limit"`
PatID string `json:"pat_id"`
ID string `json:"id"`
}
type ScopesPage struct {
Total uint64 `json:"total"`
Offset uint64 `json:"offset,omitempty"`
Limit uint64 `json:"limit,omitempy"`
Scopes []Scope `json:"scopes,omitempty"`
}
func (pat PAT) MarshalBinary() ([]byte, error) {
return json.Marshal(pat)
}
func (pat *PAT) UnmarshalBinary(data []byte) error {
return json.Unmarshal(data, pat)
}
func (pat *PAT) String() string {
@@ -708,17 +343,12 @@ func (pat *PAT) String() string {
return string(str)
}
// Expired verifies if the key is expired.
func (pat PAT) Expired() bool {
return pat.ExpiresAt.UTC().Before(time.Now().UTC())
}
// PATS specifies function which are required for Personal access Token implementation.
//go:generate mockery --name PATS --output=./mocks --filename pats.go --quiet --note "Copyright (c) Abstract Machines"
type PATS interface {
// Create function creates new PAT for given valid inputs.
CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope Scope) (PAT, error)
CreatePAT(ctx context.Context, token, name, description string, duration time.Duration) (PAT, error)
// UpdateName function updates the name for the given PAT ID.
UpdatePATName(ctx context.Context, token, patID, name string) (PAT, error)
@@ -729,7 +359,10 @@ type PATS interface {
// Retrieve function retrieves the PAT for given ID.
RetrievePAT(ctx context.Context, userID string, patID string) (PAT, error)
// List function lists all the PATs for the user.
// RemoveAllPAT function removes all PATs of user.
RemoveAllPAT(ctx context.Context, token string) error
// ListPATS function lists all the PATs for the user.
ListPATS(ctx context.Context, token string, pm PATSPageMeta) (PATSPage, error)
// Delete function deletes the PAT for given ID.
@@ -741,23 +374,23 @@ type PATS interface {
// RevokeSecret function revokes the secret for the given ID.
RevokePATSecret(ctx context.Context, token, patID string) error
// AddScope function adds a new scope entry.
AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error)
// AddScope function adds a new scope.
AddScope(ctx context.Context, token, patID string, scopes []Scope) error
// RemoveScope function removes a scope entry.
RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error)
// RemoveScope function removes a scope.
RemoveScope(ctx context.Context, token string, patID string, scopeIDs ...string) error
// ClearAllScope function removes all scope entry.
ClearPATAllScopeEntry(ctx context.Context, token, patID string) error
// RemovePATAllScope function removes all scope.
RemovePATAllScope(ctx context.Context, token, patID string) error
// List function lists all the Scopes for the patID.
ListScopes(ctx context.Context, token string, pm ScopesPageMeta) (ScopesPage, error)
// IdentifyPAT function will valid the secret.
IdentifyPAT(ctx context.Context, paToken string) (PAT, error)
// AuthorizePAT function will valid the secret and check the given scope exists.
AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error
// CheckPAT function will check the given scope exists.
CheckPAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error
AuthorizePAT(ctx context.Context, userID, patID string, entityType EntityType, optionalDomainID string, operation Operation, entityID string) error
}
// PATSRepository specifies PATS persistence API.
@@ -770,6 +403,9 @@ type PATSRepository interface {
// Retrieve retrieves users PAT by its unique identifier.
Retrieve(ctx context.Context, userID, patID string) (pat PAT, err error)
// RetrieveScope retrieves PAT scopes by its unique identifier.
RetrieveScope(ctx context.Context, pm ScopesPageMeta) (scopes ScopesPage, err error)
// RetrieveSecretAndRevokeStatus retrieves secret and revoke status of PAT by its unique identifier.
RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, bool, error)
@@ -794,11 +430,27 @@ type PATSRepository interface {
// Remove removes Key with provided ID.
Remove(ctx context.Context, userID, patID string) error
AddScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error)
// RemoveAllPAT removes all PAT for a given user.
RemoveAllPAT(ctx context.Context, userID string) error
RemoveScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error)
AddScope(ctx context.Context, userID string, scopes []Scope) error
CheckScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error
RemoveScope(ctx context.Context, userID string, scopesIDs ...string) error
RemoveAllScopeEntry(ctx context.Context, userID, patID string) error
CheckScope(ctx context.Context, userID, patID string, entityType EntityType, optionalDomainID string, operation Operation, entityID string) error
RemoveAllScope(ctx context.Context, patID string) error
}
//go:generate mockery --name Cache --output=./mocks --filename cache.go --quiet --note "Copyright (c) Abstract Machines"
type Cache interface {
Save(ctx context.Context, userID string, scopes []Scope) error
CheckScope(ctx context.Context, userID, patID, optionalDomainID string, entityType EntityType, operation Operation, entityID string) bool
Remove(ctx context.Context, userID string, scopesID []string) error
RemoveUserAllScope(ctx context.Context, userID string) error
RemoveAllScope(ctx context.Context, userID, patID string) error
}
+39
View File
@@ -65,6 +65,45 @@ func Migration() *migrate.MemoryMigrationSource {
`,
},
},
{
Id: "auth_4",
Up: []string{
`CREATE TABLE IF NOT EXISTS pats (
id VARCHAR(36) PRIMARY KEY,
name VARCHAR(254) NOT NULL,
user_id VARCHAR(36),
description TEXT,
secret TEXT,
issued_at TIMESTAMP,
expires_at TIMESTAMP,
updated_at TIMESTAMP,
revoked BOOLEAN,
revoked_at TIMESTAMP,
last_used_at TIMESTAMP,
UNIQUE (id, name, secret)
)`,
},
Down: []string{
`DROP TABLE IF EXISTS pats`,
},
},
{
Id: "auth_5",
Up: []string{
`CREATE TABLE IF NOT EXISTS pat_scopes (
id VARCHAR(36) PRIMARY KEY,
pat_id VARCHAR(36) REFERENCES pats(id) ON DELETE CASCADE,
optional_domain_id VARCHAR(36),
entity_type VARCHAR(50) NOT NULL,
operation VARCHAR(50) NOT NULL,
entity_id VARCHAR(50) NOT NULL,
UNIQUE (pat_id, optional_domain_id, entity_type, operation, entity_id)
);`,
},
Down: []string{
`DROP TABLE IF EXISTS pat_scopes;`,
},
},
},
}
}
+167
View File
@@ -0,0 +1,167 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import (
"database/sql"
"time"
"github.com/absmach/supermq/auth"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
)
type dbPat struct {
ID string `db:"id,omitempty"`
User string `db:"user_id,omitempty"`
Name string `db:"name,omitempty"`
Description string `db:"description,omitempty"`
Secret string `db:"secret,omitempty"`
IssuedAt time.Time `db:"issued_at,omitempty"`
ExpiresAt time.Time `db:"expires_at,omitempty"`
UpdatedAt sql.NullTime `db:"updated_at,omitempty"`
LastUsedAt sql.NullTime `db:"last_used_at,omitempty"`
Revoked bool `db:"revoked,omitempty"`
RevokedAt sql.NullTime `db:"revoked_at,omitempty"`
}
type dbScope struct {
ID string `db:"id,omitempty"`
PatID string `db:"pat_id,omitempty"`
OptionalDomainID string `db:"optional_domain_id,omitempty"`
EntityType string `db:"entity_type,omitempty"`
EntityID string `db:"entity_id,omitempty"`
Operation string `db:"operation,omitempty"`
}
type dbPagemeta struct {
Limit uint64 `db:"limit"`
Offset uint64 `db:"offset"`
User string `db:"user_id"`
PatID string `db:"pat_id"`
ScopesID []string `db:"scopes_id"`
ID string `db:"id"`
Name string `db:"name"`
UpdatedAt sql.NullTime `db:"updated_at"`
ExpiresAt time.Time `db:"expires_at"`
RevokedAt sql.NullTime `db:"revoked_at"`
Description string `db:"description"`
Secret string `db:"secret"`
}
func toAuthPat(db dbPat) (auth.PAT, error) {
if db.ID == "" {
return auth.PAT{}, repoerr.ErrNotFound
}
updatedAt := time.Time{}
lastUsedAt := time.Time{}
revokedAt := time.Time{}
if db.UpdatedAt.Valid {
updatedAt = db.UpdatedAt.Time
}
if db.LastUsedAt.Valid {
lastUsedAt = db.LastUsedAt.Time
}
if db.RevokedAt.Valid {
revokedAt = db.RevokedAt.Time
}
pat := auth.PAT{
ID: db.ID,
User: db.User,
Name: db.Name,
Description: db.Description,
Secret: db.Secret,
IssuedAt: db.IssuedAt,
ExpiresAt: db.ExpiresAt,
UpdatedAt: updatedAt,
LastUsedAt: lastUsedAt,
Revoked: db.Revoked,
RevokedAt: revokedAt,
}
return pat, nil
}
func toAuthScope(dsc []dbScope) ([]auth.Scope, error) {
scope := []auth.Scope{}
for _, s := range dsc {
entityType, err := auth.ParseEntityType(s.EntityType)
if err != nil {
return []auth.Scope{}, err
}
operation, err := auth.ParseOperation(s.Operation)
if err != nil {
return []auth.Scope{}, err
}
scope = append(scope, auth.Scope{
ID: s.ID,
PatID: s.PatID,
OptionalDomainID: s.OptionalDomainID,
EntityType: entityType,
EntityID: s.EntityID,
Operation: operation,
})
}
return scope, nil
}
func toDBPats(pat auth.PAT) (dbPat, error) {
var updatedAt, lastUsedAt, revokedAt sql.NullTime
if !pat.UpdatedAt.IsZero() {
updatedAt = sql.NullTime{
Time: pat.UpdatedAt,
Valid: true,
}
}
if !pat.LastUsedAt.IsZero() {
lastUsedAt = sql.NullTime{
Time: pat.LastUsedAt,
Valid: true,
}
}
if !pat.RevokedAt.IsZero() {
revokedAt = sql.NullTime{
Time: pat.RevokedAt,
Valid: true,
}
}
return dbPat{
ID: pat.ID,
User: pat.User,
Name: pat.Name,
Description: pat.Description,
Secret: pat.Secret,
IssuedAt: pat.IssuedAt,
ExpiresAt: pat.ExpiresAt,
UpdatedAt: updatedAt,
LastUsedAt: lastUsedAt,
Revoked: pat.Revoked,
RevokedAt: revokedAt,
}, nil
}
func toDBScope(sc []auth.Scope) []dbScope {
var scopes []dbScope
for _, s := range sc {
scopes = append(scopes, dbScope{
ID: s.ID,
PatID: s.PatID,
OptionalDomainID: s.OptionalDomainID,
EntityType: s.EntityType.String(),
EntityID: s.EntityID,
Operation: s.Operation.String(),
})
}
return scopes
}
+657
View File
@@ -0,0 +1,657 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package postgres
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
repoerr "github.com/absmach/supermq/pkg/errors/repository"
"github.com/absmach/supermq/pkg/postgres"
)
var _ auth.PATSRepository = (*patRepo)(nil)
type patRepo struct {
db postgres.Database
cache auth.Cache
}
func NewPatRepo(db postgres.Database, cache auth.Cache) auth.PATSRepository {
return &patRepo{
db: db,
cache: cache,
}
}
func (pr *patRepo) Save(ctx context.Context, pat auth.PAT) error {
q := `
INSERT INTO pats (
id, user_id, name, description, secret, issued_at, expires_at,
updated_at, last_used_at, revoked, revoked_at
) VALUES (
:id, :user_id, :name, :description, :secret, :issued_at, :expires_at,
:updated_at, :last_used_at, :revoked, :revoked_at
)`
dbPat, err := toDBPats(pat)
if err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
_, err = pr.db.NamedQueryContext(ctx, q, dbPat)
if err != nil {
return postgres.HandleError(repoerr.ErrCreateEntity, err)
}
return nil
}
func (pr *patRepo) Retrieve(ctx context.Context, userID, patID string) (auth.PAT, error) {
pat, err := pr.retrievePATFromDB(ctx, userID, patID)
if err != nil {
return auth.PAT{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
return pat, nil
}
func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSPageMeta) (auth.PATSPage, error) {
q := `
SELECT
p.id, p.user_id, p.name, p.description, p.issued_at, p.expires_at,
p.updated_at, p.revoked, p.revoked_at
FROM pats p WHERE user_id = :user_id
ORDER BY issued_at DESC
LIMIT :limit OFFSET :offset`
dbPage := dbPagemeta{
Limit: pm.Limit,
Offset: pm.Offset,
User: userID,
}
rows, err := pr.db.NamedQueryContext(ctx, q, dbPage)
if err != nil {
return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
defer rows.Close()
var items []auth.PAT
for rows.Next() {
var pat dbPat
if err := rows.StructScan(&pat); err != nil {
return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
var updatedAt, revokedAt time.Time
if pat.UpdatedAt.Valid {
updatedAt = pat.UpdatedAt.Time
}
if pat.RevokedAt.Valid {
revokedAt = pat.RevokedAt.Time
}
items = append(items, auth.PAT{
ID: pat.ID,
User: pat.User,
Name: pat.Name,
Description: pat.Description,
IssuedAt: pat.IssuedAt,
ExpiresAt: pat.ExpiresAt,
UpdatedAt: updatedAt,
Revoked: pat.Revoked,
RevokedAt: revokedAt,
})
}
cq := `SELECT COUNT(*) FROM pats p WHERE user_id = :user_id`
total, err := postgres.Total(ctx, pr.db, cq, dbPage)
if err != nil {
return auth.PATSPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
page := auth.PATSPage{
PATS: items,
Total: total,
Offset: pm.Offset,
Limit: pm.Limit,
}
return page, nil
}
func (pr *patRepo) RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, bool, error) {
q := `
SELECT p.secret, p.revoked, p.expires_at
FROM pats p
WHERE user_id = $1 AND id = $2`
rows, err := pr.db.QueryContext(ctx, q, userID, patID)
if err != nil {
return "", true, true, postgres.HandleError(repoerr.ErrNotFound, err)
}
defer rows.Close()
var secret string
var revoked bool
var expiresAt time.Time
if !rows.Next() {
return "", true, true, repoerr.ErrNotFound
}
if err := rows.Scan(&secret, &revoked, &expiresAt); err != nil {
return "", true, true, postgres.HandleError(repoerr.ErrNotFound, err)
}
expired := time.Now().After(expiresAt)
return secret, revoked, expired, nil
}
func (pr *patRepo) UpdateName(ctx context.Context, userID, patID, name string) (auth.PAT, error) {
q := `
UPDATE pats p
SET name = :name, updated_at = :updated_at
WHERE user_id = :user_id AND id = :id
RETURNING id, user_id, name, description, secret, issued_at, updated_at, expires_at, revoked, revoked_at, last_used_at`
upm := dbPagemeta{
User: userID,
ID: patID,
Name: name,
UpdatedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
}
rows, err := pr.db.NamedQueryContext(ctx, q, upm)
if err != nil {
return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if !rows.Next() {
return auth.PAT{}, repoerr.ErrNotFound
}
var pat dbPat
if err := rows.StructScan(&pat); err != nil {
return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
res, err := toAuthPat(pat)
if err != nil {
return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return res, nil
}
func (pr *patRepo) UpdateDescription(ctx context.Context, userID, patID, description string) (auth.PAT, error) {
q := `
UPDATE pats
SET description = :description, updated_at = :updated_at
WHERE user_id = :user_id AND id = :id
RETURNING id, user_id, name, description, secret, issued_at, updated_at, expires_at, revoked, revoked_at, last_used_at`
upm := dbPagemeta{
User: userID,
ID: patID,
UpdatedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
Description: description,
}
rows, err := pr.db.NamedQueryContext(ctx, q, upm)
if err != nil {
return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if !rows.Next() {
return auth.PAT{}, repoerr.ErrNotFound
}
var pat dbPat
if err := rows.StructScan(&pat); err != nil {
return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
res, err := toAuthPat(pat)
if err != nil {
return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return res, nil
}
func (pr *patRepo) UpdateTokenHash(ctx context.Context, userID, patID, tokenHash string, expiryAt time.Time) (auth.PAT, error) {
q := `
UPDATE pats
SET secret = :secret, expires_at = :expires_at, updated_at = :updated_at
WHERE user_id = :user_id AND id = :id
RETURNING id, user_id, name, description, secret, issued_at, updated_at, expires_at, revoked, revoked_at, last_used_at`
upm := dbPagemeta{
User: userID,
ID: patID,
UpdatedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
ExpiresAt: expiryAt,
Secret: tokenHash,
}
rows, err := pr.db.NamedQueryContext(ctx, q, upm)
if err != nil {
return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
if !rows.Next() {
return auth.PAT{}, repoerr.ErrNotFound
}
var pat dbPat
if err := rows.StructScan(&pat); err != nil {
return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
res, err := toAuthPat(pat)
if err != nil {
return auth.PAT{}, errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return res, nil
}
func (pr *patRepo) Revoke(ctx context.Context, userID, patID string) error {
q := `
UPDATE pats
SET revoked = true, revoked_at = :revoked_at
WHERE user_id = :user_id AND id = :id`
upm := dbPagemeta{
User: userID,
ID: patID,
RevokedAt: sql.NullTime{
Time: time.Now(),
Valid: true,
},
}
_, err := pr.db.NamedQueryContext(ctx, q, upm)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return nil
}
func (pr *patRepo) Reactivate(ctx context.Context, userID, patID string) error {
q := `
UPDATE pats
SET revoked = false, revoked_at = NULL
WHERE user_id = :user_id AND id = :id`
upm := dbPagemeta{
User: userID,
ID: patID,
}
_, err := pr.db.NamedQueryContext(ctx, q, upm)
if err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return nil
}
func (pr *patRepo) Remove(ctx context.Context, userID, patID string) error {
q := `DELETE FROM pats WHERE user_id = :user_id AND id = :id`
upm := dbPagemeta{
User: userID,
ID: patID,
}
_, err := pr.db.NamedQueryContext(ctx, q, upm)
if err != nil {
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
}
return nil
}
func (pr *patRepo) RemoveAllPAT(ctx context.Context, userID string) error {
q := `DELETE FROM pats WHERE user_id = :user_id`
pm := dbPagemeta{
User: userID,
}
_, err := pr.db.NamedQueryContext(ctx, q, pm)
if err != nil {
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
}
if err := pr.cache.RemoveUserAllScope(ctx, userID); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
return nil
}
func (pr *patRepo) AddScope(ctx context.Context, userID string, scopes []auth.Scope) error {
q := `
INSERT INTO pat_scopes (id, pat_id, entity_type, optional_domain_id, operation, entity_id)
VALUES (:id, :pat_id, :entity_type, :optional_domain_id, :operation, :entity_id)`
var newScopes []auth.Scope
for _, sc := range scopes {
processedScope, err := pr.processScope(ctx, sc)
if err != nil {
return err
}
if processedScope.ID != "" {
newScopes = append(newScopes, processedScope)
}
}
if len(newScopes) > 0 {
_, err := pr.db.NamedQueryContext(ctx, q, toDBScope(newScopes))
if err != nil {
return postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
}
if err := pr.cache.Save(ctx, userID, scopes); err != nil {
return errors.Wrap(repoerr.ErrUpdateEntity, err)
}
return nil
}
func (pr *patRepo) processScope(ctx context.Context, sc auth.Scope) (auth.Scope, error) {
q := `
SELECT COUNT(*)
FROM pat_scopes
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND optional_domain_id = :optional_domain_id
AND operation = :operation
AND entity_id = :entity_id
LIMIT 1`
params := dbScope{
PatID: sc.PatID,
OptionalDomainID: sc.OptionalDomainID,
EntityType: sc.EntityType.String(),
Operation: sc.Operation.String(),
EntityID: auth.AnyIDs,
}
rows, err := pr.db.NamedQueryContext(ctx, q, params)
if err != nil {
return auth.Scope{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
defer rows.Close()
var count int
if rows.Next() {
if err := rows.Scan(&count); err != nil {
return auth.Scope{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
}
if count > 0 {
return auth.Scope{}, repoerr.ErrConflict
}
if sc.EntityID == auth.AnyIDs {
newParams := dbScope{
PatID: sc.PatID,
OptionalDomainID: sc.OptionalDomainID,
EntityType: sc.EntityType.String(),
Operation: sc.Operation.String(),
}
checkEntityQuery := `
SELECT COUNT(*)
FROM pat_scopes
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND optional_domain_id = :optional_domain_id
AND operation = :operation
LIMIT 1`
rows, err := pr.db.NamedQueryContext(ctx, checkEntityQuery, newParams)
if err != nil {
return auth.Scope{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
defer rows.Close()
var count int
if rows.Next() {
if err := rows.Scan(&count); err != nil {
return auth.Scope{}, postgres.HandleError(repoerr.ErrViewEntity, err)
}
}
if count > 0 {
updateWithWildcardQuery := `
UPDATE pat_scopes
SET entity_id = :entity_id
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND optional_domain_id = :optional_domain_id
AND operation = :operation`
_, err = pr.db.NamedQueryContext(ctx, updateWithWildcardQuery, params)
if err != nil {
return auth.Scope{}, postgres.HandleError(repoerr.ErrUpdateEntity, err)
}
return auth.Scope{}, nil
}
}
return sc, nil
}
func (pr *patRepo) RemoveScope(ctx context.Context, userID string, scopesIDs ...string) error {
deleteScopesQuery := fmt.Sprintf(`DELETE FROM pat_scopes WHERE id IN ('%s')`, strings.Join(scopesIDs, ","))
res, err := pr.db.ExecContext(ctx, deleteScopesQuery)
if err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
if rows, _ := res.RowsAffected(); rows == 0 {
return repoerr.ErrNotFound
}
if err := pr.cache.Remove(ctx, userID, scopesIDs); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
return nil
}
func (pr *patRepo) CheckScope(ctx context.Context, userID, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error {
q := `
SELECT id, pat_id, entity_type, optional_domain_id, operation, entity_id
FROM pat_scopes
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND optional_domain_id = :optional_domain_id
AND operation = :operation
AND (entity_id = :entity_id OR entity_id = '*')
LIMIT 1`
authorized := pr.cache.CheckScope(ctx, userID, patID, optionalDomainID, entityType, operation, entityID)
if authorized {
return nil
}
scope := dbScope{
PatID: patID,
EntityType: entityType.String(),
OptionalDomainID: optionalDomainID,
Operation: operation.String(),
EntityID: entityID,
}
rows, err := pr.db.NamedQueryContext(ctx, q, scope)
if err != nil {
return errors.Wrap(repoerr.ErrViewEntity, err)
}
defer rows.Close()
if rows.Next() {
var sc dbScope
if err := rows.StructScan(&sc); err != nil {
return errors.Wrap(repoerr.ErrViewEntity, err)
}
entityType, err := auth.ParseEntityType(sc.EntityType)
if err != nil {
return errors.Wrap(repoerr.ErrViewEntity, err)
}
operation, err := auth.ParseOperation(sc.Operation)
if err != nil {
return errors.Wrap(repoerr.ErrViewEntity, err)
}
authScope := auth.Scope{
ID: sc.ID,
PatID: sc.PatID,
OptionalDomainID: sc.OptionalDomainID,
EntityType: entityType,
EntityID: sc.EntityID,
Operation: operation,
}
if err := pr.cache.Save(ctx, userID, []auth.Scope{authScope}); err != nil {
return err
}
if authScope.Authorized(entityType, optionalDomainID, operation, entityID) {
return nil
}
}
return repoerr.ErrNotFound
}
func (pr *patRepo) RemoveAllScope(ctx context.Context, patID string) error {
pm := dbPagemeta{
PatID: patID,
}
q := `DELETE FROM pat_scopes WHERE pat_id = :pat_id`
_, err := pr.db.NamedQueryContext(ctx, q, pm)
if err != nil {
return postgres.HandleError(repoerr.ErrRemoveEntity, err)
}
if err := pr.cache.RemoveAllScope(ctx, pm.User, patID); err != nil {
return errors.Wrap(repoerr.ErrRemoveEntity, err)
}
return nil
}
func (pr *patRepo) RetrieveScope(ctx context.Context, pm auth.ScopesPageMeta) (auth.ScopesPage, error) {
dbs := dbPagemeta{
PatID: pm.PatID,
Offset: pm.Offset,
Limit: pm.Limit,
}
scopes, err := pr.retrieveScopeFromDB(ctx, dbs)
if err != nil {
return auth.ScopesPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
cq := `SELECT COUNT(*) FROM pat_scopes WHERE pat_id = :pat_id`
total, err := postgres.Total(ctx, pr.db, cq, dbs)
if err != nil {
return auth.ScopesPage{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
return auth.ScopesPage{
Total: total,
Scopes: scopes,
Offset: pm.Offset,
Limit: pm.Limit,
}, nil
}
func (pr *patRepo) retrieveScopeFromDB(ctx context.Context, pm dbPagemeta) ([]auth.Scope, error) {
q := `
SELECT id, pat_id, entity_type, optional_domain_id, operation, entity_id
FROM pat_scopes WHERE pat_id = :pat_id OFFSET :offset LIMIT :limit`
scopeRows, err := pr.db.NamedQueryContext(ctx, q, pm)
if err != nil {
return []auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
defer scopeRows.Close()
var scopes []dbScope
for scopeRows.Next() {
var scope dbScope
if err := scopeRows.StructScan(&scope); err != nil {
return []auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
scopes = append(scopes, scope)
}
sc, err := toAuthScope(scopes)
if err != nil {
return []auth.Scope{}, err
}
return sc, nil
}
func (pr *patRepo) retrievePATFromDB(ctx context.Context, userID, patID string) (auth.PAT, error) {
q := `
SELECT
id, user_id, name, description, secret, issued_at, expires_at,
updated_at, last_used_at, revoked, revoked_at
FROM pats WHERE user_id = :user_id AND id = :id`
dbp := dbPagemeta{
ID: patID,
User: userID,
}
rows, err := pr.db.NamedQueryContext(ctx, q, dbp)
if err != nil {
return auth.PAT{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
defer rows.Close()
var record dbPat
if rows.Next() {
if err := rows.StructScan(&record); err != nil {
return auth.PAT{}, errors.Wrap(repoerr.ErrViewEntity, err)
}
}
pat, err := toAuthPat(record)
if err != nil {
return auth.PAT{}, err
}
return pat, nil
}
+98 -53
View File
@@ -44,8 +44,7 @@ var (
errUpdatePAT = errors.New("failed to update PAT")
errRetrievePAT = errors.New("failed to retrieve PAT")
errDeletePAT = errors.New("failed to delete PAT")
errRevokePAT = errors.New("failed to revoke PAT")
errClearAllScope = errors.New("failed to clear all entry in scope")
errInvalidScope = errors.New("invalid scope")
)
// Authz represents a authorization service. It exposes
@@ -100,6 +99,7 @@ var _ Service = (*service)(nil)
type service struct {
keys KeyRepository
pats PATSRepository
cache Cache
hasher Hasher
idProvider supermq.IDProvider
evaluator policies.Evaluator
@@ -111,11 +111,12 @@ type service struct {
}
// New instantiates the auth service implementation.
func New(keys KeyRepository, pats PATSRepository, hasher Hasher, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service {
func New(keys KeyRepository, repo PATSRepository, cache Cache, hasher Hasher, idp supermq.IDProvider, tokenizer Tokenizer, policyEvaluator policies.Evaluator, policyService policies.Service, loginDuration, refreshDuration, invitationDuration time.Duration) Service {
return &service{
tokenizer: tokenizer,
keys: keys,
pats: pats,
pats: repo,
cache: cache,
hasher: hasher,
idProvider: idp,
evaluator: policyEvaluator,
@@ -457,7 +458,7 @@ func DecodeDomainUserID(domainUserID string) (string, string) {
}
}
func (svc service) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope Scope) (PAT, error) {
func (svc service) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration) (PAT, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return PAT{}, err
@@ -481,17 +482,18 @@ func (svc service) CreatePAT(ctx context.Context, token, name, description strin
Secret: hash,
IssuedAt: now,
ExpiresAt: now.Add(duration),
Scope: scope,
}
if err := svc.pats.Save(ctx, pat); err != nil {
return PAT{}, errors.Wrap(errCreatePAT, err)
}
pat.Secret = secret
return pat, nil
}
func (svc service) UpdatePATName(ctx context.Context, token, patID, name string) (PAT, error) {
key, err := svc.Identify(ctx, token)
key, err := svc.authnAuthzUserPAT(ctx, token, patID)
if err != nil {
return PAT{}, err
}
@@ -503,7 +505,7 @@ func (svc service) UpdatePATName(ctx context.Context, token, patID, name string)
}
func (svc service) UpdatePATDescription(ctx context.Context, token, patID, description string) (PAT, error) {
key, err := svc.Identify(ctx, token)
key, err := svc.authnAuthzUserPAT(ctx, token, patID)
if err != nil {
return PAT{}, err
}
@@ -514,8 +516,12 @@ func (svc service) UpdatePATDescription(ctx context.Context, token, patID, descr
return pat, nil
}
func (svc service) RetrievePAT(ctx context.Context, userID, patID string) (PAT, error) {
pat, err := svc.pats.Retrieve(ctx, userID, patID)
func (svc service) RetrievePAT(ctx context.Context, token, patID string) (PAT, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return PAT{}, err
}
pat, err := svc.pats.Retrieve(ctx, key.User, patID)
if err != nil {
return PAT{}, errors.Wrap(errRetrievePAT, err)
}
@@ -535,7 +541,7 @@ func (svc service) ListPATS(ctx context.Context, token string, pm PATSPageMeta)
}
func (svc service) DeletePAT(ctx context.Context, token, patID string) error {
key, err := svc.Identify(ctx, token)
key, err := svc.authnAuthzUserPAT(ctx, token, patID)
if err != nil {
return err
}
@@ -546,7 +552,7 @@ func (svc service) DeletePAT(ctx context.Context, token, patID string) error {
}
func (svc service) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (PAT, error) {
key, err := svc.Identify(ctx, token)
key, err := svc.authnAuthzUserPAT(ctx, token, patID)
if err != nil {
return PAT{}, err
}
@@ -572,48 +578,83 @@ func (svc service) ResetPATSecret(ctx context.Context, token, patID string, dura
}
func (svc service) RevokePATSecret(ctx context.Context, token, patID string) error {
key, err := svc.Identify(ctx, token)
key, err := svc.authnAuthzUserPAT(ctx, token, patID)
if err != nil {
return err
}
if err := svc.pats.Revoke(ctx, key.User, patID); err != nil {
return errors.Wrap(errRevokePAT, err)
return errors.Wrap(svcerr.ErrUpdateEntity, err)
}
return nil
}
func (svc service) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return Scope{}, err
}
scope, err := svc.pats.AddScopeEntry(ctx, key.User, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if err != nil {
return Scope{}, errors.Wrap(errRevokePAT, err)
}
return scope, nil
}
func (svc service) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return Scope{}, err
}
scope, err := svc.pats.RemoveScopeEntry(ctx, key.User, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
if err != nil {
return Scope{}, err
}
return scope, nil
}
func (svc service) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error {
func (svc service) RemoveAllPAT(ctx context.Context, token string) error {
key, err := svc.Identify(ctx, token)
if err != nil {
return err
}
if err := svc.pats.RemoveAllScopeEntry(ctx, key.User, patID); err != nil {
return errors.Wrap(errClearAllScope, err)
if err := svc.pats.RemoveAllPAT(ctx, key.User); err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
return nil
}
func (svc service) AddScope(ctx context.Context, token, patID string, scopes []Scope) error {
key, err := svc.authnAuthzUserPAT(ctx, token, patID)
if err != nil {
return err
}
for i := range len(scopes) {
scopes[i].ID, err = svc.idProvider.ID()
if err != nil {
return errors.Wrap(svcerr.ErrCreateEntity, err)
}
scopes[i].PatID = patID
}
err = svc.pats.AddScope(ctx, key.User, scopes)
if err != nil {
return errors.Wrap(svcerr.ErrCreateEntity, err)
}
return nil
}
func (svc service) RemoveScope(ctx context.Context, token, patID string, scopesIDs ...string) error {
key, err := svc.authnAuthzUserPAT(ctx, token, patID)
if err != nil {
return err
}
err = svc.pats.RemoveScope(ctx, key.User, scopesIDs...)
if err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
return nil
}
func (svc service) ListScopes(ctx context.Context, token string, pm ScopesPageMeta) (ScopesPage, error) {
_, err := svc.authnAuthzUserPAT(ctx, token, pm.PatID)
if err != nil {
return ScopesPage{}, err
}
patsPage, err := svc.pats.RetrieveScope(ctx, pm)
if err != nil {
return ScopesPage{}, errors.Wrap(errRetrievePAT, err)
}
return patsPage, nil
}
func (svc service) RemovePATAllScope(ctx context.Context, token, patID string) error {
_, err := svc.authnAuthzUserPAT(ctx, token, patID)
if err != nil {
return err
}
if err := svc.pats.RemoveAllScope(ctx, patID); err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
return nil
}
@@ -643,21 +684,11 @@ func (svc service) IdentifyPAT(ctx context.Context, secret string) (PAT, error)
return PAT{ID: patID.String(), User: userID.String()}, nil
}
func (svc service) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error {
res, err := svc.RetrievePAT(ctx, userID, patID)
if err != nil {
return err
}
if err := svc.pats.CheckScopeEntry(ctx, res.User, res.ID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...); err != nil {
func (svc service) AuthorizePAT(ctx context.Context, userID, patID string, entityType EntityType, optionalDomainID string, operation Operation, entityID string) error {
if err := svc.pats.CheckScope(ctx, userID, patID, entityType, optionalDomainID, operation, entityID); err != nil {
return errors.Wrap(svcerr.ErrAuthorization, err)
}
return nil
}
func (svc service) CheckPAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error {
if err := svc.pats.CheckScopeEntry(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...); err != nil {
return errors.Wrap(svcerr.ErrAuthorization, err)
}
return nil
}
@@ -707,3 +738,17 @@ func generateRandomString(n int) string {
}
return string(b)
}
func (svc service) authnAuthzUserPAT(ctx context.Context, token, patID string) (Key, error) {
key, err := svc.Identify(ctx, token)
if err != nil {
return Key{}, err
}
_, err = svc.pats.Retrieve(ctx, key.User, patID)
if err != nil {
return Key{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
return key, nil
}
+3 -1
View File
@@ -50,11 +50,13 @@ var (
pService *policymocks.Service
pEvaluator *policymocks.Evaluator
patsrepo *mocks.PATSRepository
cache *mocks.Cache
hasher *mocks.Hasher
)
func newService() (auth.Service, string) {
krepo = new(mocks.KeyRepository)
cache = new(mocks.Cache)
pService = new(policymocks.Service)
pEvaluator = new(policymocks.Evaluator)
patsrepo = new(mocks.PATSRepository)
@@ -72,7 +74,7 @@ func newService() (auth.Service, string) {
}
token, _ := t.Issue(key)
return auth.New(krepo, patsrepo, hasher, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token
return auth.New(krepo, patsrepo, cache, hasher, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration), token
}
func TestIssue(t *testing.T) {
+47 -45
View File
@@ -76,15 +76,14 @@ func (tm *tracingMiddleware) Authorize(ctx context.Context, pr policies.Policy)
return tm.svc.Authorize(ctx, pr)
}
func (tm *tracingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) {
func (tm *tracingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration) (auth.PAT, error) {
ctx, span := tm.tracer.Start(ctx, "create_pat", trace.WithAttributes(
attribute.String("name", name),
attribute.String("description", description),
attribute.String("duration", duration.String()),
attribute.String("scope", scope.String()),
))
defer span.End()
return tm.svc.CreatePAT(ctx, token, name, description, duration, scope)
return tm.svc.CreatePAT(ctx, token, name, description, duration)
}
func (tm *tracingMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (auth.PAT, error) {
@@ -122,6 +121,15 @@ func (tm *tracingMiddleware) ListPATS(ctx context.Context, token string, pm auth
return tm.svc.ListPATS(ctx, token, pm)
}
func (tm *tracingMiddleware) ListScopes(ctx context.Context, token string, pm auth.ScopesPageMeta) (auth.ScopesPage, error) {
ctx, span := tm.tracer.Start(ctx, "list_scopes", trace.WithAttributes(
attribute.Int64("limit", int64(pm.Limit)),
attribute.Int64("offset", int64(pm.Offset)),
))
defer span.End()
return tm.svc.ListScopes(ctx, token, pm)
}
func (tm *tracingMiddleware) DeletePAT(ctx context.Context, token, patID string) error {
ctx, span := tm.tracer.Start(ctx, "delete_pat", trace.WithAttributes(
attribute.String("pat_id", patID),
@@ -147,38 +155,47 @@ func (tm *tracingMiddleware) RevokePATSecret(ctx context.Context, token, patID s
return tm.svc.RevokePATSecret(ctx, token, patID)
}
func (tm *tracingMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
ctx, span := tm.tracer.Start(ctx, "add_pat_scope_entry", trace.WithAttributes(
attribute.String("pat_id", patID),
attribute.String("platform_entity", platformEntityType.String()),
attribute.String("optional_domain_id", optionalDomainID),
attribute.String("optional_domain_entity", optionalDomainEntityType.String()),
attribute.String("operation", operation.String()),
attribute.StringSlice("entities", entityIDs),
))
func (tm *tracingMiddleware) RemoveAllPAT(ctx context.Context, token string) error {
ctx, span := tm.tracer.Start(ctx, "clear_all_pat")
defer span.End()
return tm.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
return tm.svc.RemoveAllPAT(ctx, token)
}
func (tm *tracingMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) {
ctx, span := tm.tracer.Start(ctx, "remove_pat_scope_entry", trace.WithAttributes(
attribute.String("pat_id", patID),
attribute.String("platform_entity", platformEntityType.String()),
attribute.String("optional_domain_id", optionalDomainID),
attribute.String("optional_domain_entity", optionalDomainEntityType.String()),
attribute.String("operation", operation.String()),
attribute.StringSlice("entities", entityIDs),
))
func (tm *tracingMiddleware) AddScope(ctx context.Context, token, patID string, scopes []auth.Scope) error {
var attributes []attribute.KeyValue
for _, s := range scopes {
attributes = append(attributes, attribute.String("entity_type", s.EntityType.String()))
attributes = append(attributes, attribute.String("optional_domain_id", s.OptionalDomainID))
attributes = append(attributes, attribute.String("operation", s.Operation.String()))
attributes = append(attributes, attribute.String("entity_id", s.EntityID))
}
attributes = append(attributes, attribute.String("pat_id", patID))
ctx, span := tm.tracer.Start(ctx, "add_pat_scope", trace.WithAttributes(attributes...))
defer span.End()
return tm.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
return tm.svc.AddScope(ctx, token, patID, scopes)
}
func (tm *tracingMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error {
ctx, span := tm.tracer.Start(ctx, "clear_pat_all_scope_entry", trace.WithAttributes(
func (tm *tracingMiddleware) RemoveScope(ctx context.Context, token, patID string, scopesID ...string) error {
var attributes []attribute.KeyValue
for _, s := range scopesID {
attributes = append(attributes, attribute.String("scope_id", s))
}
attributes = append(attributes, attribute.String("pat_id", patID))
ctx, span := tm.tracer.Start(ctx, "remove_pat_scope", trace.WithAttributes(attributes...))
defer span.End()
return tm.svc.RemoveScope(ctx, token, patID, scopesID...)
}
func (tm *tracingMiddleware) RemovePATAllScope(ctx context.Context, token, patID string) error {
ctx, span := tm.tracer.Start(ctx, "clear_pat_all_scope", trace.WithAttributes(
attribute.String("pat_id", patID),
))
defer span.End()
return tm.svc.ClearPATAllScopeEntry(ctx, token, patID)
return tm.svc.RemovePATAllScope(ctx, token, patID)
}
func (tm *tracingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) {
@@ -187,29 +204,14 @@ func (tm *tracingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (a
return tm.svc.IdentifyPAT(ctx, paToken)
}
func (tm *tracingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
func (tm *tracingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error {
ctx, span := tm.tracer.Start(ctx, "authorize_pat", trace.WithAttributes(
attribute.String("pat_id", patID),
attribute.String("platform_entity", platformEntityType.String()),
attribute.String("entity_type", entityType.String()),
attribute.String("optional_domain_id", optionalDomainID),
attribute.String("optional_domain_entity", optionalDomainEntityType.String()),
attribute.String("operation", operation.String()),
attribute.StringSlice("entities", entityIDs),
attribute.String("entities", entityID),
))
defer span.End()
return tm.svc.AuthorizePAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
}
func (tm *tracingMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error {
ctx, span := tm.tracer.Start(ctx, "check_pat", trace.WithAttributes(
attribute.String("user_id", userID),
attribute.String("pat_id", patID),
attribute.String("platform_entity", platformEntityType.String()),
attribute.String("optional_domain_id", optionalDomainID),
attribute.String("optional_domain_entity", optionalDomainEntityType.String()),
attribute.String("operation", operation.String()),
attribute.StringSlice("entities", entityIDs),
))
defer span.End()
return tm.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...)
return tm.svc.AuthorizePAT(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
}
+110 -119
View File
@@ -85,13 +85,12 @@ func AuthorizationMiddleware(svc channels.Service, repo channels.Repository, aut
func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.CreateOp,
EntityIDs: auth.AnyIDs{}.Values(),
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.CreateOp,
EntityID: auth.AnyIDs,
}); err != nil {
return []channels.Channel{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -126,13 +125,12 @@ func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session a
func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.ReadOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.ReadOp,
EntityID: id,
}); err != nil {
return channels.Channel{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -153,13 +151,12 @@ func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session auth
func (am *authorizationMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.PageMetadata) (channels.Page, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.ListOp,
EntityIDs: auth.AnyIDs{}.Values(),
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.ListOp,
EntityID: auth.AnyIDs,
}); err != nil {
return channels.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -174,13 +171,12 @@ func (am *authorizationMiddleware) ListChannels(ctx context.Context, session aut
func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.ListOp,
EntityIDs: auth.AnyIDs{}.Values(),
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.ListOp,
EntityID: auth.AnyIDs,
}); err != nil {
return channels.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -194,13 +190,12 @@ func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session
func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{channel.ID},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: channel.ID,
}); err != nil {
return channels.Channel{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -221,13 +216,12 @@ func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session au
func (am *authorizationMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{channel.ID},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: channel.ID,
}); err != nil {
return channels.Channel{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -248,13 +242,12 @@ func (am *authorizationMiddleware) UpdateChannelTags(ctx context.Context, sessio
func (am *authorizationMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: id,
}); err != nil {
return channels.Channel{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -275,13 +268,12 @@ func (am *authorizationMiddleware) EnableChannel(ctx context.Context, session au
func (am *authorizationMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: id,
}); err != nil {
return channels.Channel{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -302,13 +294,12 @@ func (am *authorizationMiddleware) DisableChannel(ctx context.Context, session a
func (am *authorizationMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.DeleteOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.DeleteOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -328,28 +319,29 @@ func (am *authorizationMiddleware) RemoveChannel(ctx context.Context, session au
func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.CreateOp,
EntityIDs: chIDs,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
for _, chID := range chIDs {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.CreateOp,
EntityID: chID,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
}
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.CreateOp,
EntityIDs: thIDs,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
for _, thID := range thIDs {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.CreateOp,
EntityID: thID,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
}
}
for _, chID := range chIDs {
@@ -380,28 +372,29 @@ func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Se
func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.DeleteOp,
EntityIDs: chIDs,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
for _, chID := range chIDs {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.DeleteOp,
EntityID: chID,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
}
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.DeleteOp,
EntityIDs: thIDs,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
for _, thID := range thIDs {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.DeleteOp,
EntityID: thID,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
}
}
@@ -434,13 +427,12 @@ func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn
func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -471,13 +463,12 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a
func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainChannelsScope,
Operation: auth.DeleteOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.DeleteOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
+72 -84
View File
@@ -77,13 +77,12 @@ func AuthorizationMiddleware(entityType string, svc clients.Service, authz smqau
func (am *authorizationMiddleware) CreateClients(ctx context.Context, session authn.Session, client ...clients.Client) ([]clients.Client, []roles.RoleProvision, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.CreateOp,
EntityIDs: auth.AnyIDs{}.Values(),
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.CreateOp,
EntityID: auth.AnyIDs,
}); err != nil {
return []clients.Client{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -105,13 +104,12 @@ func (am *authorizationMiddleware) CreateClients(ctx context.Context, session au
func (am *authorizationMiddleware) View(ctx context.Context, session authn.Session, id string) (clients.Client, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.ReadOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.ReadOp,
EntityID: id,
}); err != nil {
return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -132,13 +130,12 @@ func (am *authorizationMiddleware) View(ctx context.Context, session authn.Sessi
func (am *authorizationMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.ListOp,
EntityIDs: auth.AnyIDs{}.Values(),
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.ListOp,
EntityID: auth.AnyIDs,
}); err != nil {
return clients.ClientsPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -154,13 +151,12 @@ func (am *authorizationMiddleware) ListClients(ctx context.Context, session auth
func (am *authorizationMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.ListOp,
EntityIDs: auth.AnyIDs{}.Values(),
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.ListOp,
EntityID: auth.AnyIDs,
}); err != nil {
return clients.ClientsPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -176,13 +172,12 @@ func (am *authorizationMiddleware) ListUserClients(ctx context.Context, session
func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{client.ID},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: client.ID,
}); err != nil {
return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -204,13 +199,12 @@ func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Ses
func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{client.ID},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: client.ID,
}); err != nil {
return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -232,13 +226,12 @@ func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn
func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session authn.Session, id, key string) (clients.Client, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: id,
}); err != nil {
return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -259,13 +252,12 @@ func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session aut
func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Session, id string) (clients.Client, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: id,
}); err != nil {
return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -287,13 +279,12 @@ func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Ses
func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Session, id string) (clients.Client, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: id,
}); err != nil {
return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -314,13 +305,12 @@ func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Se
func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Session, id string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainClientsScope,
Operation: auth.DeleteOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.DeleteOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -341,13 +331,12 @@ func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Ses
func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -378,13 +367,12 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a
func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.DeleteOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ClientsType,
OptionalDomainID: session.DomainID,
Operation: auth.DeleteOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
+19 -22
View File
@@ -21,12 +21,12 @@ import (
authgrpcapi "github.com/absmach/supermq/auth/api/grpc/auth"
tokengrpcapi "github.com/absmach/supermq/auth/api/grpc/token"
httpapi "github.com/absmach/supermq/auth/api/http"
"github.com/absmach/supermq/auth/bolt"
"github.com/absmach/supermq/auth/cache"
"github.com/absmach/supermq/auth/hasher"
"github.com/absmach/supermq/auth/jwt"
apostgres "github.com/absmach/supermq/auth/postgres"
"github.com/absmach/supermq/auth/tracing"
boltclient "github.com/absmach/supermq/internal/clients/bolt"
redisclient "github.com/absmach/supermq/internal/clients/redis"
smqlog "github.com/absmach/supermq/logger"
"github.com/absmach/supermq/pkg/jaeger"
"github.com/absmach/supermq/pkg/policies/spicedb"
@@ -41,7 +41,7 @@ import (
"github.com/authzed/grpcutil"
"github.com/caarlos0/env/v11"
"github.com/jmoiron/sqlx"
"go.etcd.io/bbolt"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
@@ -75,6 +75,8 @@ type config struct {
SpicedbPreSharedKey string `env:"SMQ_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"`
TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"`
ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"`
CacheURL string `env:"SMQ_AUTH_CACHE_URL" envDefault:"redis://localhost:6379/0"`
CacheKeyDuration time.Duration `env:"SMQ_AUTH_CACHE_KEY_DURATION" envDefault:"10m"`
}
func main() {
@@ -107,6 +109,14 @@ func main() {
logger.Error(err.Error())
}
cacheclient, err := redisclient.Connect(cfg.CacheURL)
if err != nil {
logger.Error(err.Error())
exitCode = 1
return
}
defer cacheclient.Close()
am := apostgres.Migration()
db, err := pgclient.Setup(dbConfig, *am)
if err != nil {
@@ -136,22 +146,7 @@ func main() {
return
}
boltDBConfig := boltclient.Config{}
if err := env.ParseWithOptions(&boltDBConfig, env.Options{Prefix: envPrefixPATDB}); err != nil {
logger.Error(fmt.Sprintf("failed to parse bolt db config : %s\n", err.Error()))
exitCode = 1
return
}
bClient, err := boltclient.Connect(boltDBConfig, bolt.Init)
if err != nil {
logger.Error(fmt.Sprintf("failed to connect to bolt db : %s\n", err.Error()))
exitCode = 1
return
}
defer bClient.Close()
svc := newService(ctx, db, tracer, cfg, dbConfig, logger, spicedbclient, bClient, boltDBConfig)
svc := newService(ctx, db, tracer, cfg, dbConfig, logger, spicedbclient, cacheclient, cfg.CacheKeyDuration)
grpcServerConfig := server.Config{Port: defSvcGRPCPort}
if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGrpc}); err != nil {
@@ -231,10 +226,12 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch
return nil
}
func newService(_ context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, bClient *bbolt.DB, bConfig boltclient.Config) auth.Service {
func newService(_ context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, cacheClient *redis.Client, keyDuration time.Duration) auth.Service {
cache := cache.NewPatsCache(cacheClient, keyDuration)
database := pgclient.NewDatabase(db, dbConfig, tracer)
keysRepo := apostgres.New(database)
patsRepo := bolt.NewPATSRepository(bClient, bConfig.Bucket)
patsRepo := apostgres.NewPatRepo(database, cache)
hasher := hasher.New()
idProvider := uuid.New()
@@ -243,7 +240,7 @@ func newService(_ context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config,
t := jwt.New([]byte(cfg.SecretKey))
svc := auth.New(keysRepo, patsRepo, hasher, idProvider, t, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
svc := auth.New(keysRepo, patsRepo, nil, hasher, idProvider, t, pEvaluator, pService, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration)
svc = api.LoggingMiddleware(svc, logger)
counter, latency := prometheus.MakeMetrics("auth", "api")
svc = api.MetricsMiddleware(svc, counter, latency)
+2
View File
@@ -99,6 +99,8 @@ SMQ_AUTH_ACCESS_TOKEN_DURATION="1h"
SMQ_AUTH_REFRESH_TOKEN_DURATION="24h"
SMQ_AUTH_INVITATION_DURATION="168h"
SMQ_AUTH_ADAPTER_INSTANCE_ID=
SMQ_AUTH_CACHE_URL=redis://auth-redis:${SMQ_REDIS_TCP_PORT}/0
SMQ_AUTH_CACHE_KEY_DURATION=10m
#### Auth Client Config
SMQ_AUTH_URL=auth:9001
+12
View File
@@ -21,6 +21,7 @@ volumes:
supermq-domains-db-volume:
supermq-domains-redis-volume:
supermq-ui-db-volume:
supermq-auth-redis-volume:
services:
spicedb:
@@ -84,6 +85,15 @@ services:
volumes:
- supermq-auth-db-volume:/var/lib/postgresql/data
auth-redis:
image: redis:7.2.4-alpine
container_name: supermq-auth-redis
restart: on-failure
networks:
- supermq-base-net
volumes:
- supermq-auth-redis-volume:/data
auth:
image: supermq/auth:${SMQ_RELEASE_TAG}
container_name: supermq-auth
@@ -130,6 +140,8 @@ services:
SMQ_SEND_TELEMETRY: ${SMQ_SEND_TELEMETRY}
SMQ_AUTH_ADAPTER_INSTANCE_ID: ${SMQ_AUTH_ADAPTER_INSTANCE_ID}
SMQ_ES_URL: ${SMQ_ES_URL}
SMQ_AUTH_CACHE_URL: ${SMQ_AUTH_CACHE_URL}
SMQ_AUTH_CACHE_KEY_DURATION: ${SMQ_AUTH_CACHE_KEY_DURATION}
ports:
- ${SMQ_AUTH_HTTP_PORT}:${SMQ_AUTH_HTTP_PORT}
- ${SMQ_AUTH_GRPC_PORT}:${SMQ_AUTH_GRPC_PORT}
+1
View File
@@ -13,6 +13,7 @@ fi
envsubst '
${SMQ_NGINX_SERVER_NAME}
${SMQ_AUTH_HTTP_PORT}
${SMQ_DOMAINS_HTTP_PORT}
${SMQ_GROUPS_HTTP_PORT}
${SMQ_USERS_HTTP_PORT}
+7
View File
@@ -57,6 +57,13 @@ http {
add_header Access-Control-Allow-Methods '*';
add_header Access-Control-Allow-Headers '*';
# Proxy pass to auth service
location ~ ^/(pats) {
include snippets/proxy-headers.conf;
add_header Access-Control-Expose-Headers Location;
proxy_pass http://auth:${SMQ_AUTH_HTTP_PORT};
}
# Proxy pass to domains service
location ~ ^/(domains|invitations) {
include snippets/proxy-headers.conf;
+7
View File
@@ -66,6 +66,13 @@ http {
add_header Access-Control-Allow-Methods '*';
add_header Access-Control-Allow-Headers '*';
# Proxy pass to auth service
location ~ ^/(pats) {
include snippets/proxy-headers.conf;
add_header Access-Control-Expose-Headers Location;
proxy_pass http://auth:${SMQ_AUTH_HTTP_PORT};
}
# Proxy pass to domains service
location ~ ^/(domains|invitations) {
include snippets/proxy-headers.conf;
-1
View File
@@ -41,7 +41,6 @@ require (
github.com/spf13/cobra v1.9.1
github.com/sqids/sqids-go v0.4.1
github.com/stretchr/testify v1.10.0
go.etcd.io/bbolt v1.4.0
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0
go.opentelemetry.io/otel v1.34.0
-2
View File
@@ -425,8 +425,6 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
go.etcd.io/bbolt v1.4.0 h1:TU77id3TnN/zKr7CO/uk+fBCwF2jGcMuw2B/FMAzYIk=
go.etcd.io/bbolt v1.4.0/go.mod h1:AsD+OCi/qPN1giOX1aiLAha3o1U8rAz65bvN4j0sRuk=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.59.0 h1:rgMkmiGfix9vFJDcDi1PK8WEQP4FLQwLDfhp5ZLpFeE=
+84 -98
View File
@@ -84,13 +84,12 @@ func AuthorizationMiddleware(entityType string, svc groups.Service, repo groups.
func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, []roles.RoleProvision, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.CreateOp,
EntityIDs: auth.AnyIDs{}.Values(),
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.CreateOp,
EntityID: auth.AnyIDs,
}); err != nil {
return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -125,13 +124,12 @@ func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session auth
func (am *authorizationMiddleware) UpdateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{g.ID},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: g.ID,
}); err != nil {
return groups.Group{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -154,13 +152,12 @@ func (am *authorizationMiddleware) UpdateGroup(ctx context.Context, session auth
func (am *authorizationMiddleware) ViewGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.ReadOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.ReadOp,
EntityID: id,
}); err != nil {
return groups.Group{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -183,13 +180,12 @@ func (am *authorizationMiddleware) ViewGroup(ctx context.Context, session authn.
func (am *authorizationMiddleware) ListGroups(ctx context.Context, session authn.Session, gm groups.PageMeta) (groups.Page, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.ListOp,
EntityIDs: auth.AnyIDs{}.Values(),
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.ListOp,
EntityID: auth.AnyIDs,
}); err != nil {
return groups.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -235,13 +231,12 @@ func (am *authorizationMiddleware) ListUserGroups(ctx context.Context, session a
func (am *authorizationMiddleware) EnableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: id,
}); err != nil {
return groups.Group{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -263,13 +258,12 @@ func (am *authorizationMiddleware) EnableGroup(ctx context.Context, session auth
func (am *authorizationMiddleware) DisableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: id,
}); err != nil {
return groups.Group{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -291,13 +285,12 @@ func (am *authorizationMiddleware) DisableGroup(ctx context.Context, session aut
func (am *authorizationMiddleware) DeleteGroup(ctx context.Context, session authn.Session, id string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.DeleteOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.DeleteOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -319,13 +312,12 @@ func (am *authorizationMiddleware) DeleteGroup(ctx context.Context, session auth
func (am *authorizationMiddleware) RetrieveGroupHierarchy(ctx context.Context, session authn.Session, id string, hm groups.HierarchyPageMeta) (groups.HierarchyPage, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.ListOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.ListOp,
EntityID: id,
}); err != nil {
return groups.HierarchyPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -346,13 +338,12 @@ func (am *authorizationMiddleware) RetrieveGroupHierarchy(ctx context.Context, s
func (am *authorizationMiddleware) AddParentGroup(ctx context.Context, session authn.Session, id, parentID string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -383,13 +374,12 @@ func (am *authorizationMiddleware) AddParentGroup(ctx context.Context, session a
func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.DeleteOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.DeleteOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -427,13 +417,12 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
func (am *authorizationMiddleware) AddChildrenGroups(ctx context.Context, session authn.Session, id string, childrenGroupIDs []string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.UpdateOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -467,13 +456,12 @@ func (am *authorizationMiddleware) AddChildrenGroups(ctx context.Context, sessio
func (am *authorizationMiddleware) RemoveChildrenGroups(ctx context.Context, session authn.Session, id string, childrenGroupIDs []string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.DeleteOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.DeleteOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -495,13 +483,12 @@ func (am *authorizationMiddleware) RemoveChildrenGroups(ctx context.Context, ses
func (am *authorizationMiddleware) RemoveAllChildrenGroups(ctx context.Context, session authn.Session, id string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.DeleteOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.DeleteOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -523,13 +510,12 @@ func (am *authorizationMiddleware) RemoveAllChildrenGroups(ctx context.Context,
func (am *authorizationMiddleware) ListChildrenGroups(ctx context.Context, session authn.Session, id string, startLevel, endLevel int64, pm groups.PageMeta) (groups.Page, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: auth.PlatformDomainsScope,
OptionalDomainID: session.DomainID,
OptionalDomainEntityType: auth.DomainGroupsScope,
Operation: auth.ListOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.GroupsType,
OptionalDomainID: session.DomainID,
Operation: auth.ListOp,
EntityID: id,
}); err != nil {
return groups.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
-83
View File
@@ -1,83 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package bolt
import (
"io/fs"
"strconv"
"time"
"github.com/absmach/supermq/pkg/errors"
"github.com/caarlos0/env/v11"
bolt "go.etcd.io/bbolt"
)
var (
errConfig = errors.New("failed to load BoltDB configuration")
errConnect = errors.New("failed to connect to BoltDB database")
errInit = errors.New("failed to initialize to BoltDB database")
)
type FileMode fs.FileMode
func (fm *FileMode) UnmarshalText(text []byte) error {
temp, err := strconv.ParseUint(string(text), 8, 32)
if err != nil {
return err
}
*fm = FileMode(temp)
return nil
}
// Config contains BoltDB specific parameters.
type Config struct {
FileDirPath string `env:"FILE_DIR_PATH" envDefault:"./supermq-data"`
FileName string `env:"FILE_NAME" envDefault:"supermq-pat.db"`
FileMode FileMode `env:"FILE_MODE" envDefault:"0600"`
Bucket string `env:"BUCKET" envDefault:"supermq"`
Timeout time.Duration `env:"TIMEOUT" envDefault:"0"`
}
// Setup load configuration from environment and creates new BoltDB.
func Setup(envPrefix string, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) {
return SetupDB(envPrefix, initFn)
}
// SetupDB load configuration from environment,.
func SetupDB(envPrefix string, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) {
cfg := Config{}
if err := env.ParseWithOptions(&cfg, env.Options{Prefix: envPrefix}); err != nil {
return nil, errors.Wrap(errConfig, err)
}
bdb, err := Connect(cfg, initFn)
if err != nil {
return nil, err
}
return bdb, nil
}
// Connect establishes connection to the BoltDB.
func Connect(cfg Config, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) {
filePath := cfg.FileDirPath + "/" + cfg.FileName
db, err := bolt.Open(filePath, fs.FileMode(cfg.FileMode), nil)
if err != nil {
return nil, errors.Wrap(errConnect, err)
}
if initFn != nil {
if err := Init(db, cfg, initFn); err != nil {
return nil, err
}
}
return db, nil
}
func Init(db *bolt.DB, cfg Config, initFn func(*bolt.Tx, string) error) error {
if err := db.Update(func(tx *bolt.Tx) error {
return initFn(tx, cfg.Bucket)
}); err != nil {
return errors.Wrap(errInit, err)
}
return nil
}
-9
View File
@@ -1,9 +0,0 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
// Package BoltDB contains the domain concept definitions needed to support
// Supermq BoltDB database functionality.
//
// It provides the abstraction of the BoltDB database service, which is used
// to configure, setup and connect to the BoltDB database.
package bolt
+6 -7
View File
@@ -39,13 +39,12 @@ message AuthZReq {
}
message AuthZPatReq {
string user_id = 1; // User id
string pat_id = 2; // Pat id
uint32 platform_entity_type = 3; // Platform entity type
string optional_domain_id = 4; // Optional domain id
uint32 optional_domain_entity_type = 5; // Optional domain entity type
uint32 operation = 6; // Operation
repeated string entity_ids = 7; // EntityIDs
string user_id = 1; // User id
string pat_id = 2; // Pat id
uint32 entity_type = 3; // Entity type
string optional_domain_id = 4; // Optional domain id
uint32 operation = 6; // Operation
string entity_id = 7; // EntityID
}
message AuthZRes {
+6 -7
View File
@@ -124,13 +124,12 @@ func (a authorization) checkDomain(ctx context.Context, subjectType, subject, do
func (a authorization) AuthorizePAT(ctx context.Context, pr authz.PatReq) error {
req := grpcAuthV1.AuthZPatReq{
UserId: pr.UserID,
PatId: pr.PatID,
PlatformEntityType: uint32(pr.PlatformEntityType),
OptionalDomainId: pr.OptionalDomainID,
OptionalDomainEntityType: uint32(pr.OptionalDomainEntityType),
Operation: uint32(pr.Operation),
EntityIds: pr.EntityIDs,
UserId: pr.UserID,
PatId: pr.PatID,
EntityType: uint32(pr.EntityType),
OptionalDomainId: pr.OptionalDomainID,
Operation: uint32(pr.Operation),
EntityId: pr.EntityID,
}
res, err := a.authSvcClient.AuthorizePAT(ctx, &req)
if err != nil {
+6 -7
View File
@@ -47,13 +47,12 @@ type PolicyReq struct {
}
type PatReq struct {
UserID string `json:"user_id,omitempty"` // UserID
PatID string `json:"pat_id,omitempty"` // UserID
PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"` // Platform entity type
OptionalDomainID string `json:"optional_domainID,omitempty"` // Optional domain id
OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"` // Optional domain entity type
Operation auth.OperationType `json:"operation,omitempty"` // Operation
EntityIDs []string `json:"entityIDs,omitempty"` // EntityIDs
UserID string `json:"user_id,omitempty"` // UserID
PatID string `json:"pat_id,omitempty"` // UserID
EntityType auth.EntityType `json:"entity_type,omitempty"` // Entity type
OptionalDomainID string `json:"optional_domainID,omitempty"` // Optional domain id
Operation auth.Operation `json:"operation,omitempty"` // Operation
EntityID string `json:"entityID,omitempty"` // EntityID
}
// Authz is supermq authorization library.
-1
View File
@@ -180,7 +180,6 @@ func (es *eventHandler) removeParentGroupHandler(ctx context.Context, data map[s
if err != nil {
return errors.Wrap(errRemoveParentGroupEvent, err)
}
if err := es.repo.UnassignParentGroup(ctx, g.Parent, id); err != nil {
return errors.Wrap(errRemoveParentGroupEvent, err)
}
+65 -78
View File
@@ -46,12 +46,11 @@ func (am *authorizationMiddleware) Register(ctx context.Context, session authn.S
func (am *authorizationMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.ReadOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.ReadOp,
EntityID: id,
}); err != nil {
return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -67,12 +66,11 @@ func (am *authorizationMiddleware) View(ctx context.Context, session authn.Sessi
func (am *authorizationMiddleware) ViewProfile(ctx context.Context, session authn.Session) (users.User, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.ReadOp,
EntityIDs: []string{session.UserID},
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.ReadOp,
EntityID: session.UserID,
}); err != nil {
return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -83,12 +81,11 @@ func (am *authorizationMiddleware) ViewProfile(ctx context.Context, session auth
func (am *authorizationMiddleware) ListUsers(ctx context.Context, session authn.Session, pm users.Page) (users.UsersPage, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.ListOp,
EntityIDs: smqauth.AnyIDs{}.Values(),
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.ListOp,
EntityID: smqauth.AnyIDs,
}); err != nil {
return users.UsersPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -107,12 +104,11 @@ func (am *authorizationMiddleware) SearchUsers(ctx context.Context, pm users.Pag
func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Session, user users.User) (users.User, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.UpdateOp,
EntityIDs: []string{user.ID},
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.UpdateOp,
EntityID: user.ID,
}); err != nil {
return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -128,12 +124,11 @@ func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Ses
func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn.Session, user users.User) (users.User, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.UpdateOp,
EntityIDs: []string{user.ID},
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.UpdateOp,
EntityID: user.ID,
}); err != nil {
return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -149,12 +144,11 @@ func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn
func (am *authorizationMiddleware) UpdateEmail(ctx context.Context, session authn.Session, id, email string) (users.User, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.UpdateOp,
EntityID: id,
}); err != nil {
return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -169,12 +163,11 @@ func (am *authorizationMiddleware) UpdateEmail(ctx context.Context, session auth
func (am *authorizationMiddleware) UpdateUsername(ctx context.Context, session authn.Session, id, username string) (users.User, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.UpdateOp,
EntityID: id,
}); err != nil {
return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -190,12 +183,11 @@ func (am *authorizationMiddleware) UpdateUsername(ctx context.Context, session a
func (am *authorizationMiddleware) UpdateProfilePicture(ctx context.Context, session authn.Session, user users.User) (users.User, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.UpdateOp,
EntityIDs: []string{user.ID},
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.UpdateOp,
EntityID: user.ID,
}); err != nil {
return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -215,12 +207,11 @@ func (am *authorizationMiddleware) GenerateResetToken(ctx context.Context, email
func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session authn.Session, oldSecret, newSecret string) (users.User, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.UpdateOp,
EntityIDs: []string{session.UserID},
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.UpdateOp,
EntityID: session.UserID,
}); err != nil {
return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -240,12 +231,11 @@ func (am *authorizationMiddleware) SendPasswordReset(ctx context.Context, host,
func (am *authorizationMiddleware) UpdateRole(ctx context.Context, session authn.Session, user users.User) (users.User, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.UpdateOp,
EntityIDs: []string{user.ID},
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.UpdateOp,
EntityID: user.ID,
}); err != nil {
return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -265,12 +255,11 @@ func (am *authorizationMiddleware) UpdateRole(ctx context.Context, session authn
func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Session, id string) (users.User, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.UpdateOp,
EntityID: id,
}); err != nil {
return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -286,12 +275,11 @@ func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Ses
func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Session, id string) (users.User, error) {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.UpdateOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.UpdateOp,
EntityID: id,
}); err != nil {
return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}
@@ -307,12 +295,11 @@ func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Se
func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Session, id string) error {
if session.Type == authn.PersonalAccessToken {
if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{
UserID: session.UserID,
PatID: session.PatID,
PlatformEntityType: smqauth.PlatformUsersScope,
OptionalDomainEntityType: smqauth.DomainNullScope,
Operation: smqauth.DeleteOp,
EntityIDs: []string{id},
UserID: session.UserID,
PatID: session.PatID,
EntityType: smqauth.UsersType,
Operation: smqauth.DeleteOp,
EntityID: id,
}); err != nil {
return errors.Wrap(svcerr.ErrUnauthorizedPAT, err)
}