mirror of
https://github.com/absmach/magistrala.git
synced 2026-06-23 04:10:28 +00:00
SMQ-2609 - Enable superadmin to perform actions over entities (#2688)
Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
@@ -70,9 +70,9 @@ func (x *AuthNReq) GetToken() string {
|
||||
|
||||
type AuthNRes struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // id
|
||||
UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // user id
|
||||
DomainId string `protobuf:"bytes,3,opt,name=domain_id,json=domainId,proto3" json:"domain_id,omitempty"` // domain id
|
||||
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // token id
|
||||
UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // user id
|
||||
UserRole uint32 `protobuf:"varint,3,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"` // user role
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -121,11 +121,11 @@ func (x *AuthNRes) GetUserId() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *AuthNRes) GetDomainId() string {
|
||||
func (x *AuthNRes) GetUserRole() uint32 {
|
||||
if x != nil {
|
||||
return x.DomainId
|
||||
return x.UserRole
|
||||
}
|
||||
return ""
|
||||
return 0
|
||||
}
|
||||
|
||||
type AuthZReq struct {
|
||||
@@ -382,7 +382,7 @@ const file_auth_v1_auth_proto_rawDesc = "" +
|
||||
"\bAuthNRes\x12\x0e\n" +
|
||||
"\x02id\x18\x01 \x01(\tR\x02id\x12\x17\n" +
|
||||
"\auser_id\x18\x02 \x01(\tR\x06userId\x12\x1b\n" +
|
||||
"\tdomain_id\x18\x03 \x01(\tR\bdomainId\"\xa2\x02\n" +
|
||||
"\tuser_role\x18\x03 \x01(\rR\buserRole\"\xa2\x02\n" +
|
||||
"\bAuthZReq\x12\x16\n" +
|
||||
"\x06domain\x18\x01 \x01(\tR\x06domain\x12!\n" +
|
||||
"\fsubject_type\x18\x02 \x01(\tR\vsubjectType\x12!\n" +
|
||||
|
||||
@@ -27,6 +27,7 @@ const (
|
||||
type IssueReq struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
|
||||
UserRole uint32 `protobuf:"varint,2,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"`
|
||||
Type uint32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
@@ -69,6 +70,13 @@ func (x *IssueReq) GetUserId() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *IssueReq) GetUserRole() uint32 {
|
||||
if x != nil {
|
||||
return x.UserRole
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *IssueReq) GetType() uint32 {
|
||||
if x != nil {
|
||||
return x.Type
|
||||
@@ -187,9 +195,10 @@ var File_token_v1_token_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_token_v1_token_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\x14token/v1/token.proto\x12\btoken.v1\"7\n" +
|
||||
"\x14token/v1/token.proto\x12\btoken.v1\"T\n" +
|
||||
"\bIssueReq\x12\x17\n" +
|
||||
"\auser_id\x18\x01 \x01(\tR\x06userId\x12\x12\n" +
|
||||
"\auser_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n" +
|
||||
"\tuser_role\x18\x02 \x01(\rR\buserRole\x12\x12\n" +
|
||||
"\x04type\x18\x03 \x01(\rR\x04type\"1\n" +
|
||||
"\n" +
|
||||
"RefreshReq\x12#\n" +
|
||||
|
||||
+6
-1
@@ -38,7 +38,12 @@ func AuthenticateMiddleware(authn smqauthn.Authentication, domainCheck bool) fun
|
||||
return
|
||||
}
|
||||
resp.DomainID = domain
|
||||
resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID)
|
||||
switch resp.Role {
|
||||
case smqauthn.AdminRole:
|
||||
resp.DomainUserID = resp.UserID
|
||||
case smqauthn.UserRole:
|
||||
resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), SessionKey, resp)
|
||||
|
||||
@@ -75,7 +75,7 @@ func (client authGrpcClient) Authenticate(ctx context.Context, token *grpcAuthV1
|
||||
return &grpcAuthV1.AuthNRes{}, grpcapi.DecodeError(err)
|
||||
}
|
||||
ir := res.(authenticateRes)
|
||||
return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID, DomainId: ir.domainID}, nil
|
||||
return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID, UserRole: uint32(ir.userRole)}, nil
|
||||
}
|
||||
|
||||
func encodeIdentifyRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
@@ -85,7 +85,7 @@ func encodeIdentifyRequest(_ context.Context, grpcReq interface{}) (interface{},
|
||||
|
||||
func decodeIdentifyResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
|
||||
res := grpcRes.(*grpcAuthV1.AuthNRes)
|
||||
return authenticateRes{id: res.GetId(), userID: res.GetUserId(), domainID: res.GetDomainId()}, nil
|
||||
return authenticateRes{id: res.GetId(), userID: res.GetUserId(), userRole: auth.Role(res.UserRole)}, nil
|
||||
}
|
||||
|
||||
func (client authGrpcClient) AuthenticatePAT(ctx context.Context, token *grpcAuthV1.AuthNReq, _ ...grpc.CallOption) (*grpcAuthV1.AuthNRes, error) {
|
||||
|
||||
@@ -23,7 +23,7 @@ func authenticateEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
return authenticateRes{}, err
|
||||
}
|
||||
|
||||
return authenticateRes{id: key.Subject, userID: key.User, domainID: key.Domain}, nil
|
||||
return authenticateRes{userID: key.Subject, userRole: key.Role}, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ func TestIdentify(t *testing.T) {
|
||||
{
|
||||
desc: "authenticate user with valid user token",
|
||||
token: validToken,
|
||||
idt: &grpcAuthV1.AuthNRes{Id: id, UserId: email, DomainId: domainID},
|
||||
idt: &grpcAuthV1.AuthNRes{UserId: id, UserRole: uint32(auth.UserRole)},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
@@ -100,7 +100,7 @@ func TestIdentify(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
svcCall := svc.On("Identify", mock.Anything, mock.Anything).Return(auth.Key{Subject: id, User: email, Domain: domainID}, tc.svcErr)
|
||||
svcCall := svc.On("Identify", mock.Anything, mock.Anything).Return(auth.Key{Subject: id, Role: auth.UserRole}, tc.svcErr)
|
||||
idt, err := grpcClient.Authenticate(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token})
|
||||
if idt != nil {
|
||||
assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt))
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
|
||||
package auth
|
||||
|
||||
import smqauth "github.com/absmach/supermq/auth"
|
||||
|
||||
type authenticateRes struct {
|
||||
id string
|
||||
userID string
|
||||
domainID string
|
||||
userRole smqauth.Role
|
||||
}
|
||||
|
||||
type authorizeRes struct {
|
||||
|
||||
@@ -82,7 +82,7 @@ func decodeAuthenticateRequest(_ context.Context, grpcReq interface{}) (interfac
|
||||
|
||||
func encodeAuthenticateResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
|
||||
res := grpcRes.(authenticateRes)
|
||||
return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, DomainId: res.domainID}, nil
|
||||
return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, UserRole: uint32(res.userRole)}, nil
|
||||
}
|
||||
|
||||
func encodeAuthenticatePATResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
|
||||
|
||||
@@ -53,8 +53,9 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR
|
||||
defer cancel()
|
||||
|
||||
res, err := client.issue(ctx, issueReq{
|
||||
userID: req.GetUserId(),
|
||||
keyType: auth.KeyType(req.GetType()),
|
||||
userID: req.GetUserId(),
|
||||
userRole: auth.Role(req.GetUserRole()),
|
||||
keyType: auth.KeyType(req.GetType()),
|
||||
})
|
||||
if err != nil {
|
||||
return &grpcTokenV1.Token{}, grpcapi.DecodeError(err)
|
||||
@@ -65,8 +66,9 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR
|
||||
func encodeIssueRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
req := grpcReq.(issueReq)
|
||||
return &grpcTokenV1.IssueReq{
|
||||
UserId: req.userID,
|
||||
Type: uint32(req.keyType),
|
||||
UserId: req.userID,
|
||||
UserRole: uint32(req.userRole),
|
||||
Type: uint32(req.keyType),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -18,8 +18,9 @@ func issueEndpoint(svc auth.Service) endpoint.Endpoint {
|
||||
}
|
||||
|
||||
key := auth.Key{
|
||||
Type: req.keyType,
|
||||
User: req.userID,
|
||||
Type: req.keyType,
|
||||
Subject: req.userID,
|
||||
Role: req.userRole,
|
||||
}
|
||||
tkn, err := svc.Issue(ctx, "", key)
|
||||
if err != nil {
|
||||
|
||||
@@ -9,8 +9,9 @@ import (
|
||||
)
|
||||
|
||||
type issueReq struct {
|
||||
userID string
|
||||
keyType auth.KeyType
|
||||
userID string
|
||||
userRole auth.Role
|
||||
keyType auth.KeyType
|
||||
}
|
||||
|
||||
func (req issueReq) validate() error {
|
||||
|
||||
@@ -55,8 +55,9 @@ func (s *tokenGrpcServer) Refresh(ctx context.Context, req *grpcTokenV1.RefreshR
|
||||
func decodeIssueRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
|
||||
req := grpcReq.(*grpcTokenV1.IssueReq)
|
||||
return issueReq{
|
||||
userID: req.GetUserId(),
|
||||
keyType: auth.KeyType(req.GetType()),
|
||||
userID: req.GetUserId(),
|
||||
userRole: auth.Role(req.GetUserRole()),
|
||||
keyType: auth.KeyType(req.GetType()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -37,6 +37,12 @@ const (
|
||||
invalidDuration = 7 * 24 * time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
krepo *mocks.KeyRepository
|
||||
pEvaluator *policymocks.Evaluator
|
||||
callback *mocks.CallBack
|
||||
)
|
||||
|
||||
type issueRequest struct {
|
||||
Duration time.Duration `json:"duration,omitempty"`
|
||||
Type uint32 `json:"type,omitempty"`
|
||||
@@ -67,18 +73,18 @@ func (tr testRequest) make() (*http.Response, error) {
|
||||
return tr.client.Do(req)
|
||||
}
|
||||
|
||||
func newService() (auth.Service, *mocks.KeyRepository) {
|
||||
krepo := new(mocks.KeyRepository)
|
||||
func newService() auth.Service {
|
||||
krepo = new(mocks.KeyRepository)
|
||||
pRepo := new(mocks.PATSRepository)
|
||||
cache := new(mocks.Cache)
|
||||
hash := new(mocks.Hasher)
|
||||
idProvider := uuid.NewMock()
|
||||
pService := new(policymocks.Service)
|
||||
pEvaluator := new(policymocks.Evaluator)
|
||||
pEvaluator = new(policymocks.Evaluator)
|
||||
t := jwt.New([]byte(secret))
|
||||
callback := new(mocks.CallBack)
|
||||
callback = new(mocks.CallBack)
|
||||
|
||||
return auth.New(krepo, pRepo, cache, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration, callback), krepo
|
||||
return auth.New(krepo, pRepo, cache, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration, callback)
|
||||
}
|
||||
|
||||
func newServer(svc auth.Service) *httptest.Server {
|
||||
@@ -95,9 +101,13 @@ func toJSON(data interface{}) string {
|
||||
}
|
||||
|
||||
func TestIssue(t *testing.T) {
|
||||
svc, krepo := newService()
|
||||
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id})
|
||||
svc := newService()
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
|
||||
callBackCall := callback.On("Authorize", mock.Anything, mock.Anything).Return(nil)
|
||||
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, Role: auth.UserRole, IssuedAt: time.Now(), Subject: id})
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
policyCall.Unset()
|
||||
callBackCall.Unset()
|
||||
|
||||
ts := newServer(svc)
|
||||
defer ts.Close()
|
||||
@@ -196,16 +206,22 @@ func TestIssue(t *testing.T) {
|
||||
body: strings.NewReader(tc.req),
|
||||
}
|
||||
repocall := krepo.On("Save", mock.Anything, mock.Anything).Return("", nil)
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
|
||||
callBackCall := callback.On("Authorize", mock.Anything, mock.Anything).Return(nil)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
repocall.Unset()
|
||||
policyCall.Unset()
|
||||
callBackCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieve(t *testing.T) {
|
||||
svc, krepo := newService()
|
||||
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id})
|
||||
svc := newService()
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
|
||||
callBackCall := callback.On("Authorize", mock.Anything, mock.Anything).Return(nil)
|
||||
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, Role: auth.UserRole, IssuedAt: time.Now(), Subject: id})
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id}
|
||||
|
||||
@@ -213,6 +229,8 @@ func TestRetrieve(t *testing.T) {
|
||||
k, err := svc.Issue(context.Background(), token.AccessToken, key)
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
repocall.Unset()
|
||||
policyCall.Unset()
|
||||
callBackCall.Unset()
|
||||
|
||||
ts := newServer(svc)
|
||||
defer ts.Close()
|
||||
@@ -269,17 +287,23 @@ func TestRetrieve(t *testing.T) {
|
||||
url: fmt.Sprintf("%s/keys/%s", ts.URL, tc.id),
|
||||
token: tc.token,
|
||||
}
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
|
||||
callBackCall := callback.On("Authorize", mock.Anything, mock.Anything).Return(nil)
|
||||
repocall := krepo.On("Retrieve", mock.Anything, mock.Anything, mock.Anything).Return(tc.key, tc.err)
|
||||
res, err := req.make()
|
||||
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
|
||||
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
|
||||
repocall.Unset()
|
||||
policyCall.Unset()
|
||||
callBackCall.Unset()
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevoke(t *testing.T) {
|
||||
svc, krepo := newService()
|
||||
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id})
|
||||
svc := newService()
|
||||
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
|
||||
callBackCall := callback.On("Authorize", mock.Anything, mock.Anything).Return(nil)
|
||||
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, Role: auth.UserRole, IssuedAt: time.Now(), Subject: id})
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id}
|
||||
|
||||
@@ -287,6 +311,8 @@ func TestRevoke(t *testing.T) {
|
||||
k, err := svc.Issue(context.Background(), token.AccessToken, key)
|
||||
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
|
||||
repocall.Unset()
|
||||
policyCall.Unset()
|
||||
callBackCall.Unset()
|
||||
|
||||
ts := newServer(svc)
|
||||
defer ts.Close()
|
||||
|
||||
+6
-38
@@ -20,11 +20,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
tokenType = "type"
|
||||
userField = "user"
|
||||
domainField = "domain"
|
||||
issuerName = "supermq.auth"
|
||||
secret = "test"
|
||||
tokenType = "type"
|
||||
roleField = "role"
|
||||
issuerName = "supermq.auth"
|
||||
secret = "test"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -39,10 +38,7 @@ func newToken(issuerName string, key auth.Key) string {
|
||||
IssuedAt(key.IssuedAt).
|
||||
Claim(tokenType, "r").
|
||||
Expiration(key.ExpiresAt)
|
||||
builder.Claim(userField, key.User)
|
||||
if key.Domain != "" {
|
||||
builder.Claim(domainField, key.Domain)
|
||||
}
|
||||
builder.Claim(roleField, key.Role)
|
||||
if key.Subject != "" {
|
||||
builder.Subject(key.Subject)
|
||||
}
|
||||
@@ -73,8 +69,6 @@ func TestIssue(t *testing.T) {
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Type: auth.AccessKey,
|
||||
Subject: testsutil.GenerateUUID(t),
|
||||
User: testsutil.GenerateUUID(t),
|
||||
Domain: testsutil.GenerateUUID(t),
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
|
||||
},
|
||||
@@ -86,8 +80,6 @@ func TestIssue(t *testing.T) {
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Type: auth.AccessKey,
|
||||
Subject: testsutil.GenerateUUID(t),
|
||||
User: testsutil.GenerateUUID(t),
|
||||
Domain: "",
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
},
|
||||
err: nil,
|
||||
@@ -98,8 +90,6 @@ func TestIssue(t *testing.T) {
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Type: auth.AccessKey,
|
||||
Subject: "",
|
||||
User: testsutil.GenerateUUID(t),
|
||||
Domain: testsutil.GenerateUUID(t),
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
},
|
||||
err: nil,
|
||||
@@ -110,8 +100,6 @@ func TestIssue(t *testing.T) {
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Type: auth.KeyType(auth.InvitationKey + 1),
|
||||
Subject: testsutil.GenerateUUID(t),
|
||||
User: testsutil.GenerateUUID(t),
|
||||
Domain: testsutil.GenerateUUID(t),
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
},
|
||||
err: nil,
|
||||
@@ -122,8 +110,6 @@ func TestIssue(t *testing.T) {
|
||||
ID: testsutil.GenerateUUID(t),
|
||||
Type: auth.AccessKey,
|
||||
Subject: "",
|
||||
User: testsutil.GenerateUUID(t),
|
||||
Domain: "",
|
||||
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
|
||||
},
|
||||
@@ -158,11 +144,6 @@ func TestParse(t *testing.T) {
|
||||
expToken, err := tokenizer.Issue(expKey)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing expired key expected to succeed: %s", err))
|
||||
|
||||
emptyDomainKey := key()
|
||||
emptyDomainKey.Domain = ""
|
||||
emptyDomainToken, err := tokenizer.Issue(emptyDomainKey)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
|
||||
|
||||
emptySubjectKey := key()
|
||||
emptySubjectKey.Subject = ""
|
||||
emptySubjectToken, err := tokenizer.Issue(emptySubjectKey)
|
||||
@@ -174,9 +155,7 @@ func TestParse(t *testing.T) {
|
||||
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
|
||||
|
||||
emptyKey := key()
|
||||
emptyKey.Domain = ""
|
||||
emptyKey.Subject = ""
|
||||
emptyToken, err := tokenizer.Issue(emptyKey)
|
||||
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
|
||||
|
||||
inValidToken := newToken("invalid", key())
|
||||
@@ -223,12 +202,6 @@ func TestParse(t *testing.T) {
|
||||
token: newToken(issuerName, key()),
|
||||
err: authjwt.ErrJSONHandle,
|
||||
},
|
||||
{
|
||||
desc: "parse token with empty domain",
|
||||
key: emptyDomainKey,
|
||||
token: emptyDomainToken,
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
desc: "parse token with empty subject",
|
||||
key: emptySubjectKey,
|
||||
@@ -241,12 +214,6 @@ func TestParse(t *testing.T) {
|
||||
token: emptyTypeToken,
|
||||
err: errors.ErrAuthentication,
|
||||
},
|
||||
{
|
||||
desc: "parse token with empty domain and subject",
|
||||
key: emptyKey,
|
||||
token: emptyToken,
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
@@ -266,6 +233,7 @@ func key() auth.Key {
|
||||
ID: "66af4a67-3823-438a-abd7-efdb613eaef6",
|
||||
Type: auth.AccessKey,
|
||||
Issuer: "supermq.auth",
|
||||
Role: auth.UserRole,
|
||||
Subject: "66af4a67-3823-438a-abd7-efdb613eaef6",
|
||||
IssuedAt: time.Now().UTC().Add(-10 * time.Second).Round(time.Second),
|
||||
ExpiresAt: exp,
|
||||
|
||||
+23
-8
@@ -6,8 +6,6 @@ package jwt
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/absmach/supermq/auth"
|
||||
"github.com/absmach/supermq/pkg/errors"
|
||||
@@ -21,6 +19,8 @@ var (
|
||||
errInvalidIssuer = errors.New("invalid token issuer value")
|
||||
// errInvalidType is returned when there is no type field.
|
||||
errInvalidType = errors.New("invalid token type")
|
||||
// errInvalidRole is returned when the role is invalid.
|
||||
errInvalidRole = errors.New("invalid role")
|
||||
// errJWTExpiryKey is used to check if the token is expired.
|
||||
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
|
||||
// ErrSignJWT indicates an error in signing jwt token.
|
||||
@@ -35,6 +35,7 @@ const (
|
||||
issuerName = "supermq.auth"
|
||||
tokenType = "type"
|
||||
userField = "user"
|
||||
RoleField = "role"
|
||||
oauthProviderField = "oauth_provider"
|
||||
oauthAccessTokenField = "access_token"
|
||||
oauthRefreshTokenField = "refresh_token"
|
||||
@@ -60,7 +61,7 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) {
|
||||
IssuedAt(key.IssuedAt).
|
||||
Claim(tokenType, key.Type).
|
||||
Expiration(key.ExpiresAt)
|
||||
builder.Claim(userField, key.User)
|
||||
builder.Claim(RoleField, key.Role)
|
||||
if key.Subject != "" {
|
||||
builder.Subject(key.Subject)
|
||||
}
|
||||
@@ -132,17 +133,31 @@ func toKey(tkn jwt.Token) (auth.Key, error) {
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidType
|
||||
}
|
||||
ktype, err := strconv.ParseInt(fmt.Sprintf("%v", tType), 10, 64)
|
||||
if err != nil {
|
||||
return auth.Key{}, err
|
||||
kType, ok := tType.(float64)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidType
|
||||
}
|
||||
kt := auth.KeyType(ktype)
|
||||
kt := auth.KeyType(kType)
|
||||
if !kt.Validate() {
|
||||
return auth.Key{}, errInvalidType
|
||||
}
|
||||
|
||||
tRole, ok := tkn.Get(RoleField)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidRole
|
||||
}
|
||||
kRole, ok := tRole.(float64)
|
||||
if !ok {
|
||||
return auth.Key{}, errInvalidRole
|
||||
}
|
||||
kr := auth.Role(kRole)
|
||||
if !kr.Validate() {
|
||||
return auth.Key{}, errInvalidRole
|
||||
}
|
||||
|
||||
key.ID = tkn.JwtID()
|
||||
key.Type = auth.KeyType(ktype)
|
||||
key.Type = auth.KeyType(kType)
|
||||
key.Role = auth.Role(kRole)
|
||||
key.Issuer = tkn.Issuer()
|
||||
key.Subject = tkn.Subject()
|
||||
key.IssuedAt = tkn.IssuedAt()
|
||||
|
||||
+25
-5
@@ -57,14 +57,35 @@ func (kt KeyType) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
type Role uint32
|
||||
|
||||
const (
|
||||
UserRole Role = iota + 1
|
||||
AdminRole
|
||||
)
|
||||
|
||||
func (r Role) String() string {
|
||||
switch r {
|
||||
case UserRole:
|
||||
return "user"
|
||||
case AdminRole:
|
||||
return "admin"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (r Role) Validate() bool {
|
||||
return UserRole <= r && r <= AdminRole
|
||||
}
|
||||
|
||||
// Key represents API key.
|
||||
type Key struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type KeyType `json:"type,omitempty"`
|
||||
Issuer string `json:"issuer,omitempty"`
|
||||
Subject string `json:"subject,omitempty"` // user ID
|
||||
User string `json:"user,omitempty"`
|
||||
Domain string `json:"domain,omitempty"` // domain user ID
|
||||
Role Role `json:"role,omitempty"`
|
||||
IssuedAt time.Time `json:"issued_at,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||
}
|
||||
@@ -75,11 +96,10 @@ func (key Key) String() string {
|
||||
type: %s,
|
||||
issuer_id: %s,
|
||||
subject: %s,
|
||||
user: %s,
|
||||
domain: %s,
|
||||
role: %s,
|
||||
iat: %v,
|
||||
eat: %v
|
||||
}`, key.ID, key.Type, key.Issuer, key.Subject, key.User, key.Domain, key.IssuedAt, key.ExpiresAt)
|
||||
}`, key.ID, key.Type, key.Issuer, key.Subject, key.Role, key.IssuedAt, key.ExpiresAt)
|
||||
}
|
||||
|
||||
// Expired verifies if the key is expired.
|
||||
|
||||
+47
-45
@@ -35,6 +35,7 @@ var (
|
||||
errRetrieve = errors.New("failed to retrieve key data")
|
||||
errIdentify = errors.New("failed to validate token")
|
||||
errPlatform = errors.New("invalid platform id")
|
||||
errRoleAuth = errors.New("failed to authorize user role")
|
||||
|
||||
errMalformedPAT = errors.New("malformed personal access token")
|
||||
errFailedToParseUUID = errors.New("failed to parse string to UUID")
|
||||
@@ -133,7 +134,7 @@ func (svc service) Issue(ctx context.Context, token string, key Key) (Token, err
|
||||
case RefreshKey:
|
||||
return svc.refreshKey(ctx, token, key)
|
||||
case RecoveryKey:
|
||||
return svc.tmpKey(recoveryDuration, key)
|
||||
return svc.tmpKey(ctx, recoveryDuration, key)
|
||||
case InvitationKey:
|
||||
return svc.invitationKey(ctx, key)
|
||||
default:
|
||||
@@ -205,7 +206,6 @@ func (svc service) Authorize(ctx context.Context, pr policies.Policy) error {
|
||||
return svcerr.ErrAuthentication
|
||||
}
|
||||
pr.Subject = key.Subject
|
||||
pr.Domain = key.Domain
|
||||
}
|
||||
if err := svc.checkPolicy(ctx, pr); err != nil {
|
||||
return err
|
||||
@@ -259,8 +259,11 @@ func (svc service) PolicyValidation(pr policies.Policy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (svc service) tmpKey(duration time.Duration, key Key) (Token, error) {
|
||||
func (svc service) tmpKey(ctx context.Context, duration time.Duration, key Key) (Token, error) {
|
||||
key.ExpiresAt = time.Now().Add(duration)
|
||||
if err := svc.checkUserRole(ctx, key); err != nil {
|
||||
return Token{}, errors.Wrap(errIssueTmp, err)
|
||||
}
|
||||
value, err := svc.tokenizer.Issue(key)
|
||||
if err != nil {
|
||||
return Token{}, errors.Wrap(errIssueTmp, err)
|
||||
@@ -274,9 +277,8 @@ func (svc service) accessKey(ctx context.Context, key Key) (Token, error) {
|
||||
key.Type = AccessKey
|
||||
key.ExpiresAt = time.Now().Add(svc.loginDuration)
|
||||
|
||||
key.Subject, err = svc.checkUserDomain(ctx, key)
|
||||
if err != nil {
|
||||
return Token{}, errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
if err := svc.checkUserRole(ctx, key); err != nil {
|
||||
return Token{}, errors.Wrap(errIssueUser, err)
|
||||
}
|
||||
|
||||
access, err := svc.tokenizer.Issue(key)
|
||||
@@ -299,9 +301,8 @@ func (svc service) invitationKey(ctx context.Context, key Key) (Token, error) {
|
||||
key.Type = InvitationKey
|
||||
key.ExpiresAt = time.Now().Add(svc.invitationDuration)
|
||||
|
||||
key.Subject, err = svc.checkUserDomain(ctx, key)
|
||||
if err != nil {
|
||||
return Token{}, err
|
||||
if err := svc.checkUserRole(ctx, key); err != nil {
|
||||
return Token{}, errors.Wrap(errIssueTmp, err)
|
||||
}
|
||||
|
||||
access, err := svc.tokenizer.Issue(key)
|
||||
@@ -321,16 +322,13 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token
|
||||
return Token{}, errIssueUser
|
||||
}
|
||||
key.ID = k.ID
|
||||
if key.Domain == "" {
|
||||
key.Domain = k.Domain
|
||||
}
|
||||
key.User = k.User
|
||||
key.Type = AccessKey
|
||||
key.Subject = k.Subject
|
||||
|
||||
key.Subject, err = svc.checkUserDomain(ctx, key)
|
||||
if err != nil {
|
||||
return Token{}, errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
if err := svc.checkUserRole(ctx, key); err != nil {
|
||||
return Token{}, errors.Wrap(errIssueUser, err)
|
||||
}
|
||||
key.Role = k.Role
|
||||
|
||||
key.ExpiresAt = time.Now().Add(svc.loginDuration)
|
||||
access, err := svc.tokenizer.Issue(key)
|
||||
@@ -348,32 +346,33 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token
|
||||
return Token{AccessToken: access, RefreshToken: refresh}, nil
|
||||
}
|
||||
|
||||
func (svc service) checkUserDomain(ctx context.Context, key Key) (subject string, err error) {
|
||||
if key.Domain != "" {
|
||||
// Check user is platform admin.
|
||||
func (svc service) checkUserRole(ctx context.Context, key Key) (err error) {
|
||||
switch key.Role {
|
||||
case AdminRole:
|
||||
if err = svc.Authorize(ctx, policies.Policy{
|
||||
Subject: key.User,
|
||||
Subject: key.Subject,
|
||||
SubjectType: policies.UserType,
|
||||
Permission: policies.AdminPermission,
|
||||
Object: policies.SuperMQObject,
|
||||
ObjectType: policies.PlatformType,
|
||||
}); err == nil {
|
||||
return key.User, nil
|
||||
}); err != nil {
|
||||
return errRoleAuth
|
||||
}
|
||||
// Check user is domain member.
|
||||
domainUserSubject := EncodeDomainUserID(key.Domain, key.User)
|
||||
return nil
|
||||
case UserRole:
|
||||
if err = svc.Authorize(ctx, policies.Policy{
|
||||
Subject: domainUserSubject,
|
||||
Subject: key.Subject,
|
||||
SubjectType: policies.UserType,
|
||||
Permission: policies.MembershipPermission,
|
||||
Object: key.Domain,
|
||||
ObjectType: policies.DomainType,
|
||||
Object: policies.SuperMQObject,
|
||||
ObjectType: policies.PlatformType,
|
||||
}); err != nil {
|
||||
return "", err
|
||||
return errRoleAuth
|
||||
}
|
||||
return domainUserSubject, nil
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (svc service) userKey(ctx context.Context, token string, key Key) (Token, error) {
|
||||
@@ -386,6 +385,9 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e
|
||||
if key.Subject == "" {
|
||||
key.Subject = sub
|
||||
}
|
||||
if err := svc.checkUserRole(ctx, key); err != nil {
|
||||
return Token{}, errors.Wrap(errIssueUser, err)
|
||||
}
|
||||
|
||||
keyID, err := svc.idProvider.ID()
|
||||
if err != nil {
|
||||
@@ -471,7 +473,7 @@ func (svc service) CreatePAT(ctx context.Context, token, name, description strin
|
||||
if err != nil {
|
||||
return PAT{}, errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
secret, hash, err := svc.generateSecretAndHash(key.User, id)
|
||||
secret, hash, err := svc.generateSecretAndHash(key.Subject, id)
|
||||
if err != nil {
|
||||
return PAT{}, errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
@@ -479,7 +481,7 @@ func (svc service) CreatePAT(ctx context.Context, token, name, description strin
|
||||
now := time.Now()
|
||||
pat := PAT{
|
||||
ID: id,
|
||||
User: key.User,
|
||||
User: key.Subject,
|
||||
Name: name,
|
||||
Description: description,
|
||||
Secret: hash,
|
||||
@@ -502,7 +504,7 @@ func (svc service) UpdatePATName(ctx context.Context, token, patID, name string)
|
||||
if err != nil {
|
||||
return PAT{}, err
|
||||
}
|
||||
pat, err := svc.pats.UpdateName(ctx, key.User, patID, name)
|
||||
pat, err := svc.pats.UpdateName(ctx, key.Subject, patID, name)
|
||||
if err != nil {
|
||||
return PAT{}, errors.Wrap(errUpdatePAT, err)
|
||||
}
|
||||
@@ -514,7 +516,7 @@ func (svc service) UpdatePATDescription(ctx context.Context, token, patID, descr
|
||||
if err != nil {
|
||||
return PAT{}, err
|
||||
}
|
||||
pat, err := svc.pats.UpdateDescription(ctx, key.User, patID, description)
|
||||
pat, err := svc.pats.UpdateDescription(ctx, key.Subject, patID, description)
|
||||
if err != nil {
|
||||
return PAT{}, errors.Wrap(errUpdatePAT, err)
|
||||
}
|
||||
@@ -526,7 +528,7 @@ func (svc service) RetrievePAT(ctx context.Context, token, patID string) (PAT, e
|
||||
if err != nil {
|
||||
return PAT{}, err
|
||||
}
|
||||
pat, err := svc.pats.Retrieve(ctx, key.User, patID)
|
||||
pat, err := svc.pats.Retrieve(ctx, key.Subject, patID)
|
||||
if err != nil {
|
||||
return PAT{}, errors.Wrap(errRetrievePAT, err)
|
||||
}
|
||||
@@ -538,7 +540,7 @@ func (svc service) ListPATS(ctx context.Context, token string, pm PATSPageMeta)
|
||||
if err != nil {
|
||||
return PATSPage{}, err
|
||||
}
|
||||
patsPage, err := svc.pats.RetrieveAll(ctx, key.User, pm)
|
||||
patsPage, err := svc.pats.RetrieveAll(ctx, key.Subject, pm)
|
||||
if err != nil {
|
||||
return PATSPage{}, errors.Wrap(errRetrievePAT, err)
|
||||
}
|
||||
@@ -550,7 +552,7 @@ func (svc service) DeletePAT(ctx context.Context, token, patID string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := svc.pats.Remove(ctx, key.User, patID); err != nil {
|
||||
if err := svc.pats.Remove(ctx, key.Subject, patID); err != nil {
|
||||
return errors.Wrap(errDeletePAT, err)
|
||||
}
|
||||
return nil
|
||||
@@ -563,17 +565,17 @@ func (svc service) ResetPATSecret(ctx context.Context, token, patID string, dura
|
||||
}
|
||||
|
||||
// Generate new HashToken take place here
|
||||
secret, hash, err := svc.generateSecretAndHash(key.User, patID)
|
||||
secret, hash, err := svc.generateSecretAndHash(key.Subject, patID)
|
||||
if err != nil {
|
||||
return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
pat, err := svc.pats.UpdateTokenHash(ctx, key.User, patID, hash, time.Now().Add(duration))
|
||||
pat, err := svc.pats.UpdateTokenHash(ctx, key.Subject, patID, hash, time.Now().Add(duration))
|
||||
if err != nil {
|
||||
return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
|
||||
if err := svc.pats.Reactivate(ctx, key.User, patID); err != nil {
|
||||
if err := svc.pats.Reactivate(ctx, key.Subject, patID); err != nil {
|
||||
return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
pat.Secret = secret
|
||||
@@ -588,7 +590,7 @@ func (svc service) RevokePATSecret(ctx context.Context, token, patID string) err
|
||||
return err
|
||||
}
|
||||
|
||||
if err := svc.pats.Revoke(ctx, key.User, patID); err != nil {
|
||||
if err := svc.pats.Revoke(ctx, key.Subject, patID); err != nil {
|
||||
return errors.Wrap(svcerr.ErrUpdateEntity, err)
|
||||
}
|
||||
return nil
|
||||
@@ -599,7 +601,7 @@ func (svc service) RemoveAllPAT(ctx context.Context, token string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := svc.pats.RemoveAllPAT(ctx, key.User); err != nil {
|
||||
if err := svc.pats.RemoveAllPAT(ctx, key.Subject); err != nil {
|
||||
return errors.Wrap(svcerr.ErrRemoveEntity, err)
|
||||
}
|
||||
return nil
|
||||
@@ -620,7 +622,7 @@ func (svc service) AddScope(ctx context.Context, token, patID string, scopes []S
|
||||
scopes[i].PatID = patID
|
||||
}
|
||||
|
||||
err = svc.pats.AddScope(ctx, key.User, scopes)
|
||||
err = svc.pats.AddScope(ctx, key.Subject, scopes)
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrCreateEntity, err)
|
||||
}
|
||||
@@ -633,7 +635,7 @@ func (svc service) RemoveScope(ctx context.Context, token, patID string, scopesI
|
||||
return err
|
||||
}
|
||||
|
||||
err = svc.pats.RemoveScope(ctx, key.User, scopesIDs...)
|
||||
err = svc.pats.RemoveScope(ctx, key.Subject, scopesIDs...)
|
||||
if err != nil {
|
||||
return errors.Wrap(svcerr.ErrRemoveEntity, err)
|
||||
}
|
||||
@@ -750,7 +752,7 @@ func (svc service) authnAuthzUserPAT(ctx context.Context, token, patID string) (
|
||||
return Key{}, err
|
||||
}
|
||||
|
||||
_, err = svc.pats.Retrieve(ctx, key.User, patID)
|
||||
_, err = svc.pats.Retrieve(ctx, key.Subject, patID)
|
||||
if err != nil {
|
||||
return Key{}, errors.Wrap(svcerr.ErrAuthorization, err)
|
||||
}
|
||||
|
||||
+195
-516
File diff suppressed because it is too large
Load Diff
@@ -318,9 +318,9 @@ definition domain {
|
||||
channel_update + channel_read + channel_delete + channel_set_parent_group + channel_connect_to_client + channel_publish + channel_subscribe +
|
||||
channel_manage_role + channel_add_role_users + channel_remove_role_users + channel_view_role_users +
|
||||
group_update + group_membership + group_read + group_delete + group_set_child + group_set_parent +
|
||||
group_manage_role + group_add_role_users + group_remove_role_users + group_view_role_users
|
||||
group_manage_role + group_add_role_users + group_remove_role_users + group_view_role_users + organization->admin
|
||||
|
||||
permission admin = read & update & enable & disable & delete & manage_role & add_role_users & remove_role_users & view_role_users
|
||||
permission admin = (read & update & enable & disable & delete & manage_role & add_role_users & remove_role_users & view_role_users) + organization->admin
|
||||
|
||||
permission client_create_permission = client_create + team->client_create + organization->admin
|
||||
permission channel_create_permission = channel_create + team->channel_create + organization->admin
|
||||
|
||||
@@ -58,6 +58,11 @@ func (am *authorizationMiddleware) CreateDomain(ctx context.Context, session aut
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) RetrieveDomain(ctx context.Context, session authn.Session, id string, withRoles bool) (domains.Domain, error) {
|
||||
if err := am.checkSuperAdmin(ctx, session.UserID); err == nil {
|
||||
session.SuperAdmin = true
|
||||
return am.svc.RetrieveDomain(ctx, session, id, withRoles)
|
||||
}
|
||||
|
||||
if err := am.authorize(ctx, domains.OpRetrieveDomain, authz.PolicyReq{
|
||||
Subject: session.DomainUserID,
|
||||
SubjectType: policies.UserType,
|
||||
@@ -127,13 +132,7 @@ func (am *authorizationMiddleware) FreezeDomain(ctx context.Context, session aut
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) ListDomains(ctx context.Context, session authn.Session, page domains.Page) (domains.DomainsPage, error) {
|
||||
if err := am.authz.Authorize(ctx, authz.PolicyReq{
|
||||
Subject: session.UserID,
|
||||
SubjectType: policies.UserType,
|
||||
Permission: policies.AdminPermission,
|
||||
ObjectType: policies.PlatformType,
|
||||
Object: policies.SuperMQObject,
|
||||
}); err == nil {
|
||||
if err := am.checkSuperAdmin(ctx, session.UserID); err == nil {
|
||||
session.SuperAdmin = true
|
||||
}
|
||||
|
||||
@@ -247,6 +246,19 @@ func (am *authorizationMiddleware) checkAdmin(ctx context.Context, session authn
|
||||
return svcerr.ErrAuthorization
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, userID string) error {
|
||||
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
|
||||
SubjectType: policies.UserType,
|
||||
Subject: userID,
|
||||
Permission: policies.AdminPermission,
|
||||
ObjectType: policies.PlatformType,
|
||||
Object: policies.SuperMQObject,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *authorizationMiddleware) extAuthorize(ctx context.Context, subj, perm, objType, obj string) error {
|
||||
req := authz.PolicyReq{
|
||||
SubjectType: policies.UserType,
|
||||
|
||||
@@ -21,9 +21,9 @@ message AuthNReq {
|
||||
}
|
||||
|
||||
message AuthNRes {
|
||||
string id = 1; // id
|
||||
string id = 1; // token id
|
||||
string user_id = 2; // user id
|
||||
string domain_id = 3; // domain id
|
||||
uint32 user_role = 3; // user role
|
||||
}
|
||||
|
||||
message AuthZReq {
|
||||
|
||||
@@ -13,6 +13,7 @@ service TokenService {
|
||||
|
||||
message IssueReq {
|
||||
string user_id = 1;
|
||||
uint32 user_role = 2;
|
||||
uint32 type = 3;
|
||||
}
|
||||
|
||||
|
||||
+9
-1
@@ -27,13 +27,21 @@ func (t TokenType) String() string {
|
||||
}
|
||||
}
|
||||
|
||||
type Role uint32
|
||||
|
||||
const (
|
||||
UserRole Role = iota + 1
|
||||
AdminRole
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
Type TokenType
|
||||
PatID string
|
||||
DomainUserID string
|
||||
UserID string
|
||||
DomainID string
|
||||
DomainUserID string
|
||||
SuperAdmin bool
|
||||
Role Role
|
||||
}
|
||||
|
||||
// Authn is supermq authentication library.
|
||||
|
||||
@@ -54,5 +54,5 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S
|
||||
return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err)
|
||||
}
|
||||
|
||||
return authn.Session{Type: authn.AccessToken, DomainUserID: res.GetId(), UserID: res.GetUserId(), DomainID: res.GetDomainId()}, nil
|
||||
return authn.Session{Type: authn.AccessToken, UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil
|
||||
}
|
||||
|
||||
+1
-1
@@ -110,7 +110,7 @@ func (svc service) IssueToken(ctx context.Context, identity, secret string) (*gr
|
||||
return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrLogin, err)
|
||||
}
|
||||
|
||||
token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, Type: uint32(smqauth.AccessKey)})
|
||||
token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, UserRole: uint32(dbUser.Role + 1), Type: uint32(smqauth.AccessKey)})
|
||||
if err != nil {
|
||||
return &grpcTokenV1.Token{}, errors.Wrap(errIssueToken, err)
|
||||
}
|
||||
|
||||
@@ -1395,7 +1395,7 @@ func TestIssueToken(t *testing.T) {
|
||||
|
||||
for _, tc := range cases {
|
||||
repoCall := cRepo.On("RetrieveByUsername", context.Background(), tc.user.Credentials.Username).Return(tc.retrieveByUsernameResponse, tc.retrieveByUsernameErr)
|
||||
authCall := auth.On("Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, Type: uint32(smqauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr)
|
||||
authCall := auth.On("Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr)
|
||||
token, err := svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret)
|
||||
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
|
||||
if err == nil {
|
||||
@@ -1403,7 +1403,7 @@ func TestIssueToken(t *testing.T) {
|
||||
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
|
||||
ok := repoCall.Parent.AssertCalled(t, "RetrieveByUsername", context.Background(), tc.user.Credentials.Username)
|
||||
assert.True(t, ok, fmt.Sprintf("RetrieveByUsername was not called on %s", tc.desc))
|
||||
ok = authCall.Parent.AssertCalled(t, "Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, Type: uint32(smqauth.AccessKey)})
|
||||
ok = authCall.Parent.AssertCalled(t, "Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)})
|
||||
assert.True(t, ok, fmt.Sprintf("Issue was not called on %s", tc.desc))
|
||||
}
|
||||
authCall.Unset()
|
||||
|
||||
Reference in New Issue
Block a user