SMQ-2627 - Align PATs with new architecture (#3295)

Signed-off-by: nyagamunene <stevenyaga2014@gmail.com>
This commit is contained in:
Steve Munene
2026-02-03 16:06:30 +03:00
committed by GitHub
parent 0d4f4c9266
commit 61c120f947
49 changed files with 950 additions and 1687 deletions
+66 -226
View File
@@ -70,11 +70,10 @@ func (x *AuthNReq) GetToken() string {
type AuthNRes struct {
state protoimpl.MessageState `protogen:"open.v1"`
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // token id
UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // user id
UserRole uint32 `protobuf:"varint,3,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"` // user role
Verified bool `protobuf:"varint,4,opt,name=verified,proto3" json:"verified,omitempty"` // verified user
TokenType uint32 `protobuf:"varint,5,opt,name=token_type,json=tokenType,proto3" json:"token_type,omitempty"` // token type
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
UserRole uint32 `protobuf:"varint,3,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"`
Verified bool `protobuf:"varint,4,opt,name=verified,proto3" json:"verified,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -137,25 +136,22 @@ func (x *AuthNRes) GetVerified() bool {
return false
}
func (x *AuthNRes) GetTokenType() uint32 {
if x != nil {
return x.TokenType
}
return 0
}
type PolicyReq struct {
state protoimpl.MessageState `protogen:"open.v1"`
TokenType uint32 `protobuf:"varint,1,opt,name=token_type,json=tokenType,proto3" json:"token_type,omitempty"` // Token type
Domain string `protobuf:"bytes,2,opt,name=domain,proto3" json:"domain,omitempty"` // Domain
SubjectType string `protobuf:"bytes,3,opt,name=subject_type,json=subjectType,proto3" json:"subject_type,omitempty"` // Client or User
SubjectKind string `protobuf:"bytes,4,opt,name=subject_kind,json=subjectKind,proto3" json:"subject_kind,omitempty"` // ID or Token
SubjectRelation string `protobuf:"bytes,5,opt,name=subject_relation,json=subjectRelation,proto3" json:"subject_relation,omitempty"` // Subject relation
Subject string `protobuf:"bytes,6,opt,name=subject,proto3" json:"subject,omitempty"` // Subject value
Relation string `protobuf:"bytes,7,opt,name=relation,proto3" json:"relation,omitempty"` // Relation to filter
Permission string `protobuf:"bytes,8,opt,name=permission,proto3" json:"permission,omitempty"` // Action
Object string `protobuf:"bytes,9,opt,name=object,proto3" json:"object,omitempty"` // Object ID
ObjectType string `protobuf:"bytes,10,opt,name=object_type,json=objectType,proto3" json:"object_type,omitempty"` // Client, User, Group
Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"`
SubjectType string `protobuf:"bytes,2,opt,name=subject_type,json=subjectType,proto3" json:"subject_type,omitempty"`
SubjectKind string `protobuf:"bytes,3,opt,name=subject_kind,json=subjectKind,proto3" json:"subject_kind,omitempty"`
SubjectRelation string `protobuf:"bytes,4,opt,name=subject_relation,json=subjectRelation,proto3" json:"subject_relation,omitempty"`
Subject string `protobuf:"bytes,5,opt,name=subject,proto3" json:"subject,omitempty"`
Relation string `protobuf:"bytes,6,opt,name=relation,proto3" json:"relation,omitempty"`
Permission string `protobuf:"bytes,7,opt,name=permission,proto3" json:"permission,omitempty"`
Object string `protobuf:"bytes,8,opt,name=object,proto3" json:"object,omitempty"`
ObjectType string `protobuf:"bytes,9,opt,name=object_type,json=objectType,proto3" json:"object_type,omitempty"`
PatId string `protobuf:"bytes,10,opt,name=pat_id,json=patId,proto3" json:"pat_id,omitempty"`
Operation string `protobuf:"bytes,11,opt,name=operation,proto3" json:"operation,omitempty"`
UserId string `protobuf:"bytes,12,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
EntityId string `protobuf:"bytes,13,opt,name=entity_id,json=entityId,proto3" json:"entity_id,omitempty"`
EntityType string `protobuf:"bytes,14,opt,name=entity_type,json=entityType,proto3" json:"entity_type,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -190,13 +186,6 @@ func (*PolicyReq) Descriptor() ([]byte, []int) {
return file_auth_v1_auth_proto_rawDescGZIP(), []int{2}
}
func (x *PolicyReq) GetTokenType() uint32 {
if x != nil {
return x.TokenType
}
return 0
}
func (x *PolicyReq) GetDomain() string {
if x != nil {
return x.Domain
@@ -260,172 +249,41 @@ func (x *PolicyReq) GetObjectType() string {
return ""
}
type PATReq struct {
state protoimpl.MessageState `protogen:"open.v1"`
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // User id (PAT)
PatId string `protobuf:"bytes,2,opt,name=pat_id,json=patId,proto3" json:"pat_id,omitempty"` // Pat id
EntityType uint32 `protobuf:"varint,3,opt,name=entity_type,json=entityType,proto3" json:"entity_type,omitempty"` // Entity type (PAT)
OptionalDomainId string `protobuf:"bytes,4,opt,name=optional_domain_id,json=optionalDomainId,proto3" json:"optional_domain_id,omitempty"` // Optional domain id (PAT)
Operation uint32 `protobuf:"varint,5,opt,name=operation,proto3" json:"operation,omitempty"` // Operation (PAT)
EntityId string `protobuf:"bytes,6,opt,name=entity_id,json=entityId,proto3" json:"entity_id,omitempty"` // EntityID (PAT)
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *PATReq) Reset() {
*x = PATReq{}
mi := &file_auth_v1_auth_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *PATReq) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PATReq) ProtoMessage() {}
func (x *PATReq) ProtoReflect() protoreflect.Message {
mi := &file_auth_v1_auth_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PATReq.ProtoReflect.Descriptor instead.
func (*PATReq) Descriptor() ([]byte, []int) {
return file_auth_v1_auth_proto_rawDescGZIP(), []int{3}
}
func (x *PATReq) GetUserId() string {
if x != nil {
return x.UserId
}
return ""
}
func (x *PATReq) GetPatId() string {
func (x *PolicyReq) GetPatId() string {
if x != nil {
return x.PatId
}
return ""
}
func (x *PATReq) GetEntityType() uint32 {
func (x *PolicyReq) GetOperation() string {
if x != nil {
return x.EntityType
}
return 0
}
func (x *PATReq) GetOptionalDomainId() string {
if x != nil {
return x.OptionalDomainId
return x.Operation
}
return ""
}
func (x *PATReq) GetOperation() uint32 {
func (x *PolicyReq) GetUserId() string {
if x != nil {
return x.Operation
return x.UserId
}
return 0
return ""
}
func (x *PATReq) GetEntityId() string {
func (x *PolicyReq) GetEntityId() string {
if x != nil {
return x.EntityId
}
return ""
}
type AuthZReq struct {
state protoimpl.MessageState `protogen:"open.v1"`
// Types that are valid to be assigned to AuthType:
//
// *AuthZReq_Policy
// *AuthZReq_Pat
AuthType isAuthZReq_AuthType `protobuf_oneof:"auth_type"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *AuthZReq) Reset() {
*x = AuthZReq{}
mi := &file_auth_v1_auth_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *AuthZReq) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*AuthZReq) ProtoMessage() {}
func (x *AuthZReq) ProtoReflect() protoreflect.Message {
mi := &file_auth_v1_auth_proto_msgTypes[4]
func (x *PolicyReq) GetEntityType() string {
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
return x.EntityType
}
return mi.MessageOf(x)
return ""
}
// Deprecated: Use AuthZReq.ProtoReflect.Descriptor instead.
func (*AuthZReq) Descriptor() ([]byte, []int) {
return file_auth_v1_auth_proto_rawDescGZIP(), []int{4}
}
func (x *AuthZReq) GetAuthType() isAuthZReq_AuthType {
if x != nil {
return x.AuthType
}
return nil
}
func (x *AuthZReq) GetPolicy() *PolicyReq {
if x != nil {
if x, ok := x.AuthType.(*AuthZReq_Policy); ok {
return x.Policy
}
}
return nil
}
func (x *AuthZReq) GetPat() *PATReq {
if x != nil {
if x, ok := x.AuthType.(*AuthZReq_Pat); ok {
return x.Pat
}
}
return nil
}
type isAuthZReq_AuthType interface {
isAuthZReq_AuthType()
}
type AuthZReq_Policy struct {
Policy *PolicyReq `protobuf:"bytes,1,opt,name=policy,proto3,oneof"` // Policy-based authorization
}
type AuthZReq_Pat struct {
Pat *PATReq `protobuf:"bytes,2,opt,name=pat,proto3,oneof"` // PAT authorization
}
func (*AuthZReq_Policy) isAuthZReq_AuthType() {}
func (*AuthZReq_Pat) isAuthZReq_AuthType() {}
type AuthZRes struct {
state protoimpl.MessageState `protogen:"open.v1"`
Authorized bool `protobuf:"varint,1,opt,name=authorized,proto3" json:"authorized,omitempty"`
@@ -436,7 +294,7 @@ type AuthZRes struct {
func (x *AuthZRes) Reset() {
*x = AuthZRes{}
mi := &file_auth_v1_auth_proto_msgTypes[5]
mi := &file_auth_v1_auth_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -448,7 +306,7 @@ func (x *AuthZRes) String() string {
func (*AuthZRes) ProtoMessage() {}
func (x *AuthZRes) ProtoReflect() protoreflect.Message {
mi := &file_auth_v1_auth_proto_msgTypes[5]
mi := &file_auth_v1_auth_proto_msgTypes[3]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -461,7 +319,7 @@ func (x *AuthZRes) ProtoReflect() protoreflect.Message {
// Deprecated: Use AuthZRes.ProtoReflect.Descriptor instead.
func (*AuthZRes) Descriptor() ([]byte, []int) {
return file_auth_v1_auth_proto_rawDescGZIP(), []int{5}
return file_auth_v1_auth_proto_rawDescGZIP(), []int{3}
}
func (x *AuthZRes) GetAuthorized() bool {
@@ -484,49 +342,39 @@ const file_auth_v1_auth_proto_rawDesc = "" +
"\n" +
"\x12auth/v1/auth.proto\x12\aauth.v1\" \n" +
"\bAuthNReq\x12\x14\n" +
"\x05token\x18\x01 \x01(\tR\x05token\"\x8b\x01\n" +
"\x05token\x18\x01 \x01(\tR\x05token\"l\n" +
"\bAuthNRes\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\x12\x17\n" +
"\auser_id\x18\x02 \x01(\tR\x06userId\x12\x1b\n" +
"\tuser_role\x18\x03 \x01(\rR\buserRole\x12\x1a\n" +
"\bverified\x18\x04 \x01(\bR\bverified\x12\x1d\n" +
"\bverified\x18\x04 \x01(\bR\bverified\"\xaf\x03\n" +
"\tPolicyReq\x12\x16\n" +
"\x06domain\x18\x01 \x01(\tR\x06domain\x12!\n" +
"\fsubject_type\x18\x02 \x01(\tR\vsubjectType\x12!\n" +
"\fsubject_kind\x18\x03 \x01(\tR\vsubjectKind\x12)\n" +
"\x10subject_relation\x18\x04 \x01(\tR\x0fsubjectRelation\x12\x18\n" +
"\asubject\x18\x05 \x01(\tR\asubject\x12\x1a\n" +
"\brelation\x18\x06 \x01(\tR\brelation\x12\x1e\n" +
"\n" +
"token_type\x18\x05 \x01(\rR\ttokenType\"\xc2\x02\n" +
"\tPolicyReq\x12\x1d\n" +
"\n" +
"token_type\x18\x01 \x01(\rR\ttokenType\x12\x16\n" +
"\x06domain\x18\x02 \x01(\tR\x06domain\x12!\n" +
"\fsubject_type\x18\x03 \x01(\tR\vsubjectType\x12!\n" +
"\fsubject_kind\x18\x04 \x01(\tR\vsubjectKind\x12)\n" +
"\x10subject_relation\x18\x05 \x01(\tR\x0fsubjectRelation\x12\x18\n" +
"\asubject\x18\x06 \x01(\tR\asubject\x12\x1a\n" +
"\brelation\x18\a \x01(\tR\brelation\x12\x1e\n" +
"\n" +
"permission\x18\b \x01(\tR\n" +
"permission\x18\a \x01(\tR\n" +
"permission\x12\x16\n" +
"\x06object\x18\t \x01(\tR\x06object\x12\x1f\n" +
"\vobject_type\x18\n" +
" \x01(\tR\n" +
"objectType\"\xc2\x01\n" +
"\x06PATReq\x12\x17\n" +
"\auser_id\x18\x01 \x01(\tR\x06userId\x12\x15\n" +
"\x06pat_id\x18\x02 \x01(\tR\x05patId\x12\x1f\n" +
"\ventity_type\x18\x03 \x01(\rR\n" +
"entityType\x12,\n" +
"\x12optional_domain_id\x18\x04 \x01(\tR\x10optionalDomainId\x12\x1c\n" +
"\toperation\x18\x05 \x01(\rR\toperation\x12\x1b\n" +
"\tentity_id\x18\x06 \x01(\tR\bentityId\"j\n" +
"\bAuthZReq\x12,\n" +
"\x06policy\x18\x01 \x01(\v2\x12.auth.v1.PolicyReqH\x00R\x06policy\x12#\n" +
"\x03pat\x18\x02 \x01(\v2\x0f.auth.v1.PATReqH\x00R\x03patB\v\n" +
"\tauth_type\":\n" +
"\x06object\x18\b \x01(\tR\x06object\x12\x1f\n" +
"\vobject_type\x18\t \x01(\tR\n" +
"objectType\x12\x15\n" +
"\x06pat_id\x18\n" +
" \x01(\tR\x05patId\x12\x1c\n" +
"\toperation\x18\v \x01(\tR\toperation\x12\x17\n" +
"\auser_id\x18\f \x01(\tR\x06userId\x12\x1b\n" +
"\tentity_id\x18\r \x01(\tR\bentityId\x12\x1f\n" +
"\ventity_type\x18\x0e \x01(\tR\n" +
"entityType\":\n" +
"\bAuthZRes\x12\x1e\n" +
"\n" +
"authorized\x18\x01 \x01(\bR\n" +
"authorized\x12\x0e\n" +
"\x02id\x18\x02 \x01(\tR\x02id2z\n" +
"\vAuthService\x123\n" +
"\tAuthorize\x12\x11.auth.v1.AuthZReq\x1a\x11.auth.v1.AuthZRes\"\x00\x126\n" +
"\x02id\x18\x02 \x01(\tR\x02id2{\n" +
"\vAuthService\x124\n" +
"\tAuthorize\x12\x12.auth.v1.PolicyReq\x1a\x11.auth.v1.AuthZRes\"\x00\x126\n" +
"\fAuthenticate\x12\x11.auth.v1.AuthNReq\x1a\x11.auth.v1.AuthNRes\"\x00B-Z+github.com/absmach/supermq/api/grpc/auth/v1b\x06proto3"
var (
@@ -541,27 +389,23 @@ func file_auth_v1_auth_proto_rawDescGZIP() []byte {
return file_auth_v1_auth_proto_rawDescData
}
var file_auth_v1_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_auth_v1_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
var file_auth_v1_auth_proto_goTypes = []any{
(*AuthNReq)(nil), // 0: auth.v1.AuthNReq
(*AuthNRes)(nil), // 1: auth.v1.AuthNRes
(*PolicyReq)(nil), // 2: auth.v1.PolicyReq
(*PATReq)(nil), // 3: auth.v1.PATReq
(*AuthZReq)(nil), // 4: auth.v1.AuthZReq
(*AuthZRes)(nil), // 5: auth.v1.AuthZRes
(*AuthZRes)(nil), // 3: auth.v1.AuthZRes
}
var file_auth_v1_auth_proto_depIdxs = []int32{
2, // 0: auth.v1.AuthZReq.policy:type_name -> auth.v1.PolicyReq
3, // 1: auth.v1.AuthZReq.pat:type_name -> auth.v1.PATReq
4, // 2: auth.v1.AuthService.Authorize:input_type -> auth.v1.AuthZReq
0, // 3: auth.v1.AuthService.Authenticate:input_type -> auth.v1.AuthNReq
5, // 4: auth.v1.AuthService.Authorize:output_type -> auth.v1.AuthZRes
1, // 5: auth.v1.AuthService.Authenticate:output_type -> auth.v1.AuthNRes
4, // [4:6] is the sub-list for method output_type
2, // [2:4] is the sub-list for method input_type
2, // [2:2] is the sub-list for extension type_name
2, // [2:2] is the sub-list for extension extendee
0, // [0:2] is the sub-list for field type_name
2, // 0: auth.v1.AuthService.Authorize:input_type -> auth.v1.PolicyReq
0, // 1: auth.v1.AuthService.Authenticate:input_type -> auth.v1.AuthNReq
3, // 2: auth.v1.AuthService.Authorize:output_type -> auth.v1.AuthZRes
1, // 3: auth.v1.AuthService.Authenticate:output_type -> auth.v1.AuthNRes
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_auth_v1_auth_proto_init() }
@@ -569,17 +413,13 @@ func file_auth_v1_auth_proto_init() {
if File_auth_v1_auth_proto != nil {
return
}
file_auth_v1_auth_proto_msgTypes[4].OneofWrappers = []any{
(*AuthZReq_Policy)(nil),
(*AuthZReq_Pat)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_auth_v1_auth_proto_rawDesc), len(file_auth_v1_auth_proto_rawDesc)),
NumEnums: 0,
NumMessages: 6,
NumMessages: 4,
NumExtensions: 0,
NumServices: 1,
},
+6 -6
View File
@@ -33,7 +33,7 @@ const (
// AuthService is a service that provides authentication
// and authorization functionalities for SuperMQ services.
type AuthServiceClient interface {
Authorize(ctx context.Context, in *AuthZReq, opts ...grpc.CallOption) (*AuthZRes, error)
Authorize(ctx context.Context, in *PolicyReq, opts ...grpc.CallOption) (*AuthZRes, error)
Authenticate(ctx context.Context, in *AuthNReq, opts ...grpc.CallOption) (*AuthNRes, error)
}
@@ -45,7 +45,7 @@ func NewAuthServiceClient(cc grpc.ClientConnInterface) AuthServiceClient {
return &authServiceClient{cc}
}
func (c *authServiceClient) Authorize(ctx context.Context, in *AuthZReq, opts ...grpc.CallOption) (*AuthZRes, error) {
func (c *authServiceClient) Authorize(ctx context.Context, in *PolicyReq, opts ...grpc.CallOption) (*AuthZRes, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AuthZRes)
err := c.cc.Invoke(ctx, AuthService_Authorize_FullMethodName, in, out, cOpts...)
@@ -72,7 +72,7 @@ func (c *authServiceClient) Authenticate(ctx context.Context, in *AuthNReq, opts
// AuthService is a service that provides authentication
// and authorization functionalities for SuperMQ services.
type AuthServiceServer interface {
Authorize(context.Context, *AuthZReq) (*AuthZRes, error)
Authorize(context.Context, *PolicyReq) (*AuthZRes, error)
Authenticate(context.Context, *AuthNReq) (*AuthNRes, error)
mustEmbedUnimplementedAuthServiceServer()
}
@@ -84,7 +84,7 @@ type AuthServiceServer interface {
// pointer dereference when methods are called.
type UnimplementedAuthServiceServer struct{}
func (UnimplementedAuthServiceServer) Authorize(context.Context, *AuthZReq) (*AuthZRes, error) {
func (UnimplementedAuthServiceServer) Authorize(context.Context, *PolicyReq) (*AuthZRes, error) {
return nil, status.Errorf(codes.Unimplemented, "method Authorize not implemented")
}
func (UnimplementedAuthServiceServer) Authenticate(context.Context, *AuthNReq) (*AuthNRes, error) {
@@ -112,7 +112,7 @@ func RegisterAuthServiceServer(s grpc.ServiceRegistrar, srv AuthServiceServer) {
}
func _AuthService_Authorize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AuthZReq)
in := new(PolicyReq)
if err := dec(in); err != nil {
return nil, err
}
@@ -124,7 +124,7 @@ func _AuthService_Authorize_Handler(srv interface{}, ctx context.Context, dec fu
FullMethod: AuthService_Authorize_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AuthServiceServer).Authorize(ctx, req.(*AuthZReq))
return srv.(AuthServiceServer).Authorize(ctx, req.(*PolicyReq))
}
return interceptor(ctx, in, info, handler)
}
+29 -51
View File
@@ -70,32 +70,27 @@ func decodeIdentifyResponse(_ context.Context, grpcRes any) (any, error) {
return authenticateRes{id: res.GetId(), userID: res.GetUserId(), userRole: auth.Role(res.UserRole), verified: res.GetVerified()}, nil
}
func (client authGrpcClient) Authorize(ctx context.Context, req *grpcAuthV1.AuthZReq, _ ...grpc.CallOption) (r *grpcAuthV1.AuthZRes, err error) {
func (client authGrpcClient) Authorize(ctx context.Context, req *grpcAuthV1.PolicyReq, _ ...grpc.CallOption) (r *grpcAuthV1.AuthZRes, err error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
var authReqData authReq
if policy := req.GetPolicy(); policy != nil {
if req != nil {
authReqData = authReq{
TokenType: policy.GetTokenType(),
Domain: policy.GetDomain(),
SubjectType: policy.GetSubjectType(),
Subject: policy.GetSubject(),
SubjectKind: policy.GetSubjectKind(),
Relation: policy.GetRelation(),
Permission: policy.GetPermission(),
ObjectType: policy.GetObjectType(),
Object: policy.GetObject(),
}
} else if pat := req.GetPat(); pat != nil {
authReqData = authReq{
UserID: pat.GetUserId(),
PatID: pat.GetPatId(),
EntityType: auth.EntityType(pat.GetEntityType()),
OptionalDomainID: pat.GetOptionalDomainId(),
Operation: auth.Operation(pat.GetOperation()),
EntityID: pat.GetEntityId(),
Domain: req.GetDomain(),
SubjectType: req.GetSubjectType(),
Subject: req.GetSubject(),
SubjectKind: req.GetSubjectKind(),
Relation: req.GetRelation(),
Permission: req.GetPermission(),
ObjectType: req.GetObjectType(),
Object: req.GetObject(),
UserID: req.GetUserId(),
PatID: req.GetPatId(),
EntityType: req.GetEntityType(),
Operation: req.GetOperation(),
EntityID: req.GetEntityId(),
}
}
@@ -116,36 +111,19 @@ func decodeAuthorizeResponse(_ context.Context, grpcRes any) (any, error) {
func encodeAuthorizeRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(authReq)
// Check if this is a PAT request (has PatID) or policy request
if req.PatID != "" {
return &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Pat{
Pat: &grpcAuthV1.PATReq{
UserId: req.UserID,
PatId: req.PatID,
EntityType: uint32(req.EntityType),
OptionalDomainId: req.OptionalDomainID,
Operation: uint32(req.Operation),
EntityId: req.EntityID,
},
},
}, nil
}
// Otherwise, it's a policy request
return &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Policy{
Policy: &grpcAuthV1.PolicyReq{
TokenType: req.TokenType,
Domain: req.Domain,
SubjectType: req.SubjectType,
Subject: req.Subject,
SubjectKind: req.SubjectKind,
Relation: req.Relation,
Permission: req.Permission,
ObjectType: req.ObjectType,
Object: req.Object,
},
},
return &grpcAuthV1.PolicyReq{
Domain: req.Domain,
SubjectType: req.SubjectType,
Subject: req.Subject,
SubjectKind: req.SubjectKind,
Relation: req.Relation,
Permission: req.Permission,
ObjectType: req.ObjectType,
Object: req.Object,
UserId: req.UserID,
PatId: req.PatID,
EntityType: req.EntityType,
Operation: req.Operation,
EntityId: req.EntityID,
}, nil
}
+13 -15
View File
@@ -36,21 +36,19 @@ func authorizeEndpoint(svc auth.Service) endpoint.Endpoint {
}
err := svc.Authorize(ctx, policies.Policy{
TokenType: req.TokenType,
Domain: req.Domain,
SubjectType: req.SubjectType,
SubjectKind: req.SubjectKind,
Subject: req.Subject,
Relation: req.Relation,
Permission: req.Permission,
ObjectType: req.ObjectType,
Object: req.Object,
UserID: req.UserID,
PatID: req.PatID,
EntityType: uint32(req.EntityType),
OptionalDomainID: req.OptionalDomainID,
Operation: uint32(req.Operation),
EntityID: req.EntityID,
Domain: req.Domain,
SubjectType: req.SubjectType,
SubjectKind: req.SubjectKind,
Subject: req.Subject,
Relation: req.Relation,
Permission: req.Permission,
ObjectType: req.ObjectType,
Object: req.Object,
PatID: req.PatID,
Operation: req.Operation,
UserID: req.UserID,
EntityType: req.EntityType,
EntityID: req.EntityID,
})
if err != nil {
return authorizeRes{authorized: false}, err
+101 -120
View File
@@ -17,6 +17,7 @@ import (
"github.com/absmach/supermq/internal/testsutil"
"github.com/absmach/supermq/pkg/errors"
svcerr "github.com/absmach/supermq/pkg/errors/service"
"github.com/absmach/supermq/pkg/policies"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
@@ -128,24 +129,20 @@ func TestAuthorize(t *testing.T) {
cases := []struct {
desc string
token string
authRequest *grpcAuthV1.AuthZReq
authRequest *grpcAuthV1.PolicyReq
authResponse *grpcAuthV1.AuthZRes
err error
}{
{
desc: "authorize user with authorized token",
token: validToken,
authRequest: &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Policy{
Policy: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
},
authRequest: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: true},
err: nil,
@@ -153,17 +150,13 @@ func TestAuthorize(t *testing.T) {
{
desc: "authorize user with unauthorized token",
token: inValidToken,
authRequest: &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Policy{
Policy: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
},
authRequest: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: svcerr.ErrAuthorization,
@@ -171,17 +164,13 @@ func TestAuthorize(t *testing.T) {
{
desc: "authorize user with empty subject",
token: validToken,
authRequest: &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Policy{
Policy: &grpcAuthV1.PolicyReq{
Subject: "",
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
},
authRequest: &grpcAuthV1.PolicyReq{
Subject: "",
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingPolicySub,
@@ -189,17 +178,13 @@ func TestAuthorize(t *testing.T) {
{
desc: "authorize user with empty subject type",
token: validToken,
authRequest: &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Policy{
Policy: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: "",
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
},
authRequest: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: "",
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingPolicySub,
@@ -207,17 +192,13 @@ func TestAuthorize(t *testing.T) {
{
desc: "authorize user with empty object",
token: validToken,
authRequest: &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Policy{
Policy: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: "",
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
},
authRequest: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: "",
ObjectType: usersType,
Relation: memberRelation,
Permission: adminPermission,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingPolicyObj,
@@ -225,17 +206,13 @@ func TestAuthorize(t *testing.T) {
{
desc: "authorize user with empty object type",
token: validToken,
authRequest: &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Policy{
Policy: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: "",
Relation: memberRelation,
Permission: adminPermission,
},
},
authRequest: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: "",
Relation: memberRelation,
Permission: adminPermission,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingPolicyObj,
@@ -243,17 +220,13 @@ func TestAuthorize(t *testing.T) {
{
desc: "authorize user with empty permission",
token: validToken,
authRequest: &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Policy{
Policy: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: "",
},
},
authRequest: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: usersType,
Object: authoritiesObj,
ObjectType: usersType,
Relation: memberRelation,
Permission: "",
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMalformedPolicyPer,
@@ -261,17 +234,19 @@ func TestAuthorize(t *testing.T) {
{
desc: "authorize user with valid PAT token",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Pat{
Pat: &grpcAuthV1.PATReq{
UserId: id,
PatId: id,
EntityType: uint32(auth.ClientsType),
OptionalDomainId: domainID,
Operation: uint32(auth.CreateOp),
EntityId: clientID,
},
},
authRequest: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.ViewPermission,
PatId: id,
ObjectType: policies.ClientType,
Domain: domainID,
Operation: "view",
Object: clientID,
UserId: id,
EntityId: clientID,
EntityType: auth.ClientsScopeStr,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: true},
err: nil,
@@ -279,17 +254,19 @@ func TestAuthorize(t *testing.T) {
{
desc: "authorize user with unauthorized PAT token",
token: inValidToken,
authRequest: &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Pat{
Pat: &grpcAuthV1.PATReq{
UserId: id,
PatId: id,
EntityType: uint32(auth.ClientsType),
OptionalDomainId: domainID,
Operation: uint32(auth.CreateOp),
EntityId: clientID,
},
},
authRequest: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.ViewPermission,
PatId: id,
ObjectType: policies.ClientType,
Domain: domainID,
Operation: "view",
Object: clientID,
UserId: id,
EntityId: clientID,
EntityType: auth.ClientsScopeStr,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: svcerr.ErrAuthorization,
@@ -297,16 +274,18 @@ func TestAuthorize(t *testing.T) {
{
desc: "authorize PAT with missing user id",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Pat{
Pat: &grpcAuthV1.PATReq{
PatId: id,
EntityType: uint32(auth.ClientsType),
OptionalDomainId: domainID,
Operation: uint32(auth.CreateOp),
EntityId: clientID,
},
},
authRequest: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.ViewPermission,
PatId: id,
ObjectType: policies.ClientType,
Domain: domainID,
Operation: "view",
Object: clientID,
EntityId: clientID,
EntityType: auth.ClientsScopeStr,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingUserID,
@@ -314,16 +293,18 @@ func TestAuthorize(t *testing.T) {
{
desc: "authorize PAT with missing entity id",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Pat{
Pat: &grpcAuthV1.PATReq{
UserId: id,
PatId: id,
EntityType: uint32(auth.ClientsType),
OptionalDomainId: domainID,
Operation: uint32(auth.CreateOp),
},
},
authRequest: &grpcAuthV1.PolicyReq{
Subject: id,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Permission: policies.ViewPermission,
PatId: id,
ObjectType: policies.ClientType,
Domain: domainID,
Operation: "view",
Object: clientID,
UserId: id,
EntityType: auth.ClientsScopeStr,
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingID,
+8 -9
View File
@@ -5,7 +5,6 @@ package auth
import (
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/auth"
)
type authenticateReq struct {
@@ -25,7 +24,6 @@ func (req authenticateReq) validate() error {
// 2. object - an entity over which action will be executed
// 3. action - type of action that will be executed (read/write).
type authReq struct {
TokenType uint32
Domain string
SubjectType string
SubjectKind string
@@ -36,12 +34,11 @@ type authReq struct {
Object string
// PAT authorization fields
UserID string
PatID string
EntityType auth.EntityType
OptionalDomainID string
Operation auth.Operation
EntityID string
UserID string
PatID string
EntityType string
Operation string
EntityID string
}
func (req authReq) validate() error {
@@ -52,7 +49,9 @@ func (req authReq) validate() error {
if req.EntityID == "" {
return apiutil.ErrMissingID
}
return nil
if req.EntityType == "" {
return apiutil.ErrMissingPolicyObj
}
}
if req.Subject == "" || req.SubjectType == "" {
+19 -25
View File
@@ -45,7 +45,7 @@ func (s *authGrpcServer) Authenticate(ctx context.Context, req *grpcAuthV1.AuthN
return res.(*grpcAuthV1.AuthNRes), nil
}
func (s *authGrpcServer) Authorize(ctx context.Context, req *grpcAuthV1.AuthZReq) (*grpcAuthV1.AuthZRes, error) {
func (s *authGrpcServer) Authorize(ctx context.Context, req *grpcAuthV1.PolicyReq) (*grpcAuthV1.AuthZRes, error) {
_, res, err := s.authorize.ServeGRPC(ctx, req)
if err != nil {
return nil, grpcapi.EncodeError(err)
@@ -64,32 +64,26 @@ func encodeAuthenticateResponse(_ context.Context, grpcRes any) (any, error) {
}
func decodeAuthorizeRequest(_ context.Context, grpcReq any) (any, error) {
req := grpcReq.(*grpcAuthV1.AuthZReq)
if policy := req.GetPolicy(); policy != nil {
return authReq{
TokenType: policy.GetTokenType(),
Domain: policy.GetDomain(),
SubjectType: policy.GetSubjectType(),
SubjectKind: policy.GetSubjectKind(),
Subject: policy.GetSubject(),
Relation: policy.GetRelation(),
Permission: policy.GetPermission(),
ObjectType: policy.GetObjectType(),
Object: policy.GetObject(),
}, nil
}
if pat := req.GetPat(); pat != nil {
return authReq{
UserID: pat.GetUserId(),
PatID: pat.GetPatId(),
EntityType: auth.EntityType(pat.GetEntityType()),
OptionalDomainID: pat.GetOptionalDomainId(),
Operation: auth.Operation(pat.GetOperation()),
EntityID: pat.GetEntityId(),
}, nil
req := grpcReq.(*grpcAuthV1.PolicyReq)
if req == nil {
return authReq{}, nil
}
return authReq{}, nil
return authReq{
Domain: req.GetDomain(),
SubjectType: req.GetSubjectType(),
SubjectKind: req.GetSubjectKind(),
Subject: req.GetSubject(),
Relation: req.GetRelation(),
Permission: req.GetPermission(),
ObjectType: req.GetObjectType(),
Object: req.GetObject(),
UserID: req.GetUserId(),
PatID: req.GetPatId(),
EntityType: req.GetEntityType(),
Operation: req.GetOperation(),
EntityID: req.GetEntityId(),
}, nil
}
func encodeAuthorizeResponse(_ context.Context, grpcRes any) (any, error) {
+8 -8
View File
@@ -508,17 +508,17 @@ func TestClearAllPATReqValidate(t *testing.T) {
func TestAddScopeReqValidate(t *testing.T) {
validScope := auth.Scope{
OptionalDomainID: "domain1",
EntityType: auth.GroupsType,
EntityID: "entity1",
Operation: auth.CreateOp,
DomainID: "domain1",
EntityType: auth.GroupsType,
EntityID: "entity1",
Operation: "create",
}
invalidScope := auth.Scope{
OptionalDomainID: "",
EntityType: auth.GroupsType,
EntityID: "",
Operation: auth.CreateOp,
DomainID: "",
EntityType: auth.GroupsType,
EntityID: "",
Operation: "view",
}
cases := []struct {
+6 -6
View File
@@ -28,7 +28,7 @@ func NewPatsCache(client *redis.Client, duration time.Duration) auth.Cache {
func (pc *patCache) Save(ctx context.Context, userID string, scopes []auth.Scope) error {
for _, sc := range scopes {
key := generateKey(userID, sc.PatID, sc.OptionalDomainID, sc.EntityType, sc.Operation, sc.EntityID)
key := generateKey(userID, sc.PatID, sc.DomainID, sc.EntityType, sc.Operation, sc.EntityID)
if err := pc.client.Set(ctx, key, sc.ID, pc.duration).Err(); err != nil {
return errors.Wrap(repoerr.ErrCreateEntity, err)
}
@@ -37,9 +37,9 @@ func (pc *patCache) Save(ctx context.Context, userID string, scopes []auth.Scope
return nil
}
func (pc *patCache) CheckScope(ctx context.Context, userID, patID, optionalDomainID string, entityType auth.EntityType, operation auth.Operation, entityID string) bool {
exactKey := fmt.Sprintf("pat:%s:%s:%s:%s:%s:%s", userID, patID, entityType, optionalDomainID, operation, entityID)
wildcardKey := fmt.Sprintf("pat:%s:%s:%s:%s:%s:*", userID, patID, entityType, operation, operation)
func (pc *patCache) CheckScope(ctx context.Context, userID, patID, domainID string, entityType auth.EntityType, operation string, entityID string) bool {
exactKey := fmt.Sprintf("pat:%s:%s:%s:%s:%s:%s", userID, patID, entityType, domainID, operation, entityID)
wildcardKey := fmt.Sprintf("pat:%s:%s:%s:%s:%s:*", userID, patID, entityType, domainID, operation)
res, err := pc.client.Exists(ctx, exactKey, wildcardKey).Result()
if err != nil {
@@ -115,6 +115,6 @@ func (pc *patCache) RemoveAllScope(ctx context.Context, userID, patID string) er
return nil
}
func generateKey(userID, patID, optionalDomainId string, entityType auth.EntityType, operation auth.Operation, entityID string) string {
return fmt.Sprintf("pat:%s:%s:%s:%s:%s:%s", userID, patID, entityType, optionalDomainId, operation, entityID)
func generateKey(userID, patID, domainId string, entityType auth.EntityType, operation string, entityID string) string {
return fmt.Sprintf("pat:%s:%s:%s:%s:%s:%s", userID, patID, entityType, domainId, operation, entityID)
}
+6 -6
View File
@@ -307,8 +307,8 @@ func (lm *loggingMiddleware) AddScope(ctx context.Context, token, patID string,
var groupArgs []any
for _, s := range scopes {
groupArgs = append(groupArgs, slog.String("entity_type", s.EntityType.String()))
groupArgs = append(groupArgs, slog.String("optional_domain_id", s.OptionalDomainID))
groupArgs = append(groupArgs, slog.String("operation", s.Operation.String()))
groupArgs = append(groupArgs, slog.String("domain_id", s.DomainID))
groupArgs = append(groupArgs, slog.String("operation", s.Operation))
groupArgs = append(groupArgs, slog.String("entity_id", s.EntityID))
}
@@ -379,13 +379,13 @@ func (lm *loggingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (p
return lm.svc.IdentifyPAT(ctx, paToken)
}
func (lm *loggingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) (err error) {
func (lm *loggingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, entityType auth.EntityType, domainID string, operation string, entityID string) (err error) {
defer func(begin time.Time) {
args := []any{
slog.String("duration", time.Since(begin).String()),
slog.String("entity_type", entityType.String()),
slog.String("optional_domain_id", optionalDomainID),
slog.String("operation", operation.String()),
slog.String("domain_id", domainID),
slog.String("operation", operation),
slog.String("entities", entityID),
}
if err != nil {
@@ -395,5 +395,5 @@ func (lm *loggingMiddleware) AuthorizePAT(ctx context.Context, userID, patID str
}
lm.logger.Info("Authorize PAT completed successfully", args...)
}(time.Now())
return lm.svc.AuthorizePAT(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
return lm.svc.AuthorizePAT(ctx, userID, patID, entityType, domainID, operation, entityID)
}
+2 -2
View File
@@ -195,10 +195,10 @@ func (ms *metricsMiddleware) IdentifyPAT(ctx context.Context, paToken string) (a
return ms.svc.IdentifyPAT(ctx, paToken)
}
func (ms *metricsMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error {
func (ms *metricsMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, entityType auth.EntityType, domainID string, operation string, entityID string) error {
defer func(begin time.Time) {
ms.counter.With("method", "authorize_pat").Add(1)
ms.latency.With("method", "authorize_pat").Observe(time.Since(begin).Seconds())
}(time.Now())
return ms.svc.AuthorizePAT(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
return ms.svc.AuthorizePAT(ctx, userID, patID, entityType, domainID, operation, entityID)
}
+6 -6
View File
@@ -169,8 +169,8 @@ func (tm *tracingMiddleware) AddScope(ctx context.Context, token, patID string,
var attributes []attribute.KeyValue
for _, s := range scopes {
attributes = append(attributes, attribute.String("entity_type", s.EntityType.String()))
attributes = append(attributes, attribute.String("optional_domain_id", s.OptionalDomainID))
attributes = append(attributes, attribute.String("operation", s.Operation.String()))
attributes = append(attributes, attribute.String("domain_id", s.DomainID))
attributes = append(attributes, attribute.String("operation", s.Operation))
attributes = append(attributes, attribute.String("entity_id", s.EntityID))
}
@@ -208,14 +208,14 @@ func (tm *tracingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (a
return tm.svc.IdentifyPAT(ctx, paToken)
}
func (tm *tracingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error {
func (tm *tracingMiddleware) AuthorizePAT(ctx context.Context, userID, patID string, entityType auth.EntityType, domainID string, operation string, entityID string) error {
ctx, span := tm.tracer.Start(ctx, "authorize_pat", trace.WithAttributes(
attribute.String("pat_id", patID),
attribute.String("entity_type", entityType.String()),
attribute.String("optional_domain_id", optionalDomainID),
attribute.String("operation", operation.String()),
attribute.String("domain_id", domainID),
attribute.String("operation", operation),
attribute.String("entities", entityID),
))
defer span.End()
return tm.svc.AuthorizePAT(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
return tm.svc.AuthorizePAT(ctx, userID, patID, entityType, domainID, operation, entityID)
}
+7 -7
View File
@@ -43,7 +43,7 @@ func (_m *Cache) EXPECT() *Cache_Expecter {
}
// CheckScope provides a mock function for the type Cache
func (_mock *Cache) CheckScope(ctx context.Context, userID string, patID string, optionalDomainID string, entityType auth.EntityType, operation auth.Operation, entityID string) bool {
func (_mock *Cache) CheckScope(ctx context.Context, userID string, patID string, optionalDomainID string, entityType auth.EntityType, operation string, entityID string) bool {
ret := _mock.Called(ctx, userID, patID, optionalDomainID, entityType, operation, entityID)
if len(ret) == 0 {
@@ -51,7 +51,7 @@ func (_mock *Cache) CheckScope(ctx context.Context, userID string, patID string,
}
var r0 bool
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, auth.EntityType, auth.Operation, string) bool); ok {
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, string, auth.EntityType, string, string) bool); ok {
r0 = returnFunc(ctx, userID, patID, optionalDomainID, entityType, operation, entityID)
} else {
r0 = ret.Get(0).(bool)
@@ -70,13 +70,13 @@ type Cache_CheckScope_Call struct {
// - patID string
// - optionalDomainID string
// - entityType auth.EntityType
// - operation auth.Operation
// - operation string
// - entityID string
func (_e *Cache_Expecter) CheckScope(ctx interface{}, userID interface{}, patID interface{}, optionalDomainID interface{}, entityType interface{}, operation interface{}, entityID interface{}) *Cache_CheckScope_Call {
return &Cache_CheckScope_Call{Call: _e.mock.On("CheckScope", ctx, userID, patID, optionalDomainID, entityType, operation, entityID)}
}
func (_c *Cache_CheckScope_Call) Run(run func(ctx context.Context, userID string, patID string, optionalDomainID string, entityType auth.EntityType, operation auth.Operation, entityID string)) *Cache_CheckScope_Call {
func (_c *Cache_CheckScope_Call) Run(run func(ctx context.Context, userID string, patID string, optionalDomainID string, entityType auth.EntityType, operation string, entityID string)) *Cache_CheckScope_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
@@ -98,9 +98,9 @@ func (_c *Cache_CheckScope_Call) Run(run func(ctx context.Context, userID string
if args[4] != nil {
arg4 = args[4].(auth.EntityType)
}
var arg5 auth.Operation
var arg5 string
if args[5] != nil {
arg5 = args[5].(auth.Operation)
arg5 = args[5].(string)
}
var arg6 string
if args[6] != nil {
@@ -124,7 +124,7 @@ func (_c *Cache_CheckScope_Call) Return(b bool) *Cache_CheckScope_Call {
return _c
}
func (_c *Cache_CheckScope_Call) RunAndReturn(run func(ctx context.Context, userID string, patID string, optionalDomainID string, entityType auth.EntityType, operation auth.Operation, entityID string) bool) *Cache_CheckScope_Call {
func (_c *Cache_CheckScope_Call) RunAndReturn(run func(ctx context.Context, userID string, patID string, optionalDomainID string, entityType auth.EntityType, operation string, entityID string) bool) *Cache_CheckScope_Call {
_c.Call.Return(run)
return _c
}
+12 -12
View File
@@ -113,16 +113,16 @@ func (_c *PATS_AddScope_Call) RunAndReturn(run func(ctx context.Context, token s
}
// AuthorizePAT provides a mock function for the type PATS
func (_mock *PATS) AuthorizePAT(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error {
ret := _mock.Called(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
func (_mock *PATS) AuthorizePAT(ctx context.Context, userID string, patID string, entityType auth.EntityType, domainID string, operation string, entityID string) error {
ret := _mock.Called(ctx, userID, patID, entityType, domainID, operation, entityID)
if len(ret) == 0 {
panic("no return value specified for AuthorizePAT")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, auth.EntityType, string, auth.Operation, string) error); ok {
r0 = returnFunc(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, auth.EntityType, string, string, string) error); ok {
r0 = returnFunc(ctx, userID, patID, entityType, domainID, operation, entityID)
} else {
r0 = ret.Error(0)
}
@@ -139,14 +139,14 @@ type PATS_AuthorizePAT_Call struct {
// - userID string
// - patID string
// - entityType auth.EntityType
// - optionalDomainID string
// - operation auth.Operation
// - domainID string
// - operation string
// - entityID string
func (_e *PATS_Expecter) AuthorizePAT(ctx interface{}, userID interface{}, patID interface{}, entityType interface{}, optionalDomainID interface{}, operation interface{}, entityID interface{}) *PATS_AuthorizePAT_Call {
return &PATS_AuthorizePAT_Call{Call: _e.mock.On("AuthorizePAT", ctx, userID, patID, entityType, optionalDomainID, operation, entityID)}
func (_e *PATS_Expecter) AuthorizePAT(ctx interface{}, userID interface{}, patID interface{}, entityType interface{}, domainID interface{}, operation interface{}, entityID interface{}) *PATS_AuthorizePAT_Call {
return &PATS_AuthorizePAT_Call{Call: _e.mock.On("AuthorizePAT", ctx, userID, patID, entityType, domainID, operation, entityID)}
}
func (_c *PATS_AuthorizePAT_Call) Run(run func(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string)) *PATS_AuthorizePAT_Call {
func (_c *PATS_AuthorizePAT_Call) Run(run func(ctx context.Context, userID string, patID string, entityType auth.EntityType, domainID string, operation string, entityID string)) *PATS_AuthorizePAT_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
@@ -168,9 +168,9 @@ func (_c *PATS_AuthorizePAT_Call) Run(run func(ctx context.Context, userID strin
if args[4] != nil {
arg4 = args[4].(string)
}
var arg5 auth.Operation
var arg5 string
if args[5] != nil {
arg5 = args[5].(auth.Operation)
arg5 = args[5].(string)
}
var arg6 string
if args[6] != nil {
@@ -194,7 +194,7 @@ func (_c *PATS_AuthorizePAT_Call) Return(err error) *PATS_AuthorizePAT_Call {
return _c
}
func (_c *PATS_AuthorizePAT_Call) RunAndReturn(run func(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error) *PATS_AuthorizePAT_Call {
func (_c *PATS_AuthorizePAT_Call) RunAndReturn(run func(ctx context.Context, userID string, patID string, entityType auth.EntityType, domainID string, operation string, entityID string) error) *PATS_AuthorizePAT_Call {
_c.Call.Return(run)
return _c
}
+12 -12
View File
@@ -107,16 +107,16 @@ func (_c *PATSRepository_AddScope_Call) RunAndReturn(run func(ctx context.Contex
}
// CheckScope provides a mock function for the type PATSRepository
func (_mock *PATSRepository) CheckScope(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error {
ret := _mock.Called(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
func (_mock *PATSRepository) CheckScope(ctx context.Context, userID string, patID string, entityType auth.EntityType, domainID string, operation string, entityID string) error {
ret := _mock.Called(ctx, userID, patID, entityType, domainID, operation, entityID)
if len(ret) == 0 {
panic("no return value specified for CheckScope")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, auth.EntityType, string, auth.Operation, string) error); ok {
r0 = returnFunc(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, auth.EntityType, string, string, string) error); ok {
r0 = returnFunc(ctx, userID, patID, entityType, domainID, operation, entityID)
} else {
r0 = ret.Error(0)
}
@@ -133,14 +133,14 @@ type PATSRepository_CheckScope_Call struct {
// - userID string
// - patID string
// - entityType auth.EntityType
// - optionalDomainID string
// - operation auth.Operation
// - domainID string
// - operation string
// - entityID string
func (_e *PATSRepository_Expecter) CheckScope(ctx interface{}, userID interface{}, patID interface{}, entityType interface{}, optionalDomainID interface{}, operation interface{}, entityID interface{}) *PATSRepository_CheckScope_Call {
return &PATSRepository_CheckScope_Call{Call: _e.mock.On("CheckScope", ctx, userID, patID, entityType, optionalDomainID, operation, entityID)}
func (_e *PATSRepository_Expecter) CheckScope(ctx interface{}, userID interface{}, patID interface{}, entityType interface{}, domainID interface{}, operation interface{}, entityID interface{}) *PATSRepository_CheckScope_Call {
return &PATSRepository_CheckScope_Call{Call: _e.mock.On("CheckScope", ctx, userID, patID, entityType, domainID, operation, entityID)}
}
func (_c *PATSRepository_CheckScope_Call) Run(run func(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string)) *PATSRepository_CheckScope_Call {
func (_c *PATSRepository_CheckScope_Call) Run(run func(ctx context.Context, userID string, patID string, entityType auth.EntityType, domainID string, operation string, entityID string)) *PATSRepository_CheckScope_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
@@ -162,9 +162,9 @@ func (_c *PATSRepository_CheckScope_Call) Run(run func(ctx context.Context, user
if args[4] != nil {
arg4 = args[4].(string)
}
var arg5 auth.Operation
var arg5 string
if args[5] != nil {
arg5 = args[5].(auth.Operation)
arg5 = args[5].(string)
}
var arg6 string
if args[6] != nil {
@@ -188,7 +188,7 @@ func (_c *PATSRepository_CheckScope_Call) Return(err error) *PATSRepository_Chec
return _c
}
func (_c *PATSRepository_CheckScope_Call) RunAndReturn(run func(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error) *PATSRepository_CheckScope_Call {
func (_c *PATSRepository_CheckScope_Call) RunAndReturn(run func(ctx context.Context, userID string, patID string, entityType auth.EntityType, domainID string, operation string, entityID string) error) *PATSRepository_CheckScope_Call {
_c.Call.Return(run)
return _c
}
+12 -12
View File
@@ -171,16 +171,16 @@ func (_c *Service_Authorize_Call) RunAndReturn(run func(ctx context.Context, pr
}
// AuthorizePAT provides a mock function for the type Service
func (_mock *Service) AuthorizePAT(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error {
ret := _mock.Called(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
func (_mock *Service) AuthorizePAT(ctx context.Context, userID string, patID string, entityType auth.EntityType, domainID string, operation string, entityID string) error {
ret := _mock.Called(ctx, userID, patID, entityType, domainID, operation, entityID)
if len(ret) == 0 {
panic("no return value specified for AuthorizePAT")
}
var r0 error
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, auth.EntityType, string, auth.Operation, string) error); ok {
r0 = returnFunc(ctx, userID, patID, entityType, optionalDomainID, operation, entityID)
if returnFunc, ok := ret.Get(0).(func(context.Context, string, string, auth.EntityType, string, string, string) error); ok {
r0 = returnFunc(ctx, userID, patID, entityType, domainID, operation, entityID)
} else {
r0 = ret.Error(0)
}
@@ -197,14 +197,14 @@ type Service_AuthorizePAT_Call struct {
// - userID string
// - patID string
// - entityType auth.EntityType
// - optionalDomainID string
// - operation auth.Operation
// - domainID string
// - operation string
// - entityID string
func (_e *Service_Expecter) AuthorizePAT(ctx interface{}, userID interface{}, patID interface{}, entityType interface{}, optionalDomainID interface{}, operation interface{}, entityID interface{}) *Service_AuthorizePAT_Call {
return &Service_AuthorizePAT_Call{Call: _e.mock.On("AuthorizePAT", ctx, userID, patID, entityType, optionalDomainID, operation, entityID)}
func (_e *Service_Expecter) AuthorizePAT(ctx interface{}, userID interface{}, patID interface{}, entityType interface{}, domainID interface{}, operation interface{}, entityID interface{}) *Service_AuthorizePAT_Call {
return &Service_AuthorizePAT_Call{Call: _e.mock.On("AuthorizePAT", ctx, userID, patID, entityType, domainID, operation, entityID)}
}
func (_c *Service_AuthorizePAT_Call) Run(run func(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string)) *Service_AuthorizePAT_Call {
func (_c *Service_AuthorizePAT_Call) Run(run func(ctx context.Context, userID string, patID string, entityType auth.EntityType, domainID string, operation string, entityID string)) *Service_AuthorizePAT_Call {
_c.Call.Run(func(args mock.Arguments) {
var arg0 context.Context
if args[0] != nil {
@@ -226,9 +226,9 @@ func (_c *Service_AuthorizePAT_Call) Run(run func(ctx context.Context, userID st
if args[4] != nil {
arg4 = args[4].(string)
}
var arg5 auth.Operation
var arg5 string
if args[5] != nil {
arg5 = args[5].(auth.Operation)
arg5 = args[5].(string)
}
var arg6 string
if args[6] != nil {
@@ -252,7 +252,7 @@ func (_c *Service_AuthorizePAT_Call) Return(err error) *Service_AuthorizePAT_Cal
return _c
}
func (_c *Service_AuthorizePAT_Call) RunAndReturn(run func(ctx context.Context, userID string, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error) *Service_AuthorizePAT_Call {
func (_c *Service_AuthorizePAT_Call) RunAndReturn(run func(ctx context.Context, userID string, patID string, entityType auth.EntityType, domainID string, operation string, entityID string) error) *Service_AuthorizePAT_Call {
_c.Call.Return(run)
return _c
}
+127 -146
View File
@@ -12,114 +12,53 @@ import (
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/pkg/errors"
)
const AnyIDs = "*"
type Operation uint32
const (
CreateOp Operation = iota
ReadOp
ListOp
UpdateOp
DeleteOp
ShareOp
UnshareOp
PublishOp
SubscribeOp
"github.com/absmach/supermq/pkg/permissions"
)
const (
createOpStr = "create"
readOpStr = "read"
listOpStr = "list"
updateOpStr = "update"
deleteOpStr = "delete"
shareOpStr = "share"
UnshareOpStr = "unshare"
PublishOpStr = "publish"
SubscribeOpStr = "subscribe"
AnyIDs = "*"
RoleOperationPrefix = "role_"
)
func (op Operation) String() string {
switch op {
case CreateOp:
return createOpStr
case ReadOp:
return readOpStr
case ListOp:
return listOpStr
case UpdateOp:
return updateOpStr
case DeleteOp:
return deleteOpStr
case ShareOp:
return shareOpStr
case UnshareOp:
return UnshareOpStr
case PublishOp:
return PublishOpStr
case SubscribeOp:
return SubscribeOpStr
default:
return fmt.Sprintf("unknown operation type %d", op)
}
}
const (
OpCreate = "create"
OpList = "list"
func (op Operation) ValidString() (string, error) {
str := op.String()
if str == fmt.Sprintf("unknown operation type %d", op) {
return "", errors.New(str)
}
return str, nil
}
OpCreateClients = "create_clients"
OpListClients = "list_clients"
OpCreateChannels = "create_channels"
OpListChannels = "list_channels"
OpCreateGroups = "create_groups"
OpListGroups = "list_groups"
func ParseOperation(op string) (Operation, error) {
switch op {
case createOpStr:
return CreateOp, nil
case readOpStr:
return ReadOp, nil
case listOpStr:
return ListOp, nil
case updateOpStr:
return UpdateOp, nil
case deleteOpStr:
return DeleteOp, nil
case shareOpStr:
return ShareOp, nil
case UnshareOpStr:
return UnshareOp, nil
case PublishOpStr:
return PublishOp, nil
case SubscribeOpStr:
return SubscribeOp, nil
default:
return 0, fmt.Errorf("unknown operation type %s", op)
}
}
OpShare = "share"
OpUnshare = "unshare"
func (op Operation) MarshalJSON() ([]byte, error) {
return json.Marshal(op.String())
}
OpDashboardShare = "dashboard_share"
OpDashboardUnshare = "dashboard_unshare"
func (op *Operation) UnmarshalJSON(data []byte) error {
str := strings.Trim(string(data), "\"")
val, err := ParseOperation(str)
*op = val
return err
}
OpPublish = "publish"
OpSubscribe = "subscribe"
func (op Operation) MarshalText() (text []byte, err error) {
return []byte(op.String()), nil
}
OpMessagePublish = "message_publish"
OpMessageSubscribe = "message_subscribe"
)
func (op *Operation) UnmarshalText(data []byte) (err error) {
str := strings.Trim(string(data), "\"")
*op, err = ParseOperation(str)
return err
}
var errInvalidEntityOp = errors.NewRequestError("operation not valid for entity type")
type Operation = permissions.Operation
// Dashboard operations.
const (
DashboardShareOp Operation = iota + 400
DashboardUnshareOp
)
// Messages operations.
const (
MessagePublishOp Operation = iota + 500
MessageSubscribeOp
)
type EntityType uint32
@@ -127,20 +66,18 @@ const (
GroupsType EntityType = iota
ChannelsType
ClientsType
DomainsType
UsersType
DashboardType
MessagesType
DomainsType
)
const (
GroupsScopeStr = "groups"
ChannelsScopeStr = "channels"
ClientsScopeStr = "clients"
DomainsStr = "domains"
UsersStr = "users"
DashboardsStr = "dashboards"
MessagesStr = "messages"
DomainsStr = "domains"
)
func (et EntityType) String() string {
@@ -151,27 +88,17 @@ func (et EntityType) String() string {
return ChannelsScopeStr
case ClientsType:
return ClientsScopeStr
case DomainsType:
return DomainsStr
case UsersType:
return UsersStr
case DashboardType:
return DashboardsStr
case MessagesType:
return MessagesStr
case DomainsType:
return DomainsStr
default:
return fmt.Sprintf("unknown domain entity type %d", et)
}
}
func (et EntityType) ValidString() (string, error) {
str := et.String()
if str == fmt.Sprintf("unknown operation type %d", et) {
return "", errors.New(str)
}
return str, nil
}
func ParseEntityType(et string) (EntityType, error) {
switch et {
case GroupsScopeStr:
@@ -180,14 +107,12 @@ func ParseEntityType(et string) (EntityType, error) {
return ChannelsType, nil
case ClientsScopeStr:
return ClientsType, nil
case DomainsStr:
return DomainsType, nil
case UsersStr:
return UsersType, nil
case DashboardsStr:
return DashboardType, nil
case MessagesStr:
return MessagesType, nil
case DomainsStr:
return DomainsType, nil
default:
return 0, fmt.Errorf("unknown domain entity type %s", et)
}
@@ -214,39 +139,101 @@ func (et *EntityType) UnmarshalText(data []byte) (err error) {
return err
}
func IsValidOperationForEntity(entityType EntityType, operation string) bool {
switch entityType {
case ClientsType, ChannelsType, GroupsType, DomainsType:
return true
case DashboardType:
return operation == OpDashboardShare || operation == OpDashboardUnshare
case MessagesType:
return operation == OpMessagePublish || operation == OpMessageSubscribe
default:
return false
}
}
// Example Scope as JSON
//
// [
// {
// "optional_domain_id": "domain_1",
// "domain_id": "domain_1",
// "entity_type": "groups",
// "operation": "create",
// "operation": "view",
// "entity_id": "*"
// },
// {
// "optional_domain_id": "domain_1",
// "domain_id": "domain_1",
// "entity_type": "channels",
// "operation": "delete",
// "entity_id": "channel1"
// },
// {
// "optional_domain_id": "domain_1",
// "entity_type": "things",
// "domain_id": "domain_1",
// "entity_type": "clients",
// "operation": "update",
// "entity_id": "*"
// }
// ]
type Scope struct {
ID string `json:"id"`
PatID string `json:"pat_id"`
OptionalDomainID string `json:"optional_domain_id"`
EntityType EntityType `json:"entity_type"`
EntityID string `json:"entity_id"`
Operation Operation `json:"operation"`
ID string `json:"id"`
PatID string `json:"pat_id"`
DomainID string `json:"domain_id"`
EntityType EntityType `json:"entity_type"`
EntityID string `json:"entity_id"`
Operation string `json:"operation"`
}
func (s *Scope) Authorized(entityType EntityType, optionalDomainID string, operation Operation, entityID string) bool {
func (s *Scope) UnmarshalJSON(data []byte) error {
type Alias Scope
aux := (*Alias)(s)
if err := json.Unmarshal(data, aux); err != nil {
return err
}
switch s.EntityType {
case ClientsType:
switch s.Operation {
case OpCreate:
s.Operation = OpCreateClients
case OpList:
s.Operation = OpListClients
}
case ChannelsType:
switch s.Operation {
case OpCreate:
s.Operation = OpCreateChannels
case OpList:
s.Operation = OpListChannels
}
case GroupsType:
switch s.Operation {
case OpCreate:
s.Operation = OpCreateGroups
case OpList:
s.Operation = OpListGroups
}
case DashboardType:
switch s.Operation {
case OpShare:
s.Operation = OpDashboardShare
case OpUnshare:
s.Operation = OpDashboardUnshare
}
case MessagesType:
switch s.Operation {
case OpPublish:
s.Operation = OpMessagePublish
case OpSubscribe:
s.Operation = OpMessageSubscribe
}
}
return nil
}
func (s *Scope) Authorized(entityType EntityType, domainID string, operation string, entityID string) bool {
if s == nil {
return false
}
@@ -255,7 +242,7 @@ func (s *Scope) Authorized(entityType EntityType, optionalDomainID string, opera
return false
}
if optionalDomainID != "" && s.OptionalDomainID != optionalDomainID {
if s.DomainID != "" && s.DomainID != domainID {
return false
}
@@ -270,6 +257,7 @@ func (s *Scope) Authorized(entityType EntityType, optionalDomainID string, opera
if s.EntityID == entityID {
return true
}
return false
}
@@ -281,11 +269,12 @@ func (s *Scope) Validate() error {
return apiutil.ErrMissingEntityID
}
switch s.EntityType {
case ChannelsType, GroupsType, ClientsType:
if s.OptionalDomainID == "" {
return apiutil.ErrMissingDomainID
}
if s.DomainID == "" {
return apiutil.ErrMissingDomainID
}
if !IsValidOperationForEntity(s.EntityType, s.Operation) {
return errors.Wrap(apiutil.ErrInvalidQueryParams, errInvalidEntityOp)
}
return nil
@@ -358,14 +347,6 @@ func (pat *PAT) Validate() error {
return nil
}
func (pat *PAT) String() string {
str, err := json.MarshalIndent(pat, "", " ")
if err != nil {
return fmt.Sprintf("failed to convert PAT to string: json marshal error :%s", err.Error())
}
return string(str)
}
// PATS specifies function which are required for Personal access Token implementation.
type PATS interface {
// Create function creates new PAT for given valid inputs.
@@ -411,7 +392,7 @@ type PATS interface {
IdentifyPAT(ctx context.Context, paToken string) (PAT, error)
// AuthorizePAT function will valid the secret and check the given scope exists.
AuthorizePAT(ctx context.Context, userID, patID string, entityType EntityType, optionalDomainID string, operation Operation, entityID string) error
AuthorizePAT(ctx context.Context, userID, patID string, entityType EntityType, domainID string, operation string, entityID string) error
}
// PATSRepository specifies PATS persistence API.
@@ -456,7 +437,7 @@ type PATSRepository interface {
RemoveScope(ctx context.Context, userID string, scopesIDs ...string) error
CheckScope(ctx context.Context, userID, patID string, entityType EntityType, optionalDomainID string, operation Operation, entityID string) error
CheckScope(ctx context.Context, userID, patID string, entityType EntityType, domainID string, operation string, entityID string) error
RemoveAllScope(ctx context.Context, patID string) error
}
@@ -464,7 +445,7 @@ type PATSRepository interface {
type Cache interface {
Save(ctx context.Context, userID string, scopes []Scope) error
CheckScope(ctx context.Context, userID, patID, optionalDomainID string, entityType EntityType, operation Operation, entityID string) bool
CheckScope(ctx context.Context, userID, patID, optionalDomainID string, entityType EntityType, operation string, entityID string) bool
Remove(ctx context.Context, userID string, scopesID []string) error
-358
View File
@@ -10,342 +10,6 @@ import (
"github.com/stretchr/testify/assert"
)
func TestOperationString(t *testing.T) {
cases := []struct {
desc string
op auth.Operation
expected string
}{
{
desc: "Create operation",
op: auth.CreateOp,
expected: "create",
},
{
desc: "Read operation",
op: auth.ReadOp,
expected: "read",
},
{
desc: "List operation",
op: auth.ListOp,
expected: "list",
},
{
desc: "Update operation",
op: auth.UpdateOp,
expected: "update",
},
{
desc: "Delete operation",
op: auth.DeleteOp,
expected: "delete",
},
{
desc: "Share operation",
op: auth.ShareOp,
expected: "share",
},
{
desc: "Unshare operation",
op: auth.UnshareOp,
expected: "unshare",
},
{
desc: "Publish operation",
op: auth.PublishOp,
expected: "publish",
},
{
desc: "Subscribe operation",
op: auth.SubscribeOp,
expected: "subscribe",
},
{
desc: "Unknown operation",
op: auth.Operation(100),
expected: "unknown operation type 100",
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
got := tc.op.String()
assert.Equal(t, tc.expected, got, "String() = %v, expected %v", got, tc.expected)
})
}
}
func TestOperationValidString(t *testing.T) {
cases := []struct {
desc string
op auth.Operation
expected string
err bool
}{
{
desc: "Valid create operation",
op: auth.CreateOp,
expected: "create",
err: false,
},
{
desc: "Valid read operation",
op: auth.ReadOp,
expected: "read",
err: false,
},
{
desc: "Invalid operation",
op: auth.Operation(100),
expected: "",
err: true,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
got, err := tc.op.ValidString()
if tc.err {
assert.Error(t, err, "ValidString() should return error")
} else {
assert.NoError(t, err, "ValidString() should not return error")
assert.Equal(t, tc.expected, got, "ValidString() = %v, expected %v", got, tc.expected)
}
})
}
}
func TestParseOperation(t *testing.T) {
cases := []struct {
desc string
op string
expected auth.Operation
err bool
}{
{
desc: "Parse create",
op: "create",
expected: auth.CreateOp,
err: false,
},
{
desc: "Parse read",
op: "read",
expected: auth.ReadOp,
err: false,
},
{
desc: "Parse list",
op: "list",
expected: auth.ListOp,
err: false,
},
{
desc: "Parse update",
op: "update",
expected: auth.UpdateOp,
err: false,
},
{
desc: "Parse delete",
op: "delete",
expected: auth.DeleteOp,
err: false,
},
{
desc: "Parse share",
op: "share",
expected: auth.ShareOp,
err: false,
},
{
desc: "Parse unshare",
op: "unshare",
expected: auth.UnshareOp,
err: false,
},
{
desc: "Parse publish",
op: "publish",
expected: auth.PublishOp,
err: false,
},
{
desc: "Parse subscribe",
op: "subscribe",
expected: auth.SubscribeOp,
err: false,
},
{
desc: "Parse unknown operation",
op: "unknown",
expected: auth.Operation(0),
err: true,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
got, err := auth.ParseOperation(tc.op)
if tc.err {
assert.Error(t, err, "ParseOperation() should return error")
} else {
assert.NoError(t, err, "ParseOperation() should not return error")
assert.Equal(t, tc.expected, got, "ParseOperation() = %v, expected %v", got, tc.expected)
}
})
}
}
func TestOperationMarshalJSON(t *testing.T) {
cases := []struct {
desc string
op auth.Operation
expected []byte
err error
}{
{
desc: "Marshal create",
op: auth.CreateOp,
expected: []byte(`"create"`),
err: nil,
},
{
desc: "Marshal read",
op: auth.ReadOp,
expected: []byte(`"read"`),
err: nil,
},
{
desc: "Marshal delete",
op: auth.DeleteOp,
expected: []byte(`"delete"`),
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
got, err := tc.op.MarshalJSON()
assert.Equal(t, tc.err, err, "MarshalJSON() error = %v, expected %v", err, tc.err)
assert.Equal(t, tc.expected, got, "MarshalJSON() = %v, expected %v", got, tc.expected)
})
}
}
func TestOperationUnmarshalJSON(t *testing.T) {
cases := []struct {
desc string
data []byte
expected auth.Operation
err bool
}{
{
desc: "Unmarshal create",
data: []byte(`"create"`),
expected: auth.CreateOp,
err: false,
},
{
desc: "Unmarshal read",
data: []byte(`"read"`),
expected: auth.ReadOp,
err: false,
},
{
desc: "Unmarshal unknown",
data: []byte(`"unknown"`),
expected: auth.Operation(0),
err: true,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
var op auth.Operation
err := op.UnmarshalJSON(tc.data)
if tc.err {
assert.Error(t, err, "UnmarshalJSON() should return error")
} else {
assert.NoError(t, err, "UnmarshalJSON() should not return error")
assert.Equal(t, tc.expected, op, "UnmarshalJSON() = %v, expected %v", op, tc.expected)
}
})
}
}
func TestOperationMarshalText(t *testing.T) {
cases := []struct {
desc string
op auth.Operation
expected []byte
err error
}{
{
desc: "Marshal create as text",
op: auth.CreateOp,
expected: []byte("create"),
err: nil,
},
{
desc: "Marshal read as text",
op: auth.ReadOp,
expected: []byte("read"),
err: nil,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
got, err := tc.op.MarshalText()
assert.Equal(t, tc.err, err, "MarshalText() error = %v, expected %v", err, tc.err)
assert.Equal(t, tc.expected, got, "MarshalText() = %v, expected %v", got, tc.expected)
})
}
}
func TestOperationUnmarshalText(t *testing.T) {
cases := []struct {
desc string
data []byte
expected auth.Operation
err bool
}{
{
desc: "Unmarshal create from text",
data: []byte("create"),
expected: auth.CreateOp,
err: false,
},
{
desc: "Unmarshal read from text",
data: []byte("read"),
expected: auth.ReadOp,
err: false,
},
{
desc: "Unmarshal unknown from text",
data: []byte("unknown"),
expected: auth.Operation(0),
err: true,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
var op auth.Operation
err := op.UnmarshalText(tc.data)
if tc.err {
assert.Error(t, err, "UnmarshalText() should return error")
} else {
assert.NoError(t, err, "UnmarshalText() should not return error")
assert.Equal(t, tc.expected, op, "UnmarshalText() = %v, expected %v", op, tc.expected)
}
})
}
}
func TestEntityTypeString(t *testing.T) {
cases := []struct {
desc string
@@ -367,16 +31,6 @@ func TestEntityTypeString(t *testing.T) {
et: auth.ClientsType,
expected: "clients",
},
{
desc: "Domains entity type",
et: auth.DomainsType,
expected: "domains",
},
{
desc: "Users entity type",
et: auth.UsersType,
expected: "users",
},
{
desc: "Dashboard entity type",
et: auth.DashboardType,
@@ -427,18 +81,6 @@ func TestParseEntityType(t *testing.T) {
expected: auth.ClientsType,
err: false,
},
{
desc: "Parse domains",
et: "domains",
expected: auth.DomainsType,
err: false,
},
{
desc: "Parse users",
et: "users",
expected: auth.UsersType,
err: false,
},
{
desc: "Parse dashboards",
et: "dashboards",
+9
View File
@@ -125,6 +125,15 @@ func Migration() *migrate.MemoryMigrationSource {
`ALTER TABLE pats ALTER COLUMN last_used_at TYPE TIMESTAMP;`,
},
},
{
Id: "auth_7",
Up: []string{
`ALTER TABLE pat_scopes RENAME COLUMN optional_domain_id TO domain_id;`,
},
Down: []string{
`ALTER TABLE pat_scopes RENAME COLUMN domain_id TO optional_domain_id;`,
},
},
},
}
}
+18 -22
View File
@@ -26,12 +26,12 @@ type dbPat struct {
}
type dbScope struct {
ID string `db:"id,omitempty"`
PatID string `db:"pat_id,omitempty"`
OptionalDomainID string `db:"optional_domain_id,omitempty"`
EntityType string `db:"entity_type,omitempty"`
EntityID string `db:"entity_id,omitempty"`
Operation string `db:"operation,omitempty"`
ID string `db:"id,omitempty"`
PatID string `db:"pat_id,omitempty"`
DomainID string `db:"domain_id,omitempty"`
EntityType string `db:"entity_type,omitempty"`
EntityID string `db:"entity_id,omitempty"`
Operation string `db:"operation,omitempty"`
}
type dbPagemeta struct {
@@ -92,17 +92,13 @@ func toAuthScope(dsc []dbScope) ([]auth.Scope, error) {
if err != nil {
return []auth.Scope{}, err
}
operation, err := auth.ParseOperation(s.Operation)
if err != nil {
return []auth.Scope{}, err
}
scope = append(scope, auth.Scope{
ID: s.ID,
PatID: s.PatID,
OptionalDomainID: s.OptionalDomainID,
EntityType: entityType,
EntityID: s.EntityID,
Operation: operation,
ID: s.ID,
PatID: s.PatID,
DomainID: s.DomainID,
EntityType: entityType,
EntityID: s.EntityID,
Operation: s.Operation,
})
}
@@ -152,12 +148,12 @@ func toDBScope(sc []auth.Scope) []dbScope {
var scopes []dbScope
for _, s := range sc {
scopes = append(scopes, dbScope{
ID: s.ID,
PatID: s.PatID,
OptionalDomainID: s.OptionalDomainID,
EntityType: s.EntityType.String(),
EntityID: s.EntityID,
Operation: s.Operation.String(),
ID: s.ID,
PatID: s.PatID,
DomainID: s.DomainID,
EntityType: s.EntityType.String(),
EntityID: s.EntityID,
Operation: s.Operation,
})
}
return scopes
+34 -36
View File
@@ -18,6 +18,8 @@ import (
var _ auth.PATSRepository = (*patRepo)(nil)
var errInsufficientPATScope = errors.NewRequestError("PAT does not have the required scope permissions")
type patRepo struct {
db postgres.Database
cache auth.Cache
@@ -374,8 +376,8 @@ func (pr *patRepo) RemoveAllPAT(ctx context.Context, userID string) error {
func (pr *patRepo) AddScope(ctx context.Context, userID string, scopes []auth.Scope) error {
q := `
INSERT INTO pat_scopes (id, pat_id, entity_type, optional_domain_id, operation, entity_id)
VALUES (:id, :pat_id, :entity_type, :optional_domain_id, :operation, :entity_id)`
INSERT INTO pat_scopes (id, pat_id, entity_type, domain_id, operation, entity_id)
VALUES (:id, :pat_id, :entity_type, :domain_id, :operation, :entity_id)`
var newScopes []auth.Scope
@@ -410,17 +412,17 @@ func (pr *patRepo) processScope(ctx context.Context, sc auth.Scope) (auth.Scope,
FROM pat_scopes
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND optional_domain_id = :optional_domain_id
AND domain_id = :domain_id
AND operation = :operation
AND entity_id = :entity_id
LIMIT 1`
params := dbScope{
PatID: sc.PatID,
OptionalDomainID: sc.OptionalDomainID,
EntityType: sc.EntityType.String(),
Operation: sc.Operation.String(),
EntityID: auth.AnyIDs,
PatID: sc.PatID,
DomainID: sc.DomainID,
EntityType: sc.EntityType.String(),
Operation: sc.Operation,
EntityID: auth.AnyIDs,
}
rows, err := pr.db.NamedQueryContext(ctx, q, params)
@@ -442,10 +444,10 @@ func (pr *patRepo) processScope(ctx context.Context, sc auth.Scope) (auth.Scope,
if sc.EntityID == auth.AnyIDs {
newParams := dbScope{
PatID: sc.PatID,
OptionalDomainID: sc.OptionalDomainID,
EntityType: sc.EntityType.String(),
Operation: sc.Operation.String(),
PatID: sc.PatID,
DomainID: sc.DomainID,
EntityType: sc.EntityType.String(),
Operation: sc.Operation,
}
checkEntityQuery := `
@@ -453,7 +455,7 @@ func (pr *patRepo) processScope(ctx context.Context, sc auth.Scope) (auth.Scope,
FROM pat_scopes
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND optional_domain_id = :optional_domain_id
AND domain_id = :domain_id
AND operation = :operation
LIMIT 1`
@@ -476,7 +478,7 @@ func (pr *patRepo) processScope(ctx context.Context, sc auth.Scope) (auth.Scope,
SET entity_id = :entity_id
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND optional_domain_id = :optional_domain_id
AND domain_id = :domain_id
AND operation = :operation`
rows, err = pr.db.NamedQueryContext(ctx, updateWithWildcardQuery, params)
@@ -511,28 +513,28 @@ func (pr *patRepo) RemoveScope(ctx context.Context, userID string, scopesIDs ...
return nil
}
func (pr *patRepo) CheckScope(ctx context.Context, userID, patID string, entityType auth.EntityType, optionalDomainID string, operation auth.Operation, entityID string) error {
func (pr *patRepo) CheckScope(ctx context.Context, userID, patID string, entityType auth.EntityType, domainID string, operation string, entityID string) error {
q := `
SELECT id, pat_id, entity_type, optional_domain_id, operation, entity_id
SELECT id, pat_id, entity_type, domain_id, operation, entity_id
FROM pat_scopes
WHERE pat_id = :pat_id
AND entity_type = :entity_type
AND optional_domain_id = :optional_domain_id
AND domain_id = :domain_id
AND operation = :operation
AND (entity_id = :entity_id OR entity_id = '*')
LIMIT 1`
authorized := pr.cache.CheckScope(ctx, userID, patID, optionalDomainID, entityType, operation, entityID)
authorized := pr.cache.CheckScope(ctx, userID, patID, domainID, entityType, operation, entityID)
if authorized {
return nil
}
scope := dbScope{
PatID: patID,
EntityType: entityType.String(),
OptionalDomainID: optionalDomainID,
Operation: operation.String(),
EntityID: entityID,
PatID: patID,
EntityType: entityType.String(),
DomainID: domainID,
Operation: operation,
EntityID: entityID,
}
rows, err := pr.db.NamedQueryContext(ctx, q, scope)
@@ -551,29 +553,25 @@ func (pr *patRepo) CheckScope(ctx context.Context, userID, patID string, entityT
if err != nil {
return errors.Wrap(repoerr.ErrViewEntity, err)
}
operation, err := auth.ParseOperation(sc.Operation)
if err != nil {
return errors.Wrap(repoerr.ErrViewEntity, err)
}
authScope := auth.Scope{
ID: sc.ID,
PatID: sc.PatID,
OptionalDomainID: sc.OptionalDomainID,
EntityType: entityType,
EntityID: sc.EntityID,
Operation: operation,
ID: sc.ID,
PatID: sc.PatID,
DomainID: sc.DomainID,
EntityType: entityType,
EntityID: sc.EntityID,
Operation: sc.Operation,
}
if err := pr.cache.Save(ctx, userID, []auth.Scope{authScope}); err != nil {
return err
}
if authScope.Authorized(entityType, optionalDomainID, operation, entityID) {
if authScope.Authorized(entityType, domainID, operation, entityID) {
return nil
}
}
return repoerr.ErrNotFound
return errInsufficientPATScope
}
func (pr *patRepo) RemoveAllScope(ctx context.Context, patID string) error {
@@ -625,7 +623,7 @@ func (pr *patRepo) RetrieveScope(ctx context.Context, pm auth.ScopesPageMeta) (a
func (pr *patRepo) retrieveScopeFromDB(ctx context.Context, pm dbPagemeta) ([]auth.Scope, error) {
q := `
SELECT id, pat_id, entity_type, optional_domain_id, operation, entity_id
SELECT id, pat_id, entity_type, domain_id, operation, entity_id
FROM pat_scopes WHERE pat_id = :pat_id OFFSET :offset LIMIT :limit`
scopeRows, err := pr.db.NamedQueryContext(ctx, q, pm)
if err != nil {
+143 -136
View File
@@ -9,131 +9,134 @@ import (
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/auth"
channelsOps "github.com/absmach/supermq/channels/operations"
clientsOps "github.com/absmach/supermq/clients/operations"
groupsOps "github.com/absmach/supermq/groups/operations"
"github.com/stretchr/testify/assert"
)
func TestScopeAuthorized(t *testing.T) {
cases := []struct {
desc string
scope *auth.Scope
entityType auth.EntityType
optionalDomainID string
operation auth.Operation
entityID string
expected bool
desc string
scope *auth.Scope
entityType auth.EntityType
domainID string
operation string
entityID string
expected bool
}{
{
desc: "Authorized with matching entity type, domain, operation and entity ID",
scope: &auth.Scope{
EntityType: auth.GroupsType,
OptionalDomainID: "domain1",
Operation: auth.CreateOp,
EntityID: "entity1",
EntityType: auth.GroupsType,
DomainID: "domain1",
Operation: "view",
EntityID: "entity1",
},
entityType: auth.GroupsType,
optionalDomainID: "domain1",
operation: auth.CreateOp,
entityID: "entity1",
expected: true,
entityType: auth.GroupsType,
domainID: "domain1",
operation: "view",
entityID: "entity1",
expected: true,
},
{
desc: "Authorized with wildcard entity ID",
scope: &auth.Scope{
EntityType: auth.GroupsType,
OptionalDomainID: "domain1",
Operation: auth.CreateOp,
EntityID: "*",
EntityType: auth.GroupsType,
DomainID: "domain1",
Operation: "view",
EntityID: "*",
},
entityType: auth.GroupsType,
optionalDomainID: "domain1",
operation: auth.CreateOp,
entityID: "any-entity",
expected: true,
entityType: auth.GroupsType,
domainID: "domain1",
operation: "view",
entityID: "any-entity",
expected: true,
},
{
desc: "Authorized without domain ID",
scope: &auth.Scope{
EntityType: auth.UsersType,
OptionalDomainID: "",
Operation: auth.ReadOp,
EntityID: "user1",
EntityType: auth.ClientsType,
DomainID: "",
Operation: "view",
EntityID: "client1",
},
entityType: auth.UsersType,
optionalDomainID: "",
operation: auth.ReadOp,
entityID: "user1",
expected: true,
entityType: auth.ClientsType,
domainID: "domain1",
operation: "view",
entityID: "client1",
expected: true,
},
{
desc: "Not authorized with different entity type",
scope: &auth.Scope{
EntityType: auth.GroupsType,
OptionalDomainID: "domain1",
Operation: auth.CreateOp,
EntityID: "entity1",
EntityType: auth.GroupsType,
DomainID: "domain1",
Operation: "view",
EntityID: "entity1",
},
entityType: auth.ChannelsType,
optionalDomainID: "domain1",
operation: auth.CreateOp,
entityID: "entity1",
expected: false,
entityType: auth.ChannelsType,
domainID: "domain1",
operation: "view",
entityID: "entity1",
expected: false,
},
{
desc: "Not authorized with different domain ID",
scope: &auth.Scope{
EntityType: auth.GroupsType,
OptionalDomainID: "domain1",
Operation: auth.CreateOp,
EntityID: "entity1",
EntityType: auth.GroupsType,
DomainID: "domain1",
Operation: "view",
EntityID: "entity1",
},
entityType: auth.GroupsType,
optionalDomainID: "domain2",
operation: auth.CreateOp,
entityID: "entity1",
expected: false,
entityType: auth.GroupsType,
domainID: "domain2",
operation: "view",
entityID: "entity1",
expected: false,
},
{
desc: "Not authorized with different operation",
scope: &auth.Scope{
EntityType: auth.GroupsType,
OptionalDomainID: "domain1",
Operation: auth.CreateOp,
EntityID: "entity1",
EntityType: auth.GroupsType,
DomainID: "domain1",
Operation: "view",
EntityID: "entity1",
},
entityType: auth.GroupsType,
optionalDomainID: "domain1",
operation: auth.DeleteOp,
entityID: "entity1",
expected: false,
entityType: auth.GroupsType,
domainID: "domain1",
operation: "delete",
entityID: "entity1",
expected: false,
},
{
desc: "Not authorized with different entity ID",
scope: &auth.Scope{
EntityType: auth.GroupsType,
OptionalDomainID: "domain1",
Operation: auth.CreateOp,
EntityID: "entity1",
EntityType: auth.GroupsType,
DomainID: "domain1",
Operation: "view",
EntityID: "entity1",
},
entityType: auth.GroupsType,
optionalDomainID: "domain1",
operation: auth.CreateOp,
entityID: "entity2",
expected: false,
entityType: auth.GroupsType,
domainID: "domain1",
operation: "view",
entityID: "entity2",
expected: false,
},
{
desc: "Not authorized with nil scope",
scope: nil,
entityType: auth.GroupsType,
optionalDomainID: "domain1",
operation: auth.CreateOp,
entityID: "entity1",
expected: false,
desc: "Not authorized with nil scope",
scope: nil,
entityType: auth.GroupsType,
domainID: "domain1",
operation: "view",
entityID: "entity1",
expected: false,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
result := tc.scope.Authorized(tc.entityType, tc.optionalDomainID, tc.operation, tc.entityID)
result := tc.scope.Authorized(tc.entityType, tc.domainID, tc.operation, tc.entityID)
assert.Equal(t, tc.expected, result, "Authorized() = %v, expected %v", result, tc.expected)
})
}
@@ -148,60 +151,60 @@ func TestScopeValidate(t *testing.T) {
{
desc: "Valid scope for groups with domain ID",
scope: &auth.Scope{
EntityType: auth.GroupsType,
OptionalDomainID: "domain1",
Operation: auth.CreateOp,
EntityID: "entity1",
EntityType: auth.GroupsType,
DomainID: "domain1",
Operation: "view",
EntityID: "entity1",
},
err: nil,
},
{
desc: "Valid scope for channels with domain ID",
scope: &auth.Scope{
EntityType: auth.ChannelsType,
OptionalDomainID: "domain1",
Operation: auth.ReadOp,
EntityID: "channel1",
EntityType: auth.ChannelsType,
DomainID: "domain1",
Operation: "view",
EntityID: "channel1",
},
err: nil,
},
{
desc: "Valid scope for clients with domain ID",
scope: &auth.Scope{
EntityType: auth.ClientsType,
OptionalDomainID: "domain1",
Operation: auth.UpdateOp,
EntityID: "client1",
EntityType: auth.ClientsType,
DomainID: "domain1",
Operation: "update",
EntityID: "client1",
},
err: nil,
},
{
desc: "Valid scope for users without domain ID",
desc: "Valid scope for messages with domain ID",
scope: &auth.Scope{
EntityType: auth.UsersType,
OptionalDomainID: "",
Operation: auth.DeleteOp,
EntityID: "user1",
EntityType: auth.MessagesType,
DomainID: "domain1",
Operation: "message_publish",
EntityID: "message1",
},
err: nil,
},
{
desc: "Valid scope for domains without domain ID",
desc: "Valid scope for dashboard with domain ID",
scope: &auth.Scope{
EntityType: auth.DomainsType,
OptionalDomainID: "",
Operation: auth.ListOp,
EntityID: "domain1",
EntityType: auth.DashboardType,
DomainID: "domain1",
Operation: "dashboard_share",
EntityID: "dashboard1",
},
err: nil,
},
{
desc: "Valid scope with wildcard entity ID",
scope: &auth.Scope{
EntityType: auth.GroupsType,
OptionalDomainID: "domain1",
Operation: auth.CreateOp,
EntityID: "*",
EntityType: auth.GroupsType,
DomainID: "domain1",
Operation: "view",
EntityID: "*",
},
err: nil,
},
@@ -213,40 +216,60 @@ func TestScopeValidate(t *testing.T) {
{
desc: "Invalid scope without entity ID",
scope: &auth.Scope{
EntityType: auth.GroupsType,
OptionalDomainID: "domain1",
Operation: auth.CreateOp,
EntityID: "",
EntityType: auth.GroupsType,
DomainID: "domain1",
Operation: groupsOps.OperationDetails()[groupsOps.OpViewGroup].Name,
EntityID: "",
},
err: apiutil.ErrMissingEntityID,
},
{
desc: "Invalid scope for groups without domain ID",
scope: &auth.Scope{
EntityType: auth.GroupsType,
OptionalDomainID: "",
Operation: auth.CreateOp,
EntityID: "entity1",
EntityType: auth.GroupsType,
DomainID: "",
Operation: groupsOps.OperationDetails()[groupsOps.OpViewGroup].Name,
EntityID: "entity1",
},
err: apiutil.ErrMissingDomainID,
},
{
desc: "Invalid scope for channels without domain ID",
scope: &auth.Scope{
EntityType: auth.ChannelsType,
OptionalDomainID: "",
Operation: auth.CreateOp,
EntityID: "channel1",
EntityType: auth.ChannelsType,
DomainID: "",
Operation: channelsOps.OperationDetails()[channelsOps.OpViewChannel].Name,
EntityID: "channel1",
},
err: apiutil.ErrMissingDomainID,
},
{
desc: "Invalid scope for clients without domain ID",
scope: &auth.Scope{
EntityType: auth.ClientsType,
OptionalDomainID: "",
Operation: auth.CreateOp,
EntityID: "client1",
EntityType: auth.ClientsType,
DomainID: "",
Operation: clientsOps.OperationDetails()[clientsOps.OpViewClient].Name,
EntityID: "client1",
},
err: apiutil.ErrMissingDomainID,
},
{
desc: "Invalid scope for dashboard without domain ID",
scope: &auth.Scope{
EntityType: auth.DashboardType,
DomainID: "",
Operation: auth.OpShare,
EntityID: "dashboard1",
},
err: apiutil.ErrMissingDomainID,
},
{
desc: "Invalid scope for messages without domain ID",
scope: &auth.Scope{
EntityType: auth.MessagesType,
DomainID: "",
Operation: auth.OpPublish,
EntityID: "message1",
},
err: apiutil.ErrMissingDomainID,
},
@@ -351,19 +374,3 @@ func TestPATMarshalUnmarshalBinary(t *testing.T) {
assert.Equal(t, pat.Secret, newPAT.Secret, "Secret mismatch")
assert.Equal(t, pat.Status, newPAT.Status, "Status mismatch")
}
func TestPATString(t *testing.T) {
pat := &auth.PAT{
ID: "pat-id",
User: "user-id",
Name: "test-pat",
Description: "test description",
Status: auth.ActiveStatus,
}
str := pat.String()
assert.NotEmpty(t, str, "String() should return non-empty string")
assert.Contains(t, str, "pat-id", "String() should contain ID")
assert.Contains(t, str, "user-id", "String() should contain User")
assert.Contains(t, str, "test-pat", "String() should contain Name")
}
+8 -10
View File
@@ -25,11 +25,6 @@ const (
patSecretSeparator = "_"
)
const (
AccessTokenType uint32 = iota
PersonalAccessTokenType
)
var (
// ErrExpiry indicates that the token is expired.
ErrExpiry = errors.New("token is expired")
@@ -211,11 +206,14 @@ func (svc service) RetrieveJWKS() []PublicKeyInfo {
}
func (svc service) Authorize(ctx context.Context, pr policies.Policy) error {
if pr.PatID != "" && pr.TokenType == PersonalAccessTokenType {
if err := svc.AuthorizePAT(ctx, pr.UserID, pr.PatID, EntityType(pr.EntityType), pr.OptionalDomainID, Operation(pr.Operation), pr.EntityID); err != nil {
if pr.PatID != "" {
entityType, err := ParseEntityType(pr.EntityType)
if err != nil {
return err
}
if err := svc.AuthorizePAT(ctx, pr.UserID, pr.PatID, entityType, pr.Domain, pr.Operation, pr.EntityID); err != nil {
return err
}
return nil
}
if err := svc.PolicyValidation(pr); err != nil {
@@ -735,8 +733,8 @@ func (svc service) IdentifyPAT(ctx context.Context, secret string) (PAT, error)
return pat, nil
}
func (svc service) AuthorizePAT(ctx context.Context, userID, patID string, entityType EntityType, optionalDomainID string, operation Operation, entityID string) error {
if err := svc.pats.CheckScope(ctx, userID, patID, entityType, optionalDomainID, operation, entityID); err != nil {
func (svc service) AuthorizePAT(ctx context.Context, userID, patID string, entityType EntityType, domainID string, operation string, entityID string) error {
if err := svc.pats.CheckScope(ctx, userID, patID, entityType, domainID, operation, entityID); err != nil {
return errors.Wrap(svcerr.ErrAuthorization, err)
}
+44 -79
View File
@@ -9,9 +9,10 @@ import (
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/groups"
"github.com/absmach/supermq/channels/operations"
cOperations "github.com/absmach/supermq/clients/operations"
dOperations "github.com/absmach/supermq/domains/operations"
gOperations "github.com/absmach/supermq/groups/operations"
"github.com/absmach/supermq/pkg/authn"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/connections"
@@ -79,26 +80,19 @@ func NewAuthorization(
}
func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session authn.Session, chs ...channels.Channel) ([]channels.Channel, []roles.RoleProvision, error) {
req := smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.DomainType,
Object: session.DomainID,
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.CreateOp,
EntityID: auth.AnyIDs,
}
if err := am.authorize(ctx, session, policies.DomainType, domains.OpCreateDomainChannels, req); err != nil {
if err := am.authorize(ctx, session, policies.DomainType, dOperations.OpCreateDomainChannels, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.DomainType,
Object: session.DomainID,
}); err != nil {
return []channels.Channel{}, []roles.RoleProvision{}, errors.Wrap(err, errDomainCreateChannels)
}
for _, ch := range chs {
if ch.ParentGroup != "" {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpGroupSetChildChannel, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, gOperations.OpGroupSetChildChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -114,18 +108,12 @@ func (am *authorizationMiddleware) CreateChannels(ctx context.Context, session a
}
func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session authn.Session, id string, withRoles bool) (channels.Channel, error) {
if err := am.authorize(ctx, session, policies.ChannelType, channels.OpViewChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ChannelType,
Object: id,
UserID: session.UserID,
PatID: session.PatID,
EntityType: auth.ChannelsType,
OptionalDomainID: session.DomainID,
Operation: auth.ReadOp,
EntityID: id,
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpViewChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.ChannelType,
Object: id,
}); err != nil {
return channels.Channel{}, errors.Wrap(err, errView)
}
@@ -134,7 +122,13 @@ func (am *authorizationMiddleware) ViewChannel(ctx context.Context, session auth
}
func (am *authorizationMiddleware) ListChannels(ctx context.Context, session authn.Session, pm channels.Page) (channels.ChannelsPage, error) {
if err := am.checkSuperAdmin(ctx, session); err == nil {
if err := am.authorize(ctx, session, policies.DomainType, dOperations.OpListDomainChannels, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
ObjectType: policies.DomainType,
Object: session.DomainID,
}); err == nil {
session.SuperAdmin = true
}
@@ -150,7 +144,7 @@ func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session
}
func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
if err := am.authorize(ctx, session, policies.ChannelType, channels.OpUpdateChannel, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpUpdateChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -164,7 +158,7 @@ func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session au
}
func (am *authorizationMiddleware) UpdateChannelTags(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) {
if err := am.authorize(ctx, session, policies.ChannelType, channels.OpUpdateChannelTags, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpUpdateChannelTags, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -178,7 +172,7 @@ func (am *authorizationMiddleware) UpdateChannelTags(ctx context.Context, sessio
}
func (am *authorizationMiddleware) EnableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
if err := am.authorize(ctx, session, policies.ChannelType, channels.OpEnableChannel, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpEnableChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -192,7 +186,7 @@ func (am *authorizationMiddleware) EnableChannel(ctx context.Context, session au
}
func (am *authorizationMiddleware) DisableChannel(ctx context.Context, session authn.Session, id string) (channels.Channel, error) {
if err := am.authorize(ctx, session, policies.ChannelType, channels.OpDisableChannel, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpDisableChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -206,7 +200,7 @@ func (am *authorizationMiddleware) DisableChannel(ctx context.Context, session a
}
func (am *authorizationMiddleware) RemoveChannel(ctx context.Context, session authn.Session, id string) error {
if err := am.authorize(ctx, session, policies.ChannelType, channels.OpDeleteChannel, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpDeleteChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -221,7 +215,7 @@ func (am *authorizationMiddleware) RemoveChannel(ctx context.Context, session au
func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
for _, chID := range chIDs {
if err := am.authorize(ctx, session, policies.ChannelType, channels.OpConnectClient, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpConnectClient, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -233,7 +227,7 @@ func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Se
}
for _, thID := range thIDs {
if err := am.authorize(ctx, session, policies.ClientType, clients.OpConnectToChannel, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ClientType, cOperations.OpConnectToChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -249,7 +243,7 @@ func (am *authorizationMiddleware) Connect(ctx context.Context, session authn.Se
func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn.Session, chIDs, thIDs []string, connTypes []connections.ConnType) error {
for _, chID := range chIDs {
if err := am.authorize(ctx, session, policies.ChannelType, channels.OpDisconnectClient, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpDisconnectClient, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -261,7 +255,7 @@ func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn
}
for _, thID := range thIDs {
if err := am.authorize(ctx, session, policies.ClientType, clients.OpDisconnectFromChannel, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ClientType, cOperations.OpDisconnectFromChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -276,7 +270,7 @@ func (am *authorizationMiddleware) Disconnect(ctx context.Context, session authn
}
func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
if err := am.authorize(ctx, session, policies.ChannelType, channels.OpSetParentGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpSetParentGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -286,7 +280,7 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a
return errors.Wrap(err, errSetParentGroup)
}
if err := am.authorize(ctx, session, policies.GroupType, groups.OpGroupSetChildChannel, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, gOperations.OpGroupSetChildChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -300,7 +294,7 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a
}
func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
if err := am.authorize(ctx, session, policies.ChannelType, channels.OpSetParentGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ChannelType, operations.OpSetParentGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -316,7 +310,7 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
}
if ch.ParentGroup != "" {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpGroupRemoveChildChannel, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, gOperations.OpGroupRemoveChildChannel, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -332,11 +326,9 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
}
func (am *authorizationMiddleware) authorize(ctx context.Context, session authn.Session, entityType string, op permissions.Operation, req smqauthz.PolicyReq) error {
req.TokenType = session.Type
req.UserID = session.UserID
req.PatID = session.PatID
req.OptionalDomainID = session.DomainID
req.Domain = session.DomainID
perm, err := am.entitiesOps.GetPermission(entityType, op)
if err != nil {
return err
@@ -344,38 +336,11 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, session authn.
req.Permission = perm.String()
if req.PatID != "" && req.TokenType == authn.PersonalAccessToken {
req.EntityType = auth.ChannelsType
req.EntityID = req.Object
switch op {
case channels.OpViewChannel:
req.Operation = auth.ReadOp
case channels.OpListUserChannels:
req.Operation = auth.ListOp
req.EntityID = auth.AnyIDs
case channels.OpUpdateChannel,
channels.OpUpdateChannelTags,
channels.OpEnableChannel,
channels.OpDisableChannel,
channels.OpSetParentGroup:
req.Operation = auth.UpdateOp
case channels.OpDeleteChannel,
channels.OpRemoveParentGroup:
req.Operation = auth.DeleteOp
case channels.OpConnectClient:
req.Operation = auth.CreateOp
case channels.OpDisconnectClient:
req.Operation = auth.DeleteOp
case domains.OpCreateDomainChannels,
domains.OpListDomainChannels:
if op == domains.OpCreateDomainChannels {
req.Operation = auth.CreateOp
} else {
req.Operation = auth.ListOp
}
req.EntityID = auth.AnyIDs
}
req.EntityID = req.Object
req.EntityType = auth.ChannelsType.String()
req.Operation = am.entitiesOps.OperationName(entityType, op)
if op == operations.OpListUserChannels || op == dOperations.OpCreateDomainChannels || op == dOperations.OpListDomainChannels {
req.EntityID = auth.AnyIDs
}
if err := am.authz.Authorize(ctx, req); err != nil {
+15 -14
View File
@@ -8,7 +8,8 @@ import (
"time"
"github.com/absmach/supermq/channels"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/channels/operations"
dOperations "github.com/absmach/supermq/domains/operations"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/connections"
@@ -55,7 +56,7 @@ func (cm *calloutMiddleware) CreateChannels(ctx context.Context, session authn.S
"count": len(chs),
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpCreateDomainChannels, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, dOperations.OpCreateDomainChannels, params); err != nil {
return []channels.Channel{}, []roles.RoleProvision{}, err
}
@@ -67,7 +68,7 @@ func (cm *calloutMiddleware) ViewChannel(ctx context.Context, session authn.Sess
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ChannelType, channels.OpViewChannel, params); err != nil {
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpViewChannel, params); err != nil {
return channels.Channel{}, err
}
@@ -79,7 +80,7 @@ func (cm *calloutMiddleware) ListChannels(ctx context.Context, session authn.Ses
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpListDomainChannels, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, dOperations.OpListDomainChannels, params); err != nil {
return channels.ChannelsPage{}, err
}
@@ -92,7 +93,7 @@ func (cm *calloutMiddleware) ListUserChannels(ctx context.Context, session authn
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, policies.ChannelType, channels.OpListUserChannels, params); err != nil {
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpListUserChannels, params); err != nil {
return channels.ChannelsPage{}, err
}
@@ -104,7 +105,7 @@ func (cm *calloutMiddleware) UpdateChannel(ctx context.Context, session authn.Se
"entity_id": channel.ID,
}
if err := cm.callOut(ctx, session, policies.ChannelType, channels.OpUpdateChannel, params); err != nil {
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpUpdateChannel, params); err != nil {
return channels.Channel{}, err
}
@@ -116,7 +117,7 @@ func (cm *calloutMiddleware) UpdateChannelTags(ctx context.Context, session auth
"entity_id": channel.ID,
}
if err := cm.callOut(ctx, session, policies.ChannelType, channels.OpUpdateChannelTags, params); err != nil {
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpUpdateChannelTags, params); err != nil {
return channels.Channel{}, err
}
@@ -128,7 +129,7 @@ func (cm *calloutMiddleware) EnableChannel(ctx context.Context, session authn.Se
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ChannelType, channels.OpEnableChannel, params); err != nil {
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpEnableChannel, params); err != nil {
return channels.Channel{}, err
}
@@ -140,7 +141,7 @@ func (cm *calloutMiddleware) DisableChannel(ctx context.Context, session authn.S
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ChannelType, channels.OpDisableChannel, params); err != nil {
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpDisableChannel, params); err != nil {
return channels.Channel{}, err
}
@@ -152,7 +153,7 @@ func (cm *calloutMiddleware) RemoveChannel(ctx context.Context, session authn.Se
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ChannelType, channels.OpDeleteChannel, params); err != nil {
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpDeleteChannel, params); err != nil {
return err
}
@@ -166,7 +167,7 @@ func (cm *calloutMiddleware) Connect(ctx context.Context, session authn.Session,
"connection_types": connTypes,
}
if err := cm.callOut(ctx, session, policies.ChannelType, channels.OpConnectClient, params); err != nil {
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpConnectClient, params); err != nil {
return err
}
@@ -180,7 +181,7 @@ func (cm *calloutMiddleware) Disconnect(ctx context.Context, session authn.Sessi
"connection_types": connTypes,
}
if err := cm.callOut(ctx, session, policies.ChannelType, channels.OpDisconnectClient, params); err != nil {
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpDisconnectClient, params); err != nil {
return err
}
@@ -193,7 +194,7 @@ func (cm *calloutMiddleware) SetParentGroup(ctx context.Context, session authn.S
"parent_group_id": parentGroupID,
}
if err := cm.callOut(ctx, session, policies.ChannelType, channels.OpSetParentGroup, params); err != nil {
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpSetParentGroup, params); err != nil {
return err
}
@@ -211,7 +212,7 @@ func (cm *calloutMiddleware) RemoveParentGroup(ctx context.Context, session auth
"parent_group_id": ch.ParentGroup,
}
if err := cm.callOut(ctx, session, policies.ChannelType, channels.OpRemoveParentGroup, params); err != nil {
if err := cm.callOut(ctx, session, policies.ChannelType, operations.OpRemoveParentGroup, params); err != nil {
return err
}
}
@@ -1,7 +1,7 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package channels
package operations
import (
"github.com/absmach/supermq/pkg/permissions"
+21 -58
View File
@@ -8,8 +8,9 @@ import (
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/groups"
"github.com/absmach/supermq/clients/operations"
dOperations "github.com/absmach/supermq/domains/operations"
gOperations "github.com/absmach/supermq/groups/operations"
"github.com/absmach/supermq/pkg/authn"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/errors"
@@ -72,7 +73,7 @@ func NewAuthorization(
}
func (am *authorizationMiddleware) CreateClients(ctx context.Context, session authn.Session, client ...clients.Client) ([]clients.Client, []roles.RoleProvision, error) {
if err := am.authorize(ctx, session, policies.DomainType, domains.OpCreateDomainClients, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.DomainType, dOperations.OpCreateDomainClients, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -86,7 +87,7 @@ func (am *authorizationMiddleware) CreateClients(ctx context.Context, session au
}
func (am *authorizationMiddleware) View(ctx context.Context, session authn.Session, id string, withRoles bool) (clients.Client, error) {
if err := am.authorize(ctx, session, policies.ClientType, clients.OpViewClient, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ClientType, operations.OpViewClient, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -116,7 +117,7 @@ func (am *authorizationMiddleware) ListUserClients(ctx context.Context, session
}
func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) {
if err := am.authorize(ctx, session, policies.ClientType, clients.OpUpdateClient, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ClientType, operations.OpUpdateClient, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -130,7 +131,7 @@ func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Ses
}
func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) {
if err := am.authorize(ctx, session, policies.ClientType, clients.OpUpdateClientTags, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ClientType, operations.OpUpdateClientTags, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -144,7 +145,7 @@ func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn
}
func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session authn.Session, id, key string) (clients.Client, error) {
if err := am.authorize(ctx, session, policies.ClientType, clients.OpUpdateClientSecret, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ClientType, operations.OpUpdateClientSecret, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -158,7 +159,7 @@ func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session aut
}
func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Session, id string) (clients.Client, error) {
if err := am.authorize(ctx, session, policies.ClientType, clients.OpEnableClient, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ClientType, operations.OpEnableClient, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -172,7 +173,7 @@ func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Ses
}
func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Session, id string) (clients.Client, error) {
if err := am.authorize(ctx, session, policies.ClientType, clients.OpDisableClient, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ClientType, operations.OpDisableClient, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -186,7 +187,7 @@ func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Se
}
func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Session, id string) error {
if err := am.authorize(ctx, session, policies.ClientType, clients.OpDeleteClient, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ClientType, operations.OpDeleteClient, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -200,7 +201,7 @@ func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Ses
}
func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error {
if err := am.authorize(ctx, session, policies.ClientType, clients.OpSetParentGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ClientType, operations.OpSetParentGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -210,7 +211,7 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a
return errors.Wrap(err, errSetParentGroup)
}
if err := am.authorize(ctx, session, policies.GroupType, groups.OpGroupSetChildClient, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, gOperations.OpGroupSetChildClient, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -224,7 +225,7 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a
}
func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
if err := am.authorize(ctx, session, policies.ClientType, clients.OpRemoveParentGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.ClientType, operations.OpRemoveParentGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -240,7 +241,7 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
}
if th.ParentGroup != "" {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpGroupRemoveChildClient, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, gOperations.OpGroupRemoveChildClient, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -256,10 +257,9 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
}
func (am *authorizationMiddleware) authorize(ctx context.Context, session authn.Session, entityType string, op permissions.Operation, req smqauthz.PolicyReq) error {
req.TokenType = session.Type
req.UserID = session.UserID
req.PatID = session.PatID
req.OptionalDomainID = session.DomainID
req.Domain = session.DomainID
perm, err := am.entitiesOps.GetPermission(entityType, op)
if err != nil {
@@ -268,48 +268,11 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, session authn.
req.Permission = perm.String()
if req.PatID != "" && req.TokenType == authn.PersonalAccessToken {
req.EntityID = req.Object
switch entityType {
case policies.ClientType:
req.EntityType = auth.ClientsType
switch op {
case clients.OpViewClient:
req.Operation = auth.ReadOp
case clients.OpListUserClients:
req.Operation = auth.ListOp
req.EntityID = auth.AnyIDs
case clients.OpUpdateClient,
clients.OpUpdateClientTags,
clients.OpUpdateClientSecret,
clients.OpEnableClient,
clients.OpDisableClient,
clients.OpSetParentGroup:
req.Operation = auth.UpdateOp
case clients.OpDeleteClient,
clients.OpRemoveParentGroup:
req.Operation = auth.DeleteOp
}
case policies.DomainType:
req.EntityType = auth.ClientsType
switch op {
case domains.OpCreateDomainClients:
req.Operation = auth.CreateOp
req.EntityID = auth.AnyIDs
case domains.OpListDomainClients:
req.Operation = auth.ListOp
req.EntityID = auth.AnyIDs
}
case policies.GroupType:
req.EntityType = auth.ClientsType
switch op {
case groups.OpGroupSetChildClient:
req.Operation = auth.UpdateOp
case groups.OpGroupRemoveChildClient:
req.Operation = auth.DeleteOp
}
}
req.EntityID = req.Object
req.EntityType = auth.ClientsType.String()
req.Operation = am.entitiesOps.OperationName(entityType, op)
if op == operations.OpListUserClients || op == dOperations.OpCreateDomainClients || op == dOperations.OpListDomainClients {
req.EntityID = auth.AnyIDs
}
if err := am.authz.Authorize(ctx, req); err != nil {
+14 -13
View File
@@ -8,7 +8,8 @@ import (
"time"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/clients/operations"
dOperations "github.com/absmach/supermq/domains/operations"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/permissions"
@@ -52,7 +53,7 @@ func (cm *calloutMiddleware) CreateClients(ctx context.Context, session authn.Se
"count": len(client),
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpCreateDomainClients, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, dOperations.OpCreateDomainClients, params); err != nil {
return []clients.Client{}, []roles.RoleProvision{}, err
}
@@ -63,7 +64,7 @@ func (cm *calloutMiddleware) View(ctx context.Context, session authn.Session, id
params := map[string]any{
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ClientType, clients.OpViewClient, params); err != nil {
if err := cm.callOut(ctx, session, policies.ClientType, operations.OpViewClient, params); err != nil {
return clients.Client{}, err
}
@@ -75,7 +76,7 @@ func (cm *calloutMiddleware) ListClients(ctx context.Context, session authn.Sess
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpListDomainClients, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, dOperations.OpListDomainClients, params); err != nil {
return clients.ClientsPage{}, err
}
@@ -88,7 +89,7 @@ func (cm *calloutMiddleware) ListUserClients(ctx context.Context, session authn.
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, policies.ClientType, clients.OpListUserClients, params); err != nil {
if err := cm.callOut(ctx, session, policies.ClientType, operations.OpListUserClients, params); err != nil {
return clients.ClientsPage{}, err
}
@@ -100,7 +101,7 @@ func (cm *calloutMiddleware) Update(ctx context.Context, session authn.Session,
"entity_id": client.ID,
}
if err := cm.callOut(ctx, session, policies.ClientType, clients.OpUpdateClient, params); err != nil {
if err := cm.callOut(ctx, session, policies.ClientType, operations.OpUpdateClient, params); err != nil {
return clients.Client{}, err
}
@@ -112,7 +113,7 @@ func (cm *calloutMiddleware) UpdateTags(ctx context.Context, session authn.Sessi
"entity_id": client.ID,
}
if err := cm.callOut(ctx, session, policies.ClientType, clients.OpUpdateClientTags, params); err != nil {
if err := cm.callOut(ctx, session, policies.ClientType, operations.OpUpdateClientTags, params); err != nil {
return clients.Client{}, err
}
@@ -124,7 +125,7 @@ func (cm *calloutMiddleware) UpdateSecret(ctx context.Context, session authn.Ses
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ClientType, clients.OpUpdateClientSecret, params); err != nil {
if err := cm.callOut(ctx, session, policies.ClientType, operations.OpUpdateClientSecret, params); err != nil {
return clients.Client{}, err
}
@@ -136,7 +137,7 @@ func (cm *calloutMiddleware) Enable(ctx context.Context, session authn.Session,
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ClientType, clients.OpEnableClient, params); err != nil {
if err := cm.callOut(ctx, session, policies.ClientType, operations.OpEnableClient, params); err != nil {
return clients.Client{}, err
}
@@ -148,7 +149,7 @@ func (cm *calloutMiddleware) Disable(ctx context.Context, session authn.Session,
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ClientType, clients.OpDisableClient, params); err != nil {
if err := cm.callOut(ctx, session, policies.ClientType, operations.OpDisableClient, params); err != nil {
return clients.Client{}, err
}
@@ -160,7 +161,7 @@ func (cm *calloutMiddleware) Delete(ctx context.Context, session authn.Session,
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.ClientType, clients.OpDeleteClient, params); err != nil {
if err := cm.callOut(ctx, session, policies.ClientType, operations.OpDeleteClient, params); err != nil {
return err
}
@@ -173,7 +174,7 @@ func (cm *calloutMiddleware) SetParentGroup(ctx context.Context, session authn.S
"parent_id": parentGroupID,
}
if err := cm.callOut(ctx, session, policies.ClientType, clients.OpSetParentGroup, params); err != nil {
if err := cm.callOut(ctx, session, policies.ClientType, operations.OpSetParentGroup, params); err != nil {
return err
}
@@ -192,7 +193,7 @@ func (cm *calloutMiddleware) RemoveParentGroup(ctx context.Context, session auth
"parent_id": th.ParentGroup,
}
if err := cm.callOut(ctx, session, policies.ClientType, clients.OpRemoveParentGroup, params); err != nil {
if err := cm.callOut(ctx, session, policies.ClientType, operations.OpRemoveParentGroup, params); err != nil {
return err
}
}
@@ -1,7 +1,7 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package clients
package operations
import (
"github.com/absmach/supermq/pkg/permissions"
+8 -7
View File
@@ -25,12 +25,13 @@ import (
"github.com/absmach/supermq/channels/cache"
"github.com/absmach/supermq/channels/events"
"github.com/absmach/supermq/channels/middleware"
channelsOps "github.com/absmach/supermq/channels/operations"
"github.com/absmach/supermq/channels/postgres"
pChannels "github.com/absmach/supermq/channels/private"
"github.com/absmach/supermq/clients"
"github.com/absmach/supermq/domains"
clientsOps "github.com/absmach/supermq/clients/operations"
domainsOps "github.com/absmach/supermq/domains/operations"
dpostgres "github.com/absmach/supermq/domains/postgres"
"github.com/absmach/supermq/groups"
groupsOps "github.com/absmach/supermq/groups/operations"
gpostgres "github.com/absmach/supermq/groups/postgres"
redisclient "github.com/absmach/supermq/internal/clients/redis"
smqlog "github.com/absmach/supermq/logger"
@@ -428,10 +429,10 @@ func newService(ctx context.Context, db *sqlx.DB, dbConfig pgclient.Config, cach
policies.ClientType: clientOps,
},
permissions.EntitiesOperationDetails[permissions.Operation]{
policies.ChannelType: channels.OperationDetails(),
policies.DomainType: domains.OperationDetails(),
policies.GroupType: groups.OperationDetails(),
policies.ClientType: clients.OperationDetails(),
policies.ChannelType: channelsOps.OperationDetails(),
policies.DomainType: domainsOps.OperationDetails(),
policies.GroupType: groupsOps.OperationDetails(),
policies.ClientType: clientsOps.OperationDetails(),
},
)
if err != nil {
+6 -5
View File
@@ -25,11 +25,12 @@ import (
"github.com/absmach/supermq/clients/cache"
"github.com/absmach/supermq/clients/events"
"github.com/absmach/supermq/clients/middleware"
clientsOps "github.com/absmach/supermq/clients/operations"
"github.com/absmach/supermq/clients/postgres"
pClients "github.com/absmach/supermq/clients/private"
"github.com/absmach/supermq/domains"
doperations "github.com/absmach/supermq/domains/operations"
dpostgres "github.com/absmach/supermq/domains/postgres"
svcgroups "github.com/absmach/supermq/groups"
goperations "github.com/absmach/supermq/groups/operations"
gpostgres "github.com/absmach/supermq/groups/postgres"
redisclient "github.com/absmach/supermq/internal/clients/redis"
smqlog "github.com/absmach/supermq/logger"
@@ -422,9 +423,9 @@ func newService(ctx context.Context, db *sqlx.DB, dbConfig pgclient.Config, auth
policies.GroupType: groupOps,
},
permissions.EntitiesOperationDetails[permissions.Operation]{
policies.ClientType: clients.OperationDetails(),
policies.DomainType: domains.OperationDetails(),
policies.GroupType: svcgroups.OperationDetails(),
policies.ClientType: clientsOps.OperationDetails(),
policies.DomainType: doperations.OperationDetails(),
policies.GroupType: goperations.OperationDetails(),
},
)
if err != nil {
+2 -1
View File
@@ -23,6 +23,7 @@ import (
cache "github.com/absmach/supermq/domains/cache"
"github.com/absmach/supermq/domains/events"
dmw "github.com/absmach/supermq/domains/middleware"
doperations "github.com/absmach/supermq/domains/operations"
dpostgres "github.com/absmach/supermq/domains/postgres"
"github.com/absmach/supermq/domains/private"
redisclient "github.com/absmach/supermq/internal/clients/redis"
@@ -319,7 +320,7 @@ func newDomainService(ctx context.Context, domainsRepo domainsSvc.Repository, ca
entitiesOps, err := permissions.NewEntitiesOperations(
permissions.EntitiesPermission{policies.DomainType: domainOps},
permissions.EntitiesOperationDetails[permissions.Operation]{policies.DomainType: domains.OperationDetails()},
permissions.EntitiesOperationDetails[permissions.Operation]{policies.DomainType: doperations.OperationDetails()},
)
if err != nil {
return nil, fmt.Errorf("failed to create entities operations: %w", err)
+4 -3
View File
@@ -18,7 +18,7 @@ import (
grpcClientsV1 "github.com/absmach/supermq/api/grpc/clients/v1"
grpcGroupsV1 "github.com/absmach/supermq/api/grpc/groups/v1"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/domains"
doperations "github.com/absmach/supermq/domains/operations"
dpostgres "github.com/absmach/supermq/domains/postgres"
"github.com/absmach/supermq/groups"
gpsvc "github.com/absmach/supermq/groups"
@@ -26,6 +26,7 @@ import (
httpapi "github.com/absmach/supermq/groups/api/http"
"github.com/absmach/supermq/groups/events"
"github.com/absmach/supermq/groups/middleware"
goperations "github.com/absmach/supermq/groups/operations"
"github.com/absmach/supermq/groups/postgres"
pgroups "github.com/absmach/supermq/groups/private"
smqlog "github.com/absmach/supermq/logger"
@@ -378,8 +379,8 @@ func newService(ctx context.Context, authz smqauthz.Authorization, policy polici
policies.DomainType: domainOps,
},
permissions.EntitiesOperationDetails[permissions.Operation]{
policies.GroupType: groups.OperationDetails(),
policies.DomainType: domains.OperationDetails(),
policies.GroupType: goperations.OperationDetails(),
policies.DomainType: doperations.OperationDetails(),
},
)
if err != nil {
-2
View File
@@ -130,5 +130,3 @@ domains:
- check_members_exists: view_role_users_permission
- remove_members: remove_role_users_permission
- remove_all_members: remove_role_users_permission
+6 -15
View File
@@ -8,6 +8,7 @@ import (
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/domains/operations"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/authz"
smqauthz "github.com/absmach/supermq/pkg/authz"
@@ -61,8 +62,7 @@ func (am *authorizationMiddleware) RetrieveDomain(ctx context.Context, session a
return am.svc.RetrieveDomain(ctx, session, id, withRoles)
}
if err := am.authorize(ctx, policies.DomainType, domains.OpRetrieveDomain, authz.PolicyReq{
TokenType: session.Type,
if err := am.authorize(ctx, policies.DomainType, operations.OpRetrieveDomain, authz.PolicyReq{
Subject: session.DomainUserID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
@@ -76,8 +76,7 @@ func (am *authorizationMiddleware) RetrieveDomain(ctx context.Context, session a
}
func (am *authorizationMiddleware) UpdateDomain(ctx context.Context, session authn.Session, id string, d domains.DomainReq) (domains.Domain, error) {
if err := am.authorize(ctx, policies.DomainType, domains.OpUpdateDomain, authz.PolicyReq{
TokenType: session.Type,
if err := am.authorize(ctx, policies.DomainType, operations.OpUpdateDomain, authz.PolicyReq{
Subject: session.DomainUserID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
@@ -91,8 +90,7 @@ func (am *authorizationMiddleware) UpdateDomain(ctx context.Context, session aut
}
func (am *authorizationMiddleware) EnableDomain(ctx context.Context, session authn.Session, id string) (domains.Domain, error) {
if err := am.authorize(ctx, policies.DomainType, domains.OpEnableDomain, authz.PolicyReq{
TokenType: session.Type,
if err := am.authorize(ctx, policies.DomainType, operations.OpEnableDomain, authz.PolicyReq{
Subject: session.DomainUserID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
@@ -106,8 +104,7 @@ func (am *authorizationMiddleware) EnableDomain(ctx context.Context, session aut
}
func (am *authorizationMiddleware) DisableDomain(ctx context.Context, session authn.Session, id string) (domains.Domain, error) {
if err := am.authorize(ctx, policies.DomainType, domains.OpDisableDomain, authz.PolicyReq{
TokenType: session.Type,
if err := am.authorize(ctx, policies.DomainType, operations.OpDisableDomain, authz.PolicyReq{
Subject: session.DomainUserID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
@@ -123,7 +120,6 @@ func (am *authorizationMiddleware) DisableDomain(ctx context.Context, session au
func (am *authorizationMiddleware) FreezeDomain(ctx context.Context, session authn.Session, id string) (domains.Domain, error) {
// Only SuperAdmin can freeze the domain
if err := am.authz.Authorize(ctx, authz.PolicyReq{
TokenType: session.Type,
Subject: session.UserID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
@@ -148,7 +144,6 @@ func (am *authorizationMiddleware) ListDomains(ctx context.Context, session auth
func (am *authorizationMiddleware) SendInvitation(ctx context.Context, session authn.Session, invitation domains.Invitation) (domains.Invitation, error) {
domainUserId := auth.EncodeDomainUserID(invitation.DomainID, invitation.InviteeUserID)
req := authz.PolicyReq{
TokenType: session.Type,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: domainUserId,
@@ -173,8 +168,7 @@ func (am *authorizationMiddleware) ListInvitations(ctx context.Context, session
}
func (am *authorizationMiddleware) ListDomainInvitations(ctx context.Context, session authn.Session, page domains.InvitationPageMeta) (invs domains.InvitationPage, err error) {
if err := am.authorize(ctx, policies.DomainType, domains.OpListDomainInvitations, authz.PolicyReq{
TokenType: session.Type,
if err := am.authorize(ctx, policies.DomainType, operations.OpListDomainInvitations, authz.PolicyReq{
Subject: session.DomainUserID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
@@ -222,7 +216,6 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, entityType str
// checkAdmin checks if the given user is a domain or platform administrator.
func (am *authorizationMiddleware) checkAdmin(ctx context.Context, session authn.Session) error {
req := smqauthz.PolicyReq{
TokenType: session.Type,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: session.DomainUserID,
@@ -235,7 +228,6 @@ func (am *authorizationMiddleware) checkAdmin(ctx context.Context, session authn
}
req = smqauthz.PolicyReq{
TokenType: session.Type,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
Subject: session.UserID,
@@ -256,7 +248,6 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session
return svcerr.ErrSuperAdminAction
}
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
TokenType: session.Type,
SubjectType: policies.UserType,
Subject: session.UserID,
Permission: policies.AdminPermission,
+14 -13
View File
@@ -8,6 +8,7 @@ import (
"time"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/domains/operations"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/permissions"
@@ -48,7 +49,7 @@ func (cm *calloutMiddleware) CreateDomain(ctx context.Context, session authn.Ses
"entity_id": d.ID,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpCreateDomain, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpCreateDomain, params); err != nil {
return domains.Domain{}, nil, err
}
@@ -61,7 +62,7 @@ func (cm *calloutMiddleware) RetrieveDomain(ctx context.Context, session authn.S
"with_roles": withRoles,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpRetrieveDomain, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpRetrieveDomain, params); err != nil {
return domains.Domain{}, err
}
@@ -74,7 +75,7 @@ func (cm *calloutMiddleware) UpdateDomain(ctx context.Context, session authn.Ses
"domain_req": d,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpUpdateDomain, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpUpdateDomain, params); err != nil {
return domains.Domain{}, err
}
@@ -86,7 +87,7 @@ func (cm *calloutMiddleware) EnableDomain(ctx context.Context, session authn.Ses
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpEnableDomain, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpEnableDomain, params); err != nil {
return domains.Domain{}, err
}
@@ -98,7 +99,7 @@ func (cm *calloutMiddleware) DisableDomain(ctx context.Context, session authn.Se
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpDisableDomain, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpDisableDomain, params); err != nil {
return domains.Domain{}, err
}
@@ -110,7 +111,7 @@ func (cm *calloutMiddleware) FreezeDomain(ctx context.Context, session authn.Ses
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpFreezeDomain, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpFreezeDomain, params); err != nil {
return domains.Domain{}, err
}
@@ -122,7 +123,7 @@ func (cm *calloutMiddleware) ListDomains(ctx context.Context, session authn.Sess
"page": page,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpListDomains, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpListDomains, params); err != nil {
return domains.DomainsPage{}, err
}
@@ -137,7 +138,7 @@ func (cm *calloutMiddleware) SendInvitation(ctx context.Context, session authn.S
// While entity here is technically an invitation, Domain is used as
// the entity in callout since the invitation refers to the domain.
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpSendDomainInvitation, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpSendDomainInvitation, params); err != nil {
return domains.Invitation{}, err
}
@@ -149,7 +150,7 @@ func (cm *calloutMiddleware) ListInvitations(ctx context.Context, session authn.
"page": page,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpListInvitations, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpListInvitations, params); err != nil {
return domains.InvitationPage{}, err
}
@@ -162,7 +163,7 @@ func (cm *calloutMiddleware) ListDomainInvitations(ctx context.Context, session
"page": page,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpListDomainInvitations, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpListDomainInvitations, params); err != nil {
return domains.InvitationPage{}, err
}
@@ -176,7 +177,7 @@ func (cm *calloutMiddleware) AcceptInvitation(ctx context.Context, session authn
// Similar to sending an invitation, Domain is used as the
// entity in callout since the invitation refers to the domain.
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpAcceptInvitation, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpAcceptInvitation, params); err != nil {
return domains.Invitation{}, err
}
@@ -190,7 +191,7 @@ func (cm *calloutMiddleware) RejectInvitation(ctx context.Context, session authn
// Similar to sending and accepting, Domain is used as
// the entity in callout since the invitation refers to the domain.
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpRejectInvitation, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpRejectInvitation, params); err != nil {
return domains.Invitation{}, err
}
@@ -203,7 +204,7 @@ func (cm *calloutMiddleware) DeleteInvitation(ctx context.Context, session authn
"invitee_user_id": inviteeUserID,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpDeleteDomainInvitation, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, operations.OpDeleteDomainInvitation, params); err != nil {
return err
}
@@ -1,7 +1,7 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package domains
package operations
import (
"github.com/absmach/supermq/pkg/permissions"
+29 -55
View File
@@ -8,8 +8,9 @@ import (
"fmt"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/domains"
dOperations "github.com/absmach/supermq/domains/operations"
"github.com/absmach/supermq/groups"
"github.com/absmach/supermq/groups/operations"
"github.com/absmach/supermq/pkg/authn"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/errors"
@@ -62,7 +63,7 @@ func NewAuthorization(
if err := entitiesOps.Validate(); err != nil {
return nil, err
}
ram, err := rolemgr.NewAuthorization(entityType, svc, authz, roleOps)
ram, err := rolemgr.NewAuthorization(policies.GroupType, svc, authz, roleOps)
if err != nil {
return nil, err
}
@@ -77,7 +78,7 @@ func NewAuthorization(
}
func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, []roles.RoleProvision, error) {
if err := am.authorize(ctx, session, policies.DomainType, domains.OpCreateDomainGroups, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.DomainType, dOperations.OpCreateDomainGroups, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
@@ -89,7 +90,7 @@ func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session auth
}
if g.Parent != "" {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpAddChildrenGroups, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpAddChildrenGroups, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
@@ -105,7 +106,7 @@ func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session auth
}
func (am *authorizationMiddleware) UpdateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, error) {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpUpdateGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpUpdateGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
@@ -120,7 +121,7 @@ func (am *authorizationMiddleware) UpdateGroup(ctx context.Context, session auth
}
func (am *authorizationMiddleware) UpdateGroupTags(ctx context.Context, session authn.Session, group groups.Group) (groups.Group, error) {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpUpdateGroupTags, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpUpdateGroupTags, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -134,7 +135,7 @@ func (am *authorizationMiddleware) UpdateGroupTags(ctx context.Context, session
}
func (am *authorizationMiddleware) ViewGroup(ctx context.Context, session authn.Session, id string, withRoles bool) (groups.Group, error) {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpViewGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpViewGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
@@ -153,7 +154,7 @@ func (am *authorizationMiddleware) ListGroups(ctx context.Context, session authn
session.SuperAdmin = true
return am.svc.ListGroups(ctx, session, gm)
}
if err := am.authorize(ctx, session, policies.DomainType, domains.OpListDomainGroups, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.DomainType, dOperations.OpListDomainGroups, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
@@ -172,7 +173,7 @@ func (am *authorizationMiddleware) ListUserGroups(ctx context.Context, session a
session.SuperAdmin = true
return am.svc.ListGroups(ctx, session, pm)
}
if err := am.authorize(ctx, session, policies.DomainType, domains.OpListDomainGroups, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.DomainType, dOperations.OpListDomainGroups, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
SubjectKind: policies.UsersKind,
@@ -187,7 +188,7 @@ func (am *authorizationMiddleware) ListUserGroups(ctx context.Context, session a
}
func (am *authorizationMiddleware) EnableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpEnableGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpEnableGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -201,7 +202,7 @@ func (am *authorizationMiddleware) EnableGroup(ctx context.Context, session auth
}
func (am *authorizationMiddleware) DisableGroup(ctx context.Context, session authn.Session, id string) (groups.Group, error) {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpDisableGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpDisableGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -215,7 +216,7 @@ func (am *authorizationMiddleware) DisableGroup(ctx context.Context, session aut
}
func (am *authorizationMiddleware) DeleteGroup(ctx context.Context, session authn.Session, id string) error {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpDeleteGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpDeleteGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -229,7 +230,7 @@ func (am *authorizationMiddleware) DeleteGroup(ctx context.Context, session auth
}
func (am *authorizationMiddleware) RetrieveGroupHierarchy(ctx context.Context, session authn.Session, id string, hm groups.HierarchyPageMeta) (groups.HierarchyPage, error) {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpRetrieveGroupHierarchy, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpRetrieveGroupHierarchy, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -243,7 +244,7 @@ func (am *authorizationMiddleware) RetrieveGroupHierarchy(ctx context.Context, s
}
func (am *authorizationMiddleware) AddParentGroup(ctx context.Context, session authn.Session, id, parentID string) error {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpAddParentGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpAddParentGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -253,7 +254,7 @@ func (am *authorizationMiddleware) AddParentGroup(ctx context.Context, session a
return errors.Wrap(errSetParentGroup, err)
}
if err := am.authorize(ctx, session, policies.GroupType, groups.OpAddChildrenGroups, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpAddChildrenGroups, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -267,7 +268,7 @@ func (am *authorizationMiddleware) AddParentGroup(ctx context.Context, session a
}
func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpRemoveParentGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpRemoveParentGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -283,7 +284,7 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
}
if group.Parent != "" {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpRemoveParentGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpRemoveParentGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -298,7 +299,7 @@ func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, sessio
}
func (am *authorizationMiddleware) AddChildrenGroups(ctx context.Context, session authn.Session, id string, childrenGroupIDs []string) error {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpAddChildrenGroups, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpAddChildrenGroups, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -309,7 +310,7 @@ func (am *authorizationMiddleware) AddChildrenGroups(ctx context.Context, sessio
}
for _, childID := range childrenGroupIDs {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpAddParentGroup, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpAddParentGroup, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -324,7 +325,7 @@ func (am *authorizationMiddleware) AddChildrenGroups(ctx context.Context, sessio
}
func (am *authorizationMiddleware) RemoveChildrenGroups(ctx context.Context, session authn.Session, id string, childrenGroupIDs []string) error {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpRemoveChildrenGroups, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpRemoveChildrenGroups, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -338,7 +339,7 @@ func (am *authorizationMiddleware) RemoveChildrenGroups(ctx context.Context, ses
}
func (am *authorizationMiddleware) RemoveAllChildrenGroups(ctx context.Context, session authn.Session, id string) error {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpRemoveAllChildrenGroups, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpRemoveAllChildrenGroups, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -352,7 +353,7 @@ func (am *authorizationMiddleware) RemoveAllChildrenGroups(ctx context.Context,
}
func (am *authorizationMiddleware) ListChildrenGroups(ctx context.Context, session authn.Session, id string, startLevel, endLevel int64, pm groups.PageMeta) (groups.Page, error) {
if err := am.authorize(ctx, session, policies.GroupType, groups.OpListChildrenGroups, smqauthz.PolicyReq{
if err := am.authorize(ctx, session, policies.GroupType, operations.OpListChildrenGroups, smqauthz.PolicyReq{
Domain: session.DomainID,
SubjectType: policies.UserType,
Subject: session.DomainUserID,
@@ -382,10 +383,9 @@ func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, session
}
func (am *authorizationMiddleware) authorize(ctx context.Context, session authn.Session, entityType string, op permissions.Operation, pr smqauthz.PolicyReq) error {
pr.TokenType = session.Type
pr.UserID = session.UserID
pr.PatID = session.PatID
pr.OptionalDomainID = session.DomainID
pr.Domain = session.DomainID
perm, err := am.entitiesOps.GetPermission(entityType, op)
if err != nil {
@@ -393,37 +393,11 @@ func (am *authorizationMiddleware) authorize(ctx context.Context, session authn.
}
pr.Permission = perm.String()
if pr.PatID != "" && pr.TokenType == authn.PersonalAccessToken {
pr.EntityType = auth.GroupsType
pr.EntityID = pr.Object
switch op {
case groups.OpViewGroup:
pr.Operation = auth.ReadOp
case groups.OpListUserGroups,
groups.OpRetrieveGroupHierarchy,
groups.OpListChildrenGroups,
domains.OpListDomainGroups:
pr.Operation = auth.ListOp
if op == domains.OpListDomainGroups {
pr.EntityID = auth.AnyIDs
}
case groups.OpUpdateGroup,
groups.OpUpdateGroupTags,
groups.OpEnableGroup,
groups.OpDisableGroup,
groups.OpAddParentGroup,
groups.OpAddChildrenGroups:
pr.Operation = auth.UpdateOp
case groups.OpDeleteGroup,
groups.OpRemoveParentGroup,
groups.OpRemoveChildrenGroups,
groups.OpRemoveAllChildrenGroups:
pr.Operation = auth.DeleteOp
case domains.OpCreateDomainGroups:
pr.Operation = auth.CreateOp
pr.EntityID = auth.AnyIDs
}
pr.EntityID = pr.Object
pr.EntityType = auth.GroupsType.String()
pr.Operation = am.entitiesOps.OperationName(entityType, op)
if op == dOperations.OpListDomainGroups || op == dOperations.OpCreateDomainGroups {
pr.EntityID = auth.AnyIDs
}
if err := am.authz.Authorize(ctx, pr); err != nil {
+18 -17
View File
@@ -7,8 +7,9 @@ import (
"context"
"time"
"github.com/absmach/supermq/domains"
dOperations "github.com/absmach/supermq/domains/operations"
"github.com/absmach/supermq/groups"
"github.com/absmach/supermq/groups/operations"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/callout"
"github.com/absmach/supermq/pkg/errors"
@@ -54,7 +55,7 @@ func (cm *calloutMiddleware) CreateGroup(ctx context.Context, session authn.Sess
"count": 1,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpCreateDomainGroups, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, dOperations.OpCreateDomainGroups, params); err != nil {
return groups.Group{}, nil, err
}
@@ -67,7 +68,7 @@ func (cm *calloutMiddleware) UpdateGroup(ctx context.Context, session authn.Sess
"group": group,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpUpdateGroup, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpUpdateGroup, params); err != nil {
return groups.Group{}, err
}
@@ -80,7 +81,7 @@ func (cm *calloutMiddleware) UpdateGroupTags(ctx context.Context, session authn.
"tags": group.Tags,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpUpdateGroupTags, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpUpdateGroupTags, params); err != nil {
return groups.Group{}, err
}
@@ -92,7 +93,7 @@ func (cm *calloutMiddleware) ViewGroup(ctx context.Context, session authn.Sessio
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpViewGroup, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpViewGroup, params); err != nil {
return groups.Group{}, err
}
@@ -104,7 +105,7 @@ func (cm *calloutMiddleware) ListGroups(ctx context.Context, session authn.Sessi
"pagemeta": gm,
}
if err := cm.callOut(ctx, session, policies.DomainType, domains.OpListDomainGroups, params); err != nil {
if err := cm.callOut(ctx, session, policies.DomainType, dOperations.OpListDomainGroups, params); err != nil {
return groups.Page{}, err
}
@@ -117,7 +118,7 @@ func (cm *calloutMiddleware) ListUserGroups(ctx context.Context, session authn.S
"pagemeta": gm,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpListUserGroups, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpListUserGroups, params); err != nil {
return groups.Page{}, err
}
@@ -129,7 +130,7 @@ func (cm *calloutMiddleware) EnableGroup(ctx context.Context, session authn.Sess
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpEnableGroup, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpEnableGroup, params); err != nil {
return groups.Group{}, err
}
@@ -141,7 +142,7 @@ func (cm *calloutMiddleware) DisableGroup(ctx context.Context, session authn.Ses
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpDisableGroup, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpDisableGroup, params); err != nil {
return groups.Group{}, err
}
@@ -153,7 +154,7 @@ func (cm *calloutMiddleware) DeleteGroup(ctx context.Context, session authn.Sess
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpDeleteGroup, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpDeleteGroup, params); err != nil {
return err
}
@@ -166,7 +167,7 @@ func (cm *calloutMiddleware) RetrieveGroupHierarchy(ctx context.Context, session
"hierarchy_pagemeta": hm,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpRetrieveGroupHierarchy, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpRetrieveGroupHierarchy, params); err != nil {
return groups.HierarchyPage{}, err
}
@@ -179,7 +180,7 @@ func (cm *calloutMiddleware) AddParentGroup(ctx context.Context, session authn.S
"parent_id": parentID,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpAddParentGroup, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpAddParentGroup, params); err != nil {
return err
}
@@ -197,7 +198,7 @@ func (cm *calloutMiddleware) RemoveParentGroup(ctx context.Context, session auth
"parent_id": group.Parent,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpRemoveParentGroup, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpRemoveParentGroup, params); err != nil {
return err
}
@@ -210,7 +211,7 @@ func (cm *calloutMiddleware) AddChildrenGroups(ctx context.Context, session auth
"children_group_ids": childrenGroupIDs,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpAddChildrenGroups, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpAddChildrenGroups, params); err != nil {
return err
}
@@ -223,7 +224,7 @@ func (cm *calloutMiddleware) RemoveChildrenGroups(ctx context.Context, session a
"children_group_ids": childrenGroupIDs,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpRemoveChildrenGroups, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpRemoveChildrenGroups, params); err != nil {
return err
}
@@ -235,7 +236,7 @@ func (cm *calloutMiddleware) RemoveAllChildrenGroups(ctx context.Context, sessio
"entity_id": id,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpRemoveAllChildrenGroups, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpRemoveAllChildrenGroups, params); err != nil {
return err
}
@@ -250,7 +251,7 @@ func (cm *calloutMiddleware) ListChildrenGroups(ctx context.Context, session aut
"pagemeta": pm,
}
if err := cm.callOut(ctx, session, policies.GroupType, groups.OpListChildrenGroups, params); err != nil {
if err := cm.callOut(ctx, session, policies.GroupType, operations.OpListChildrenGroups, params); err != nil {
return groups.Page{}, err
}
@@ -1,7 +1,7 @@
// Copyright (c) Abstract Machines
// SPDX-License-Identifier: Apache-2.0
package groups
package operations
import "github.com/absmach/supermq/pkg/permissions"
-1
View File
@@ -110,7 +110,6 @@ func (h *handler) AuthPublish(ctx context.Context, topic *string, payload *[]byt
clientID, err := h.authAccess(ctx, s.Username, string(s.Password), domainID, channelID, connections.Publish, topicType)
if err != nil {
fmt.Println("AuthPublish authAccess error:", err)
return err
}
+19 -32
View File
@@ -9,7 +9,7 @@ option go_package = "github.com/absmach/supermq/api/grpc/auth/v1";
// AuthService is a service that provides authentication
// and authorization functionalities for SuperMQ services.
service AuthService {
rpc Authorize(AuthZReq) returns (AuthZRes) {}
rpc Authorize(PolicyReq) returns (AuthZRes) {}
rpc Authenticate(AuthNReq) returns (AuthNRes) {}
}
@@ -19,40 +19,27 @@ message AuthNReq {
}
message AuthNRes {
string id = 1; // token id
string user_id = 2; // user id
uint32 user_role = 3; // user role
bool verified = 4; // verified user
uint32 token_type = 5; // token type
string id = 1;
string user_id = 2;
uint32 user_role = 3;
bool verified = 4;
}
message PolicyReq {
uint32 token_type = 1; // Token type
string domain = 2; // Domain
string subject_type = 3; // Client or User
string subject_kind = 4; // ID or Token
string subject_relation = 5; // Subject relation
string subject = 6; // Subject value
string relation = 7; // Relation to filter
string permission = 8; // Action
string object = 9; // Object ID
string object_type = 10; // Client, User, Group
}
message PATReq {
string user_id = 1; // User id (PAT)
string pat_id = 2; // Pat id
uint32 entity_type = 3; // Entity type (PAT)
string optional_domain_id = 4; // Optional domain id (PAT)
uint32 operation = 5; // Operation (PAT)
string entity_id = 6; // EntityID (PAT)
}
message AuthZReq {
oneof auth_type {
PolicyReq policy = 1; // Policy-based authorization
PATReq pat = 2; // PAT authorization
}
string domain = 1;
string subject_type = 2;
string subject_kind = 3;
string subject_relation = 4;
string subject = 5;
string relation = 6;
string permission = 7;
string object = 8;
string object_type = 9;
string pat_id = 10;
string operation = 11;
string user_id = 12;
string entity_id = 13;
string entity_type = 14;
}
message AuthZRes {
+1 -2
View File
@@ -5,7 +5,6 @@ package authsvc
import (
"context"
"strings"
grpcAuthV1 "github.com/absmach/supermq/api/grpc/auth/v1"
"github.com/absmach/supermq/auth/api/grpc/auth"
@@ -44,7 +43,7 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S
return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err)
}
if strings.HasPrefix(token, authn.PatPrefix) {
if res.GetId() != "" {
return authn.Session{Type: authn.PersonalAccessToken, PatID: res.GetId(), UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil
}
+2 -2
View File
@@ -11,8 +11,8 @@ import (
"strconv"
apiutil "github.com/absmach/supermq/api/http/util"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
"github.com/absmach/supermq/pkg/policies"
"github.com/go-chi/chi/v5"
)
@@ -154,7 +154,7 @@ func (a *authnMiddleware) Middleware() func(http.Handler) http.Handler {
case AdminRole:
resp.DomainUserID = resp.UserID
case UserRole:
resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID)
resp.DomainUserID = policies.EncodeDomainUserID(domain, resp.UserID)
}
}
+34 -67
View File
@@ -9,7 +9,6 @@ import (
grpcAuthV1 "github.com/absmach/supermq/api/grpc/auth/v1"
"github.com/absmach/supermq/auth/api/grpc/auth"
"github.com/absmach/supermq/domains"
"github.com/absmach/supermq/pkg/authn"
"github.com/absmach/supermq/pkg/authz"
pkgDomians "github.com/absmach/supermq/pkg/domains"
"github.com/absmach/supermq/pkg/errors"
@@ -48,28 +47,6 @@ func NewAuthorization(ctx context.Context, cfg grpcclient.Config, domainsAuthz p
}
func (a authorization) Authorize(ctx context.Context, pr authz.PolicyReq) error {
if pr.PatID != "" && pr.TokenType == authn.PersonalAccessToken {
req := grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Pat{
Pat: &grpcAuthV1.PATReq{
UserId: pr.UserID,
PatId: pr.PatID,
EntityType: uint32(pr.EntityType),
OptionalDomainId: pr.OptionalDomainID,
Operation: uint32(pr.Operation),
EntityId: pr.EntityID,
},
},
}
res, err := a.authSvcClient.Authorize(ctx, &req)
if err != nil {
return errors.Wrap(errors.ErrAuthorization, err)
}
if !res.GetAuthorized() {
return errors.ErrAuthorization
}
}
if pr.SubjectType == policies.UserType && (pr.ObjectType == policies.GroupType || pr.ObjectType == policies.ClientType || pr.ObjectType == policies.DomainType) {
domainID := pr.Domain
if domainID == "" {
@@ -83,21 +60,23 @@ func (a authorization) Authorize(ctx context.Context, pr authz.PolicyReq) error
}
}
req := grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Policy{
Policy: &grpcAuthV1.PolicyReq{
Domain: pr.Domain,
SubjectType: pr.SubjectType,
SubjectKind: pr.SubjectKind,
SubjectRelation: pr.SubjectRelation,
Subject: pr.Subject,
Relation: pr.Relation,
Permission: pr.Permission,
Object: pr.Object,
ObjectType: pr.ObjectType,
},
},
req := grpcAuthV1.PolicyReq{
Domain: pr.Domain,
SubjectType: pr.SubjectType,
SubjectKind: pr.SubjectKind,
SubjectRelation: pr.SubjectRelation,
Subject: pr.Subject,
Relation: pr.Relation,
Permission: pr.Permission,
Object: pr.Object,
ObjectType: pr.ObjectType,
PatId: pr.PatID,
Operation: pr.Operation,
UserId: pr.UserID,
EntityId: pr.EntityID,
EntityType: pr.EntityType,
}
res, err := a.authSvcClient.Authorize(ctx, &req)
if err != nil {
return errors.Wrap(errors.ErrAuthorization, err)
@@ -116,44 +95,32 @@ func (a authorization) checkDomain(ctx context.Context, subjectType, subject, do
switch status {
case domains.FreezeStatus:
_, err := a.authSvcClient.Authorize(ctx, &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Policy{
Policy: &grpcAuthV1.PolicyReq{
Subject: subject,
SubjectType: subjectType,
Permission: policies.AdminPermission,
Object: policies.SuperMQObject,
ObjectType: policies.PlatformType,
},
},
_, err := a.authSvcClient.Authorize(ctx, &grpcAuthV1.PolicyReq{
Subject: subject,
SubjectType: subjectType,
Permission: policies.AdminPermission,
Object: policies.SuperMQObject,
ObjectType: policies.PlatformType,
})
return err
case domains.DisabledStatus:
_, err := a.authSvcClient.Authorize(ctx, &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Policy{
Policy: &grpcAuthV1.PolicyReq{
Subject: subject,
SubjectType: subjectType,
Permission: policies.AdminPermission,
Object: domainID,
ObjectType: policies.DomainType,
},
},
_, err := a.authSvcClient.Authorize(ctx, &grpcAuthV1.PolicyReq{
Subject: subject,
SubjectType: subjectType,
Permission: policies.AdminPermission,
Object: domainID,
ObjectType: policies.DomainType,
})
return err
case domains.EnabledStatus:
_, err := a.authSvcClient.Authorize(ctx, &grpcAuthV1.AuthZReq{
AuthType: &grpcAuthV1.AuthZReq_Policy{
Policy: &grpcAuthV1.PolicyReq{
Subject: subject,
SubjectType: subjectType,
Permission: policies.MembershipPermission,
Object: domainID,
ObjectType: policies.DomainType,
},
},
_, err := a.authSvcClient.Authorize(ctx, &grpcAuthV1.PolicyReq{
Subject: subject,
SubjectType: subjectType,
Permission: policies.MembershipPermission,
Object: domainID,
ObjectType: policies.DomainType,
})
return err
+6 -12
View File
@@ -5,15 +5,9 @@ package authz
import (
"context"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/authn"
)
type PolicyReq struct {
// TokenType contains the token type. Used to differentiate between PAT and regular user tokens.
TokenType authn.TokenType `json:"token_type,omitempty"`
// Domain contains the domain ID.
Domain string `json:"domain,omitempty"`
@@ -50,12 +44,12 @@ type PolicyReq struct {
Permission string `json:"permission,omitempty"`
// PAT authorization fields
UserID string `json:"user_id,omitempty"` // UserID who owns the PAT
PatID string `json:"pat_id,omitempty"` // PAT ID
EntityType auth.EntityType `json:"entity_type,omitempty"` // Entity type
OptionalDomainID string `json:"optional_domainID,omitempty"` // Optional domain ID for PAT scope checking
Operation auth.Operation `json:"operation,omitempty"` // Operation type
EntityID string `json:"entityID,omitempty"` // Entity ID
UserID string `json:"user_id,omitempty"`
PatID string `json:"pat_id,omitempty"`
EntityType string `json:"entity_type,omitempty"`
DomainID string `json:"domain_id,omitempty"`
Operation string `json:"operation,omitempty"`
EntityID string `json:"entity_id,omitempty"`
}
// Authz is supermq authorization library.
+6 -29
View File
@@ -42,43 +42,20 @@ func (pc *PermissionConfig) GetEntityPermissions(entityType string) (map[string]
operations := make(map[string]Permission)
for _, op := range entityPerms.Operations {
for name, perm := range op {
operations[name] = Permission(perm)
if perm != "" {
operations[name] = Permission(perm)
}
}
}
rolesOperations := make(map[string]Permission)
for _, op := range entityPerms.RolesOperations {
for name, perm := range op {
rolesOperations[name] = Permission(perm)
if perm != "" {
rolesOperations[name] = Permission(perm)
}
}
}
return operations, rolesOperations, nil
}
func BuildEntitiesPermissions(
config *PermissionConfig,
entityTypes []string,
operationDetailsFuncs map[string]func() map[Operation]OperationDetails,
) (EntitiesPermission, EntitiesOperationDetails[Operation], error) {
entitiesPermission := make(EntitiesPermission)
entitiesOperationDetails := make(EntitiesOperationDetails[Operation])
for _, entityType := range entityTypes {
ops, _, err := config.GetEntityPermissions(entityType)
if err != nil {
return nil, nil, fmt.Errorf("failed to get permissions for %s: %w", entityType, err)
}
entitiesPermission[entityType] = ops
detailsFunc, ok := operationDetailsFuncs[entityType]
if !ok {
return nil, nil, fmt.Errorf("operation details function not found for entity type %s", entityType)
}
entitiesOperationDetails[entityType] = detailsFunc()
}
return entitiesPermission, entitiesOperationDetails, nil
}
+5 -7
View File
@@ -52,16 +52,14 @@ type Policy struct {
// PAT authorization fields
// UserID contains the user ID who owns the PAT.
UserID string `json:"user_id,omitempty"`
// PatID contains the personal access token ID.
PatID string `json:"pat_id,omitempty"`
// EntityType contains the entity type for PAT authorization.
EntityType uint32 `json:"entity_type,omitempty"`
// OptionalDomainID contains the optional domain ID for PAT scope checking.
OptionalDomainID string `json:"optional_domain_id,omitempty"`
// Operation contains the operation type for PAT authorization.
Operation uint32 `json:"operation,omitempty"`
Operation string `json:"operation,omitempty"`
// UserID contains the user ID for PAT authorization.
UserID string `json:"user_id,omitempty"`
// EntityType contains the entity type for PAT authorization.
EntityType string `json:"entity_type,omitempty"`
// EntityID contains the entity ID for PAT authorization.
EntityID string `json:"entity_id,omitempty"`
}
@@ -5,7 +5,9 @@ package middleware
import (
"context"
"fmt"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/authn"
smqauthz "github.com/absmach/supermq/pkg/authz"
"github.com/absmach/supermq/pkg/errors"
@@ -40,7 +42,7 @@ func NewAuthorization(entityType string, svc roles.RoleManager, authz smqauthz.A
}
func (ram RoleManagerAuthorizationMiddleware) AddRole(ctx context.Context, session authn.Session, entityID, roleName string, optionalActions []string, optionalMembers []string) (roles.RoleProvision, error) {
if err := ram.authorize(ctx, roles.OpAddRole, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpAddRole, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -57,7 +59,7 @@ func (ram RoleManagerAuthorizationMiddleware) AddRole(ctx context.Context, sessi
}
func (ram RoleManagerAuthorizationMiddleware) RemoveRole(ctx context.Context, session authn.Session, entityID, roleID string) error {
if err := ram.authorize(ctx, roles.OpRemoveRole, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRemoveRole, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -71,7 +73,7 @@ func (ram RoleManagerAuthorizationMiddleware) RemoveRole(ctx context.Context, se
}
func (ram RoleManagerAuthorizationMiddleware) UpdateRoleName(ctx context.Context, session authn.Session, entityID, roleID, newRoleName string) (roles.Role, error) {
if err := ram.authorize(ctx, roles.OpUpdateRoleName, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpUpdateRoleName, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -85,7 +87,7 @@ func (ram RoleManagerAuthorizationMiddleware) UpdateRoleName(ctx context.Context
}
func (ram RoleManagerAuthorizationMiddleware) RetrieveRole(ctx context.Context, session authn.Session, entityID, roleID string) (roles.Role, error) {
if err := ram.authorize(ctx, roles.OpRetrieveRole, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRetrieveRole, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -99,7 +101,7 @@ func (ram RoleManagerAuthorizationMiddleware) RetrieveRole(ctx context.Context,
}
func (ram RoleManagerAuthorizationMiddleware) RetrieveAllRoles(ctx context.Context, session authn.Session, entityID string, limit, offset uint64) (roles.RolePage, error) {
if err := ram.authorize(ctx, roles.OpRetrieveAllRoles, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRetrieveAllRoles, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -117,7 +119,7 @@ func (ram RoleManagerAuthorizationMiddleware) ListAvailableActions(ctx context.C
}
func (ram RoleManagerAuthorizationMiddleware) RoleAddActions(ctx context.Context, session authn.Session, entityID, roleID string, actions []string) (ops []string, err error) {
if err := ram.authorize(ctx, roles.OpRoleAddActions, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRoleAddActions, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -132,7 +134,7 @@ func (ram RoleManagerAuthorizationMiddleware) RoleAddActions(ctx context.Context
}
func (ram RoleManagerAuthorizationMiddleware) RoleListActions(ctx context.Context, session authn.Session, entityID, roleID string) ([]string, error) {
if err := ram.authorize(ctx, roles.OpRoleListActions, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRoleListActions, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -147,7 +149,7 @@ func (ram RoleManagerAuthorizationMiddleware) RoleListActions(ctx context.Contex
}
func (ram RoleManagerAuthorizationMiddleware) RoleCheckActionsExists(ctx context.Context, session authn.Session, entityID, roleID string, actions []string) (bool, error) {
if err := ram.authorize(ctx, roles.OpRoleCheckActionsExists, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRoleCheckActionsExists, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -161,7 +163,7 @@ func (ram RoleManagerAuthorizationMiddleware) RoleCheckActionsExists(ctx context
}
func (ram RoleManagerAuthorizationMiddleware) RoleRemoveActions(ctx context.Context, session authn.Session, entityID, roleID string, actions []string) (err error) {
if err := ram.authorize(ctx, roles.OpRoleRemoveActions, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRoleRemoveActions, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -176,7 +178,7 @@ func (ram RoleManagerAuthorizationMiddleware) RoleRemoveActions(ctx context.Cont
}
func (ram RoleManagerAuthorizationMiddleware) RoleRemoveAllActions(ctx context.Context, session authn.Session, entityID, roleID string) error {
if err := ram.authorize(ctx, roles.OpRoleRemoveAllActions, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRoleRemoveAllActions, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -190,7 +192,7 @@ func (ram RoleManagerAuthorizationMiddleware) RoleRemoveAllActions(ctx context.C
}
func (ram RoleManagerAuthorizationMiddleware) RoleAddMembers(ctx context.Context, session authn.Session, entityID, roleID string, members []string) ([]string, error) {
if err := ram.authorize(ctx, roles.OpRoleAddMembers, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRoleAddMembers, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -208,7 +210,7 @@ func (ram RoleManagerAuthorizationMiddleware) RoleAddMembers(ctx context.Context
}
func (ram RoleManagerAuthorizationMiddleware) RoleListMembers(ctx context.Context, session authn.Session, entityID, roleID string, limit, offset uint64) (roles.MembersPage, error) {
if err := ram.authorize(ctx, roles.OpRoleListMembers, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRoleListMembers, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -222,7 +224,7 @@ func (ram RoleManagerAuthorizationMiddleware) RoleListMembers(ctx context.Contex
}
func (ram RoleManagerAuthorizationMiddleware) RoleCheckMembersExists(ctx context.Context, session authn.Session, entityID, roleID string, members []string) (bool, error) {
if err := ram.authorize(ctx, roles.OpRoleCheckMembersExists, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRoleCheckMembersExists, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -236,7 +238,7 @@ func (ram RoleManagerAuthorizationMiddleware) RoleCheckMembersExists(ctx context
}
func (ram RoleManagerAuthorizationMiddleware) RoleRemoveAllMembers(ctx context.Context, session authn.Session, entityID, roleID string) (err error) {
if err := ram.authorize(ctx, roles.OpRoleRemoveAllMembers, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRoleRemoveAllMembers, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -250,7 +252,7 @@ func (ram RoleManagerAuthorizationMiddleware) RoleRemoveAllMembers(ctx context.C
}
func (ram RoleManagerAuthorizationMiddleware) ListEntityMembers(ctx context.Context, session authn.Session, entityID string, pageQuery roles.MembersRolePageQuery) (roles.MembersRolePage, error) {
if err := ram.authorize(ctx, roles.OpRoleListMembers, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRoleListMembers, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -264,7 +266,7 @@ func (ram RoleManagerAuthorizationMiddleware) ListEntityMembers(ctx context.Cont
}
func (ram RoleManagerAuthorizationMiddleware) RemoveEntityMembers(ctx context.Context, session authn.Session, entityID string, members []string) error {
if err := ram.authorize(ctx, roles.OpRoleRemoveAllMembers, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRoleRemoveAllMembers, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -278,7 +280,7 @@ func (ram RoleManagerAuthorizationMiddleware) RemoveEntityMembers(ctx context.Co
}
func (ram RoleManagerAuthorizationMiddleware) RoleRemoveMembers(ctx context.Context, session authn.Session, entityID, roleID string, members []string) (err error) {
if err := ram.authorize(ctx, roles.OpRoleRemoveMembers, smqauthz.PolicyReq{
if err := ram.authorize(ctx, session, roles.OpRoleRemoveMembers, smqauthz.PolicyReq{
Domain: session.DomainID,
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -291,7 +293,11 @@ func (ram RoleManagerAuthorizationMiddleware) RoleRemoveMembers(ctx context.Cont
return ram.svc.RoleRemoveMembers(ctx, session, entityID, roleID, members)
}
func (ram RoleManagerAuthorizationMiddleware) authorize(ctx context.Context, op permissions.RoleOperation, pr smqauthz.PolicyReq) error {
func (ram RoleManagerAuthorizationMiddleware) authorize(ctx context.Context, session authn.Session, op permissions.RoleOperation, pr smqauthz.PolicyReq) error {
pr.UserID = session.UserID
pr.PatID = session.PatID
pr.Domain = session.DomainID
perm, err := ram.ops.GetPermission(op)
if err != nil {
return err
@@ -299,6 +305,22 @@ func (ram RoleManagerAuthorizationMiddleware) authorize(ctx context.Context, op
pr.Permission = perm.String()
pr.EntityID = pr.Object
opName := ram.ops.OperationName(op)
var patEntityType string
switch pr.ObjectType {
case policies.GroupType:
patEntityType = auth.GroupsType.String()
case policies.ClientType:
patEntityType = auth.ClientsType.String()
case policies.ChannelType:
patEntityType = auth.ChannelsType.String()
default:
return errors.Wrap(errors.ErrAuthorization, fmt.Errorf("unsupported entity type for PAT: %s", pr.ObjectType))
}
pr.EntityType = patEntityType
pr.Operation = auth.RoleOperationPrefix + opName
if err := ram.authz.Authorize(ctx, pr); err != nil {
return errors.Wrap(errors.ErrAuthorization, err)
}