SMQ-2609 - Enable superadmin to perform actions over entities (#2688)

Signed-off-by: Felix Gateru <felix.gateru@gmail.com>
This commit is contained in:
Felix Gateru
2025-04-10 17:55:05 +03:00
committed by GitHub
parent a01d2571de
commit 299cee7771
26 changed files with 418 additions and 666 deletions
+7 -7
View File
@@ -70,9 +70,9 @@ func (x *AuthNReq) GetToken() string {
type AuthNRes struct {
state protoimpl.MessageState `protogen:"open.v1"`
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // id
UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // user id
DomainId string `protobuf:"bytes,3,opt,name=domain_id,json=domainId,proto3" json:"domain_id,omitempty"` // domain id
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` // token id
UserId string `protobuf:"bytes,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` // user id
UserRole uint32 `protobuf:"varint,3,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"` // user role
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -121,11 +121,11 @@ func (x *AuthNRes) GetUserId() string {
return ""
}
func (x *AuthNRes) GetDomainId() string {
func (x *AuthNRes) GetUserRole() uint32 {
if x != nil {
return x.DomainId
return x.UserRole
}
return ""
return 0
}
type AuthZReq struct {
@@ -382,7 +382,7 @@ const file_auth_v1_auth_proto_rawDesc = "" +
"\bAuthNRes\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\x12\x17\n" +
"\auser_id\x18\x02 \x01(\tR\x06userId\x12\x1b\n" +
"\tdomain_id\x18\x03 \x01(\tR\bdomainId\"\xa2\x02\n" +
"\tuser_role\x18\x03 \x01(\rR\buserRole\"\xa2\x02\n" +
"\bAuthZReq\x12\x16\n" +
"\x06domain\x18\x01 \x01(\tR\x06domain\x12!\n" +
"\fsubject_type\x18\x02 \x01(\tR\vsubjectType\x12!\n" +
+11 -2
View File
@@ -27,6 +27,7 @@ const (
type IssueReq struct {
state protoimpl.MessageState `protogen:"open.v1"`
UserId string `protobuf:"bytes,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"`
UserRole uint32 `protobuf:"varint,2,opt,name=user_role,json=userRole,proto3" json:"user_role,omitempty"`
Type uint32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
@@ -69,6 +70,13 @@ func (x *IssueReq) GetUserId() string {
return ""
}
func (x *IssueReq) GetUserRole() uint32 {
if x != nil {
return x.UserRole
}
return 0
}
func (x *IssueReq) GetType() uint32 {
if x != nil {
return x.Type
@@ -187,9 +195,10 @@ var File_token_v1_token_proto protoreflect.FileDescriptor
const file_token_v1_token_proto_rawDesc = "" +
"\n" +
"\x14token/v1/token.proto\x12\btoken.v1\"7\n" +
"\x14token/v1/token.proto\x12\btoken.v1\"T\n" +
"\bIssueReq\x12\x17\n" +
"\auser_id\x18\x01 \x01(\tR\x06userId\x12\x12\n" +
"\auser_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n" +
"\tuser_role\x18\x02 \x01(\rR\buserRole\x12\x12\n" +
"\x04type\x18\x03 \x01(\rR\x04type\"1\n" +
"\n" +
"RefreshReq\x12#\n" +
+6 -1
View File
@@ -38,7 +38,12 @@ func AuthenticateMiddleware(authn smqauthn.Authentication, domainCheck bool) fun
return
}
resp.DomainID = domain
resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID)
switch resp.Role {
case smqauthn.AdminRole:
resp.DomainUserID = resp.UserID
case smqauthn.UserRole:
resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID)
}
}
ctx := context.WithValue(r.Context(), SessionKey, resp)
+2 -2
View File
@@ -75,7 +75,7 @@ func (client authGrpcClient) Authenticate(ctx context.Context, token *grpcAuthV1
return &grpcAuthV1.AuthNRes{}, grpcapi.DecodeError(err)
}
ir := res.(authenticateRes)
return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID, DomainId: ir.domainID}, nil
return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID, UserRole: uint32(ir.userRole)}, nil
}
func encodeIdentifyRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
@@ -85,7 +85,7 @@ func encodeIdentifyRequest(_ context.Context, grpcReq interface{}) (interface{},
func decodeIdentifyResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
res := grpcRes.(*grpcAuthV1.AuthNRes)
return authenticateRes{id: res.GetId(), userID: res.GetUserId(), domainID: res.GetDomainId()}, nil
return authenticateRes{id: res.GetId(), userID: res.GetUserId(), userRole: auth.Role(res.UserRole)}, nil
}
func (client authGrpcClient) AuthenticatePAT(ctx context.Context, token *grpcAuthV1.AuthNReq, _ ...grpc.CallOption) (*grpcAuthV1.AuthNRes, error) {
+1 -1
View File
@@ -23,7 +23,7 @@ func authenticateEndpoint(svc auth.Service) endpoint.Endpoint {
return authenticateRes{}, err
}
return authenticateRes{id: key.Subject, userID: key.User, domainID: key.Domain}, nil
return authenticateRes{userID: key.Subject, userRole: key.Role}, nil
}
}
+2 -2
View File
@@ -80,7 +80,7 @@ func TestIdentify(t *testing.T) {
{
desc: "authenticate user with valid user token",
token: validToken,
idt: &grpcAuthV1.AuthNRes{Id: id, UserId: email, DomainId: domainID},
idt: &grpcAuthV1.AuthNRes{UserId: id, UserRole: uint32(auth.UserRole)},
err: nil,
},
{
@@ -100,7 +100,7 @@ func TestIdentify(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("Identify", mock.Anything, mock.Anything).Return(auth.Key{Subject: id, User: email, Domain: domainID}, tc.svcErr)
svcCall := svc.On("Identify", mock.Anything, mock.Anything).Return(auth.Key{Subject: id, Role: auth.UserRole}, tc.svcErr)
idt, err := grpcClient.Authenticate(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token})
if idt != nil {
assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt))
+3 -1
View File
@@ -3,10 +3,12 @@
package auth
import smqauth "github.com/absmach/supermq/auth"
type authenticateRes struct {
id string
userID string
domainID string
userRole smqauth.Role
}
type authorizeRes struct {
+1 -1
View File
@@ -82,7 +82,7 @@ func decodeAuthenticateRequest(_ context.Context, grpcReq interface{}) (interfac
func encodeAuthenticateResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
res := grpcRes.(authenticateRes)
return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, DomainId: res.domainID}, nil
return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, UserRole: uint32(res.userRole)}, nil
}
func encodeAuthenticatePATResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
+6 -4
View File
@@ -53,8 +53,9 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR
defer cancel()
res, err := client.issue(ctx, issueReq{
userID: req.GetUserId(),
keyType: auth.KeyType(req.GetType()),
userID: req.GetUserId(),
userRole: auth.Role(req.GetUserRole()),
keyType: auth.KeyType(req.GetType()),
})
if err != nil {
return &grpcTokenV1.Token{}, grpcapi.DecodeError(err)
@@ -65,8 +66,9 @@ func (client tokenGrpcClient) Issue(ctx context.Context, req *grpcTokenV1.IssueR
func encodeIssueRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(issueReq)
return &grpcTokenV1.IssueReq{
UserId: req.userID,
Type: uint32(req.keyType),
UserId: req.userID,
UserRole: uint32(req.userRole),
Type: uint32(req.keyType),
}, nil
}
+3 -2
View File
@@ -18,8 +18,9 @@ func issueEndpoint(svc auth.Service) endpoint.Endpoint {
}
key := auth.Key{
Type: req.keyType,
User: req.userID,
Type: req.keyType,
Subject: req.userID,
Role: req.userRole,
}
tkn, err := svc.Issue(ctx, "", key)
if err != nil {
+3 -2
View File
@@ -9,8 +9,9 @@ import (
)
type issueReq struct {
userID string
keyType auth.KeyType
userID string
userRole auth.Role
keyType auth.KeyType
}
func (req issueReq) validate() error {
+3 -2
View File
@@ -55,8 +55,9 @@ func (s *tokenGrpcServer) Refresh(ctx context.Context, req *grpcTokenV1.RefreshR
func decodeIssueRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*grpcTokenV1.IssueReq)
return issueReq{
userID: req.GetUserId(),
keyType: auth.KeyType(req.GetType()),
userID: req.GetUserId(),
userRole: auth.Role(req.GetUserRole()),
keyType: auth.KeyType(req.GetType()),
}, nil
}
+37 -11
View File
@@ -37,6 +37,12 @@ const (
invalidDuration = 7 * 24 * time.Hour
)
var (
krepo *mocks.KeyRepository
pEvaluator *policymocks.Evaluator
callback *mocks.CallBack
)
type issueRequest struct {
Duration time.Duration `json:"duration,omitempty"`
Type uint32 `json:"type,omitempty"`
@@ -67,18 +73,18 @@ func (tr testRequest) make() (*http.Response, error) {
return tr.client.Do(req)
}
func newService() (auth.Service, *mocks.KeyRepository) {
krepo := new(mocks.KeyRepository)
func newService() auth.Service {
krepo = new(mocks.KeyRepository)
pRepo := new(mocks.PATSRepository)
cache := new(mocks.Cache)
hash := new(mocks.Hasher)
idProvider := uuid.NewMock()
pService := new(policymocks.Service)
pEvaluator := new(policymocks.Evaluator)
pEvaluator = new(policymocks.Evaluator)
t := jwt.New([]byte(secret))
callback := new(mocks.CallBack)
callback = new(mocks.CallBack)
return auth.New(krepo, pRepo, cache, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration, callback), krepo
return auth.New(krepo, pRepo, cache, hash, idProvider, t, pEvaluator, pService, loginDuration, refreshDuration, invalidDuration, callback)
}
func newServer(svc auth.Service) *httptest.Server {
@@ -95,9 +101,13 @@ func toJSON(data interface{}) string {
}
func TestIssue(t *testing.T) {
svc, krepo := newService()
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id})
svc := newService()
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
callBackCall := callback.On("Authorize", mock.Anything, mock.Anything).Return(nil)
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, Role: auth.UserRole, IssuedAt: time.Now(), Subject: id})
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
policyCall.Unset()
callBackCall.Unset()
ts := newServer(svc)
defer ts.Close()
@@ -196,16 +206,22 @@ func TestIssue(t *testing.T) {
body: strings.NewReader(tc.req),
}
repocall := krepo.On("Save", mock.Anything, mock.Anything).Return("", nil)
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
callBackCall := callback.On("Authorize", mock.Anything, mock.Anything).Return(nil)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
repocall.Unset()
policyCall.Unset()
callBackCall.Unset()
}
}
func TestRetrieve(t *testing.T) {
svc, krepo := newService()
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id})
svc := newService()
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
callBackCall := callback.On("Authorize", mock.Anything, mock.Anything).Return(nil)
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, Role: auth.UserRole, IssuedAt: time.Now(), Subject: id})
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id}
@@ -213,6 +229,8 @@ func TestRetrieve(t *testing.T) {
k, err := svc.Issue(context.Background(), token.AccessToken, key)
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
repocall.Unset()
policyCall.Unset()
callBackCall.Unset()
ts := newServer(svc)
defer ts.Close()
@@ -269,17 +287,23 @@ func TestRetrieve(t *testing.T) {
url: fmt.Sprintf("%s/keys/%s", ts.URL, tc.id),
token: tc.token,
}
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
callBackCall := callback.On("Authorize", mock.Anything, mock.Anything).Return(nil)
repocall := krepo.On("Retrieve", mock.Anything, mock.Anything, mock.Anything).Return(tc.key, tc.err)
res, err := req.make()
assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err))
assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode))
repocall.Unset()
policyCall.Unset()
callBackCall.Unset()
}
}
func TestRevoke(t *testing.T) {
svc, krepo := newService()
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id})
svc := newService()
policyCall := pEvaluator.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil)
callBackCall := callback.On("Authorize", mock.Anything, mock.Anything).Return(nil)
token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, Role: auth.UserRole, IssuedAt: time.Now(), Subject: id})
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id}
@@ -287,6 +311,8 @@ func TestRevoke(t *testing.T) {
k, err := svc.Issue(context.Background(), token.AccessToken, key)
assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err))
repocall.Unset()
policyCall.Unset()
callBackCall.Unset()
ts := newServer(svc)
defer ts.Close()
+6 -38
View File
@@ -20,11 +20,10 @@ import (
)
const (
tokenType = "type"
userField = "user"
domainField = "domain"
issuerName = "supermq.auth"
secret = "test"
tokenType = "type"
roleField = "role"
issuerName = "supermq.auth"
secret = "test"
)
var (
@@ -39,10 +38,7 @@ func newToken(issuerName string, key auth.Key) string {
IssuedAt(key.IssuedAt).
Claim(tokenType, "r").
Expiration(key.ExpiresAt)
builder.Claim(userField, key.User)
if key.Domain != "" {
builder.Claim(domainField, key.Domain)
}
builder.Claim(roleField, key.Role)
if key.Subject != "" {
builder.Subject(key.Subject)
}
@@ -73,8 +69,6 @@ func TestIssue(t *testing.T) {
ID: testsutil.GenerateUUID(t),
Type: auth.AccessKey,
Subject: testsutil.GenerateUUID(t),
User: testsutil.GenerateUUID(t),
Domain: testsutil.GenerateUUID(t),
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
},
@@ -86,8 +80,6 @@ func TestIssue(t *testing.T) {
ID: testsutil.GenerateUUID(t),
Type: auth.AccessKey,
Subject: testsutil.GenerateUUID(t),
User: testsutil.GenerateUUID(t),
Domain: "",
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
},
err: nil,
@@ -98,8 +90,6 @@ func TestIssue(t *testing.T) {
ID: testsutil.GenerateUUID(t),
Type: auth.AccessKey,
Subject: "",
User: testsutil.GenerateUUID(t),
Domain: testsutil.GenerateUUID(t),
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
},
err: nil,
@@ -110,8 +100,6 @@ func TestIssue(t *testing.T) {
ID: testsutil.GenerateUUID(t),
Type: auth.KeyType(auth.InvitationKey + 1),
Subject: testsutil.GenerateUUID(t),
User: testsutil.GenerateUUID(t),
Domain: testsutil.GenerateUUID(t),
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
},
err: nil,
@@ -122,8 +110,6 @@ func TestIssue(t *testing.T) {
ID: testsutil.GenerateUUID(t),
Type: auth.AccessKey,
Subject: "",
User: testsutil.GenerateUUID(t),
Domain: "",
IssuedAt: time.Now().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: time.Now().Add(10 * time.Minute).Round(time.Second),
},
@@ -158,11 +144,6 @@ func TestParse(t *testing.T) {
expToken, err := tokenizer.Issue(expKey)
require.Nil(t, err, fmt.Sprintf("issuing expired key expected to succeed: %s", err))
emptyDomainKey := key()
emptyDomainKey.Domain = ""
emptyDomainToken, err := tokenizer.Issue(emptyDomainKey)
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
emptySubjectKey := key()
emptySubjectKey.Subject = ""
emptySubjectToken, err := tokenizer.Issue(emptySubjectKey)
@@ -174,9 +155,7 @@ func TestParse(t *testing.T) {
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
emptyKey := key()
emptyKey.Domain = ""
emptyKey.Subject = ""
emptyToken, err := tokenizer.Issue(emptyKey)
require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err))
inValidToken := newToken("invalid", key())
@@ -223,12 +202,6 @@ func TestParse(t *testing.T) {
token: newToken(issuerName, key()),
err: authjwt.ErrJSONHandle,
},
{
desc: "parse token with empty domain",
key: emptyDomainKey,
token: emptyDomainToken,
err: nil,
},
{
desc: "parse token with empty subject",
key: emptySubjectKey,
@@ -241,12 +214,6 @@ func TestParse(t *testing.T) {
token: emptyTypeToken,
err: errors.ErrAuthentication,
},
{
desc: "parse token with empty domain and subject",
key: emptyKey,
token: emptyToken,
err: nil,
},
}
for _, tc := range cases {
@@ -266,6 +233,7 @@ func key() auth.Key {
ID: "66af4a67-3823-438a-abd7-efdb613eaef6",
Type: auth.AccessKey,
Issuer: "supermq.auth",
Role: auth.UserRole,
Subject: "66af4a67-3823-438a-abd7-efdb613eaef6",
IssuedAt: time.Now().UTC().Add(-10 * time.Second).Round(time.Second),
ExpiresAt: exp,
+23 -8
View File
@@ -6,8 +6,6 @@ package jwt
import (
"context"
"encoding/json"
"fmt"
"strconv"
"github.com/absmach/supermq/auth"
"github.com/absmach/supermq/pkg/errors"
@@ -21,6 +19,8 @@ var (
errInvalidIssuer = errors.New("invalid token issuer value")
// errInvalidType is returned when there is no type field.
errInvalidType = errors.New("invalid token type")
// errInvalidRole is returned when the role is invalid.
errInvalidRole = errors.New("invalid role")
// errJWTExpiryKey is used to check if the token is expired.
errJWTExpiryKey = errors.New(`"exp" not satisfied`)
// ErrSignJWT indicates an error in signing jwt token.
@@ -35,6 +35,7 @@ const (
issuerName = "supermq.auth"
tokenType = "type"
userField = "user"
RoleField = "role"
oauthProviderField = "oauth_provider"
oauthAccessTokenField = "access_token"
oauthRefreshTokenField = "refresh_token"
@@ -60,7 +61,7 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) {
IssuedAt(key.IssuedAt).
Claim(tokenType, key.Type).
Expiration(key.ExpiresAt)
builder.Claim(userField, key.User)
builder.Claim(RoleField, key.Role)
if key.Subject != "" {
builder.Subject(key.Subject)
}
@@ -132,17 +133,31 @@ func toKey(tkn jwt.Token) (auth.Key, error) {
if !ok {
return auth.Key{}, errInvalidType
}
ktype, err := strconv.ParseInt(fmt.Sprintf("%v", tType), 10, 64)
if err != nil {
return auth.Key{}, err
kType, ok := tType.(float64)
if !ok {
return auth.Key{}, errInvalidType
}
kt := auth.KeyType(ktype)
kt := auth.KeyType(kType)
if !kt.Validate() {
return auth.Key{}, errInvalidType
}
tRole, ok := tkn.Get(RoleField)
if !ok {
return auth.Key{}, errInvalidRole
}
kRole, ok := tRole.(float64)
if !ok {
return auth.Key{}, errInvalidRole
}
kr := auth.Role(kRole)
if !kr.Validate() {
return auth.Key{}, errInvalidRole
}
key.ID = tkn.JwtID()
key.Type = auth.KeyType(ktype)
key.Type = auth.KeyType(kType)
key.Role = auth.Role(kRole)
key.Issuer = tkn.Issuer()
key.Subject = tkn.Subject()
key.IssuedAt = tkn.IssuedAt()
+25 -5
View File
@@ -57,14 +57,35 @@ func (kt KeyType) String() string {
}
}
type Role uint32
const (
UserRole Role = iota + 1
AdminRole
)
func (r Role) String() string {
switch r {
case UserRole:
return "user"
case AdminRole:
return "admin"
default:
return "unknown"
}
}
func (r Role) Validate() bool {
return UserRole <= r && r <= AdminRole
}
// Key represents API key.
type Key struct {
ID string `json:"id,omitempty"`
Type KeyType `json:"type,omitempty"`
Issuer string `json:"issuer,omitempty"`
Subject string `json:"subject,omitempty"` // user ID
User string `json:"user,omitempty"`
Domain string `json:"domain,omitempty"` // domain user ID
Role Role `json:"role,omitempty"`
IssuedAt time.Time `json:"issued_at,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
}
@@ -75,11 +96,10 @@ func (key Key) String() string {
type: %s,
issuer_id: %s,
subject: %s,
user: %s,
domain: %s,
role: %s,
iat: %v,
eat: %v
}`, key.ID, key.Type, key.Issuer, key.Subject, key.User, key.Domain, key.IssuedAt, key.ExpiresAt)
}`, key.ID, key.Type, key.Issuer, key.Subject, key.Role, key.IssuedAt, key.ExpiresAt)
}
// Expired verifies if the key is expired.
+47 -45
View File
@@ -35,6 +35,7 @@ var (
errRetrieve = errors.New("failed to retrieve key data")
errIdentify = errors.New("failed to validate token")
errPlatform = errors.New("invalid platform id")
errRoleAuth = errors.New("failed to authorize user role")
errMalformedPAT = errors.New("malformed personal access token")
errFailedToParseUUID = errors.New("failed to parse string to UUID")
@@ -133,7 +134,7 @@ func (svc service) Issue(ctx context.Context, token string, key Key) (Token, err
case RefreshKey:
return svc.refreshKey(ctx, token, key)
case RecoveryKey:
return svc.tmpKey(recoveryDuration, key)
return svc.tmpKey(ctx, recoveryDuration, key)
case InvitationKey:
return svc.invitationKey(ctx, key)
default:
@@ -205,7 +206,6 @@ func (svc service) Authorize(ctx context.Context, pr policies.Policy) error {
return svcerr.ErrAuthentication
}
pr.Subject = key.Subject
pr.Domain = key.Domain
}
if err := svc.checkPolicy(ctx, pr); err != nil {
return err
@@ -259,8 +259,11 @@ func (svc service) PolicyValidation(pr policies.Policy) error {
return nil
}
func (svc service) tmpKey(duration time.Duration, key Key) (Token, error) {
func (svc service) tmpKey(ctx context.Context, duration time.Duration, key Key) (Token, error) {
key.ExpiresAt = time.Now().Add(duration)
if err := svc.checkUserRole(ctx, key); err != nil {
return Token{}, errors.Wrap(errIssueTmp, err)
}
value, err := svc.tokenizer.Issue(key)
if err != nil {
return Token{}, errors.Wrap(errIssueTmp, err)
@@ -274,9 +277,8 @@ func (svc service) accessKey(ctx context.Context, key Key) (Token, error) {
key.Type = AccessKey
key.ExpiresAt = time.Now().Add(svc.loginDuration)
key.Subject, err = svc.checkUserDomain(ctx, key)
if err != nil {
return Token{}, errors.Wrap(svcerr.ErrAuthorization, err)
if err := svc.checkUserRole(ctx, key); err != nil {
return Token{}, errors.Wrap(errIssueUser, err)
}
access, err := svc.tokenizer.Issue(key)
@@ -299,9 +301,8 @@ func (svc service) invitationKey(ctx context.Context, key Key) (Token, error) {
key.Type = InvitationKey
key.ExpiresAt = time.Now().Add(svc.invitationDuration)
key.Subject, err = svc.checkUserDomain(ctx, key)
if err != nil {
return Token{}, err
if err := svc.checkUserRole(ctx, key); err != nil {
return Token{}, errors.Wrap(errIssueTmp, err)
}
access, err := svc.tokenizer.Issue(key)
@@ -321,16 +322,13 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token
return Token{}, errIssueUser
}
key.ID = k.ID
if key.Domain == "" {
key.Domain = k.Domain
}
key.User = k.User
key.Type = AccessKey
key.Subject = k.Subject
key.Subject, err = svc.checkUserDomain(ctx, key)
if err != nil {
return Token{}, errors.Wrap(svcerr.ErrAuthorization, err)
if err := svc.checkUserRole(ctx, key); err != nil {
return Token{}, errors.Wrap(errIssueUser, err)
}
key.Role = k.Role
key.ExpiresAt = time.Now().Add(svc.loginDuration)
access, err := svc.tokenizer.Issue(key)
@@ -348,32 +346,33 @@ func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token
return Token{AccessToken: access, RefreshToken: refresh}, nil
}
func (svc service) checkUserDomain(ctx context.Context, key Key) (subject string, err error) {
if key.Domain != "" {
// Check user is platform admin.
func (svc service) checkUserRole(ctx context.Context, key Key) (err error) {
switch key.Role {
case AdminRole:
if err = svc.Authorize(ctx, policies.Policy{
Subject: key.User,
Subject: key.Subject,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
Object: policies.SuperMQObject,
ObjectType: policies.PlatformType,
}); err == nil {
return key.User, nil
}); err != nil {
return errRoleAuth
}
// Check user is domain member.
domainUserSubject := EncodeDomainUserID(key.Domain, key.User)
return nil
case UserRole:
if err = svc.Authorize(ctx, policies.Policy{
Subject: domainUserSubject,
Subject: key.Subject,
SubjectType: policies.UserType,
Permission: policies.MembershipPermission,
Object: key.Domain,
ObjectType: policies.DomainType,
Object: policies.SuperMQObject,
ObjectType: policies.PlatformType,
}); err != nil {
return "", err
return errRoleAuth
}
return domainUserSubject, nil
return nil
default:
return nil
}
return "", nil
}
func (svc service) userKey(ctx context.Context, token string, key Key) (Token, error) {
@@ -386,6 +385,9 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e
if key.Subject == "" {
key.Subject = sub
}
if err := svc.checkUserRole(ctx, key); err != nil {
return Token{}, errors.Wrap(errIssueUser, err)
}
keyID, err := svc.idProvider.ID()
if err != nil {
@@ -471,7 +473,7 @@ func (svc service) CreatePAT(ctx context.Context, token, name, description strin
if err != nil {
return PAT{}, errors.Wrap(svcerr.ErrCreateEntity, err)
}
secret, hash, err := svc.generateSecretAndHash(key.User, id)
secret, hash, err := svc.generateSecretAndHash(key.Subject, id)
if err != nil {
return PAT{}, errors.Wrap(svcerr.ErrCreateEntity, err)
}
@@ -479,7 +481,7 @@ func (svc service) CreatePAT(ctx context.Context, token, name, description strin
now := time.Now()
pat := PAT{
ID: id,
User: key.User,
User: key.Subject,
Name: name,
Description: description,
Secret: hash,
@@ -502,7 +504,7 @@ func (svc service) UpdatePATName(ctx context.Context, token, patID, name string)
if err != nil {
return PAT{}, err
}
pat, err := svc.pats.UpdateName(ctx, key.User, patID, name)
pat, err := svc.pats.UpdateName(ctx, key.Subject, patID, name)
if err != nil {
return PAT{}, errors.Wrap(errUpdatePAT, err)
}
@@ -514,7 +516,7 @@ func (svc service) UpdatePATDescription(ctx context.Context, token, patID, descr
if err != nil {
return PAT{}, err
}
pat, err := svc.pats.UpdateDescription(ctx, key.User, patID, description)
pat, err := svc.pats.UpdateDescription(ctx, key.Subject, patID, description)
if err != nil {
return PAT{}, errors.Wrap(errUpdatePAT, err)
}
@@ -526,7 +528,7 @@ func (svc service) RetrievePAT(ctx context.Context, token, patID string) (PAT, e
if err != nil {
return PAT{}, err
}
pat, err := svc.pats.Retrieve(ctx, key.User, patID)
pat, err := svc.pats.Retrieve(ctx, key.Subject, patID)
if err != nil {
return PAT{}, errors.Wrap(errRetrievePAT, err)
}
@@ -538,7 +540,7 @@ func (svc service) ListPATS(ctx context.Context, token string, pm PATSPageMeta)
if err != nil {
return PATSPage{}, err
}
patsPage, err := svc.pats.RetrieveAll(ctx, key.User, pm)
patsPage, err := svc.pats.RetrieveAll(ctx, key.Subject, pm)
if err != nil {
return PATSPage{}, errors.Wrap(errRetrievePAT, err)
}
@@ -550,7 +552,7 @@ func (svc service) DeletePAT(ctx context.Context, token, patID string) error {
if err != nil {
return err
}
if err := svc.pats.Remove(ctx, key.User, patID); err != nil {
if err := svc.pats.Remove(ctx, key.Subject, patID); err != nil {
return errors.Wrap(errDeletePAT, err)
}
return nil
@@ -563,17 +565,17 @@ func (svc service) ResetPATSecret(ctx context.Context, token, patID string, dura
}
// Generate new HashToken take place here
secret, hash, err := svc.generateSecretAndHash(key.User, patID)
secret, hash, err := svc.generateSecretAndHash(key.Subject, patID)
if err != nil {
return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
pat, err := svc.pats.UpdateTokenHash(ctx, key.User, patID, hash, time.Now().Add(duration))
pat, err := svc.pats.UpdateTokenHash(ctx, key.Subject, patID, hash, time.Now().Add(duration))
if err != nil {
return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
if err := svc.pats.Reactivate(ctx, key.User, patID); err != nil {
if err := svc.pats.Reactivate(ctx, key.Subject, patID); err != nil {
return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err)
}
pat.Secret = secret
@@ -588,7 +590,7 @@ func (svc service) RevokePATSecret(ctx context.Context, token, patID string) err
return err
}
if err := svc.pats.Revoke(ctx, key.User, patID); err != nil {
if err := svc.pats.Revoke(ctx, key.Subject, patID); err != nil {
return errors.Wrap(svcerr.ErrUpdateEntity, err)
}
return nil
@@ -599,7 +601,7 @@ func (svc service) RemoveAllPAT(ctx context.Context, token string) error {
if err != nil {
return err
}
if err := svc.pats.RemoveAllPAT(ctx, key.User); err != nil {
if err := svc.pats.RemoveAllPAT(ctx, key.Subject); err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
return nil
@@ -620,7 +622,7 @@ func (svc service) AddScope(ctx context.Context, token, patID string, scopes []S
scopes[i].PatID = patID
}
err = svc.pats.AddScope(ctx, key.User, scopes)
err = svc.pats.AddScope(ctx, key.Subject, scopes)
if err != nil {
return errors.Wrap(svcerr.ErrCreateEntity, err)
}
@@ -633,7 +635,7 @@ func (svc service) RemoveScope(ctx context.Context, token, patID string, scopesI
return err
}
err = svc.pats.RemoveScope(ctx, key.User, scopesIDs...)
err = svc.pats.RemoveScope(ctx, key.Subject, scopesIDs...)
if err != nil {
return errors.Wrap(svcerr.ErrRemoveEntity, err)
}
@@ -750,7 +752,7 @@ func (svc service) authnAuthzUserPAT(ctx context.Context, token, patID string) (
return Key{}, err
}
_, err = svc.pats.Retrieve(ctx, key.User, patID)
_, err = svc.pats.Retrieve(ctx, key.Subject, patID)
if err != nil {
return Key{}, errors.Wrap(svcerr.ErrAuthorization, err)
}
+195 -516
View File
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -318,9 +318,9 @@ definition domain {
channel_update + channel_read + channel_delete + channel_set_parent_group + channel_connect_to_client + channel_publish + channel_subscribe +
channel_manage_role + channel_add_role_users + channel_remove_role_users + channel_view_role_users +
group_update + group_membership + group_read + group_delete + group_set_child + group_set_parent +
group_manage_role + group_add_role_users + group_remove_role_users + group_view_role_users
group_manage_role + group_add_role_users + group_remove_role_users + group_view_role_users + organization->admin
permission admin = read & update & enable & disable & delete & manage_role & add_role_users & remove_role_users & view_role_users
permission admin = (read & update & enable & disable & delete & manage_role & add_role_users & remove_role_users & view_role_users) + organization->admin
permission client_create_permission = client_create + team->client_create + organization->admin
permission channel_create_permission = channel_create + team->channel_create + organization->admin
+19 -7
View File
@@ -58,6 +58,11 @@ func (am *authorizationMiddleware) CreateDomain(ctx context.Context, session aut
}
func (am *authorizationMiddleware) RetrieveDomain(ctx context.Context, session authn.Session, id string, withRoles bool) (domains.Domain, error) {
if err := am.checkSuperAdmin(ctx, session.UserID); err == nil {
session.SuperAdmin = true
return am.svc.RetrieveDomain(ctx, session, id, withRoles)
}
if err := am.authorize(ctx, domains.OpRetrieveDomain, authz.PolicyReq{
Subject: session.DomainUserID,
SubjectType: policies.UserType,
@@ -127,13 +132,7 @@ func (am *authorizationMiddleware) FreezeDomain(ctx context.Context, session aut
}
func (am *authorizationMiddleware) ListDomains(ctx context.Context, session authn.Session, page domains.Page) (domains.DomainsPage, error) {
if err := am.authz.Authorize(ctx, authz.PolicyReq{
Subject: session.UserID,
SubjectType: policies.UserType,
Permission: policies.AdminPermission,
ObjectType: policies.PlatformType,
Object: policies.SuperMQObject,
}); err == nil {
if err := am.checkSuperAdmin(ctx, session.UserID); err == nil {
session.SuperAdmin = true
}
@@ -247,6 +246,19 @@ func (am *authorizationMiddleware) checkAdmin(ctx context.Context, session authn
return svcerr.ErrAuthorization
}
func (am *authorizationMiddleware) checkSuperAdmin(ctx context.Context, userID string) error {
if err := am.authz.Authorize(ctx, smqauthz.PolicyReq{
SubjectType: policies.UserType,
Subject: userID,
Permission: policies.AdminPermission,
ObjectType: policies.PlatformType,
Object: policies.SuperMQObject,
}); err != nil {
return err
}
return nil
}
func (am *authorizationMiddleware) extAuthorize(ctx context.Context, subj, perm, objType, obj string) error {
req := authz.PolicyReq{
SubjectType: policies.UserType,
+2 -2
View File
@@ -21,9 +21,9 @@ message AuthNReq {
}
message AuthNRes {
string id = 1; // id
string id = 1; // token id
string user_id = 2; // user id
string domain_id = 3; // domain id
uint32 user_role = 3; // user role
}
message AuthZReq {
+1
View File
@@ -13,6 +13,7 @@ service TokenService {
message IssueReq {
string user_id = 1;
uint32 user_role = 2;
uint32 type = 3;
}
+9 -1
View File
@@ -27,13 +27,21 @@ func (t TokenType) String() string {
}
}
type Role uint32
const (
UserRole Role = iota + 1
AdminRole
)
type Session struct {
Type TokenType
PatID string
DomainUserID string
UserID string
DomainID string
DomainUserID string
SuperAdmin bool
Role Role
}
// Authn is supermq authentication library.
+1 -1
View File
@@ -54,5 +54,5 @@ func (a authentication) Authenticate(ctx context.Context, token string) (authn.S
return authn.Session{}, errors.Wrap(errors.ErrAuthentication, err)
}
return authn.Session{Type: authn.AccessToken, DomainUserID: res.GetId(), UserID: res.GetUserId(), DomainID: res.GetDomainId()}, nil
return authn.Session{Type: authn.AccessToken, UserID: res.GetUserId(), Role: authn.Role(res.GetUserRole())}, nil
}
+1 -1
View File
@@ -110,7 +110,7 @@ func (svc service) IssueToken(ctx context.Context, identity, secret string) (*gr
return &grpcTokenV1.Token{}, errors.Wrap(svcerr.ErrLogin, err)
}
token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, Type: uint32(smqauth.AccessKey)})
token, err := svc.token.Issue(ctx, &grpcTokenV1.IssueReq{UserId: dbUser.ID, UserRole: uint32(dbUser.Role + 1), Type: uint32(smqauth.AccessKey)})
if err != nil {
return &grpcTokenV1.Token{}, errors.Wrap(errIssueToken, err)
}
+2 -2
View File
@@ -1395,7 +1395,7 @@ func TestIssueToken(t *testing.T) {
for _, tc := range cases {
repoCall := cRepo.On("RetrieveByUsername", context.Background(), tc.user.Credentials.Username).Return(tc.retrieveByUsernameResponse, tc.retrieveByUsernameErr)
authCall := auth.On("Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, Type: uint32(smqauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr)
authCall := auth.On("Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)}).Return(tc.issueResponse, tc.issueErr)
token, err := svc.IssueToken(context.Background(), tc.user.Credentials.Username, tc.user.Credentials.Secret)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
if err == nil {
@@ -1403,7 +1403,7 @@ func TestIssueToken(t *testing.T) {
assert.NotEmpty(t, token.GetRefreshToken(), fmt.Sprintf("%s: expected %s not to be empty\n", tc.desc, token.GetRefreshToken()))
ok := repoCall.Parent.AssertCalled(t, "RetrieveByUsername", context.Background(), tc.user.Credentials.Username)
assert.True(t, ok, fmt.Sprintf("RetrieveByUsername was not called on %s", tc.desc))
ok = authCall.Parent.AssertCalled(t, "Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, Type: uint32(smqauth.AccessKey)})
ok = authCall.Parent.AssertCalled(t, "Issue", context.Background(), &grpcTokenV1.IssueReq{UserId: tc.user.ID, UserRole: uint32(tc.user.Role + 1), Type: uint32(smqauth.AccessKey)})
assert.True(t, ok, fmt.Sprintf("Issue was not called on %s", tc.desc))
}
authCall.Unset()